爬行动物
OpenAI 的 Reptile 算法在 PyTorch 上的有监督学习实现。
目前,它可以在 Omniglot 上运行,但尚未在 MiniImagenet 上运行。
该代码尚未经过广泛测试。欢迎提供贡献和反馈!
Omniglot 元学习数据集
torchvision 中已经有一个 Omniglot 数据集类,但它似乎更适合于有监督学习而非少样本学习。
omniglot.py
提供了一种从 Omniglot 中采样 K-shot N-way 基本任务的方法,
并提供了各种将元训练集以及基本任务进行划分的实用工具。
功能
- 使用 TensorboardX 监控训练。
- 中断并恢复训练。
- 在 Omniglot 上进行训练和评估。
- 元批量大小 > 1。
- 在 Mini-Imagenet 上进行训练和评估。
- 澄清推断性 vs. 非推断性设置。
- 在README中添加训练曲线。
- 重现 OpenAI 代码中的所有设置。
- Shell 脚本下载数据集
如何在 Omniglot 上训练
下载 Omniglot 数据集的两个部分:
- https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip
- https://github.com/brendenlake/omniglot/blob/master/python/images_evaluation.zip
在仓库中创建一个 omniglot/
文件夹,解压并合并这两个文件以拥有如下文件结构:
./train_omniglot.py
...
./omniglot/Alphabet_of_the_Magi/
./omniglot/Angelic/
./omniglot/Anglo-Saxon_Futhorc/
...
./omniglot/ULOG/
现在开始训练
python train_omniglot.py log --cuda 0 $HYPERPARAMETERS # 使用CPU
python train_omniglot.py log $HYPERPARAMETERS # 使用CUDA
其中 $HYPERPARAMETERS 取决于你的任务和超参数。
行为:
- 如果在
log/
中未找到检查点,这将创建一个log/
文件夹以存储 tensorboard 信息和检查点。 - 如果在
log/
中找到检查点,这将从最后一个检查点恢复。
训练可以随时通过 ^C
终止,并通过重新运行相同的命令从最后一个检查点恢复。
Omniglot 超参数
以下一组超参数效果不错。
它们摘自 OpenAI 实现,但稍微进行了调整,
适用于 meta-batch=1
。
对于 5-way 5-shot(红色曲线):
python train_omniglot.py log/o55 --classes 5 --shots 5 --train-shots 10 --meta-iterations 100000 --iterations 5 --test-iterations 50 --batch 10 --meta-lr 0.2 --lr 0.001
对于 5-way 1-shot(蓝色曲线):
python train_omniglot.py log/o51 --classes 5 --shots 1 --train-shots 12 --meta-iterations 200000 --iterations 12 --test-iterations 86 --batch 10 --meta-lr 0.33 --lr 0.00044
参考文献
- 原始论文: Alex Nichol, Joshua Achiam, John Schulman. "关于一阶元学习算法".
- OpenAI 博文。 去查看一下,他们有一个完全用 Javascript 运行的在线演示!
- Tensorflow 原始代码: https://github.com/openai/supervised-reptile