可微架构搜索
论文的相关代码
DARTS: Differentiable Architecture Search
刘汉霄,Karen Simonyan,杨一鸣。
arXiv:1806.09055。
该算法基于架构空间中的连续松弛和梯度下降,能够高效设计出用于图像分类(在 CIFAR-10 和 ImageNet 上)和语言建模(在 Penn Treebank 和 WikiText-2 上)的高性能卷积架构和递归架构。仅需单个 GPU 即可运行。
环境需求
Python >= 3.5.5, PyTorch == 0.3.1, torchvision == 0.2.0
注意:目前不支持 PyTorch 0.4,会导致 OOM(内存不足)。
数据集
关于获取 PTB 和 WT2 的说明可以在这里找到。而 CIFAR-10 可以通过 torchvision 自动下载,ImageNet 需要手动下载(最好下载到 SSD)并按照这里的说明进行操作。
预训练模型
最简单的入门方法是评估我们预训练的 DARTS 模型。
CIFAR-10 (cifar10_model.pt)
cd cnn && python test.py --auxiliary --model_path cifar10_model.pt
- 预期结果:测试错误率 2.63%,模型参数 3.3M。
PTB (ptb_model.pt)
cd rnn && python test.py --model_path ptb_model.pt
- 预期结果:测试困惑度 55.68,模型参数 23M。
ImageNet (imagenet_model.pt)
cd cnn && python test_imagenet.py --auxiliary --model_path imagenet_model.pt
- 预期结果:Top-1 错误率 26.7%,Top-5 错误率 8.7%,模型参数 4.7M。
架构搜索(使用小型代理模型)
要使用二阶近似进行架构搜索,运行以下命令:
cd cnn && python train_search.py --unrolled # 在 CIFAR-10 上搜索卷积单元
cd rnn && python train_search.py --unrolled # 在 PTB 上搜索递归单元
请注意,此步骤中的验证性能不代表最终架构的性能。必须使用完整尺寸的模型从头开始训练获得的基因型/架构,正如下一节所述。
还要注意,不同的运行可能会导致不同的局部最小值。要获得最佳结果,必须使用不同的随机种子重复搜索过程,并根据验证性能(通过从头开始训练派生单元进行少量轮次的训练获得)选择最佳单元。请参阅我们的 arXiv 论文中的图 3 和章节 3.2。
图:随着时间推移,最有可能的普通卷积单元、降维卷积单元和递归单元的快照。
架构评估(使用完整尺寸的模型)
要从头开始训练以评估我们最好的单元,运行以下命令:
cd cnn && python train.py --auxiliary --cutout # CIFAR-10
cd rnn && python train.py # PTB
cd rnn && python train.py --data ../data/wikitext-2 \ # WT2
--dropouth 0.15 --emsize 700 --nhidlast 700 --nhid 700 --wdecay 5e-7
cd cnn && python train_imagenet.py --auxiliary # ImageNet
一旦在 genotypes.py
中指定,定制架构也可通过 --arch
标志来支持。
由于 cuDNN 反向传播内核的不确定性,CIFAR-10 最终的训练结果会存在一定的方差。仅报告单次运行的结果可能会产生误导。通过从头开始训练我们最好的单元,10 次独立运行的测试错误率平均值很可能会落在 2.76 +/- 0.09% 的范围内。
图:CIFAR-10(4 次运行)、ImageNet 和 PTB 的预期学习曲线。
可视化
需要使用 graphviz 软件包来可视化学习到的单元:
python visualize.py DARTS
其中 DARTS
可以被替换为 genotypes.py
中的任何自定义架构。
引用
如果您在研究中使用了此代码的任何部分,请引用我们的论文:
@article{liu2018darts,
title={DARTS: Differentiable Architecture Search},
author={Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
journal={arXiv preprint arXiv:1806.09055},
year={2018}
}