简介
对于想要在视觉分类领域达到最新技术水平的研究人员和开发者来说,vit-pytorch
项目提供了一种基于PyTorch的Vision Transformer(视觉转换器,简称ViT)实现。传统的卷积神经网络在提取图像特征方面已经受到了广泛的关注,而Vision Transformer通过使用单一的Transformer编码器来进行图像分类,提供了一种别样的解决方案。
安装
该项目的安装十分简单,只需运行以下命令:
$ pip install vit-pytorch
使用方法
要使用 vit-pytorch
,首先需要导入相应的库并初始化ViT实例,然后输入图像进行预测。以下是一个简单的示例:
import torch
from vit_pytorch import ViT
v = ViT(
image_size=256,
patch_size=32,
num_classes=1000,
dim=1024,
depth=6,
heads=16,
mlp_dim=2048,
dropout=0.1,
emb_dropout=0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # 输出 (1, 1000)
参数说明
image_size
: 图像尺寸。对于非正方形图像,以最大边为准。patch_size
: 每个小块的尺寸,image_size
应该可以被patch_size
整除。生成的小块数必须超过16。num_classes
: 分类的类别数。dim
: 线性变换后的输出张量最后一维的大小。depth
: Transformer块的数量。heads
: 多头注意力机制中的头数。mlp_dim
: MLP(FeedForward层)的维度。channels
: 图像的通道数(默认3
)。dropout
: 丢弃率(介于[0, 1]
之间)。emb_dropout
: 嵌入丢弃率。pool
: 使用cls
token pooling 或mean
pooling。
简化版的 ViT
在某些情况下,使用简化版的ViT(Simple ViT)能够提高训练速度和精度。以下代码展示了如何使用Simple ViT:
import torch
from vit_pytorch import SimpleViT
v = SimpleViT(
image_size=256,
patch_size=32,
num_classes=1000,
dim=1024,
depth=6,
heads=16,
mlp_dim=2048
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # 输出 (1, 1000)
NaViT
NaViT是一种利用多变长序列的注意力和掩码处理来加速训练和提高准确性的Transformer变种。以下是使用NaViT的代码示例:
import torch
from vit_pytorch.na_vit import NaViT
v = NaViT(
image_size=256,
patch_size=32,
num_classes=1000,
dim=1024,
depth=6,
heads=16,
mlp_dim=2048,
dropout=0.1,
emb_dropout=0.1,
token_dropout_prob=0.1 # 令牌随机丢弃率为10%
)
images = [
[torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
[torch.randn(3, 128, 256), torch.randn(3, 256, 128)],
[torch.randn(3, 64, 256)]
]
preds = v(images) # 输出 (5, 1000)
知识蒸馏
知识蒸馏通过从卷积网络中提炼知识用于训练紧凑的Vision Transformer。例如,可以从ResNet50向ViT进行知识迁移。以下代码展示了如何实现这种方法:
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper
teacher = resnet50(pretrained=True)
v = DistillableViT(
image_size=256,
patch_size=32,
num_classes=1000,
dim=1024,
depth=6,
heads=8,
mlp_dim=2048,
dropout=0.1,
emb_dropout=0.1
)
distiller = DistillWrapper(
student=v,
teacher=teacher,
temperature=3,
alpha=0.5,
hard=False
)
img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))
loss = distiller(img, labels)
loss.backward()
pred = v(img) # 输出 (2, 1000)
以上介绍了 vit-pytorch
项目中的部分内容,并展示了如何通过PyTorch实现文本转换,这只是它在众多应用领域中的一个例子。当然,根据特定需求,vit-pytorch
还提供了更多的Transformer变种和功能。