LongMem
我们论文"使用长期记忆增强语言模型"的官方实现。
如果您觉得这个仓库有趣或有帮助,请引用我们的论文:
@article{LongMem,
title={使用长期记忆增强语言模型},
author={Wang, Weizhi and Dong, Li and Cheng, Hao and Liu, Xiaodong and Yan, Xifeng and Gao, Jianfeng and Wei, Furu},
journal={arXiv预印本 arXiv:2306.07174},
year={2023}
}
环境设置
-
torch: 请遵循torch官方安装指南。我们推荐torch>=1.8.0。请选择与您的cuda驱动版本一致的torch-gpu版本。
-
Faiss-GPU: 对于Nvidia V100 GPU,只需通过
pip install faiss-gpu
安装。对于Nvidia A100、A6000 GPU,请运行conda install faiss-gpu cudatoolkit=11.0 -c pytorch
。A100 GPU不受faiss-gpu官方支持,有时会导致错误,您可以参考faiss的这个git 问题寻求帮助。 -
fairseq:
pip install --editable ./fairseq
然后修订版的fairseq
和依赖包将被安装。为了稳定性,我们强烈建议您使用python 3.8。 -
其他包:
pip install -r requirements.txt
项目结构
-
预训练LLM类(L24, E1024, Alibi位置嵌入):
fairseq/fairseq/models/newgpt.py
-
带SideNetwork的Transformer解码器(L12, E1024):
fairseq/fairseq/models/sidenet/transformer_decoder_sidenet.py
-
带SideNetwork的Transformer语言模型类:
fairseq/fairseq/models/transformer_lm_sidenet.py
-
记忆库和检索:
fairseq/fairseq/modules/dynamic_memory_with_chunk.py
-
用于记忆融合的联合注意力:
fairseq/fairseq/modules/joint_multihead_attention_sum.py
记忆增强适应训练
数据收集和预处理
请从官方发布下载Pile。Pile中的每个子数据集都被组织为各种jsonline分片。您可以参考preprocess/filter_shard_tnlg.py
了解我们如何抽样训练集并按照标准fairseq预处理过程进行二值化。
记忆增强适应训练:
bash train_scripts/train_longmem.sh
评估
请首先下载GPT2-medium模型和LongMem模型的预训练检查点到checkpoints/
。
记忆增强上下文学习
# 评估gpt2基线
python eval_scripts/eval_longmem_icl.py --path /path/to/gpt2_pretrained_model
# 评估LongMem模型
python eval_scripts/eval_longmem_icl.py --path /path/to/longmem_model --pretrained-model-path /path/to/gpt2_pretrained_model
鸣谢
LongMem基于fairseq开发。感谢eleuther.ai团队构建了最大的高质量语料库Pile。