Logo

PyTorch Transformer教程:从零开始实现注意力机制

PyTorch Transformer教程:从零开始实现注意力机制

在深度学习领域,Transformer模型已经成为处理序列数据的首选架构,特别是在自然语言处理任务中。本教程将详细介绍如何使用PyTorch从零开始实现一个Transformer模型,并应用于机器翻译任务。

Transformer的优势

与传统的循环神经网络(RNN)相比,Transformer具有以下显著优势:

  1. 并行计算能力强: Transformer可以并行处理序列中的所有元素,而RNN必须按顺序处理。

  2. 长距离依赖建模能力强: Transformer可以直接访问序列中的任意位置,而RNN需要通过中间状态传递信息。

  3. 表征能力强: 多头注意力机制使Transformer能够从多个角度理解输入序列。

Transformer vs RNN

Transformer的核心组件

  1. 多头注意力机制(Multi-Head Attention)

多头注意力是Transformer的核心,它允许模型同时关注序列的不同位置,从多个表示子空间学习信息。

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)  
        V = self.W_v(value)
        
        # Split into multiple heads
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        
        # Apply attention to values
        context = torch.matmul(attn, V)
        
        # Concatenate heads
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        
        # Final linear projection
        output = self.W_o(context)
        
        return output
  1. 位置编码(Positional Encoding)

由于Transformer没有固有的序列处理能力,我们需要通过位置编码为模型提供位置信息。

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        return x + self.pe[:x.size(0), :]
  1. 前馈神经网络(Feed Forward Network)

每个Transformer层都包含一个前馈神经网络,用于进一步处理注意力机制的输出。

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))

构建Transformer模型

将上述组件组合,我们可以构建完整的Transformer模型:

class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        self.encoder_embedding = nn.Embedding(src_vocab, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        
        self.fc = nn.Linear(d_model, tgt_vocab)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src, tgt, src_mask, tgt_mask):
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
        
        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)
        
        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)
        
        output = self.fc(dec_output)
        return output

应用于机器翻译任务

我们以英语到德语的翻译为例,展示如何使用Transformer进行机器翻译。

  1. 数据准备

首先,我们需要准备平行语料库,并使用BPE(Byte Pair Encoding)等技术处理词表。

from torchtext.data import Field, BucketIterator
from torchtext.datasets import Multi30k

SRC = Field(tokenize = "spacy", tokenizer_language="en_core_web_sm", init_token = '<sos>', eos_token = '<eos>', lower = True)
TGT = Field(tokenize = "spacy", tokenizer_language="de_core_news_sm", init_token = '<sos>', eos_token = '<eos>', lower = True)

train_data, valid_data, test_data = Multi30k.splits(exts = ('.en', '.de'), fields = (SRC, TGT))

SRC.build_vocab(train_data, min_freq = 2)
TGT.build_vocab(train_data, min_freq = 2)
  1. 模型训练

接下来,我们定义损失函数和优化器,并开始训练过程。

model = Transformer(len(SRC.vocab), len(TGT.vocab), d_model=512, num_heads=8, num_layers=6, d_ff=2048, max_seq_length=100, dropout=0.1)
criterion = nn.CrossEntropyLoss(ignore_index=TGT.vocab.stoi['<pad>'])
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        tgt = batch.trg
        
        optimizer.zero_grad()
        output = model(src, tgt[:,:-1])
        output = output.contiguous().view(-1, output.shape[-1])
        tgt = tgt[:,1:].contiguous().view(-1)
        
        loss = criterion(output, tgt)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

N_EPOCHS = 10
CLIP = 1

for epoch in range(N_EPOCHS):
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f}')
  1. 翻译推理

最后,我们可以使用训练好的模型进行翻译:

def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):
    model.eval()
    
    if isinstance(sentence, str):
        nlp = spacy.load('en_core_web_sm')
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
    
    tokens = ['<sos>'] + tokens + ['<eos>']
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    
    src_mask = model.make_src_mask(src_tensor)
    
    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

    trg_indexes = [trg_field.vocab.stoi['<sos>']]

    for i in range(max_len):
        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)
        
        with torch.no_grad():
            output = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
        
        pred_token = output.argmax(2)[:,-1].item()
        trg_indexes.append(pred_token)

        if pred_token == trg_field.vocab.stoi['<eos>']:
            break
    
    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
    
    return trg_tokens[1:]

example_sentence = "The sun is shining brightly today."
translation = translate_sentence(example_sentence, SRC, TGT, model, device)
print(f'Source: {example_sentence}')
print(f'Translation: {" ".join(translation)}')

结论

通过本教程,我们详细介绍了如何使用PyTorch从零开始实现Transformer模型,并将其应用于机器翻译任务。Transformer的强大表现力和灵活性使其成为处理序列数据的首选模型之一。

随着技术的不断发展,Transformer及其变体(如BERT、GPT等)在自然语言处理、计算机视觉等领域都取得了巨大成功。深入理解Transformer的工作原理,将有助于我们更好地应用和改进这一强大的模型架构。

参考资源

  1. Attention Is All You Need - Transformer原始论文
  2. PyTorch官方文档
  3. The Illustrated Transformer - 图解Transformer工作原理

希望这篇教程能够帮助你深入理解Transformer模型,并在实际项目中灵活运用。如果你有任何问题或建议,欢迎在评论区留言讨论。让我们一起探索人工智能的无限可能! 🚀🤖

相关项目

Project Cover
fastbook
本项目提供涵盖fastai和PyTorch的深度学习教程,适合初学者与进阶用户。可通过Google Colab在线运行,无需本地配置Python环境。项目还包括MOOC课程及相关书籍,系统化帮助用户学习深度学习技术。
Project Cover
pytorch-handbook
本开源书籍为使用PyTorch进行深度学习开发的用户提供系统化的入门指南。教程内容覆盖了从环境搭建到高级应用的各个方面,包括PyTorch基础、深度学习数学原理、神经网络、卷积神经网络、循环神经网络等,还包含实践案例与多GPU并行训练技巧。书籍持续更新,与PyTorch版本同步,适合所有深度学习研究者。
Project Cover
fastai
fastai是一个深度学习库,提供高层组件以快速实现高性能结果,同时为研究人员提供可组合的低层组件。通过分层架构和Python、PyTorch的灵活性,fastai在不牺牲易用性、灵活性和性能的情况下,实现了高效的深度学习。支持多种安装方式,包括Google Colab和conda,适用于Windows和Linux。学习资源丰富,包括书籍、免费课程和详细文档。
Project Cover
annotated_deep_learning_paper_implementations
该项目提供详细文档和解释的简明PyTorch神经网络及算法实现,涵盖Transformer、GPT-NeoX、GAN、扩散模型等前沿领域,并每周更新新实现,帮助研究者和开发者高效理解深度学习算法。
Project Cover
keras
Keras 3 提供高效的模型开发,支持计算机视觉、自然语言处理等任务。选择最快的后端(如JAX),性能提升高达350%。无缝扩展,从本地到大规模集群,适合企业和初创团队。安装简单,支持GPU,兼容tf.keras代码,避免框架锁定。
Project Cover
CLIP
CLIP通过对比学习训练神经网络,结合图像和文本,实现自然语言指令预测。其在ImageNet零样本测试中的表现与ResNet50相当,无需使用原始标注数据。安装便捷,支持多种API,适用于零样本预测和线性探针评估,推动计算机视觉领域发展。
Project Cover
allennlp
AllenNLP是一个基于PyTorch的Apache 2.0自然语言处理研究库,专注于开发先进的深度学习模型。该项目已进入维护模式,并将在2022年12月16日前继续修复问题和响应用户提问。推荐的替代项目包括AI2 Tango、allennlp-light、flair和torchmetrics,以帮助用户更好地管理实验和使用预训练模型。
Project Cover
pix2pix
使用条件对抗网络实现图像到图像翻译,支持从建筑立面生成到日夜转换等多种任务。该项目能在小数据集上快速产生良好结果,并提供改进版的PyTorch实现。支持多种数据集和模型,并附有详细的安装、训练和测试指南。
Project Cover
pytorch-CycleGAN-and-pix2pix
该项目提供了PyTorch框架下的CycleGAN和pix2pix图像翻译实现,支持配对和无配对的图像翻译。最新版本引入img2img-turbo和StableDiffusion-Turbo模型,提高了训练和推理效率。项目页面包含详细的安装指南、训练和测试步骤,以及常见问题解答。适用于Linux和macOS系统,兼容最新的PyTorch版本,并提供Docker和Colab支持,便于快速上手。

最新项目

Project Cover
豆包MarsCode
豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。
Project Cover
AI写歌
Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。
Project Cover
美间AI
美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。
Project Cover
商汤小浣熊
小浣熊家族Raccoon,您的AI智能助手,致力于通过先进的人工智能技术,为用户提供高效、便捷的智能服务。无论是日常咨询还是专业问题解答,小浣熊都能以快速、准确的响应满足您的需求,让您的生活更加智能便捷。
Project Cover
有言AI
有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。
Project Cover
Kimi
Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。
Project Cover
吐司
探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。
Project Cover
SubCat字幕猫
SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。
Project Cover
AIWritePaper论文写作
AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。
投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号