Project Icon

flashbax

JAX强化学习高效体验回放缓冲库

Flashbax是一个为JAX设计的高效体验回放缓冲库,适用于强化学习算法。它提供平坦缓冲、轨迹缓冲及其优先级变体等多种缓冲类型,特点是高效内存使用、易于集成到编译函数中,并支持优先级采样。Flashbax还具有Vault功能,可将大型缓冲区保存到磁盘。这个简单灵活的框架适用于学术研究、工业应用和个人项目中的体验回放处理。

Flashbax标志 Flashbax标志

Python版本 PyPI版本 测试 代码风格 MyPy 许可证


⚡ Jax中的高速缓冲区 ⚡

概述 🔍

Flashbax是一个旨在简化强化学习(RL)中经验回放缓冲区使用的库。它专为与JAX范式兼容而设计,允许这些缓冲区在完全编译的函数和训练循环中轻松使用。

Flashbax提供了各种不同类型缓冲区的实现,如平面缓冲区、轨迹缓冲区以及两者的优先级变体。无论是用于学术研究、工业应用还是个人项目,Flashbax都为RL经验回放处理提供了简单灵活的框架。

特性 🛠️

🚀 高效缓冲区变体:所有Flashbax缓冲区都是作为轨迹缓冲区的专门变体构建的,在各种类型的缓冲区中优化内存使用和功能。

🗄️ 平面缓冲区:平面缓冲区类似于DQN等算法中使用的转换缓冲区,是一个核心组件。它使用序列长度为2(即$s_t$, $s_{t+1}$),周期为1,以全面考虑转换对。

🧺 项目缓冲区:项目缓冲区是一个存储单个项目的简单缓冲区。它适用于存储相互独立的数据,如(观察、动作、奖励、折扣、下一个观察)元组或整个回合。

🛤️ 轨迹缓冲区:轨迹缓冲区便于采样多步轨迹,适用于使用循环网络的算法,如R2D2(Kapturowski等人,2018)。

🏅 优先级缓冲区:平面缓冲区和轨迹缓冲区都可以设置优先级,实现基于用户定义优先级的采样。优先机制与PER论文(Schaul等人,2016)中概述的原则一致。

🚶 轨迹/平面队列:提供了一种队列数据结构,用于按先进先出(FIFO)顺序采样数据。该队列可用于特定用例的在线策略算法。

设置 🎬

要将Flashbax集成到您的项目中,请按以下步骤操作:

  1. 安装:首先使用pip安装Flashbax:
pip install flashbax
  1. 选择缓冲区:从平面缓冲区、轨迹缓冲区和优先级变体等多种缓冲区选项中选择。
import flashbax as fbx

buffer = fbx.make_trajectory_buffer(...)
# 或
buffer = fbx.make_prioritised_trajectory_buffer(...)
# 或
buffer = fbx.make_flat_buffer(...)
# 或
buffer = fbx.make_prioritised_flat_buffer(...)
# 或
buffer = fbx.make_item_buffer(...)
# 或
buffer = fbx.make_trajectory_queue(...)

# 初始化
state = buffer.init(example_timestep)
# 添加数据
state = buffer.add(state, example_data)
# 采样数据
batch = buffer.sample(state, rng_key)

快速开始 🏁

以下我们提供了使用平面缓冲区的最小代码示例。在此示例中,我们展示了如何使用定义平面缓冲区的每个纯函数。请注意,这些纯函数都与jax.pmapjax.jit兼容,但为简单起见,以下示例中未使用这些函数。

import jax
import jax.numpy as jnp
import flashbax as fbx

# 使用简单配置通过`make_flat_buffer`实例化平面缓冲区NamedTuple。
# 返回的`buffer`只是使用平面缓冲区所需的纯函数的容器。
buffer = fbx.make_flat_buffer(max_length=32, min_length=2, sample_batch_size=1)

# 初始化缓冲区的状态。
fake_timestep = {"obs": jnp.array([0, 0]), "reward": jnp.array(0.0)}
state = buffer.init(fake_timestep)

# 现在我们向缓冲区添加数据。
state = buffer.add(state, {"obs": jnp.array([1, 2]), "reward": jnp.array(3.0)})
print(buffer.can_sample(state))  # False,因为尚未达到min_length。

state = buffer.add(state, {"obs": jnp.array([4, 5]), "reward": jnp.array(6.0)})
print(buffer.can_sample(state))  # 仍为False,因为我们需要2个*转换*(即3个时间步)。

state = buffer.add(state, {"obs": jnp.array([7, 8]), "reward": jnp.array(9.0)})
print(buffer.can_sample(state))  # True!我们有2个转换(3个时间步)。

# 从缓冲区获取一个转换。
rng_key = jax.random.PRNGKey(0)  # 随机源。
batch = buffer.sample(state, rng_key)  # 采样

# 我们有一个转换!打印:obs = [[4 5]], obs' = [[7 8]]
print(
    f"obs = {batch.experience.first['obs']}, obs' = {batch.experience.second['obs']}"
)

示例 🧑‍💻

我们提供以下Colab示例,作为如何使用每个flashbax缓冲区的更高级教程以及使用示例:

Colab 笔记本描述
Colab平面缓冲区快速入门
Colab轨迹缓冲区快速入门
Colab优先级平面缓冲区快速入门
Colab使用 Matrax 的项目缓冲区示例
ColabAnakin DQN
ColabAnakin 优先级 DQN
ColabAnakin PPO
Colab使用向量化 Gym 环境的 DQN
  • 👾 Anakin - 基于 JAX 的架构,用于端到端地即时编译强化学习代理的训练。
  • 🎮 DQN - 实现改编自 CleanRL 的 DQN JAX 示例。
  • 🦎 Jumanji - 利用 Jumanji 基于 JAX 的环境(如贪吃蛇)进行完全即时编译的示例。
  • ✖️ Matrax - JAX 中的双人矩阵游戏。

保险库 💾

保险库是一种将 Flashbax 缓冲区保存到持久数据存储的高效机制,例如用于离线强化学习。考虑一个维度为 $(B, T, *E)$ 的 Flashbax 缓冲区,其中 $B$ 是批次维度(用于同步记录独立轨迹),$T$ 是时间/序列维度,$*E$ 表示经验数据本身的一个或多个维度。由于特定环境可能会生成大量数据,保险库通过沿时间轴读写缓冲区切片来将 $T$ 维度扩展到几乎不受限制的程度。这样,巨大的缓冲区存储可以驻留在磁盘上,从中可以将子缓冲区加载到 RAM/VRAM 中进行高效的离线训练。保险库已经在项目、平面和轨迹缓冲区上进行了测试。

更多信息,请参阅演示笔记本:Colab

重要考虑事项 ⚠️

在使用 Flashbax 缓冲区时,需要注意某些考虑事项以确保强化学习代理的正常功能。

顺序数据添加

Flashbax 使用轨迹缓冲区作为所有缓冲区类型的基础。这意味着数据必须按顺序添加。具体而言,对于平面缓冲区,每个添加的时间步必须紧跟其连续的时间步。在大多数情况下,这个要求自然得到满足,不需要过多考虑。然而,当添加完全独立的数据批次时,必须注意这个限制。未能维持时间步之间的序列关系可能导致算法问题。用户需要处理从最后一个时间步到第一个时间步的情况。这发生在同一批次中从第 n 个情节到第 n+1 个情节时。例如,我们使用自动重置包装器在终止时间步时自动重置环境。此外,我们使用折扣值(非终止状态为 1,终止状态为 0)来相应地掩蔽价值函数和奖励折扣。

有效缓冲区大小

添加数据批次时,缓冲区以块状结构创建。这意味着有效缓冲区大小取决于批次维度的大小。轨迹缓冲区允许用户指定添加批次维度和时间轴的最大长度。这将创建一个 (批次, 时间) 的块状结构,允许存储的最大时间步数为 批次*时间。为了便于使用,我们提供了 max_size 参数,允许用户设置所需的总时间步数,我们根据提供的添加批次维度计算时间轴的最大长度。因此,重要的是要注意,使用 max_size 参数时,时间轴的最大长度将等于 max_size // 添加批次大小,这将向下取整,从而减少有效缓冲区大小。这意味着人们可能认为他们增加了一定量的缓冲区大小,但实际上并没有增加。因此,为避免这种情况,我们建议采取以下两种方法之一:明确使用最大时间轴长度参数,或者以添加批次大小的倍数增加 max_size 参数。

处理情节截断

另一个关键方面是情节截断。当截断情节并将数据添加到缓冲区时,必须确保适当设置完成标志或"折扣"值。忽视这一点可能会给算法的实现和行为带来挑战。如前所述,预期算法会适当处理这些情况。使用平面缓冲区或轨迹缓冲区处理截断可能很困难,因为算法必须处理一个情节的最后时间步后面跟着下一个情节的第一个时间步的情况。为了牺牲内存效率来换取易用性,可以使用项目缓冲区来独立存储转换或整个轨迹。这意味着算法不需要处理一个情节的最后时间步后面跟着下一个情节的第一个时间步的情况,因为只有明确插入的数据才能被采样。

独立数据使用

对于打算使用缺乏顺序信息的数据的缓冲区的情况,你可以利用项目缓冲区,它是一个具有特定配置的包装轨迹缓冲区。通过将序列维度设置为 1 并将周期设置为 1,每个项目将被视为独立的。然而,当处理独立的转换项目(如观察、动作、奖励、折扣、下一个观察)时,请注意这种方法将导致缓冲区中的观察重复,从而导致不必要的内存消耗。值得注意的是,平面缓冲区的实现速度会比以这种方式使用项目缓冲区慢,这是由于硬件加速器上数据索引的固有速度问题;然而,这种权衡是为了提高内存效率。如果速度远比内存效率更重要,那么使用序列为 1 和周期为 1 的轨迹缓冲区存储完整的转换数据项。

缓冲区状态的原地更新

由于缓冲区通常很大并占用设备内存的大部分,因此执行原地更新是有益的。为此,重要的是要向顶级编译函数指定你希望执行这种原地更新操作。这表示如下:

def train(train_state, buffer_state):
    ...
    return train_state, buffer_state

# 初始化缓冲区状态
buffer_fn = fbx.make_trajectory_buffer(...)
buffer_state = buffer_fn.init(example_timestep)

# 初始化一些训练状态
train_state = train_state.init(...)

# 编译训练函数并指定缓冲区状态参数的捐赠
train_state, buffer_state = jax.jit(train, donate_argnums=(1,))(
    train_state, buffer_state
)

在调用 jax.jit 时包含 donate_argnums 很重要,这可以使 JAX 对回放缓冲区状态进行原地更新。如果省略 donate_argnums,JAX 将被迫为回放缓冲区状态的任何修改创建副本,可能会抵消所有性能优势。有关 JAX 中缓冲区捐赠的更多信息,可以在文档中找到。

使用 Vault 存储数据

如上所述,Vault 通过扩展 Flashbax 缓冲区状态的时间轴将经验数据存储到磁盘。默认情况下,Vault 方便地处理此过程的簿记:消耗缓冲区状态并保存任何新的、以前未见过的数据。例如,假设我们向 Flashbax 缓冲区写入 10 个时间步,然后将此状态保存到 Vault;由于所有这些数据都是新的,所有数据都将写入磁盘。但是,如果我们再写入一个时间步并将状态保存到 Vault,则只会写入该新时间步,防止重复已保存的数据。重要的是,必须记住 Flashbax 状态是作为环形缓冲区实现的,这意味着必须足够频繁地更新 Vault,然后再覆盖 Flashbax 缓冲区状态中未见过的数据。即如果我们的缓冲区状态的时间轴长度为 $\tau$,那么我们必须每 $\tau - 1$ 步保存到 vault 一次,以免覆盖(并丢失)未保存的数据。

总之,理解并解决这些考虑因素将帮助您避开潜在的陷阱,并确保在使用 Flashbax 缓冲区时强化学习策略的有效性。

基准测试 📈

这里我们提供了一系列初步基准测试,概述了各种 Flashbax 缓冲区与常用开源缓冲区相比的性能。在这些基准测试中,我们(除非另有明确说明)使用以下配置:

参数
缓冲区大小500_000
采样批次大小256
观察大小(32, 32, 3)
添加序列长度1
添加序列批次大小1
采样序列长度1
采样序列周期1

我们使用采样序列长度和周期为 1 的原因是为了直接与其他缓冲区进行比较,这意味着轨迹缓冲区的速度与项目缓冲区的速度相当,因为项目缓冲区只是具有此配置的包装轨迹缓冲区。这实际上意味着轨迹缓冲区被用作内存效率低下的转换缓冲区。需要注意的是,Flat Buffer 实现使用采样序列长度为 2。此外,必须记住,并非所有其他缓冲区实现都能有效利用 GPU/TPU,因此它们只在 CPU 上运行并执行设备转换。最后,我们明确使用 Python 循环来公平比较其他缓冲区。使用扫描操作可以大大提高速度(取决于观察大小)。

CPU 速度

CPU_Add CPU_Sample

TPU 速度

TPU_Add TPU_Sample

GPU 速度

我们注意到添加数据时 GPU 速度出现奇怪的行为。我们认为这是因为某些 JAX 操作尚未针对 GPU 使用进行充分优化,我们看到 Dejax 也有相同的性能问题。我们预计这些速度将来会有所改善。

GPU_Add GPU_Sample

CPU、GPU 和 TPU 添加批次

之前的基准测试每次只添加一个时间步,现在我们评估每次添加 128 个时间步的批次 - 这是大多数人在高吞吐量 RL 中会使用的功能。我们只与具有此功能的缓冲区进行比较。

CPU_Add_Batch TPU_Add_Batch

GPU_Add_Batch

最终,我们看到性能优于或可与基准测试的缓冲区相媲美,同时提供完全兼容 JAX 的缓冲区,此外还提供批量添加以及能够添加不同长度的序列等功能。我们确实注意到,由于 JAX 对 CPU、GPU 和 TPU 有不同的 XLA 后端,缓冲区的性能可能会因设备和所调用的特定操作而异。

贡献 🤝

欢迎贡献!请查看我们的问题跟踪器以了解适合新手的问题。请阅读我们的贡献指南,了解如何提交拉取请求、我们的贡献者许可协议和社区指南的详细信息。

另请参阅 📚

其他缓冲区

查看我们在基准测试中强调的社区中的其他缓冲区库。

  • 📀 Dejax: 第一个提供兼容 JAX 的回放缓冲区的库。
  • 🎶 Reverb: 用于本地和大规模分布式 RL 的高效回放缓冲区。
  • 🍰 Dopamine: 用于快速原型设计的研究框架,提供了几个核心回放缓冲区。
  • 🤖 StableBaselines3: 可靠的 RL 基线套件,具有自己易于使用的回放缓冲区。

使用示例

查看社区中使用 flashbax 的一些库:

  • 🦁 Mava: 利用 flashbax 的多智能体算法的端到端 JAX 实现。
  • 🏛️ Stoix: 利用 flashbax 的单智能体算法的端到端 JAX 实现。

引用 Flashbax ✏️

如果您在工作中使用了 Flashbax,请使用以下方式引用该库:

@misc{flashbax,
    title={Flashbax: Streamlining Experience Replay Buffers for Reinforcement Learning with JAX},
    author={Edan Toledo and Laurence Midgley and Donal Byrne and Callum Rhys Tilbury and
    Matthew Macfarlane and Cyprien Courtot and Alexandre Laterre},
    year={2023},
    url={https://github.com/instadeepai/flashbax/},
}

致谢 🙏

该库的开发得到了来自 Google 的 TPU Research Cloud (TRC) 🌤 的 Cloud TPU 支持。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号