稳定基线3
稳定基线3 (SB3) 是一组在 PyTorch 中实现的可靠的强化学习算法。它是稳定基线的下一个主要版本。
你可以在 v1.0 博客文章 或我们的 JMLR 论文中阅读有关稳定基线3的详细介绍。
这些算法将使研究社区和工业界更容易复制、改进和识别新想法,并将创造良好的基线来构建项目。我们希望这些工具将被用作基础,添加新想法,并作为比较新方法与现有方法的工具。我们还希望这些工具的简化能使初学者能够尝试更高级的工具集,而不被实现细节所困扰。
注意:尽管使用简单,稳定基线3 (SB3) 假设你对强化学习 (RL) 有一定了解。 没有一定的实践经验,你不应该使用这个库。为此,我们在文档中提供了良好的资源,以便开始使用 RL。
主要功能
每种算法的性能都经过测试(见各自页面中的结果部分),你可以查看问题 #48 和 #49以了解更多详情。
功能 | Stable-Baselines3 |
---|---|
先进的 RL 方法 | :heavy_check_mark: |
文档 | :heavy_check_mark: |
自定义环境 | :heavy_check_mark: |
自定义策略 | :heavy_check_mark: |
通用接口 | :heavy_check_mark: |
Dict 观察空间支持 | :heavy_check_mark: |
Ipython / Notebook 友好 | :heavy_check_mark: |
Tensorboard 支持 | :heavy_check_mark: |
PEP8 代码风格 | :heavy_check_mark: |
自定义回调 | :heavy_check_mark: |
高代码覆盖率 | :heavy_check_mark: |
类型提示 | :heavy_check_mark: |
计划功能
迁移指南:从稳定基线 (SB2) 到稳定基线3 (SB3)
SB2 到 SB3 的迁移指南可以在文档中找到。
文档
文档在线提供:https://stable-baselines3.readthedocs.io/
集成
稳定基线3 与其他库/服务有一些集成,如用于实验跟踪的 Weights & Biases 或用于存储/共享训练模型的 Hugging Face。在文档的专用部分可以了解更多信息。
RL 基线3 动物园:稳定基线3 强化学习代理的训练框架
RL 基线3 动物园是一个针对强化学习 (RL) 的训练框架。
它提供了训练、评估代理、调整超参数、绘制结果和录制视频的脚本。
此外,还包括常见环境和 RL 算法的调优超参数集合,以及使用这些设置训练的代理。
此存储库的目标:
- 提供一个简单的接口来训练和享受 RL 代理
- 基准测试不同的强化学习算法
- 为每个环境和 RL 算法提供调优的超参数
- 与训练好的代理一起玩得开心!
GitHub 仓库:https://github.com/DLR-RM/rl-baselines3-zoo
文档:https://rl-baselines3-zoo.readthedocs.io/en/master/
SB3-Contrib:实验性 RL 功能
我们在一个单独的贡献存储库中实现实验性功能:SB3-Contrib
这允许 SB3 保持一个稳定和紧凑的核心,同时仍然提供最新的功能,如循环 PPO (PPO LSTM)、截断分位数评论 (TQC)、分位数回归 DQN (QR-DQN) 或具有无效操作屏蔽的 PPO (Maskable PPO)。
文档在线提供:https://sb3-contrib.readthedocs.io/
稳定基线 Jax (SBX)
稳定基线 Jax (SBX) 是稳定基线3 在 Jax 中的概念验证版本,包括诸如 DroQ 或 CrossQ 等最新算法。
与 SB3 相比,它提供的功能最少,但速度可能快很多(最多可快 20 倍!):https://twitter.com/araffin2/status/1590714558628253698
安装
注意: 稳定基线3 支持 PyTorch >= 1.13
先决条件
稳定基线3 需要 Python 3.8+。
Windows 10
要在 Windows 上安装稳定基线,请查看文档。
使用 pip 安装
安装稳定基线3 包:
pip install stable-baselines3[extra]
注意: 一些壳如 Zsh 需要用引号括起方括号,例如 pip install 'stable-baselines3[extra]'
(更多信息)。
这包括 Tensorboard、OpenCV 或 ale-py
等可选依赖项,以便在 Atari 游戏上进行训练。如果不需要这些,你可以使用:
pip install stable-baselines3
请阅读文档以获取更多详细信息和替代方法(从源文件、使用 docker)。
示例
库中的大部分代码尝试遵循 sklearn 风格的强化学习算法语法。
这是一个如何在 Cartpole 环境中训练和运行 PPO 的快速示例:
import gymnasium as gym
from stable_baselines3 import PPO
env = gym.make("CartPole-v1", render_mode="human")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = vec_env.step(action)
vec_env.render()
# VecEnv 自动重置
# if done:
# obs = env.reset()
env.close()
如果环境在 Gymnasium 中注册且策略已注册,你也可以只用一行代码训练一个模型:
from stable_baselines3 import PPO
model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
请阅读文档以获取更多示例。
使用 Colab 笔记本在线尝试!
以下所有示例都可以使用 Google Colab 笔记本在线执行:
实现的算法
<SOURCE_TEXT>
名称 | 循环 | Box | 离散 | 多离散 | 多二进制 | 多重处理 |
---|---|---|---|---|---|---|
ARS1 | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
QR-DQN1 | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
RecurrentPPO1 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
TQC1 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
TRPO1 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
可掩蔽 PPO1 | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
1: 实现于 SB3 Contrib GitHub 仓库。
动作 gym.spaces
:
Box
: 包含动作空间中每个点的 N 维盒。离散
: 可能动作的列表,每个时间步只能使用其中一个动作。多离散
: 可能动作的列表,每个时间步只能使用每个离散集合中的一个动作。多二进制
: 可能动作的列表,每个时间步可以任意组合使用其中任意动作。
测试安装
安装依赖
pip install -e .[docs,tests,extra]
运行测试
可以使用 pytest
运行 stable baselines3 中的所有单元测试:
make pytest
运行单个测试文件:
python3 -m pytest -v tests/test_env_checker.py
运行单个测试:
python3 -m pytest -v -k 'test_check_env_dict_action'
你也可以使用 pytype
和 mypy
做静态类型检查:
pip install pytype mypy
make type
使用 ruff
进行代码风格检查:
pip install ruff
make lint
使用Stable-Baselines3的项目
我们试图在文档中维护一个使用stable-baselines3的项目列表,如果希望你的项目出现在此页面,请告知我们 ;)
引用本项目
在出版物中引用此仓库:
@article{stable-baselines3,
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
journal = {Journal of Machine Learning Research},
year = {2021},
volume = {22},
number = {268},
pages = {1-8},
url = {http://jmlr.org/papers/v22/20-1364.html}
}
维护者
Stable-Baselines3 目前由 Ashley Hill (aka @hill-a),Antonin Raffin (aka @araffin),Maximilian Ernestus (aka @ernestum),Adam Gleave (@AdamGleave),Anssi Kanervisto (@Miffyli) 和 Quentin Gallouédec (@qgallouedec) 维护。
重要提示:我们不提供技术支持或咨询服务,并且不通过电子邮件回答个人问题。 在这种情况下,请在 RL Discord,Reddit 或 Stack Overflow 上发布您的问题。
如何贡献
对任何有兴趣改进基线的人来说,仍有一些文档需要完成。 如果你想贡献,请先阅读 CONTRIBUTING.md 指南。
致谢
开发Stable Baselines3的初步工作部分由项目Reduced Complexity Models和Helmholtz-Gemeinschaft Deutscher Forschungszentren资助,并由欧盟的Horizon 2020研究和创新计划,根据资助协议号951992(VeriDream)资助。
最初版本,Stable Baselines,诞生于ENSTA ParisTech的机器人实验室U2IS(INRIA Flowers团队)。
Logo作者: L.M. Tenkes </SOURCE_TEXT>