x-transformers 项目介绍
x-transformers 是一个简洁但功能完备的 Transformer 库,集成了多种来自不同论文的创新性实验特性。该库旨在提供全面的 Transformer 模型解决方案,支持编码器-解码器结构、仅解码器结构(如 GPT)和仅编码器结构(如 BERT),同时还包括视觉转换器模型(如 SimpleViT)。
安装
要使用 x-transformers,可以通过 pip 命令安装:
$ pip install x-transformers
使用方法
编码器-解码器模型
通过 XTransformer 类,可以轻松搭建一个完整的编码器-解码器模型:
import torch
from x_transformers import XTransformer
model = XTransformer(
dim=512,
enc_num_tokens=256,
enc_depth=6,
enc_heads=8,
enc_max_seq_len=1024,
dec_num_tokens=256,
dec_depth=6,
dec_heads=8,
dec_max_seq_len=1024,
tie_token_emb=True # 绑定编码器和解码器的词嵌入
)
src = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 256, (1, 1024))
loss = model(src, tgt, mask=src_mask) # (1, 1024, 512)
loss.backward()
仅解码器(类似 GPT)
构建仅解码器模型非常简单,如 GPT 模型的架构:
import torch
from x_transformers import TransformerWrapper, Decoder
model = TransformerWrapper(
num_tokens=20000,
max_seq_len=1024,
attn_layers=Decoder(
dim=512,
depth=12,
heads=8
)
).cuda()
x = torch.randint(0, 256, (1, 1024)).cuda()
output = model(x) # (1, 1024, 20000)
仅编码器(类似 BERT)
使用 TransformerWrapper 类创建一个仅编码器的模型,例如 BERT 模型:
import torch
from x_transformers import TransformerWrapper, Encoder
model = TransformerWrapper(
num_tokens=20000,
max_seq_len=1024,
attn_layers=Encoder(
dim=512,
depth=12,
heads=8
)
).cuda()
x = torch.randint(0, 256, (1, 1024)).cuda()
mask = torch.ones_like(x).bool()
output = model(x, mask=mask) # (1, 1024, 20000)
特色功能
闪电注意力(Flash Attention)
这种技术通过以块状处理注意力矩阵来优化内存使用,从而支持更长的上下文处理而不受内存瓶颈的限制。使用该功能只需将 attn_flash
参数设置为 True
:
import torch
from x_transformers import TransformerWrapper, Decoder
model = TransformerWrapper(
num_tokens=20000,
max_seq_len=1024,
attn_layers=Decoder(
dim=512,
depth=6,
heads=8,
attn_flash=True # 开启闪电注意力
)
)
增强的自注意力(带持久内存的增益注意力)
这项技术通过在注意力层之前增加学习到的内存键/值来提升性能,并允许删除前馈网络,或者同时使用两者以获得最佳效果:
from x_transformers import Encoder
enc = Encoder(
dim=512,
depth=6,
heads=8,
attn_num_mem_kv=16 # 添加16个内存键/值
)
GLU 变体提升 Transformer
通过简单的门控机制(例如,使用 GELU)来显著提升 Transformer 的性能:
import torch
from x_transformers import TransformerWrapper, Decoder, Encoder
model = TransformerWrapper(
num_tokens=20000,
max_seq_len=1024,
attn_layers=Decoder(
dim=512,
depth=6,
heads=8,
ff_glu=True # 启用 GLU 变体
)
)
以上仅为 x-transformers 提供的众多特性中的一部分,通过这些创新特性,开发者可以构建更高效、性能更优的 Transformer 模型。同时,该库提供的灵活接口使得研究人员与工程师可以结合最新的学术研究,快速进行实验与迭代。