QDax: 加速的质量多样性算法
QDax是一个通过硬件加速器和大规模并行化来加速质量多样性(QD)和神经进化算法的工具。QD算法通常需要在大型CPU集群上运行数天/数周。使用QDax,QD算法现在可以在几分钟内完成!⏩ ⏩ 🕛
QDax被开发为一个研究框架:它灵活易扩展,可用于任何问题设置。从这里开始简单示例,几分钟内运行QD算法!
安装
QDax可在PyPI上获取,使用以下命令安装:
pip install qdax
或者,可以直接从源代码安装QDax的最新提交:
pip install git+https://github.com/adaptive-intelligent-robotics/QDax.git@main
通过pip
安装QDax默认安装仅CPU版本的JAX。要在NVidia GPU上使用QDax,您必须先安装CUDA、CuDNN和支持GPU的JAX。
然而,我们还提供并推荐使用Docker或conda环境来使用该仓库,默认提供GPU支持。详细步骤可在文档中找到。
基本API使用
要全面了解QDax的工作原理,我们建议从教程风格的Colab笔记本开始。这是一个使用MAP-Elites算法在选定的Brax环境(默认为Walker)中进化控制器群体的示例。
以下是主要API使用的摘要:
import jax
import functools
from qdax.core.map_elites import MAPElites
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.tasks.arm import arm_scoring_function
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.metrics import default_qd_metrics
seed = 42
num_param_dimensions = 100 # 机械臂自由度数量
init_batch_size = 100
batch_size = 1024
num_iterations = 50
grid_shape = (100, 100)
min_param = 0.0
max_param = 1.0
min_bd = 0.0
max_bd = 1.0
# 初始化随机密钥
random_key = jax.random.PRNGKey(seed)
# 初始化控制器群体
random_key, subkey = jax.random.split(random_key)
init_variables = jax.random.uniform(
subkey,
shape=(init_batch_size, num_param_dimensions),
minval=min_param,
maxval=max_param,
)
# 定义发射器
variation_fn = functools.partial(
isoline_variation,
iso_sigma=0.05,
line_sigma=0.1,
minval=min_param,
maxval=max_param,
)
mixing_emitter = MixingEmitter(
mutation_fn=lambda x, y: (x, y),
variation_fn=variation_fn,
variation_percentage=1.0,
batch_size=batch_size,
)
# 定义度量函数
metrics_fn = functools.partial(
default_qd_metrics,
qd_offset=0.0,
)
# 实例化MAP-Elites
map_elites = MAPElites(
scoring_function=arm_scoring_function,
emitter=mixing_emitter,
metrics_function=metrics_fn,
)
# 计算质心
centroids = compute_euclidean_centroids(
grid_shape=grid_shape,
minval=min_bd,
maxval=max_bd,
)
# 初始化库和发射器状态
repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key)
# 运行MAP-Elites循环
for i in range(num_iterations):
(repertoire, emitter_state, metrics, random_key,) = map_elites.update(
repertoire,
emitter_state,
random_key,
)
# 获取库内容
repertoire.genotypes, repertoire.fitnesses, repertoire.descriptors
QDax核心算法
QDax目前支持以下算法:
算法 | 示例 |
---|---|
MAP-Elites | |
CVT MAP-Elites | |
策略梯度辅助MAP-Elites (PGA-ME) | |
QDPG | |
CMA-ME | |
OMG-MEGA | |
CMA-MEGA | |
多目标MAP-Elites (MOME) | |
MAP-Elites进化策略 (MEES) | |
MAP-Elites PBT (ME-PBT) | |
MAP-Elites低扩散 (ME-LS) |
QDax基准算法
QDax库还提供了一些有用的基准算法实现:
QDax任务
QDax库还为多个标准质量多样性任务提供了众多实现。
所有这些实现及其描述都在任务目录中提供。
贡献
欢迎提出问题和贡献。更多详细信息请参阅文档中的贡献指南。
相关项目
- EvoJAX:硬件加速的神经进化。EvoJAX是一个可扩展的、通用的、硬件加速的神经进化工具包。论文
- evosax:基于JAX的进化策略
引用QDax
如果您在研究中使用了QDax并想在您的工作中引用它,请使用:
@misc{chalumeau2023qdax,
title={QDax: A Library for Quality-Diversity and Population-based Algorithms with Hardware Acceleration},
author={Felix Chalumeau and Bryan Lim and Raphael Boige and Maxime Allard and Luca Grillotti and Manon Flageat and Valentin Macé and Arthur Flajolet and Thomas Pierrot and Antoine Cully},
year={2023},
eprint={2308.03665},
archivePrefix={arXiv},
primaryClass={cs.AI}
}
贡献者
QDax由自适应智能机器人实验室(AIRL)和InstaDeep开发和维护。