MEGABYTE - Pytorch
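Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch. This implementation generalizes the architecture further so that it can have multiple local models.
Appreciation
- Stability and 🤗 Huggingface for the generous sponsorship to work on and open source cutting edge artificial intelligence research
Install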
$ pip install MEGABYTE-pytorch
Usage
import torch
from MEGABYTE_pytorch import MEGABYTE
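model = MEGABYTE(
    num_tokens = 16000,         # number of tokens
    dim = (512, 256),           # transformer model dimension (512 for the coarsest stage, 256 for the fine stage in this example)
    max_seq_len = (1024, 4),    # sequence length for global and then local. this can be more than 2
    depth = (6, 4),             # number of layers for global and then local. this can be more than 2, but its length must match that of max_seq_len
    dim_head = 64,              # dimension per attention head
    heads = 8,                  # number of attention heads
    flash_attn = True           # use Flash Attention
)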
x = torch.randint(0, 16000, (1, 1024, 4))
loss = model(x, return_loss = True)
loss.backward()
# after much training
logits = model(x)
# then sample from the logits
# or you can use the generate function
sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
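The "sample from the logits" step can be illustrated with a minimal, generic sketch. It uses only the standard library and works on the logits of a single position, given as a plain list of floats. The helper name `sample_from_logits` is hypothetical and not part of MEGABYTE_pytorch, and the sketch applies temperature scaling only; it does not try to reproduce whatever filtering `generate` performs through `filter_thres`.

```python
import math
import random

def sample_from_logits(logits, temperature=1.0, rng=random):
    """Draw one token id from unnormalized logits (hypothetical helper, not a library API)."""
    if temperature <= 0:
        # Treat a non-positive temperature as greedy decoding: pick the largest logit
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [value / temperature for value in logits]
    peak = max(scaled)  # subtract the maximum before exponentiating, for numerical stability
    weights = [math.exp(value - peak) for value in scaled]
    # random.choices normalizes the weights itself, so no explicit softmax division is needed
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]

# Toy logits over a vocabulary of 3 tokens
token_id = sample_from_logits([0.1, 2.0, -1.0], temperature=0.9)
```

Lower temperatures concentrate probability on the highest logits, and higher temperatures flatten the distribution.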
Test
Train on character-level enwik8 with patches of size 4 and a sequence length of 8192
$ python train.py
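To make the patch arithmetic concrete, here is a small standard-library sketch of how a flat byte sequence of length 8192 splits into patches of size 4. The helper `to_patches` and the zero `pad_id` are assumptions made for illustration; they are not taken from train.py or from the library, which may handle padding differently.

```python
def to_patches(data: bytes, patch_size: int = 4, pad_id: int = 0):
    """Split raw bytes into fixed-size patches of integer ids (illustrative helper)."""
    ids = list(data)  # each byte becomes an id in the range 0..255
    remainder = len(ids) % patch_size
    if remainder:
        # Assumption: right-pad the final patch with pad_id so every patch is full
        ids += [pad_id] * (patch_size - remainder)
    return [ids[i:i + patch_size] for i in range(0, len(ids), patch_size)]

patches = to_patches(b"a" * 8192, patch_size=4)
print(len(patches), len(patches[0]))  # prints "2048 4"
```

With these settings the global stage sees 2048 patch positions, while each local stage sees the 4 bytes inside one patch.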
Citations
@misc{yu2023megabyte,
title = {MEGABYTE: Predicting Million-byte Sequences with Multiscale Transformers},
author = {Lili Yu and Dániel Simig and Colin Flaherty and Armen Aghajanyan and Luke Zettlemoyer and Mike Lewis},
year = {2023},
eprint = {2305.07185},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
doi = {10.48550/ARXIV.2302.01327},
url = {https://arxiv.org/abs/2302.01327},
author = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
title = {Dual PatchNorm},
publisher = {arXiv},
year = {2023},
copyright = {Creative Commons Attribution 4.0 International}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@software{peng_bo_2021_5196578,
author = {PENG Bo},
title = {BlinkDL/RWKV-LM: 0.01},
month = {aug},
year = {2021},
publisher = {Zenodo},
version = {0.01},
doi = {10.5281/zenodo.5196578},
url = {https://doi.org/10.5281/zenodo.5196578}
}
@article{Kazemnejad2023TheIO,
title = {The Impact of Positional Encoding on Length Generalization in Transformers},
author = {Amirhossein Kazemnejad and Inkit Padhi and Karthikeyan Natesan Ramamurthy and Payel Das and Siva Reddy},
journal = {ArXiv},
year = {2023},
volume = {abs/2305.19466}
}
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}