Project Icon

PickScore

优化文本到图像生成的用户偏好数据集和模型

PickScore是一个开源项目,提供数据集和模型用于优化文本到图像生成的用户偏好预测。项目包含Pick-a-Pic v1和v2数据集,以及基于v1训练的PickScore模型。此外,还提供演示、安装指南、推理示例和训练脚本,方便研究人员和开发者进行实验和改进。PickScore致力于提升AI生成图像的质量和用户体验。

PickScore

本仓库包含论文Pick-a-Pic:文本到图像生成的用户偏好开放数据集的代码。

我们还开源了Pick-a-Pic v2数据集(包含超过100万个示例)、Pick-a-Pic v1数据集(论文中使用的原始数据集)和PickScore模型(在v1数据集上训练)。我们鼓励读者试用Pick-a-Pic的网络应用并为数据集做出贡献。

演示

我们在HF Spaces上为PickScore创建了一个简单的演示,欢迎查看 :)

安装

创建虚拟环境并下载torch:

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

然后安装其余依赖:

pip install -r requirements.txt
pip install -e .

或根据需要单独下载每个包

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
pip install transformers==4.27.3 

# 仅训练需要
pip install git+https://github.com/huggingface/accelerate.git@d1aa558119859c4b205a324afabaecabd9ef375e
pip install deepspeed==0.8.3
pip install datasets==2.10.1
pip install hydra-core==1.3.2 
pip install rich==13.3.2
pip install wandb==0.12.21
pip install -e .

# 仅在slurm上训练需要
pip install submitit==1.4.5

# 仅评估需要
pip install fire==0.4.0

使用PickScore进行推理

这里展示了一个使用PickScore作为偏好预测器进行推理的示例:

# 导入
from transformers import AutoProcessor, AutoModel
from PIL import Image
import torch

# 加载模型
device = "cuda"
processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"

processor = AutoProcessor.from_pretrained(processor_name_or_path)
model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(device)

def calc_probs(prompt, images):
    
    # 预处理
    image_inputs = processor(
        images=images,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)
    
    text_inputs = processor(
        text=prompt,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)


    with torch.no_grad():
        # 嵌入
        image_embs = model.get_image_features(**image_inputs)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
    
        text_embs = model.get_text_features(**text_inputs)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
    
        # 评分
        scores = model.logit_scale.exp() * (text_embs @ image_embs.T)[0]
        
        # 如果有多个图像可供选择,获取概率
        probs = torch.softmax(scores, dim=-1)
    
    return probs.cpu().tolist()

pil_images = [Image.open("my_amazing_images/1.jpg"), Image.open("my_amazing_images/2.jpg")]
prompt = "fantastic, increadible prompt"
print(calc_probs(prompt, pil_images))

下载Pick-a-Pic数据集

下载数据集大约需要30分钟,占用约190GB的磁盘空间。只需运行:

from datasets import load_dataset
dataset = load_dataset("yuvalkirstain/pickapic_v1", num_proc=64)
# 如果您想下载最新版本的pickapic,请下载:
# dataset = load_dataset("yuvalkirstain/pickapic_v2", num_proc=64)

您也可以使用'streaming=true',这样就不会下载整个数据集。 jpg_0和jpg_1列包含图像的字节数据,可以使用PIL和io.BytesIO读取。

请注意,数据集包含超过50万张图像,因此您可以先下载验证集(添加streaming=True以避免下载整个数据集)或不包含图像的版本(仅包含图像URL): 但最近图像URL已失效。

dataset = load_dataset("yuvalkirstain/pickapic_v1_no_images")
# 如果您想下载最新版本的pickapic,请下载:
# dataset = load_dataset("yuvalkirstain/pickapic_v2_no_images")

注意,我们打算仅通过HF数据集允许下载图像,而不是直接从AWS下载。如果URL不起作用,请从huggingface数据集下载数据。

从头开始训练PickScore

在训练之前,您可能想先下载数据集以节省计算资源。 此处的训练在8个A100 GPU上进行,大约需要40分钟。

本地

accelerate launch --dynamo_backend no --gpu_ids all --num_processes 8  --num_machines 1 --use_deepspeed trainer/scripts/train.py +experiment=clip_h output_dir=output```

Slurm

python trainer/slurm_scripts/slurm_train.py +slurm=stability 'slurm.cmd="+experiment=clip_h"'

在Pick-a-Pic上测试PickScore

 python trainer/scripts/eval_preference_predictor.py

引用

如果您觉得这项工作有用,请引用:

@inproceedings{Kirstain2023PickaPicAO,
  title={Pick-a-Pic: An Open Dataset of User Preferences for Text-to-Image Generation},
  author={Yuval Kirstain and Adam Polyak and Uriel Singer and Shahbuland Matiana and Joe Penna and Omer Levy},
  year={2023}
}
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号