Transformer 是样本高效的世界模型 (IRIS)
Transformer 是样本高效的世界模型
Vincent Micheli*, Eloi Alonso*, François Fleuret
* 表示贡献相同
简要总结
- IRIS 是一个数据高效的智能体,在世界模型中通过数百万次模拟轨迹进行训练。
- 世界模型由一个离散自编码器和一个自回归 Transformer 组成。
- 我们的方法将动态学习视为一个序列建模问题,其中自编码器构建图像标记语言,Transformer 在时间维度上组合这种语言。
BibTeX
如果您发现这份代码或论文有用,请使用以下引用:
@inproceedings{
iris2023,
title={Transformers are Sample-Efficient World Models},
author={Vincent Micheli and Eloi Alonso and Fran{\c{c}}ois Fleuret},
booktitle={The Eleventh International Conference on Learning Representations },
year={2023},
url={https://openreview.net/forum?id=vhFu1Acb0xb}
}
环境设置
- 安装 PyTorch(torch 和 torchvision)。代码开发使用 torch==1.11.0 和 torchvision==0.12.0。
- 安装其他依赖:
pip install -r requirements.txt
- 注意:Atari ROM 将随依赖项一起下载,这意味着您确认您有使用它们的许可。
启动训练
python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.device=cuda:0 wandb.mode=online
默认情况下,日志会同步到 weights & biases,设置 wandb.mode=disabled
可以关闭同步。
配置
- 所有配置文件位于
config/
目录,主配置文件是config/trainer.yaml
。 - 自定义配置最简单的方法是直接编辑这些文件。
- 有关配置管理的更多详细信息,请参阅 Hydra。
运行文件夹
每次新运行都位于 outputs/YYYY-MM-DD/hh-mm-ss/
目录。该文件夹的结构如下:
outputs/YYYY-MM-DD/hh-mm-ss/
│
└─── checkpoints
│ │ last.pt
| | optimizer.pt
| | ...
│ │
│ └─── dataset
│ │ 0.pt
│ │ 1.pt
│ │ ...
│
└─── config
│ | trainer.yaml
|
└─── media
│ │
│ └─── episodes
│ | │ ...
│ │
│ └─── reconstructions
│ | │ ...
│
└─── scripts
| | eval.py
│ │ play.sh
│ │ resume.sh
| | ...
|
└─── src
| | ...
|
└─── wandb
| ...
checkpoints
:包含模型的最新检查点、优化器和数据集。media
:episodes
:包含用于可视化目的的训练/测试/想象情节。reconstructions
:包含原始帧及其通过自动编码器重建的结果。
scripts
:从运行文件夹中,你可以使用以下三个脚本。eval.py
:运行python ./scripts/eval.py
来评估运行结果。resume.sh
:运行./scripts/resume.sh
来恢复崩溃的训练。play.sh
:用于可视化运行的一些有趣方面的工具。- 运行
./scripts/play.sh
观看代理在环境中实时游戏。如果添加-r
标志,左侧面板显示原始帧,中间面板显示缩小到离散自动编码器输入分辨率的相同帧,右侧面板显示自动编码器的输出(代理实际看到的内容)。 - 运行
./scripts/play.sh -w
使用键盘输入实时展开轨迹(即在世界模型中游戏)。注意,为了更快的交互,Transformer的内存每20帧清空一次。 - 运行
./scripts/play.sh -a
观看代理在世界模型中实时游戏。注意,为了更快的交互,Transformer的内存每20帧清空一次。 - 运行
./scripts/play.sh -e
可视化media/episodes
中包含的情节。 - 添加
-h
标志显示带有附加信息的标题。 - 按 "," 开始和停止录制。相应的片段以mp4和numpy格式保存在
media/recordings
中。 - 添加
-s
标志进入"保存模式",完成后会提示用户保存轨迹。
- 运行
结果笔记本
results/data/
文件夹包含IRIS和基准模型的原始分数(每个游戏和每次训练运行的分数)。
使用 results/results_iris.ipynb
笔记本重现论文中的图表。
预训练模型
预训练模型可在此处获取。
- 要从这些检查点之一开始训练运行,在
config/trainer.yaml
的initialization
部分,将path_to_checkpoint
设置为相应的路径,并将load_tokenizer
、load_world_model
和load_actor_critic
设置为True
。 - 要可视化其中一个检查点,请在
config/env/default.yaml
中将train.id
设置为相应的游戏,创建一个checkpoints
目录,并将检查点复制到checkpoints/last.pt
。然后,您可以按照上述描述使用./scripts/play.sh
来可视化代理。