【Tranformer-GPT】使用注意力机制进行类GPT模型训练
一、相关原理
类GPT,把输入的token作为文章的开头,进行自回归输入,最终输出接下来的文本。相比Transformer(翻译),GPT只需要用到Decoder,相比之下比较好写。
二、具体效果
现在一共训练了55万个Batch,取其中一些输出作为训练效果的体现,具体如下:
·batch-14000; loss-5.2

词汇表还没有多少,正在学习基本语句
·batch-25000; loss-3.5
句子结构基本正确,但是逻辑欠缺,出现一些不明所以的句子;对部分词语理解有误
·batch-48000; loss-3.0

对部分词的理解欠缺,句子上下文衔接不连贯
·batch-69000; loss-2.5

逻辑转换莫名其妙,出现不明所以的人物,词性未完全理解
·batch-97000; loss-2.2

基本语句流畅,句间逻辑有很大问题,并且重复的词比较多
·batch-117000; loss-2.0

句子语法正确,词义理解基本正确,重复的词汇减少,但是句间逻辑依然有比较大的问题
·batch-197000; loss-1.75

还在学习句子之间的逻辑,明显比上面好,但是出现莫名其妙的转折点
·batch-384000; loss-1.55

句子之间的逻辑明显好很多,形容词增加,句子成分更加复杂
·batch-425000; loss-1.5

生成的故事已经比较有逻辑了,中间可能有些断断续续,突然出现了一些莫名其妙的内容,但是至少已经很连贯了
·batch-514000; loss-1.45

对于长故事的创造力很强,但是句子之间还是缺乏逻辑,以及连接的时候会有很多问题。GPT还在天马行空的想象吧!
对上面的内容进行一个总结:
1、loss的折线图如下所示,逐渐趋缓并且达到一个瓶颈期

2、学习内容、产出文本解读
在多次实验中,生成参数保持一致,唯一的变量是所使用的模型。根据 Transformer 的结构特点,每个注意力头(Multi-Head Attention)都会捕捉不同层面的上下文关系。从生成结果可以看出,GPT 已经能够较好地理解词语的基本含义以及上下文之间的常见组合。然而,它在深入理解整体语义、以及在自回归生成过程中有效记忆较远的上下文内容方面,仍存在一定的不足。
三、实现代码(仅展示部分重要模块)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | class MultiHeadAttention(nn.Module): def __init__(self, d_model, n_heads, dropout=0.1): super().__init__() assert d_model % n_heads == 0 self.d_model = d_model self.n_heads = n_heads self.d_k = d_model // n_heads self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) self.dropout = nn.Dropout(dropout) def forward(self, x, mask=None): B, T, C = x.shape Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2) K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2) V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2) scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k) if mask is not None: scores = scores.masked_fill(mask == 0, -1e4) attn = F.softmax(scores.float(), dim=-1).type_as(scores) attn = self.dropout(attn) output = (attn @ V).transpose(1, 2).contiguous().view(B, T, C) return self.W_o(output) |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | class Model(nn.Module): def __init__(self, vocab_size, d_model=512, n_heads=8, num_layers=12, max_seq_len=128, d_ff=2048, dropout=0.1): super().__init__() self.d_model = d_model self.max_seq_len = max_seq_len self.embedding = nn.Embedding(vocab_size, d_model) self.pos_embedding = nn.Embedding(max_seq_len, d_model) self.dropout = nn.Dropout(dropout) self.blocks = nn.ModuleList([ Block(d_model, n_heads, d_ff, dropout) for _ in range(num_layers) ]) self.ln_f = nn.LayerNorm(d_model) self.head = nn.Linear(d_model, vocab_size, bias=False) self.embedding.weight = self.head.weight self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) elif isinstance(module, nn.LayerNorm): torch.nn.init.zeros_(module.bias) torch.nn.init.ones_(module.weight) def forward(self, idx, targets=None): B, T = idx.shape device = idx.device tok_emb = self.embedding(idx) pos = torch.arange(0, T, dtype=torch.long, device=device).unsqueeze(0) pos_emb = self.pos_embedding(pos) x = self.dropout(tok_emb + pos_emb) mask = torch.tril(torch.ones(T, T, device=device)).view(1, 1, T, T) for block in self.blocks: x = block(x, mask) x = self.ln_f(x) logits = self.head(x) loss = None if targets is not None: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) loss = torch.clamp(loss, 0, 15) return logits, loss |
登录后方可回帖