SparK: 首次成功在任何卷积网络上实现BERT/MAE风格的预训练
这是ICLR论文《为卷积网络设计BERT:Sparse和分层Masked建模》的官方实现,可以以BERT风格的自监督方式预训练任何CNN(如ResNet)。 我们尽力保持代码库的简洁、短小、易读、前沿,并且仅依赖于最少的依赖项。
🔥 新闻
- 我们的ICLR海报页面有一个简短的英文介绍![
📹录制视频、海报和幻灯片
]。 - 5月11日在OpenMMLab & ReadPaper (bilibili) 有另一场直播![
📹录制视频
] - **4月27日(UTC+8 下午8点)**在OpenMMLab (bilibili) 将举行另一场直播!
- **3月22日(UTC+8 下午8点)**在极市平台(bilibili) 将举行另一场直播![
📹录制视频
] - 3月16日(UTC+8 下午8点) 在TechBeat (将门创投) 也将有一个分享![
📹录制视频
] - 我们很荣幸被Synced("机器之心机动组 视频号" 在微信上)邀请于2月27日(UTC+0 上午11点, UTC+8 晚上7点) 讨论SparK,欢迎参与![
📹录制视频
] - 本工作作为Spotlight(前25%)被ICLR 2023接受。
- 其他文章:[
Synced
] [DeepAI
] [TheGradient
] [Bytedance
] [CVers
[QbitAI(量子位)
] [BAAI(智源)
] [机器之心机动组
] [极市平台
] [ReadPaper笔记
]
🕹️ Colab 可视化演示
查看pretrain/viz_reconstruction.ipynb以可视化SparK预训练模型的重构,例如:
我们还提供了pretrain/viz_spconv.ipynb,展示了密集卷积层的“mask pattern vanishing”问题。
有什么新内容?
🔥 预训练的CNN优于预训练的Swin-Transformer:
🔥 在SparK预训练之后较小模型可以超越未预训练的较大模型:
🔥 所有模型都能受益,表现出一种扩展行为:
🔥 生成自监督预训练超越对比学习:
查看我们的论文 了解更多分析、讨论和评估。
待办事项
目录
- 预训练代码
- 定制CNN模型的预训练教程 (预训练您的CNN模型的教程)
- 定制数据集的预训练教程 (预训练您的数据集的教程)
- 预训练Colab可视化操场(重建, 稀疏卷积)
- 微调代码
- 在
huggingface
上提供权重和可视化操场 - 在
timm
上提供权重
预训练权重(自监督;无解码器;可直接微调)
注:对于网络定义,我们直接使用timm.models.ResNet
和官方ConvNeXt。
reso.
:图像分辨率;acc@1
:ImageNet-1K微调精度(top-1)
<SOURCE_TEXT>
| 模型架构 | 分辨率 | 准确率@1 | 参数数量 | flops | 权重(自监督,无SparK解码器) |
|:--------------:|:-----:|:-----:|:-------:|:------:|:---------------------------------------------------------------------------------------------------------------------------------------|
| ResNet50 | 224 | 80.6 | 26M | 4.1G | [resnet50_1kpretrained_timm_style.pth](https://drive.google.com/file/d/1H8605HbxGvrsu4x4rIoNr-Wkd7JkxFPQ/view?usp=share_link) |
| ResNet101 | 224 | 82.2 | 45M | 7.9G | [resnet101_1kpretrained_timm_style.pth](https://drive.google.com/file/d/1ZwTztjU-_rfvOVfLoce9SMw2Fx0DQfoO/view?usp=share_link) |
| ResNet152 | 224 | 82.7 | 60M | 11.6G | [resnet152_1kpretrained_timm_style.pth](https://drive.google.com/file/d/1FOVuECnzQAI-OzE-hnrqW7tVpg8kTziM/view?usp=share_link) |
| ResNet200 | 224 | 83.1 | 65M | 15.1G | [resnet200_1kpretrained_timm_style.pth](https://drive.google.com/file/d/1_Q4e30qqhjchrdyW3fT6P98Ga-WnQ57s/view?usp=share_link) |
| ConvNeXt-S | 224 | 84.1 | 50M | 8.7G | [convnextS_1kpretrained_official_style.pth](https://drive.google.com/file/d/1Ah6lgDY5YDNXoXHQHklKKMbEd08RYivN/view?usp=share_link) |
| ConvNeXt-B | 224 | 84.8 | 89M | 15.4G | [convnextB_1kpretrained_official_style.pth](https://drive.google.com/file/d/1ZjWbqI1qoBcqeQijI5xX9E-YNkxpJcYV/view?usp=share_link) |
| ConvNeXt-L | 224 | 85.4 | 198M | 34.4G | [convnextL_1kpretrained_official_style.pth](https://drive.google.com/file/d/1qfYzGUpYBzuA88_kXkVl4KNUwfutMVfw/view?usp=share_link) |
| ConvNeXt-L | 384 | 86.0 | 198M | 101.0G | [convnextL_384_1kpretrained_official_style.pth](https://drive.google.com/file/d/1YgWNXJjI89l35P4ksAmBNWZ2JZCpj9n4/view?usp=share_link) |
<details>
<summary> <b> 预训练权重(带有SparK的UNet风格解码器;可用于重建图像) </b> </summary>
<br>
| 模型架构 | 分辨率 | 准确率@1 | 参数数量 | flops | 权重(自监督,带SparK解码器) |
|:----------:|:-----:|:-----:|:-------:|:------:|:------------------------------------------------------------------------------------------------------------------------------------------|
| ResNet50 | 224 | 80.6 | 26M | 4.1G | [res50_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1STt3w3e5q9eCPZa8VzcJj1zG6p3jLeSF/view?usp=share_link) |
| ResNet101 | 224 | 82.2 | 45M | 7.9G | [res101_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1GjN48LKtlop2YQre6---7ViCWO-3C0yr/view?usp=share_link) |
| ResNet152 | 224 | 82.7 | 60M | 11.6G | [res152_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1U3Cd94j4ZHfYR2dUjWmsEWfjP6Opx4oo/view?usp=share_link) |
| ResNet200 | 224 | 83.1 | 65M | 15.1G | [res200_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/13AFSqvIr0v-2hmb4DzVza45t_lhf2CnD/view?usp=share_link) |
| ConvNeXt-S | 224 | 84.1 | 50M | 8.7G | [cnxS224_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1bKvrE4sNq1PfzhWlQJXEPrl2kHqHRZM-/view?usp=share_link) |
| ConvNeXt-L | 384 | 86.0 | 198M | 101.0G | [cnxL384_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1ZI9Jgtb3fKWE_vDFEly29w-1FWZSNwa0/view?usp=share_link) |
</details>
<br>
## 安装与运行
我们强烈推荐您使用 `torch==1.10.0`、`torchvision==0.11.1` 和 `timm==0.5.4` 来进行复现。
检查 [INSTALL.md](INSTALL.md) 以安装所有 pip 依赖。
- **用3行代码加载预训练模型权重**
```python3
# 先下载我们的权重 `resnet50_1kpretrained_timm_style.pth`
import torch, timm
res50, state = timm.create_model('resnet50'), torch.load('resnet50_1kpretrained_timm_style.pth', 'cpu')
res50.load_state_dict(state.get('module', state), strict=False) # 以防模型权重实际上保存在 state['module'] 中
-
预训练
- 在 ImageNet-1k 上的任意 ResNet 或 ConvNeXt: 参见 pretrain/
- 您自己的 CNN 模型: 参见 pretrain/,特别是 pretrain/models/custom.py
-
微调
- 在 ImageNet-1k 上的任意 ResNet 或 ConvNeXt: 请参见 downstream_imagenet/ 以获取后续指导。
- 在 COCO 上的 ResNets: 参见 downstream_d2/
- 在 COCO 上的 ConvNeXts: 参见 downstream_mmdet/
鸣谢
我们参考了以下有用的代码库:
许可证
本项目采用 MIT 许可证。详情请参见 LICENSE。
引用
如果您觉得本项目有用,您可以给我们一个星标 ⭐,或者在您的工作中引用我们 📖:
@Article{tian2023designing,
author = {Keyu Tian 和 Yi Jiang 和 Qishuai Diao 和 Chen Lin 和 Liwei Wang 和 Zehuan Yuan},
title = {为卷积网络设计BERT:稀疏和分层的掩码建模},
journal = {arXiv:2301.03580},
year = {2023},
}
</SOURCE_TEXT>