CLIP
CLIP(对比语言-图像预训练)是一种在多种(图像,文本)配对数据上训练的神经网络。它可以通过自然语言指令来预测给定图像最相关的文本片段,而不需要直接为任务进行优化,类似于GPT-2和GPT-3的零样本能力。我们发现CLIP在ImageNet上的“零样本”表现与原始ResNet50相当,而无需使用任何原始的1.28M标注样本,克服了计算机视觉中的几个主要挑战。
方法
用法
首先,安装PyTorch 1.7.1(或更高版本)和torchvision,以及一些小的额外依赖项,然后将此仓库安装为Python包。在带有CUDA GPU的机器上,以下命令即可完成安装:
$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
$ pip install ftfy regex tqdm
$ pip install git+https://github.com/openai/CLIP.git
在没有GPU的机器上安装时,将上述cudatoolkit=11.0
替换为适当的CUDA版本或cpuonly
。
import torch
import clip
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["一个图表", "一只狗", "一只猫"]).to(device)
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)
logits_per_image, logits_per_text = model(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("标签概率:", probs) # 输出: [[0.9927937 0.00421068 0.00299572]]
API
CLIP模块clip
提供以下方法:
clip.available_models()
返回可用CLIP模型的名称。
clip.load(name, device=..., jit=False)
返回指定模型名称的模型和模型所需的TorchVision转换。当必要时,它会自动下载模型。name
参数也可以是本地检查点的路径。
可以选择性地指定运行模型的设备,默认是使用第一个CUDA设备(如果有),否则使用CPU。当jit
为False
时,将加载模型的非JIT版本。
clip.tokenize(text: Union[str, List[str]], context_length=77)
返回包含给定文本输入(单个或多个)标记序列的LongTensor。这个Tensor可以作为模型的输入。
通过clip.load()
返回的模型支持以下方法:
model.encode_image(image: Tensor)
给定一批图像,返回CLIP模型的视觉部分编码的图像特征。
model.encode_text(text: Tensor)
给定一批文本标记,返回CLIP模型的语言部分编码的文本特征。
model(image: Tensor, text: Tensor)
给定一批图像和一批文本标记,返回两个Tensors,分别包含对应每个图像和文本输入的logit分数。这些值是对应图像和文本特征之间的余弦相似度乘以100。
更多示例
零样本预测
以下代码使用CLIP执行零样本预测,正如论文附录B中所示的那样。此示例从CIFAR-100数据集中获取一张图像,并在数据集的100个文本标签中预测最可能的标签。
import os
import clip
import torch
from torchvision.datasets import CIFAR100
# 加载模型
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
# 下载数据集
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
# 准备输入
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"一张{c}的照片") for c in cifar100.classes]).to(device)
# 计算特征
with torch.no_grad():
image_features = model.encode_image(image_input)
text_features = model.encode_text(text_inputs)
# 为图像选择前5个最相似的标签
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)
# 输出结果
print("\n最优预测:\n")
for value, index in zip(values, indices):
print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
输出将如下所示(具体数值可能因计算设备而略有不同):
最优预测:
蛇: 65.31%
龟: 12.29%
甜椒: 3.83%
蜥蜴: 1.88%
鳄鱼: 1.75%
注意,这个示例使用了encode_image()
和encode_text()
方法来返回给定输入的编码特征。
线性探针评估
下面的示例使用scikit-learn对图像特征进行逻辑回归。
import os
import clip
import torch
import numpy as np
from sklearn.linear_model.LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm
# 加载模型
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
# 加载数据集
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)
def get_features(dataset):
all_features = []
all_labels = []
with torch.no_grad():
for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
features = model.encode_image(images.to(device))
all_features.append(features)
all_labels.append(labels)
return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()
# 计算图像特征
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)
# 执行逻辑回归
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)
# 使用逻辑回归分类器进行评估
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"准确率 = {accuracy:.3f}")
请注意,C
值应通过使用验证集的超参数搜索确定。
另见
- OpenCLIP:包含更大且独立训练的CLIP模型,最高到ViT-G/14
- Hugging Face的CLIP实现:便于与HF生态系统集成