Project Icon

sbx

Jax加持的Stable-Baselines3强化学习库

SBX是Stable-Baselines3的Jax实现版本,集成了SAC、TQC、PPO等多种先进强化学习算法。它与SB3保持相同API,可与RL Zoo无缝对接,并提供详细使用示例。SBX为复杂环境和任务提供高效、可靠的强化学习实现。

CI 代码风格

Stable Baselines Jax (SB3 + Jax = SBX)

这是Stable-Baselines3在Jax上的概念验证版本。

已实现的算法:

使用pip安装

对于最新的主分支版本:

pip install git+https://github.com/araffin/sbx

或者:

pip install sbx-rl

示例

import gymnasium as gym

from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

env = gym.make("Pendulum-v1", render_mode="human")

model = TQC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)

vec_env = model.get_env()
obs = vec_env.reset()
for _ in range(1000):
    vec_env.render()
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)

vec_env.close()

在RL Zoo中使用SBX

由于SBX共享SB3的API,它与RL Zoo兼容,你只需覆盖算法映射:

import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# 参见下面的注释以使用DroQ配置
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

if __name__ == "__main__":
    train()

然后你可以像使用RL Zoo一样运行这个脚本:

python train.py --algo sac --env HalfCheetah-v4 -params train_freq:4 gradient_steps:4 -P

enjoy脚本也是一样的:

import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# 参见下面的注释以使用DroQ配置
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

if __name__ == "__main__":
    enjoy()

关于DroQ的说明

DroQ是SAC的一种特殊配置。

要使用论文中的超参数配置算法,你应该使用(使用RL Zoo配置格式):

HalfCheetah-v4:
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  learning_starts: 10000
  gradient_steps: 20
  policy_delay: 20
  policy_kwargs: "dict(dropout_rate=0.01, layer_norm=True)"

然后使用上面定义的RL Zoo脚本:python train.py --algo sac --env HalfCheetah-v4 -c droq.yml -P

我们建议调整policy_delaygradient_steps参数以获得更好的速度/效率。 为Q值函数设置更高的学习率也会有帮助:qf_learning_rate: !!float 1e-3

注意:当使用CrossQ的DroQ配置时,你应该设置layer_norm=False,因为已经有批量归一化了。

基准测试

部分基准测试可以在OpenRL Benchmark上找到,你还可以在那里找到几个报告

引用本项目

在出版物中引用此仓库:

@article{stable-baselines3,
  author  = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
  title   = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
  journal = {Journal of Machine Learning Research},
  year    = {2021},
  volume  = {22},
  number  = {268},
  pages   = {1-8},
  url     = {http://jmlr.org/papers/v22/20-1364.html}
}

维护者

Stable-Baselines3目前由Ashley Hill(又名@hill-a)、Antonin Raffin(又名@araffin)、Maximilian Ernestus(又名@ernestum)、Adam Gleave(@AdamGleave)、Anssi Kanervisto(@Miffyli)和Quentin Gallouédec(@qgallouedec)维护。

重要提示:我们不提供技术支持、咨询服务,也不回答通过电子邮件发送的个人问题。 在这种情况下,请在RL DiscordRedditStack Overflow上发布你的问题。

如何贡献

对于任何有兴趣改进基线的人,仍然有一些文档工作需要完成。 如果你想贡献,请先阅读CONTRIBUTING.md指南。

贡献者

我们要感谢我们的贡献者:@jan1854

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号