Recurrent Memory Transformer - Pytorch

Implementation of Recurrent Memory Transformer (openreview) in Pytorch. They recently published a short follow-up paper demonstrating that it can copy information across at least 1 million tokens.

I believe RMT would make a stronger RL agent than AdA, which is just a Transformer-XL - Update: Recurrent Memory Decision Transformer

Yannic Kilcher paper review

Appreciation

Install

$ pip install recurrent-memory-transformer-pytorch

Usage

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 20000,               # number of tokens
    num_memory_tokens = 128,          # number of memory tokens, this determines the bottleneck for information passed to the future
    dim = 512,                        # model dimensions
    depth = 6,                        # transformer depth
    causal = True,                    # autoregressive or not
    dim_head = 64,                    # dimension per head
    heads = 8,                        # number of heads
    seq_len = 1024,                   # sequence length of a segment
    use_flash_attn = True             # whether to use flash attention
)

x = torch.randint(0, 256, (1, 1024))

logits1, mem1, _ = model(x)        # (1, 1024, 20000), (1, 128, 512), None
logits2, mem2, _ = model(x, mem1)  # (1, 1024, 20000), (1, 128, 512), None
logits3, mem3, _ = model(x, mem2)  # (1, 1024, 20000), (1, 128, 512), None

# and so on ...
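
Processing a longer stream is just a matter of passing each segment's returned memories into the next call. Below is a minimal sketch using only the model constructed above; the long sequence and its 4-segment split are hypothetical, and it assumes the memories argument may be None on the first step, as the default call shows.

long_seq = torch.randint(0, 256, (1, 4 * 1024))   # hypothetical stream of 4 segments

memories = None
all_logits = []

for segment in long_seq.split(1024, dim = -1):    # each chunk must match seq_len
    logits, memories, _ = model(segment, memories)
    all_logits.append(logits)

logits = torch.cat(all_logits, dim = 1)           # (1, 4096, 20000)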

With XL memories

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 20000,
    num_memory_tokens = 128,
    dim = 512,
    depth = 6,
    causal = True,
    dim_head = 64,
    heads = 8,
    seq_len = 1024,
    use_flash_attn = True,
    use_xl_memories = True,    # set this to True
    xl_mem_len = 512           # can be shorter than seq_len - I think just a little of the past is needed to prevent most of the RMT memories from memorizing the immediately preceding text
)

x = torch.randint(0, 256, (1, 1024))

logits1, mem1, xl_mem1 = model(x)                               # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)]
logits2, mem2, xl_mem2 = model(x, mem1, xl_memories = xl_mem1)  # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)]
logits3, mem3, xl_mem3 = model(x, mem2, xl_memories = xl_mem2)  # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)]

# and so on ...

Train on an extremely long sequence

import torch
from recurrent_memory_transformer_pytorch import (
    RecurrentMemoryTransformer,
    RecurrentMemoryTransformerWrapper
)

model = RecurrentMemoryTransformer(
    num_tokens = 256,
    num_memory_tokens = 128,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    use_flash_attn = True,
    causal = True
)

model = RecurrentMemoryTransformerWrapper(model).cuda()

seq = torch.randint(0, 256, (4, 65536)).cuda()   # an extremely long sequence - in practice, they curriculum learned this, starting with 1 segment and working up to about 7-8 segments

loss = model(seq, memory_replay_backprop = True) # memory efficient training from the memformer paper
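
A minimal training loop sketch follows, for illustration only: the curriculum over segment counts mirrors the 1 to roughly 7-8 segments mentioned in the comment above, the optimizer, learning rate, and step counts are hypothetical, and it assumes memory_replay_backprop = True accumulates gradients during the forward pass (the memformer technique), so no separate loss.backward() call is shown.

from torch.optim import Adam

optim = Adam(model.parameters(), lr = 3e-4)

# hypothetical curriculum: grow the number of 1024-token segments over training
for num_segments in (1, 2, 4, 8):
    for _ in range(100):                                   # arbitrary number of steps per stage
        seq = torch.randint(0, 256, (4, num_segments * 1024)).cuda()

        loss = model(seq, memory_replay_backprop = True)   # assumed to backprop internally, segment by segment
        optim.step()
        optim.zero_grad()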

Todo

  • move the memory replay backprop into a torch.function, test bidirectional, then test on a real problem

  • get rotary embeddings working properly with xl memories

  • add xl memories, detached

  • offer a way to turn off rotary embeddings and absolute positional embeddings, and add token shift

  • make causal masking of the memories optional

  • add the memory replay backprop technique from the memformer paper

  • relative positional encoding

Alternatives

Citations

@inproceedings{bulatov2022recurrent,
  title     = {Recurrent Memory Transformer},
  author    = {Aydar Bulatov and Yuri Kuratov and Mikhail Burtsev},
  booktitle = {Advances in Neural Information Processing Systems},
  editor    = {Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
  year      = {2022},
  url       = {https://openreview.net/forum?id=Uynr3iPhksa}
}
@misc{bulatov2023scaling,
  title     = {Scaling Transformer to 1M tokens and beyond with RMT},
  author    = {Aydar Bulatov and Yuri Kuratov and Mikhail S. Burtsev},
  year      = {2023},
  eprint    = {2304.11062},
  archivePrefix = {arXiv},
  primaryClass = {cs.CL}
}
@inproceedings{dao2022flashattention,
  title     = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author    = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle = {Advances in Neural Information Processing Systems},
  year      = {2022}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{Wu2020MemformerAM,
    title   = {Memformer: A Memory-Augmented Transformer for Sequence Modeling},
    author  = {Qingyang Wu and Zhenzhong Lan and Kun Qian and Jing Gu and Alborz Geramifard and Zhou Yu},
    booktitle = {AACL/IJCNLP},
    year    = {2020}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {August},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@software{Dayma_DALLE_Mini_2021,
    author  = {Boris Dayma and Suraj Patil and Pedro Cuenca and Khalid Saifullah and Tanishq Abraham and Phúc Lê Khắc and Luke Melas and Ritobrata Ghosh},
    doi     = {10.5281/zenodo.5146400},
    license = {Apache-2.0},
    month   = {July},
    title   = {{DALL·E Mini}},
    url     = {https://github.com/borisdayma/dalle-mini},
    version = {v0.1-alpha},
    year    = {2021}}
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}
@misc{ding2021erniedoc,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year    = {2021},
    eprint  = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@article{Xie2023ResiDualTW,
  title     = {ResiDual: Transformer with Dual Residual Connections},
  author    = {Shufang Xie and Huishuai Zhang and Junliang Guo and Xu Tan and Jiang Bian and Hany Hassan Awadalla and Arul Menezes and Tao Qin and Rui Yan},
  journal   = {ArXiv},
  year      = {2023},
  volume    = {abs/2304.14802}
}