【Transformer-GPT】Training a GPT-like Model with Attention
By e2hang
1. Principles
A GPT-style model takes the input tokens as the opening of a passage and generates autoregressively: each newly produced token is appended to the input, and the model keeps predicting the continuation text. Compared with the full Transformer used for translation, GPT only needs the decoder stack, which makes it noticeably easier to implement.
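To make the autoregressive loop concrete, here is a minimal decoding sketch (my own addition, not part of the original code). It assumes a model like the one shown in section 3, i.e. a callable that maps a (B, T) tensor of token ids to (logits, loss), and a fixed context window max_seq_len:

import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, idx, max_new_tokens, max_seq_len=128, temperature=1.0):
    # idx: (B, T) tensor of prompt token ids
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -max_seq_len:]          # crop to the positional-embedding window
        logits, _ = model(idx_cond)               # (B, T, vocab_size)
        logits = logits[:, -1, :] / temperature   # only the last position predicts the next token
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)   # sample one token per sequence
        idx = torch.cat([idx, next_id], dim=1)    # append and feed the result back in
    return idx

Each sampled token is appended to the sequence and fed back in, which is exactly the "use the output as the next input" behaviour described above.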
2. Results
3. Implementation (only the key modules are shown)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        # project and split into heads: (B, n_heads, T, d_k)
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        # scaled dot-product attention scores: (B, n_heads, T, T)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # -1e4 instead of -inf so the mask also works in fp16
            scores = scores.masked_fill(mask == 0, -1e4)
        attn = F.softmax(scores.float(), dim=-1).type_as(scores)
        attn = self.dropout(attn)
        # merge heads back to (B, T, d_model) and project out
        output = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(output)
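As a quick shape check (an added example, not from the post), the attention module can be run on random input together with a lower-triangular causal mask, where 1 means "visible" and 0 means "masked out":

import torch

mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 16, 512)                                  # (B, T, d_model)
causal = torch.tril(torch.ones(16, 16)).view(1, 1, 16, 16)   # 1 = visible, 0 = masked
out = mha(x, mask=causal)
print(out.shape)   # torch.Size([2, 16, 512])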
class Model(nn.Module):
    def __init__(self, vocab_size, d_model=512, n_heads=8, num_layers=12,
                 max_seq_len=128, d_ff=2048, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            Block(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.embedding.weight = self.head.weight  # tie input and output embeddings
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        device = idx.device
        tok_emb = self.embedding(idx)
        pos = torch.arange(0, T, dtype=torch.long, device=device).unsqueeze(0)
        pos_emb = self.pos_embedding(pos)
        x = self.dropout(tok_emb + pos_emb)
        # causal mask: position t may only attend to positions <= t
        mask = torch.tril(torch.ones(T, T, device=device)).view(1, 1, T, T)
        for block in self.blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss = torch.clamp(loss, 0, 15)  # clip extreme loss values for stability
        return logits, loss
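The post does not show the Block class, so the sketch below fills it in with a standard pre-norm block (masked self-attention plus feed-forward, each wrapped in a residual connection); treat it as an assumption rather than the author's exact code. The training step after it shows how the (logits, loss) return values are meant to be used; vocab_size, the AdamW settings, and the dummy batch are likewise placeholders.

import torch
import torch.nn as nn

class Block(nn.Module):
    # NOTE: assumed implementation -- the original post does not show Block.
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)   # residual around masked self-attention
        x = x + self.ff(self.ln2(x))           # residual around the feed-forward net
        return x

# One training step on a dummy batch (hyperparameters are placeholders).
vocab_size = 32000
model = Model(vocab_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

data = torch.randint(0, vocab_size, (4, 129))   # dummy token ids: (B, T + 1)
idx, targets = data[:, :-1], data[:, 1:]        # shift by one for next-token prediction

logits, loss = model(idx, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()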