Stable Baselines Jax (SB3 + Jax = SBX)
这是Stable-Baselines3在Jax上的概念验证版本。
已实现的算法:
- 软演员-评论家算法 (SAC) 和 SAC-N
- 截断分位数评论家 (TQC)
- 用于双重高效强化学习的丢弃Q函数 (DroQ)
- 近端策略优化 (PPO)
- 深度Q网络 (DQN)
- 双延迟DDPG (TD3)
- 深度确定性策略梯度 (DDPG)
- 深度强化学习中的批量归一化 (CrossQ)
使用pip安装
对于最新的主分支版本:
pip install git+https://github.com/araffin/sbx
或者:
pip install sbx-rl
示例
import gymnasium as gym
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
env = gym.make("Pendulum-v1", render_mode="human")
model = TQC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)
vec_env = model.get_env()
obs = vec_env.reset()
for _ in range(1000):
vec_env.render()
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = vec_env.step(action)
vec_env.close()
在RL Zoo中使用SBX
由于SBX共享SB3的API,它与RL Zoo兼容,你只需覆盖算法映射:
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# 参见下面的注释以使用DroQ配置
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
if __name__ == "__main__":
train()
然后你可以像使用RL Zoo一样运行这个脚本:
python train.py --algo sac --env HalfCheetah-v4 -params train_freq:4 gradient_steps:4 -P
enjoy脚本也是一样的:
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# 参见下面的注释以使用DroQ配置
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
if __name__ == "__main__":
enjoy()
关于DroQ的说明
DroQ是SAC的一种特殊配置。
要使用论文中的超参数配置算法,你应该使用(使用RL Zoo配置格式):
HalfCheetah-v4:
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
learning_starts: 10000
gradient_steps: 20
policy_delay: 20
policy_kwargs: "dict(dropout_rate=0.01, layer_norm=True)"
然后使用上面定义的RL Zoo脚本:python train.py --algo sac --env HalfCheetah-v4 -c droq.yml -P
。
我们建议调整policy_delay
和gradient_steps
参数以获得更好的速度/效率。
为Q值函数设置更高的学习率也会有帮助:qf_learning_rate: !!float 1e-3
。
注意:当使用CrossQ的DroQ配置时,你应该设置layer_norm=False
,因为已经有批量归一化了。
基准测试
部分基准测试可以在OpenRL Benchmark上找到,你还可以在那里找到几个报告。
引用本项目
在出版物中引用此仓库:
@article{stable-baselines3,
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
journal = {Journal of Machine Learning Research},
year = {2021},
volume = {22},
number = {268},
pages = {1-8},
url = {http://jmlr.org/papers/v22/20-1364.html}
}
维护者
Stable-Baselines3目前由Ashley Hill(又名@hill-a)、Antonin Raffin(又名@araffin)、Maximilian Ernestus(又名@ernestum)、Adam Gleave(@AdamGleave)、Anssi Kanervisto(@Miffyli)和Quentin Gallouédec(@qgallouedec)维护。
重要提示:我们不提供技术支持、咨询服务,也不回答通过电子邮件发送的个人问题。 在这种情况下,请在RL Discord、Reddit或Stack Overflow上发布你的问题。
如何贡献
对于任何有兴趣改进基线的人,仍然有一些文档工作需要完成。 如果你想贡献,请先阅读CONTRIBUTING.md指南。
贡献者
我们要感谢我们的贡献者:@jan1854。