扩散强化学习 X
DRLX 是一个用于通过强化学习进行分布式训练扩散模型的库。它旨在包装 🤗 Hugging Face 的 Diffusers 库,并使用 Accelerate 进行多 GPU 和多节点(目前尚未测试)训练。
最新消息 (2023年9月27日):查看我们最近实验的博客文章 点击这里!
📖 文档
设置
首先确保你已安装 OpenCLIP。之后,你可以通过 pypi 安装该库:
pip install drlx
或从源代码安装:
pip install git+https://github.com/CarperAI/DRLX.git
如何使用
目前我们仅测试了 Stable Diffusion 1.4、1.5 和 2.1 版本,但由于其即插即用的特性,实际上大多数流程中的任何去噪器都应该可用。使用 DRLX 保存的模型与其原始流程兼容,可以像任何其他预训练模型一样加载。目前仅支持 DDPO 算法进行训练。
from drlx.reward_modelling.aesthetics import Aesthetics
from drlx.pipeline.pickapic_prompts import PickAPicPrompts
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig
# 我们导入一个奖励模型、提示流程、训练器和配置
pipe = PickAPicPrompts()
config = DRLXConfig.load_yaml("configs/my_cfg.yml")
trainer = DDPOTrainer(config)
trainer.train(pipe, Aesthetics())
然后使用训练好的模型进行推理:
pipe = StableDiffusionPipeline.from_pretrained("out/ddpo_exp")
prompt = "A mad panda scientist"
image = pipe(prompt).images[0]
image.save("test.jpeg")
加速训练
accelerate config
accelerate launch -m [你的模块]
路线图
- 初始发布和 DDPO
- PickScore 调优模型
- DPO
- SDXL 支持