Linformer for Pytorch
An implementation of Linformer in Pytorch. Linformer comes with two deficiencies: (1) it does not work for the auto-regressive case, and (2) it assumes a fixed sequence length. However, if benchmarks show it performs well enough, it will be added to this repository as a self-attention layer usable in the encoder.
Linformer has been put into production by Facebook!
Install
$ pip install linformer
Usage
Linformer Language Model
import torch
from linformer import LinformerLM
model = LinformerLM(
    num_tokens = 20000,
    dim = 512,
    seq_len = 4096,
    depth = 12,
    heads = 8,
    dim_head = 128,        # can set the dimension of each head in multi-head attention
    k = 256,               # this is the k that the key/values are projected to along the sequence dimension
    one_kv_head = True,    # share one key/value head across all heads
    share_kv = False,      # share the same projection for keys and values
    reversible = True      # make the network reversible, like Reformer
)
x = torch.randint(0, 20000, (1, 4096))
model(x) # (1, 4096, 20000)
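The logits returned above can be trained with a standard next-token cross-entropy objective. A minimal sketch (the training step below is illustrative, not part of this repository; model is the LinformerLM instance from above):

import torch
import torch.nn.functional as F

# hypothetical next-token training step, reusing the LinformerLM defined above
tokens = torch.randint(0, 20000, (1, 4097))
inp, target = tokens[:, :-1], tokens[:, 1:]        # shift by one for next-token prediction

logits = model(inp)                                # (1, 4096, 20000)
loss = F.cross_entropy(logits.transpose(1, 2), target)
loss.backward()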
Linformer
import torch
from linformer import Linformer
model = Linformer(
    dim = 512,
    seq_len = 4096,
    depth = 12,
    heads = 8,
    k = 256,
    one_kv_head = True,
    share_kv = True
)
x = torch.randn(1, 4096, 512)
model(x) # (1, 4096, 512)
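Since the sequence length is assumed fixed, shorter inputs have to be padded up to seq_len before being passed through the encoder. A minimal sketch, where the padding helper is an assumption for illustration and not part of the library:

import torch
import torch.nn.functional as F

# hypothetical helper: right-pad the sequence dimension up to the fixed seq_len
def pad_to_seq_len(x, seq_len = 4096):
    padding = seq_len - x.shape[1]
    return F.pad(x, (0, 0, 0, padding)) if padding > 0 else x

x = torch.randn(1, 1000, 512)           # sequence shorter than seq_len
out = model(pad_to_seq_len(x))          # (1, 4096, 512), using the Linformer defined above
out = out[:, :1000]                     # keep only the original positions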
A single self-attention layer
import torch
from linformer import LinformerSelfAttention
attn = LinformerSelfAttention(
    dim = 512,
    seq_len = 4096,
    heads = 8,
    k = 256,
    one_kv_head = True,
    share_kv = True
)
x = torch.randn(1, 4096, 512)
attn(x) # (1, 4096, 512)
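To make the role of k concrete, below is a minimal single-head sketch of the Linformer mechanism itself. It is an illustration under simplified assumptions, not this library's actual implementation: keys and values are linearly projected along the sequence dimension from seq_len down to k, so attention costs O(n * k) rather than O(n^2), and because the projection matrices have shape (seq_len, k), the sequence length has to be fixed.

import torch
import torch.nn as nn

# simplified, single-head illustration of the Linformer projection trick (not this library's code)
class NaiveLinformerAttention(nn.Module):
    def __init__(self, dim, seq_len, k):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_k = nn.Linear(dim, dim, bias = False)
        self.to_v = nn.Linear(dim, dim, bias = False)
        self.proj_k = nn.Parameter(torch.randn(seq_len, k))   # E in the paper
        self.proj_v = nn.Parameter(torch.randn(seq_len, k))   # F in the paper

    def forward(self, x):
        q, keys, values = self.to_q(x), self.to_k(x), self.to_v(x)          # (b, n, d)
        keys = torch.einsum('bnd,nk->bkd', keys, self.proj_k)               # (b, k, d)
        values = torch.einsum('bnd,nk->bkd', values, self.proj_v)           # (b, k, d)
        attn = (q @ keys.transpose(-1, -2) * self.scale).softmax(dim = -1)  # (b, n, k)
        return attn @ values                                                # (b, n, d)

attn = NaiveLinformerAttention(dim = 512, seq_len = 4096, k = 256)
x = torch.randn(1, 4096, 512)
attn(x) # (1, 4096, 512)

In terms of the layers above, share_kv = True corresponds to reusing a single projection for both keys and values, and one_kv_head = True corresponds to computing a single key/value head shared across all query heads.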
The self-attention layer above can also receive contextual keys. The sequence length is validated against the length of the contextual keys rather than the source sequence.
import torch
from linformer import LinformerSelfAttention
attn = LinformerSelfAttention(
    dim = 512,
    seq_len = 8192,
    heads = 8,
    k = 256,
    one_kv_head = True,
    share_kv = True
)
x = torch.randn(1, 2048, 512)
context = torch.randn(1, 8192, 512)
attn(x, context) # (1, 2048, 512)
Citations
@misc{wang2020linformer,
    title         = {Linformer: Self-Attention with Linear Complexity},
    author        = {Sinong Wang and Belinda Z. Li and Madian Khabsa and Han Fang and Hao Ma},
    year          = {2020},
    eprint        = {2006.04768},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}

@inproceedings{kitaev2020reformer,
    title     = {Reformer: The Efficient Transformer},
    author    = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle = {International Conference on Learning Representations},
    year      = {2020},
    url       = {https://openreview.net/forum?id=rkgNKkHtvB}
}