EvoJAX:硬件加速神经进化
EvoJAX是一个可扩展、通用、硬件加速的神经进化工具包。它基于JAX库构建,使神经进化算法能够在多个TPU/GPU上并行运行神经网络。EvoJAX通过将进化算法、神经网络和任务全部用NumPy实现,并即时编译以在加速器上运行,从而实现极高的性能。
本仓库还包含了EvoJAX在广泛任务中的几个可扩展示例,包括监督学习、强化学习和生成艺术,展示了EvoJAX如何在单个加速器上几分钟内运行您的进化实验,相比之下使用CPU可能需要数小时或数天。
EvoJAX论文:https://arxiv.org/abs/2202.05008(介绍[视频](https://youtu.be/TMkft3wWpb8))
如果您希望在您的出版物中引用本项目,请使用以下BibTeX:
@article{evojax2022,
title={EvoJAX: Hardware-Accelerated Neuroevolution},
author={Tang, Yujin and Tian, Yingtao and Ha, David},
journal={arXiv preprint arXiv:2202.05008},
year={2022}
}
使用EvoJAX的出版物列表(请通过PR添加缺失的条目):
- 现代创造力进化策略:拟合具体图像和抽象概念(NeurIPS创造力研讨会2021,EvoMUSART 2022)
安装
EvoJAX是在JAX上实现的,需要先安装JAX。
安装JAX: 请首先按照JAX的安装说明进行安装,可选择GPU/TPU后端支持。 如果未设置JAX,EvoJAX安装仍会尝试拉取仅CPU版本的JAX。 请注意,Colab运行时预装了JAX。
安装EvoJAX:
# 从PyPI安装
pip install evojax
# 或者,从我们的GitHub仓库安装
pip install git+https://github.com/google/evojax.git@main
如果您还想安装某些可选功能所需的额外依赖项,请使用
pip install evojax[extra]
# 或
pip install git+https://github.com/google/evojax.git@main#egg=evojax[extra]
代码概览
EvoJAX是一个由三个主要组件组成的框架,我们期望用户对其进行扩展。
- 神经进化算法 所有神经进化算法都应实现
evojax.algo.base.NEAlgorithm
接口,并位于evojax/algo/
中。 有关EvoJAX中可用算法的信息,请参见此处。 - 策略网络 所有神经网络都应实现
evojax.policy.base.PolicyNetwork
接口,并保存在evojax/policy/
中。 在本仓库中,我们给出了MLP、ConvNet、Seq2Seq和PermutationInvariant模型的示例实现。 - 任务 所有任务都应实现
evojax.task.base.VectorizedTask
,并位于evojax/task/
中。
这些组件可以独立使用,也可以由管理训练流程的evojax.trainer
和evojax.sim_mgr
协调使用。
虽然它们应该足以满足当前提供的策略和任务的需求,但我们计划在未来根据需要扩展其功能。
示例
作为快速入门,我们提供了一些非平凡的示例(examples/
中的脚本和examples/notebooks
中的笔记本)来说明EvoJAX的用法。
我们在每个脚本的顶部提供了启动训练过程的示例命令。
这些脚本和笔记本在TPU和/或NVIDIA V100 GPU上运行:
监督学习任务
虽然在实践中显然会使用梯度下降来解决此类任务,但重点是要表明神经进化也能在短时间内以一定程度的准确性解决这些问题,这在将这些模型适应到更复杂的任务中时会很有用,因为在那些情况下基于梯度的方法可能不起作用。
- MNIST分类 - 我们展示了EvoJAX在单个GPU上5分钟内训练ConvNet策略,达到>98%的测试准确率。
- Seq2Seq学习 - 我们演示了EvoJAX能够学习一个具有数十万参数的大型网络来完成seq2seq任务。
经典控制任务
包含控制任务有两个目的:1) 与监督学习任务不同,EvoJAX中的控制任务步数不确定,因此我们使用这些示例来展示我们任务展开循环的效率。2) 我们希望展示在JAX中实现任务的加速优势,并说明如何从头开始实现一个任务。
- 运动 - Brax是一个用JAX实现的可微分物理引擎。 我们将其包装为一个任务,并在GPU/TPU上使用EvoJAX进行训练。EvoJAX只需几十分钟就能解决Brax中的运动任务。
- 倒立摆摆动 - 我们说明了如何在JAX中实现经典控制任务,并将其集成到EvoJAX的流程中,以显著加快训练速度。
新颖任务
在最后这一类中,我们超越简单的说明,展示了更实用和吸引遗传和进化计算领域研究人员的新颖任务示例,目的是帮助他们在EvoJAX中尝试想法。
多智能体水世界 | ES-CLIP:"一幅猫的画" | 史莱姆排球 |
-
抽象画(笔记本1和笔记本2) - 我们复现了这项计算创造力工作的结果,并展示了如何使用EvoJAX在单个GPU上高效地加速原本需要多个CPU和GPU的实现,这在之前是不可能的。 此外,使用多个GPU/TPU,EvoJAX可以进一步将上述工作的速度提升近乎线性。 我们还展示了EvoJAX的模块化设计允许其组件被独立使用 - 在这种情况下,可以只使用EvoJAX的ES算法,同时利用自己的训练循环和环境实现。
-
神经网络史莱姆排球 - 在这个任务中,智能体的目标是让球落在对手一侧的地面上,使对手失去一条生命。当任一智能体失去全部五条生命或达到时间限制时,回合结束。当对手失去生命时,智能体获得+1的奖励,当自己失去生命时获得-1的奖励。 EvoJAX能够在单个GPU上用不到5分钟训练出智能体,相比之下在多个CPU上需要数小时。 这个实现基于Slime Volleyball Gym环境,它是原始JavaScript版游戏的Python移植版。在所有这些版本中,内置的AI对手和不太理想的物理引擎都是相同的。
征集贡献
EvoJAX的目标是让进化计算能够在加速器上处理大量任务。
之前的一个问题是,许多进化算法只针对某篇论文中的特定任务进行了优化。这就是为什么在EvoJAX的第一个版本中,我们只关注一种算法(PGPE),同时创建了6个以上不同领域的多样化任务,确保单一算法可以毫无问题地适用于所有任务。查看已贡献算法的表格。
进化算法
我们欢迎新的进化算法被添加到这个工具包中。如果你能在提交拉取请求之前,展示你的实现可以在倒立摆(硬模式)、BRAX、水世界和MNIST上表现良好,那将会很棒。
进化算法候选ideas:
- 你最喜欢的遗传算法。
- CMA-ES(基础版本,以及改进版本如BIPOP-CMA-ES)
- 增强随机搜索(论文)
- AMaLGaM-IDEA(论文)
我们建议新算法遵循以下性能指南:
- MNIST: 90%以上
- 倒立摆: 900分以上(简单),600分以上(困难)
- 水世界: 6分以上(单智能体),2分以上(多智能体)
- Brax蚂蚁: 3000分以上
请注意,这些不是硬性要求,只是粗略的指南。
在向我们发送PR之前,请使用基准测试脚本评估你的算法,如果由于硬件限制无法测试某些任务,请告诉我们。 查看这个示例拉取请求线程,了解如何将遗传算法合并到EvoJAX中。
如果你想进一步讨论,欢迎联系evojax-dev@google.com或evojax-dev@googlegroups.com。
新任务
我们也欢迎新的任务和示例(查看这里了解EvoJAX中的所有任务)。一些建议:
- 使用进化训练神经图灵机来创建一个排序算法。
- 通过自我对弈的足球(示例)
- 进化具有Hebbian学习能力的可塑网络,可以从智能体最近的经历中记住迷宫地图。
- 执行需要未知步骤数的任务的RNN自适应计算时间。
- 使用硬注意力的任务。
姊妹项目
越来越多的研究人员在使用JAX进行进化计算。以下是相关工作列表:
-
QDax: 加速的质量多样性。一个使用JAX帮助通过硬件加速器和大规模并行加速质量多样性(QD)算法的工具。(GitHub | 论文)
-
evosax: 一个基于JAX的进化策略库,专注于JAX可组合的ask-tell功能和策略多样性。实现了10多种ES算法。(GitHub)
免责声明
这不是Google的官方产品。