DARTS简介
DARTS(Differentiable Architecture Search)是一种高效的神经网络架构搜索算法,由Hanxiao Liu等人在2018年提出。该算法基于架构空间的连续松弛和梯度下降,能够高效地为图像分类和语言建模等任务设计高性能的卷积和循环神经网络架构。DARTS只需要单个GPU就可以完成搜索过程,大大降低了硬件要求。
项目资源
DARTS的官方代码实现开源在GitHub上:
- GitHub仓库: quark0/darts
- Star数: 3.9k+
- 主要编程语言: Python
该仓库提供了DARTS算法的完整实现,包括架构搜索和评估的代码。
环境配置
DARTS的运行环境要求如下:
Python >= 3.5.5
PyTorch == 0.3.1
torchvision == 0.2.0
注意PyTorch 0.4及以上版本目前不支持,可能会导致内存溢出。
预训练模型
DARTS项目提供了在CIFAR-10、PTB和ImageNet数据集上的预训练模型:
这些预训练模型可以作为基线或者迁移学习的起点。
架构搜索
DARTS使用小型代理模型进行架构搜索,可以通过以下命令运行:
cd cnn && python train_search.py --unrolled # 用于CIFAR-10的卷积单元搜索
cd rnn && python train_search.py --unrolled # 用于PTB的循环单元搜索
需要注意的是,这一步骤中的验证性能并不代表最终架构的性能。搜索完成后,还需要使用完整大小的模型从头训练获得的基因型/架构。
架构评估
使用以下命令可以从头训练搜索得到的最佳单元:
cd cnn && python train.py --auxiliary --cutout # CIFAR-10
cd rnn && python train.py # PTB
cd cnn && python train_imagenet.py --auxiliary # ImageNet
由于cuDNN反向传播内核的非确定性,CIFAR-10的最终结果可能会有所波动。建议进行多次独立运行以获得更可靠的结果。
可视化
DARTS项目使用graphviz包来可视化学习到的单元结构:
python visualize.py DARTS
总结
DARTS提供了一种高效的神经网络架构搜索方法,大大降低了计算资源需求。该项目的开源实现为研究人员和工程师提供了方便的工具,有助于进一步探索和应用自动机器学习技术。欢迎对神经架构搜索感兴趣的读者深入学习DARTS项目的代码和相关论文。