evosax
:基于JAX的进化策略 🦎
厌倦了为神经进化处理异步进程?想要利用大规模向量化和高吞吐量加速器来实现进化策略(ES)吗?evosax
允许您利用JAX、XLA编译和自动向量化/并行化来将ES扩展到您喜欢的加速器上。其API基于ES的经典"询问"、"评估"、"告知"循环。"询问"和"告知"调用都与jit
、vmap
/pmap
和lax.scan
兼容。它包含了大量经典(如CMA-ES、差分进化等)和现代神经进化(如OpenAI-ES、增强RS等)策略。您可以在这里开始 👉
evosax
API基本用法 🍲
import jax
from evosax import CMA_ES
# 实例化搜索策略
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)
# 运行询问-评估-告知循环 - 注意:默认为最小化!
for t in range(num_generations):
rng, rng_gen, rng_eval = jax.random.split(rng, 3)
x, state = strategy.ask(rng_gen, state, es_params)
fitness = ... # 您的种群评估函数
state = strategy.tell(x, fitness, state, es_params)
# 获取整体最佳种群成员及其适应度
state.best_member, state.best_fitness
已实现的进化策略 🦎
安装 ⏳
最新版本的 evosax
可以直接从 PyPI 安装:
pip install evosax
如果你想获取最新的提交,请直接从仓库安装:
pip install git+https://github.com/RobertTLange/evosax.git@main
要在你的加速器上使用 JAX,可以在 JAX 文档中找到更多详细信息。
示例 📖
- 📓 经典 ES 任务: 以 Rosenbrock 函数为例介绍 API (CMA-ES, Simple GA 等)。
- 📓 CartPole 控制: 在
CartPole-v1
gym 任务上使用 OpenES 和 PEPG (MLP/LSTM 控制器)。 - 📓 MNIST 分类器: 在 MNIST 上使用带 CNN 网络的 OpenES。
- 📓 学习率调优-PES: 在元学习问题上使用持久/噪声重用 ES,如 Vicol et al. (2021) 所述。
- 📓 二次函数-PBT: 在玩具二次问题上使用 PBT,如 Jaderberg et al. (2017) 所述。
- 📓 重启包装器: 自定义重启包装器,例如在 (B)IPOP-CMA-ES 中使用。
- 📓 Brax 控制: 使用
EvoJAX
包装器在 Brax 任务上进化 Tanh MLPs。 - 📓 BBOB 可视化器: 在 2D 适应度景观上可视化进化过程。
主要特性 💵
-
策略多样性:
evosax
实现了超过 30 种经典和现代神经进化策略。它们都遵循相同的简单ask
/eval
API,并配备了定制工具,如 ClipUp 优化器、参数重塑为 PyTrees 和适应度整形(见下文)。 -
ask
/tell
调用的向量化/并行化:ask
和tell
调用都可以利用jit
、vmap
/pmap
。这使得不同进化策略的向量化/并行化推演成为可能。
from evosax.strategies.ars import ARS, EvoParams
# 例如,对不同的初始扰动标准差进行向量化
strategy = ARS(popsize=100, num_dims=20)
es_params = EvoParams(sigma_init=jnp.array([0.1, 0.01, 0.001]), sigma_decay=0.999, ...)
# 指定如何映射 ES 超参数
map_dict = EvoParams(sigma_init=0, sigma_decay=None, ...)
# Vmap 组合的批量初始化、ask 和 tell 函数
batch_init = jax.vmap(strategy.init, in_axes=(None, map_dict))
batch_ask = jax.vmap(strategy.ask, in_axes=(None, 0, map_dict))
batch_tell = jax.vmap(strategy.tell, in_axes=(0, 0, 0, map_dict))
- 扫描进化推演: 你还可以使用
lax.scan
扫描整个init
、ask
、eval
、tell
循环,以快速编译 ES 循环:
@partial(jax.jit, static_argnums=(1,))
def run_es_loop(rng, num_steps):
"""运行进化 ask-eval-tell 循环。"""
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)
def es_step(state_input, tmp):
"""用于 lax.scan 的辅助 es 步骤。"""
rng, state = state_input
rng, rng_iter = jax.random.split(rng)
x, state = strategy.ask(rng_iter, state, es_params)
fitness = ...
state = strategy.tell(y, fitness, state, es_params)
return [rng, state], fitness[jnp.argmin(fitness)]
_, scan_out = jax.lax.scan(es_step,
[rng, state],
[jnp.zeros(num_steps)])
return jnp.min(scan_out)
- 群体参数重塑: 我们提供了一个
ParameterReshaper
包装器,用于将平坦的参数向量重塑为 PyTrees。该包装器与 JAX 神经网络库(如 Flax/Haiku)兼容,使后续评估网络群体变得更加容易。
from flax import linen as nn
from evosax import ParameterReshaper
class MLP(nn.Module):
num_hidden_units: int
...
@nn.compact
def __call__(self, obs):
...
return ...
network = MLP(64)
net_params = network.init(rng, jnp.zeros(4,), rng)
# 根据占位网络形状初始化重塑器
param_reshaper = ParameterReshaper(net_params)
# 获取群体候选项并重塑为堆叠的 pytrees
x = strategy.ask(...)
x_shaped = param_reshaper.reshape(x)
- 灵活的适应度整形: 默认情况下,
evosax
假设适应度目标是要最小化的。如果你想最大化、执行排名居中、z 分数标准化或添加权重正则化,你可以使用FitnessShaper
:
from evosax import FitnessShaper
# 实例化可即时编译的适应度整形器(例如用于 Open ES)
fit_shaper = FitnessShaper(centered_rank=True,
z_score=False,
weight_decay=0.01,
maximize=True)
# 整形评估得到的适应度分数
fit_shaped = fit_shaper.apply(x, fitness)
资源和其他优秀的JAX-ES工具 📝
- 📺 Rob在MLC研究会议上的演讲:在ML Collective研究会议上的简短动机演讲。
- 📝 Rob的2021年2月博客:关于CMA-ES和利用JAX原语的教程。
- 💻 Evojax:Google Brain开发的JAX-ES库,具有出色的rollout包装器。
- 💻 QDax:JAX中的质量多样性算法。
致谢和引用 evosax
✏️
如果您在研究中使用了 evosax
,请引用以下论文:
@article{evosax2022github,
author = {Robert Tjarko Lange},
title = {evosax: JAX-based Evolution Strategies},
journal={arXiv preprint arXiv:2212.04180},
year = {2022},
}
我们感谢Google TRC和德国研究基金会(DFG,Deutsche Forschungsgemeinschaft)在德国卓越战略框架下对"智能科学"项目(项目编号390523135)的资金支持。
开发 👷
您可以通过运行 python -m pytest -vv --all
来执行测试套件。如果您发现了bug或缺少您喜欢的功能,欢迎创建issue和/或开始贡献 🤗。
免责声明 ⚠️
本仓库包含基于ICLR 2023发表的论文(Lange et al., 2023)的LES和DES的独立重新实现。它与Google或DeepMind无关。该实现已经过测试,在一系列任务中大致重现了官方结果。