Project Icon

ml-fastvit

高效混合视觉Transformer模型用于图像分类

FastViT是一种采用结构重参数化技术的混合视觉Transformer模型。该模型在ImageNet-1K数据集上实现了准确率和延迟的良好平衡,提供多个变体以适应不同应用场景。FastViT在iPhone 12 Pro上的基准测试显示出优秀的移动端性能。项目开源了预训练模型、训练评估代码和使用文档。

FastViT: 使用结构重参数化的快速混合视觉Transformer

这是以下论文的官方代码库:

FastViT: 使用结构重参数化的快速混合视觉Transformer Pavan Kumar Anasosalu Vasu, James Gabriel, Jeff Zhu, Oncel Tuzel, Anurag Ranjan. ICCV 2023

arxiv webpage

FastViT性能

所有模型都在ImageNet-1K上训练,并使用ModelBench应用在iPhone 12 Pro上进行基准测试。

环境设置

conda create -n fastvit python=3.9
conda activate fastvit
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt

使用方法

要使用我们的模型,请参考以下代码片段:

import torch
import models
from timm.models import create_model
from models.modules.mobileone import reparameterize_model

# 从头开始训练/微调
model = create_model("fastvit_t8")
# ... 训练 ...

# 加载未融合的预训练检查点用于微调
# 或用于下游任务训练,如检测/分割
checkpoint = torch.load('/path/to/unfused_checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
# ... 训练 ...

# 用于推理
model.eval()      
model_inf = reparameterize_model(model)
# 在测试时使用model_inf

FastViT模型库

图像分类

在ImageNet-1K上训练的模型

使用知识蒸馏在ImageNet-1K上训练的模型。

延迟基准测试

所有模型的延迟都是在iPhone 12 Pro上使用ModelBench应用测量的。 如需更多详细信息,请联系James GabrielJeff Zhu。 所有报告的数字都四舍五入到最接近的小数点。

训练

图像分类

数据集准备

下载ImageNet-1K数据集,并按以下结构组织数据:

/path/to/imagenet-1k/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  validation/
    class1/
      img3.jpeg
    class2/
      img4.jpeg

要训练FastViT模型的变体,请按照以下相应的命令操作:

FastViT-T8
# 无蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t8 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256

# 有蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t8 -b 128 --lr 1e-3 \ 
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 
--distillation-type "hard"
FastViT-T12
# 无蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256
# 使用蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t12 -b 128 --lr 1e-3 \ 
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 
--distillation-type "hard"
FastViT-S12
# 不使用蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_s12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256

# 使用蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_s12 -b 128 --lr 1e-3 \ 
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 
--distillation-type "hard"
FastViT-SA12
# 不使用蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1

# 使用蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa12 -b 128 --lr 1e-3 \ 
--native-amp --output /path/to/save/results \
--input-size 3 256 256 
--distillation-type "hard"
FastViT-SA24
# 不使用蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa24 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1

# 使用蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa24 -b 128 --lr 1e-3 \ 
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.05 \
--distillation-type "hard"
FastViT-SA36
# 不使用蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa36 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.2

# 使用蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa36 -b 128 --lr 1e-3 \ 
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1 \
--distillation-type "hard"
FastViT-MA36
# 不使用蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t8 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.35

# 使用蒸馏
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t8 -b 128 --lr 1e-3 \ 
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.2 \
--distillation-type "hard"

评估

要在ImageNet上运行评估,请按照以下示例命令操作:

FastViT-T8
# 评估未融合的检查点
python validate.py /path/to/ImageNet/dataset --model fastvit_t8 \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8.pth.tar

# 评估融合的检查点
python validate.py /path/to/ImageNet/dataset --model fastvit_t8 \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8_reparam.pth.tar \
--use-inference-mode

模型导出

要从PyTorch检查点导出CoreML包文件,请按照以下示例命令操作:

FastViT-T8
python export_model.py --variant fastvit_t8 --output-dir /path/to/save/exported_model \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8_reparam.pth.tar

引用

@inproceedings{vasufastvit2023,
  author = {Pavan Kumar Anasosalu Vasu and James Gabriel and Jeff Zhu and Oncel Tuzel and Anurag Ranjan},
  title = {FastViT:  A Fast Hybrid Vision Transformer using Structural Reparameterization},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year = {2023}
}

致谢

我们的代码库是基于多个开源贡献构建的,详情请参阅ACKNOWLEDGEMENTS

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号