Project Icon

aw_nas

模块化设计实现多种NAS算法

aw_nas是一个模块化的神经架构搜索框架,实现了ENAS、DARTS等多种主流NAS算法。框架将NAS系统分解为搜索空间、控制器等组件,通过接口实现灵活组合。支持分类、检测等多种应用场景,并提供硬件分析接口。aw_nas采用插件机制便于扩展,已应用于容错性、对抗鲁棒性等研究方向。

aw_nas:模块化可扩展的神经架构搜索框架

清华大学NICS-EFC实验室北京诺瓦奥图科技有限公司维护。

简介

神经架构搜索(NAS)因其能够以自动化方式发现神经网络架构而受到广泛关注。aw_nas是一个以模块化方式实现各种NAS算法的框架。目前,aw_nas可用于重现许多主流NAS算法的结果,如ENAS、DARTS、SNAS、FBNet、OFA、基于预测器的NAS等。我们已将aw_nas应用于各种应用场景,包括用于分类、检测、文本建模、硬件容错、对抗鲁棒性、硬件推理效率等的NAS。

此外,硬件相关的性能分析和解析接口设计得通用且易用。aw_nas还提供了多种硬件的延迟表和一些校正模型。详情请参见硬件相关

欢迎各种贡献,包括新的NAS组件实现、新的NAS应用、错误修复、文档等。

NAS系统的组成部分

NAS系统中有多个相互协作的参与者,可以分为以下几个组成部分:

  • 搜索空间
  • 控制器
  • 权重管理器
  • 评估器
  • 目标函数

这些组件之间的接口是明确定义的。我们使用awnas.rollout.base.BaseRollout类来表示所有这些组件之间的接口对象。通常,一个搜索空间定义一个或多个rollout类型(BaseRollout的子类)。例如,基本的基于单元的搜索空间cnnawnas.common.CNNSearchSpace类)对应两种rollout类型:discrete离散rollout,用于基于强化学习、进化算法的控制器等(awnas.rollout.base.Rollout类);differentiable可微rollout,用于基于梯度的NAS(awnas.rollout.base.DifferentiableRollout类)。

NAS框架

这是NAS流程和相应方法调用的图示。这里是aw_nas的简要技术概述,包括一些复现结果和硬件成本预测模型的描述。该技术概述也可在arXiv上获取(GitHub/ArXiv版本可能略有不同)。

安装

建议使用虚拟Python环境。例如,使用Anaconda,你可以先运行conda create -n awnas python==3.7.3 pip

  • 支持的Python版本:2.7、3.6、3.7
  • 支持的PyTorch版本:>=1.0.0,<1.5.0(目前,DataParallel复制中的一些补丁在1.5.0之后不兼容)

要安装awnas,运行pip install -r requirements.txt。如果你不想安装检测相关的额外内容(运行在VOC/COCO检测数据集上搜索时需要),在安装时省略",det"额外内容(参见requirements文件的最后一行)。注意,对于RTX 3090,requirements.txt中的torch==1.2.0不再适用:使用torch会导致永久卡住。请查看requirements.cu110.txt中的注释。

架构绘图依赖于graphviz包,确保安装了graphviz,例如在Ubuntu上,你可以运行sudo apt-get install graphviz

使用

安装后,你可以运行awnas --help查看可用的子命令。

示例运行输出(版本0.3.dev3):

07/04 11:41:44 PM plugin              INFO: Check plugins under /home/foxfi/awnas/plugins
07/04 11:41:44 PM plugin              INFO: Loaded plugins:
Usage: awnas [OPTIONS] COMMAND [ARGS]...

  awnas NAS框架命令行接口。使用`AWNAS_LOG_LEVEL`环境变量修改日志级别。

Options:
  --version             显示版本并退出
  --local_rank INTEGER  此进程的等级  [默认:-1]
  --help                显示此消息并退出

Commands:
  search                   搜索架构
  mpsearch                 多进程搜索架构
  random-sample            随机采样架构
  sample                   采样架构,加载pickle控制器
  eval-arch                从文件评估架构
  derive                   派生架构
  mptrain                  多进程最终训练架构
  train                    训练一个架构
  test                     测试最终训练的模型
  gen-sample-config        导出采样配置
  gen-final-sample-config  导出最终训练的采样配置
  registry                 打印注册信息

准备数据

运行awnas程序时,它会假设名为<NAME>的数据集位于AWNAS_DATA/<NAME>下,其中AWNAS_DATA基础目录从环境变量AWNAS_DATA中读取。如果未指定环境变量,默认为AWNAS_HOME/data,其中AWNAS_HOME是默认为~/awnas的环境变量。

  • Cifar-10/Cifar-100:无需特殊准备。
  • PTB:执行bash scripts/get_data.sh ptb,PTB数据将下载到${DATA_BASE}/ptb目录下。默认情况下${DATA_BASE}~/awnas/data
  • Tiny-ImageNet:执行bash scripts/get_data.sh tiny-imagenet,Tiny-ImageNet数据将下载到${DATA_BASE}/tiny-imagenet目录下。
  • 目标检测数据集VOC/COCO:执行bash scripts/get_data.sh vocbash scripts/get_data.sh coco

运行NAS搜索

ENAS 尝试运行ENAS [Pham et. al., ICML 2018]搜索(结果包括配置备份、搜索日志,保存在<TRAIN_DIR>中):

awnas search examples/basic/enas.yaml --gpu 0 --save-every <SAVE_EVERY> --train-dir <TRAIN_DIR>

配置文件中包含几个部分,描述了NAS框架中不同组件的配置。例如,在example/basic/enas.yaml中,不同的配置部分组织如下:

  1. 基于单元的CNN搜索空间:这是原始ENAS论文中5个原语微搜索空间的扩展版本。
  2. cifar-10数据集
  3. 使用embed_lstm RNN网络的RL学习控制器
  4. 基于共享权重的评估器
  5. 基于共享权重的权重管理器:超网络
  6. 分类目标
  7. 训练器:整体NAS搜索流程的编排

有关ENAS搜索配置的详细分解,请参阅配置说明

DARTS 此外,你可以通过运行以下命令来执行DARTS [Liu et. al., ICLR 2018]搜索的改进版本:

awnas search examples/basic/darts.yaml --gpu 0 --save-every <SAVE_EVERY> --train-dir <TRAIN_DIR>

我们在这里提供了组件和流程的详细说明。请注意,该配置与原始DARTS略有不同:1) entropy_coeff: 0.01:使用0.01的熵正则化系数,鼓励操作分布更接近于one-hot;2) use_prob: false:使用Gumbel-softmax采样,而不是直接使用概率。

结果复现 关于各种流行方法的精确结果复现,请参阅examples/mloss/下的文档、配置和结果。

生成样例搜索配置

要生成用于搜索的样例配置文件,可以尝试使用awnas gen-sample-config工具。例如,如果你想要一个用于在NAS-Bench-101上搜索的样例配置,运行:

awnas gen-sample-config -r nasbench-101 -d image ./sample_nb101.yaml

然后,检查sample_nb101.yaml文件,对于每种组件类型,所有声明支持nasbench-101展开类型的类都会列在文件中。删除不需要的,取消注释需要的,更改默认设置,然后该配置就可以用于在NAS-Bench-101上运行NAS。

导出与评估架构

awnas derive工具使用训练好的NAS组件采样架构。如果--test标志关闭(默认),只加载控制器来采样展开;否则,还会加载权重管理器和训练器来测试这些展开,并根据性能对采样的基因型进行排序,保存在输出文件中。

示例运行是采样10个基因型,并将它们保存到sampled_genotypes.yaml中。

awnas derive search_cfg.yaml --load <awnas搜索期间保存的检查点目录> -o sampled_genotypes.yaml -n 10 --test --gpu 0 --seed 123

注意,<TRAIN_DIR>/<EPOCH>/文件夹中的"controller/evaluator/trainer"文件包含组件的状态字典,可以加载(每<SAVE_EVERY>个周期保存一次),而"<TRAIN_DIR>/final/"文件夹中的最终检查点"controller.pt/evaluator.pt"包含整个组件对象的pickle,不能直接加载。如果你忘记指定--save-every命令行参数而没有获得状态字典检查点,你可以加载最终检查点,然后通过cd <TRAIN_DIR>/final/; python -c "controller = torch.load('./controller.pt'); controller.save('controller')"导出所需的状态字典检查点。

awnas eval-arch工具使用训练好的NAS组件评估基因型。给定一个包含基因型列表的yaml文件,可以使用保存的NAS检查点评估这些基因型:

awnas eval-arch search_cfg.yaml sampled_genotypes.yaml --load <awnas搜索期间保存的检查点目录> --gpu 0 --seed 123

基于单元架构的最终训练

awnas.final 子包提供了基于单元的架构的最终训练功能。examples/basic/final_templates/final_template.yaml 是一个常用的配置模板,用于在类 ENAS 搜索空间中进行架构的最终训练。要使用该模板,请在 final_model_cfg.genotypes 字段中填入从搜索过程中得到的基因型字符串。基因型字符串示例如下:

CNNGenotype(normal_0=[('dil_conv_3x3', 1, 2), ('skip_connect', 1, 2), ('sep_conv_3x3', 0, 3), ('sep_conv_3x3', 2, 3), ('skip_connect', 3, 4), ('sep_conv_3x3', 0, 4), ('sep_conv_5x5', 1, 5), ('sep_conv_5x5', 0, 5)], reduce_1=[('max_pool_3x3', 0, 2), ('dil_conv_5x5', 0, 2), ('avg_pool_3x3', 1, 3), ('avg_pool_3x3', 2, 3), ('sep_conv_5x5', 1, 4), ('avg_pool_3x3', 1, 4), ('sep_conv_3x3', 1, 5), ('dil_conv_5x5', 3, 5)], normal_0_concat=[2, 3, 4, 5], reduce_1_concat=[2, 3, 4, 5])

插件机制

aw_nas 提供了一个简单的插件机制,支持在包外添加额外组件或扩展现有组件。在初始化过程中,~/awnas/plugins/ 目录下的所有 Python 脚本(文件名以 .py 结尾,不包括以 test_ 开头的文件)都会被导入。因此,这些文件中定义的组件将自动注册。

例如,为了复现 FBNet [Wu et. al., CVPR 2019],我们在 examples/plugins/fbnet/fbnet_plugin.py 中添加了 FBNet 原始块的实现,并使用 aw_nas.ops.register_primitive 注册这些原始操作。为了重用 DiffSuperNet 实现的大部分代码(用于 DARTS [Liu et. al., ICLR 2018]、SNAS [Xie et. al., ICLR 2018] 等),我们创建了一个继承自 DiffSuperNetWeightInitDiffSuperNet 类,唯一的区别是添加了一个为 FBNet 量身定制的权重初始化。此外,还实现了一个 LatencyObjective 目标函数,它将延迟损失和交叉熵损失的加权和作为损失计算。

examples/plugins/robustness 目录下是用于实现对抗鲁棒性神经架构搜索的插件模块。例如,定义了各种用于评估对抗鲁棒性的目标函数。由于密集连接是对抗鲁棒性的一个重要特性,而 ENAS/DARTS 搜索空间将节点输入度限制为小于或等于 2,因此定义了一个具有可变节点输入度的新搜索空间。实现了几个具有对抗样本缓存的超网络(weights_manager),以避免多次为同一子网络重新生成对抗样本。

除了定义新组件外,你还可以使用这种机制来进行猴子补丁技巧。例如,在 examples/research/ftt-nas/fixed_point_plugins/ 下有各种定点插件。在这些插件中,诸如 nn.Conv2dnn.Linear 等原始操作被修补为具有量化和故障注入功能的模块。

硬件相关:硬件分析和解析

有关硬件分析和解析的流程和示例,请参阅 Hardware related

开发新组件

有关开发新组件的指南,请参阅 Develop New Components

研究

本代码库与以下研究相关(*: 贡献相同; ^: 共同通讯作者)

更多详情请参见examples/research/下的子目录。

如果您发现本代码库有帮助,可以引用以下研究:

@misc{ning2020awnas,
      title={aw_nas: A Modularized and Extensible NAS framework},
      author={Xuefei Ning and Changcheng Tang and Wenshuo Li and Songyi Yang and Tianchen Zhao and Niansong Zhang and Tianyi Lu and Shuang Liang and Huazhong Yang and Yu Wang},
      year={2020},
      eprint={2012.10388},
      archivePrefix={arXiv},
      primaryClass={cs.NE}
}

参考文献

  • FBNet Wu, Bichen等人。"FBNet:通过可微分神经架构搜索进行硬件感知的高效卷积网络设计"。发表于IEEE计算机视觉与模式识别会议论文集,第10734-10742页。2019年。
  • ENAS Pham, Hieu等人。"通过参数共享实现高效神经架构搜索"。发表于国际机器学习会议,第4095-4104页。2018年。
  • DARTS Liu, Hanxiao等人。"DARTS:可微分架构搜索"。发表于国际学习表示会议。2018年。
  • SNAS Xie, Sirui等人。"SNAS:随机神经架构搜索"。发表于国际学习表示会议。2018年。
  • OFA Cai, Han等人。"一劳永逸:训练一个网络并针对高效部署进行专门化"。发表于国际学习表示会议。2019年。

单元测试

覆盖率百分比(版本0.4.0-dev1)

运行pytest -x ./tests来执行单元测试。

NAS-Bench-101和NAS-Bench-201的测试默认被跳过,设置AWNAS_TEST_NASBENCH环境变量并运行pytest来执行这些测试:AWNAS_TEST_NASBENCH=1 pytest -x ./tests/test_nasbench*。还有一些其他测试由于可能非常耗时而被跳过(参见测试输出(标记为"s")和tests/下的测试用例)。

联系我们

  • 如有技术问题或改进建议,请在Github上提交问题,我们是一个小团队,但会尽最大努力及时回复。
  • 如果想讨论NAS或高效深度学习,请通过foxdoraame@gmail.com(宁学飞)和yu-wang@tsinghua.edu.cn(王玉)联系我们。
  • 我们的团队正在招募访问学生和工程师,如果您感兴趣,请查看我们网站上的信息。
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号