H-Transformer-1D
Implementation of H-Transformer-1D, a Transformer using hierarchical attention for sequence learning with subquadratic compute cost. The encoder (non-autoregressive) version of this architecture currently holds the top spot on the Long Range Arena, a benchmark for efficient Transformers.
Install
$ pip install h-transformer-1d
Usage
import torch
from h_transformer_1d import HTransformer1D
model = HTransformer1D(
    num_tokens = 256,          # number of tokens
    dim = 512,                 # dimension
    depth = 12,                # depth
    causal = False,            # autoregressive or not
    max_seq_len = 8192,        # maximum sequence length
    heads = 8,                 # number of attention heads
    dim_head = 64,             # dimension per attention head
    block_size = 128,          # block size
    reversible = True,         # use reversibility, to save memory with increased depth
    shift_tokens = True        # whether to shift half the feature space by one position along the sequence dimension, for faster convergence (experimental)
)
x = torch.randint(0, 256, (1, 8000))   # variable sequence length
mask = torch.ones((1, 8000)).bool()    # variable mask length
# network will automatically pad the sequence to the nearest power of two, carry out the hierarchical attention, etc.
logits = model(x, mask = mask) # (1, 8000, 256)
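For an autoregressive setup, a minimal training sketch might look like the following. It reuses the same HTransformer1D constructor shown above with causal = True and applies plain PyTorch cross entropy to the returned logits; the sequence length, batch size, and loss wiring are illustrative assumptions, not part of the library itself.
import torch
import torch.nn.functional as F
from h_transformer_1d import HTransformer1D

model = HTransformer1D(
    num_tokens = 256,
    dim = 512,
    depth = 12,
    causal = True,            # autoregressive this time
    max_seq_len = 8192,
    heads = 8,
    dim_head = 64,
    block_size = 128
)

seq = torch.randint(0, 256, (1, 1024))   # illustrative token sequence

logits = model(seq[:, :-1])              # predict the next token at each position
loss = F.cross_entropy(
    logits.reshape(-1, 256),             # (batch * seq, num_tokens)
    seq[:, 1:].reshape(-1)               # targets shifted by one
)
loss.backward()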
Citations
@misc{zhu2021htransformer1d,
    title   = {H-Transformer-1D: Fast One-Dimensional Hierarchical Attention for Sequences},
    author  = {Zhenhai Zhu and Radu Soricut},
    year    = {2021},
    eprint  = {2107.11906},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}