OfflineRL-Kit: 一个优雅的 PyTorch 离线强化学习库

Ray

引言

在机器学习和人工智能领域,强化学习一直是一个备受关注的研究方向。然而,传统的在线强化学习方法在某些场景下可能存在局限性,例如在真实世界中收集大量数据可能成本高昂或危险。为了解决这一问题,离线强化学习应运而生。OfflineRL-Kit 正是为此而设计的一个强大工具,它为研究人员提供了一个优雅、高效的离线强化学习开发环境。

OfflineRL-Kit Logo

OfflineRL-Kit 简介

OfflineRL-Kit 是一个基于纯 PyTorch 的离线强化学习库,由 Yihao Sun 开发。这个库的设计理念是为研究人员提供一个友好、便捷的工具,帮助他们更快速地进行离线强化学习算法的开发和实验。

主要特性

  1. 优雅的框架:OfflineRL-Kit 的代码结构清晰,易于使用,让研究人员能够快速上手。

  2. 最先进的算法:库中包含了多种state-of-the-art的离线强化学习算法,涵盖了无模型和基于模型的方法。

  3. 高可扩展性:研究人员可以基于库中的组件,用很少的代码就能构建新的算法。

  4. 并行调优支持:方便研究人员进行大规模的参数调优实验。

  5. 强大的日志系统:清晰而功能强大的日志系统,便于管理和分析实验结果。

支持的算法

OfflineRL-Kit 支持多种先进的离线强化学习算法,大致可以分为无模型方法和基于模型的方法两类:

无模型方法

  1. Conservative Q-Learning (CQL):一种保守的 Q 学习方法,通过最小化某些动作的 Q 值来避免对未见过的状态-动作对的过度估计。

  2. TD3+BC:结合了 TD3 算法和行为克隆(Behavior Cloning)的方法,在离线设置中表现良好。

  3. Implicit Q-Learning (IQL):一种隐式 Q 学习方法,通过学习值函数的分位数来改善策略学习。

  4. Ensemble-Diversified Actor Critic (EDAC):使用集成方法来增强策略的多样性,提高离线学习的鲁棒性。

  5. Mildly Conservative Q-Learning (MCQ):相较于 CQL,这是一种较为温和的保守 Q 学习方法。

基于模型的方法

  1. Model-based Offline Policy Optimization (MOPO):通过学习环境模型并在模拟环境中进行策略优化,同时考虑模型不确定性。

  2. Conservative Offline Model-Based Policy Optimization (COMBO):在 MOPO 的基础上增加了保守性约束,进一步提高了性能。

  3. Robust Adversarial Model-Based Offline Reinforcement Learning (RAMBO):引入对抗性训练来增强模型的鲁棒性。

  4. Model-Bellman Inconsistency Penalized Offline Reinforcement Learning (MOBILE):通过惩罚模型与贝尔曼方程的不一致性来改善离线学习。

安装指南

要使用 OfflineRL-Kit,您需要按照以下步骤进行安装:

  1. 首先,安装 MuJoCo 引擎。您可以从 MuJoCo 官网 下载。安装完成后,根据您安装的 MuJoCo 版本安装相应的 mujoco-py

  2. 接下来,安装 D4RL:

git clone https://github.com/Farama-Foundation/d4rl.git
cd d4rl
pip install -e .
  1. 最后,安装 OfflineRL-Kit:
git clone https://github.com/yihaosun1124/OfflineRL-Kit.git
cd OfflineRL-Kit
python setup.py install

快速开始

为了帮助您快速上手 OfflineRL-Kit,这里提供了一个使用 Conservative Q-Learning (CQL) 算法的简单示例。

训练过程

  1. 首先,创建环境并获取离线数据集:
env = gym.make(args.task)
dataset = qlearning_dataset(env)
buffer = ReplayBuffer(
buffer_size=len(dataset["observations"]),
obs_shape=args.obs_shape,
obs_dtype=np.float32,
action_dim=args.action_dim,
action_dtype=np.float32,
device=args.device
)
buffer.load_dataset(dataset)
  1. 定义模型和优化器:
actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims)
critic1_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims)
critic2_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims)
dist = TanhDiagGaussian(
    latent_dim=getattr(actor_backbone, "output_dim"),
    output_dim=args.action_dim,
    unbounded=True,
    conditioned_sigma=True
)
actor = ActorProb(actor_backbone, dist, args.device)
critic1 = Critic(critic1_backbone, args.device)
critic2 = Critic(critic2_backbone, args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
  1. 设置策略:
policy = CQLPolicy(
    actor,
    critic1,
    critic2,
    actor_optim,
    critic1_optim,
    critic2_optim,
    action_space=env.action_space,
tau=args.tau,
gamma=args.gamma,
alpha=alpha,
cql_weight=args.cql_weight,
temperature=args.temperature,
max_q_backup=args.max_q_backup,
deterministic_backup=args.deterministic_backup,
with_lagrange=args.with_lagrange,
lagrange_threshold=args.lagrange_threshold,
cql_alpha_lr=args.cql_alpha_lr,
num_repeart_actions=args.num_repeat_actions
)
  1. 定义日志记录器:
log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args))
output_config = {
    "consoleout_backup": "stdout",
    "policy_training_progress": "csv",
    "tb": "tensorboard"
}
logger = Logger(log_dirs, output_config)
logger.log_hyperparameters(vars(args))
  1. 将所有组件加载到训练器中并开始训练:
policy_trainer = MFPolicyTrainer(
    policy=policy,
    eval_env=env,
    buffer=buffer,
    logger=logger,
    epoch=args.epoch,
    step_per_epoch=args.step_per_epoch,
    batch_size=args.batch_size,
    eval_episodes=args.eval_episodes
)

policy_trainer.train()

参数调优

OfflineRL-Kit 支持使用 Ray 进行并行参数调优。以下是一个简单的示例:

ray.init()
# 加载默认参数
args = get_args()

config = {}
real_ratios = [0.05, 0.5]
seeds = list(range(2))
config["real_ratio"] = tune.grid_search(real_ratios)
config["seed"] = tune.grid_search(seeds)

analysis = tune.run(
    run_exp,
    name="tune_mopo",
    config=config,
    resources_per_trial={
        "gpu": 0.5
    }
)

日志系统

OfflineRL-Kit 的日志系统支持多种记录文件类型,包括:

  • .txt(控制台输出备份)
  • .csv(记录训练过程中的损失、性能或其他指标)
  • .tfevents(用于 TensorBoard 可视化训练曲线)
  • .json(超参数备份)

日志系统的结构清晰,便于管理实验:

└─log(root dir)
    └─task
        └─algo_0
        |   └─seed_0&timestamp_xxx
        |   |   ├─checkpoint
        |   |   ├─model
        |   |   ├─record
        |   |   │  ├─tb
        |   |   │  ├─consoleout_backup.txt
        |   |   │  ├─policy_training_progress.csv
        |   |   │  ├─hyper_param.json
        |   |   ├─result
        |   └─seed_1&timestamp_xxx
        └─algo_1

使用日志记录器的示例:

from offlinerlkit.utils.logger import Logger, make_log_dirs

log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args))
output_config = {
    "consoleout_backup": "stdout",
    "policy_training_progress": "csv",
    "dynamics_training_progress": "csv",
    "tb": "tensorboard"
}
logger = Logger(log_dirs, output_config)
logger.log_hyperparameters(vars(args))

# 记录一些指标
logger.logkv("eval/normalized_episode_reward", norm_ep_rew_mean)
logger.logkv("eval/normalized_episode_reward_std", norm_ep_rew_std)
logger.logkv("eval/episode_length", ep_length_mean)
logger.logkv("eval/episode_length_std", ep_length_std)
# 设置时间步
logger.set_timestep(num_timesteps)
# 将结果写入记录文件
logger.dumpkvs()

结果可视化

OfflineRL-Kit 提供了简单的脚本来绘制实验结果:

python run_example/plotter.py --algos "mopo" "cql" --task "hopper-medium-replay-v2"

这将生成指定算法和任务的性能对比图。

总结

OfflineRL-Kit 是一个功能强大、易于使用的离线强化学习库,为研究人员提供了一个理想的实验平台。通过其优雅的框架设计、丰富的算法支持、高度的可扩展性以及强大的日志系统,研究人员可以更加高效地进行离线强化学习的研究和开发。

无论您是刚开始接触离线强化学习的新手,还是在该领域深耕多年的专家,OfflineRL-Kit 都能为您提供宝贵的支持。我们期待看到更多研究人员利用这个工具,在离线强化学习领域取得突破性的进展。

如果您在您的工作中使用了 OfflineRL-Kit,请引用以下 bibtex:

@misc{offinerlkit,
  author = {Yihao Sun},
  title = {OfflineRL-Kit: An Elegant PyTorch Offline Reinforcement Learning Library},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/yihaosun1124/OfflineRL-Kit}},
}

让我们一起探索离线强化学习的无限可能吧!🚀🤖🧠

avatar
0
0
0
相关项目
Project Cover

AI-Optimizer

AI-Optimizer是一款多功能深度强化学习平台,涵盖从无模型到基于模型,从单智能体到多智能体的多种算法。其分布式训练框架高效便捷,支持多智能体强化学习、离线强化学习、迁移和多任务强化学习、自监督表示学习等,解决维度诅咒、非平稳性和探索-利用平衡等难题,广泛应用于无人机、围棋、扑克、机器人控制和自动驾驶等领域。

Project Cover

awesome-diffusion-model-in-rl

本项目汇总了强化学习领域应用扩散模型的最新研究论文,涵盖离线RL、机器人控制、轨迹规划等多个方向。持续追踪并整理扩散强化学习的前沿进展,为研究人员提供全面的参考资源。每篇论文均附有概述、代码链接和实验环境等详细信息,方便读者深入了解。

Project Cover

scope-rl

SCOPE-RL是一个用于离线强化学习的开源Python库。它实现了从数据生成到策略学习、评估和选择的完整流程。该库提供了多种离线策略评估(OPE)估计器和策略选择(OPS)方法,兼容OpenAI Gym和Gymnasium接口。SCOPE-RL还包含RTBGym和RecGym环境,用于模拟实际应用场景。它简化了离线强化学习的研究和实践过程,提高了实验的透明度和可靠性。

Project Cover

awesome-offline-rl

该项目汇集了离线强化学习(Offline RL)领域的研究论文、综述文章、开源实现等资源。内容涵盖离线RL的理论方法、基准测试、应用案例及相关主题。项目由康奈尔大学研究人员维护,为学术界和产业界提供离线RL的最新进展和重要文献。

Project Cover

Minari

Minari是一个面向离线强化学习研究的Python库,提供类似Gymnasium离线版本的功能。该库具备简洁的数据集读写API,支持远程数据集管理,并允许创建自定义数据集。Minari旨在为研究人员提供标准化工具,推动离线强化学习领域的进步。

Project Cover

OfflineRL-Kit

OfflineRL-Kit是基于PyTorch的离线强化学习库,提供清晰的代码结构和最新算法实现。支持CQL、TD3+BC等多种算法,具备高扩展性和强大的日志系统。该库还支持并行调优,便于研究人员进行实验。相比其他离线强化学习库,OfflineRL-Kit在性能和易用性方面都有显著优势,是离线强化学习研究的有力工具。

最新项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号