Project Icon

evosax

基于JAX的高性能进化策略框架

evosax是基于JAX的进化策略框架,通过XLA编译和自动向量化/并行化技术实现大规模进化策略的高效计算。它支持CMA-ES、OpenAI-ES等多种经典和现代神经进化算法,采用ask-evaluate-tell API设计。evosax兼容JAX的jit、vmap和lax.scan,可扩展至不同硬件加速器。该框架为进化计算研究和应用提供了高性能、灵活的工具。

evosax:基于JAX的进化策略 🦎

Python版本 PyPI版本 代码风格:black codecov 论文

厌倦了为神经进化处理异步进程?想要利用大规模向量化和高吞吐量加速器来实现进化策略(ES)吗?evosax允许您利用JAX、XLA编译和自动向量化/并行化来将ES扩展到您喜欢的加速器上。其API基于ES的经典"询问"、"评估"、"告知"循环。"询问"和"告知"调用都与jitvmap/pmaplax.scan兼容。它包含了大量经典(如CMA-ES、差分进化等)和现代神经进化(如OpenAI-ES、增强RS等)策略。您可以在这里开始 👉 Colab

evosax API基本用法 🍲

import jax
from evosax import CMA_ES

# 实例化搜索策略
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)

# 运行询问-评估-告知循环 - 注意:默认为最小化!
for t in range(num_generations):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, es_params)
    fitness = ...  # 您的种群评估函数 
    state = strategy.tell(x, fitness, state, es_params)

# 获取整体最佳种群成员及其适应度
state.best_member, state.best_fitness

已实现的进化策略 🦎

安装 ⏳

最新版本的 evosax 可以直接从 PyPI 安装:

pip install evosax

如果你想获取最新的提交,请直接从仓库安装:

pip install git+https://github.com/RobertTLange/evosax.git@main

要在你的加速器上使用 JAX,可以在 JAX 文档中找到更多详细信息。

示例 📖

主要特性 💵

  • 策略多样性: evosax 实现了超过 30 种经典和现代神经进化策略。它们都遵循相同的简单 ask/eval API,并配备了定制工具,如 ClipUp 优化器、参数重塑为 PyTrees 和适应度整形(见下文)。

  • ask/tell 调用的向量化/并行化: asktell 调用都可以利用 jitvmap/pmap。这使得不同进化策略的向量化/并行化推演成为可能。

from evosax.strategies.ars import ARS, EvoParams
# 例如,对不同的初始扰动标准差进行向量化
strategy = ARS(popsize=100, num_dims=20)
es_params = EvoParams(sigma_init=jnp.array([0.1, 0.01, 0.001]), sigma_decay=0.999, ...)

# 指定如何映射 ES 超参数
map_dict = EvoParams(sigma_init=0, sigma_decay=None, ...)

# Vmap 组合的批量初始化、ask 和 tell 函数
batch_init = jax.vmap(strategy.init, in_axes=(None, map_dict))
batch_ask = jax.vmap(strategy.ask, in_axes=(None, 0, map_dict))
batch_tell = jax.vmap(strategy.tell, in_axes=(0, 0, 0, map_dict))
  • 扫描进化推演: 你还可以使用 lax.scan 扫描整个 initaskevaltell 循环,以快速编译 ES 循环:
@partial(jax.jit, static_argnums=(1,))
def run_es_loop(rng, num_steps):
    """运行进化 ask-eval-tell 循环。"""
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    def es_step(state_input, tmp):
        """用于 lax.scan 的辅助 es 步骤。"""
        rng, state = state_input
        rng, rng_iter = jax.random.split(rng)
        x, state = strategy.ask(rng_iter, state, es_params)
        fitness = ...
        state = strategy.tell(y, fitness, state, es_params)
        return [rng, state], fitness[jnp.argmin(fitness)]

    _, scan_out = jax.lax.scan(es_step,
                               [rng, state],
                               [jnp.zeros(num_steps)])
    return jnp.min(scan_out)
  • 群体参数重塑: 我们提供了一个 ParameterReshaper 包装器,用于将平坦的参数向量重塑为 PyTrees。该包装器与 JAX 神经网络库(如 Flax/Haiku)兼容,使后续评估网络群体变得更加容易。
from flax import linen as nn
from evosax import ParameterReshaper

class MLP(nn.Module):
    num_hidden_units: int
    ...

    @nn.compact
    def __call__(self, obs):
        ...
        return ...

network = MLP(64)
net_params = network.init(rng, jnp.zeros(4,), rng)

# 根据占位网络形状初始化重塑器
param_reshaper = ParameterReshaper(net_params)

# 获取群体候选项并重塑为堆叠的 pytrees
x = strategy.ask(...)
x_shaped = param_reshaper.reshape(x)
  • 灵活的适应度整形: 默认情况下,evosax 假设适应度目标是要最小化的。如果你想最大化、执行排名居中、z 分数标准化或添加权重正则化,你可以使用 FitnessShaper:
from evosax import FitnessShaper

# 实例化可即时编译的适应度整形器(例如用于 Open ES)
fit_shaper = FitnessShaper(centered_rank=True,
                           z_score=False,
                           weight_decay=0.01,
                           maximize=True)

# 整形评估得到的适应度分数
fit_shaped = fit_shaper.apply(x, fitness) 

资源和其他优秀的JAX-ES工具 📝

致谢和引用 evosax ✏️

如果您在研究中使用了 evosax,请引用以下论文

@article{evosax2022github,
  author = {Robert Tjarko Lange},
  title = {evosax: JAX-based Evolution Strategies},
  journal={arXiv preprint arXiv:2212.04180},
  year = {2022},
}

我们感谢Google TRC和德国研究基金会(DFG,Deutsche Forschungsgemeinschaft)在德国卓越战略框架下对"智能科学"项目(项目编号390523135)的资金支持。

开发 👷

您可以通过运行 python -m pytest -vv --all 来执行测试套件。如果您发现了bug或缺少您喜欢的功能,欢迎创建issue和/或开始贡献 🤗。

免责声明 ⚠️

本仓库包含基于ICLR 2023发表的论文(Lange et al., 2023)的LES和DES的独立重新实现。它与Google或DeepMind无关。该实现已经过测试,在一系列任务中大致重现了官方结果。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号