通过匹配训练轨迹进行数据集蒸馏
项目页面 | 论文
这个仓库包含了训练专家轨迹和从我们的"通过匹配训练轨迹进行数据集蒸馏"论文(CVPR 2022)中蒸馏合成数据的代码。请查看我们的项目页面以获取更多结果。
通过匹配训练轨迹进行数据集蒸馏
George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A. Efros, Jun-Yan Zhu
CMU, MIT, UC Berkeley
CVPR 2022 (口头报告)
"数据集蒸馏"的任务是学习少量合成图像,使得仅在这个集合上训练的模型能够获得与在完整真实数据集上训练的模型相似的测试性能。
我们的方法通过直接优化假图像来蒸馏合成数据集,以诱导与完整真实数据集相似的网络训练动态。我们在合成数据上对"学生"网络进行多次迭代训练,测量"学生"网络和在真实数据上训练的"专家"网络之间在参数空间的误差,并通过所有学生网络更新反向传播来优化合成像素。
可穿戴ImageNet:合成可平铺纹理
我们可以不将合成数据视为单独的图像,而是鼓励在更大的像素画布上的每个随机裁剪(使用循环填充)诱导良好的训练轨迹。这会产生在边缘连续的基于类别的纹理。
有了这些可平铺的纹理,我们可以将它们应用到需要这种特性的区域,比如服装图案。
可视化使用FAB3D制作
入门指南
首先,下载我们的仓库:
git clone https://github.com/GeorgeCazenavette/mtt-distillation.git
cd mtt-distillation
为了快速安装,我们提供了.yaml
文件。
如果你有RTX 30XX GPU(或更新的型号),运行
conda env create -f requirements_11_3.yaml
如果你有RTX 20XX GPU(或更旧的型号),运行
conda env create -f requirements_10_2.yaml
然后你可以通过以下命令激活conda环境
conda activate distillation
Quadro用户请注意:
torch.nn.DataParallel
似乎在Quadro A5000 GPU上不起作用,这可能也适用于其他Quadro卡。
如果你在训练过程中遇到无限挂起的情况,请尝试通过在命令前加上CUDA_VISIBLE_DEVICES=0
来仅使用1个GPU运行进程。
生成专家轨迹
在进行任何蒸馏之前,你需要使用buffer.py
生成一些专家轨迹
以下命令将在CIFAR-100上训练100个使用ZCA白化的ConvNet模型,每个模型训练50个epoch:
python buffer.py --dataset=CIFAR100 --model=ConvNet --train_epochs=50 --num_experts=100 --zca --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}
我们对所有专家使用默认学习率训练了50个epoch。
通过更改--num_experts
可以更快地获得较差(但仍然有趣)的结果。请注意,专家只需训练一次,可以在多个蒸馏实验中重复使用。
通过匹配训练轨迹进行蒸馏
以下命令将使用我们刚刚生成的缓冲区将CIFAR-100蒸馏为每个类别只有1张图像:
python distill.py --dataset=CIFAR100 --ipc=1 --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}
请在下面找到完整的超参数列表:
ImageNet
我们的方法还可以将ImageNet
的子集蒸馏为低支持的合成集。
在使用buffer.py
生成专家轨迹或使用distill.py
蒸馏数据集时,你必须使用--subset
标志指定ImageNet的命名子集。
例如,
python distill.py --dataset=ImageNet --subset=imagefruit --model=ConvNetD5 --ipc=1 --res=128 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}
将把imagefruit
子集(128x128分辨率)蒸馏为以下10张图像
要注册你自己的ImageNet子集,你可以将其添加到utils.py
顶部的Config
类中。
只需创建一个包含所需类别ID的列表,并将其添加到字典中。
这个gist包含所有1k个ImageNet类别及其对应的编号。
纹理蒸馏
你还可以使用相同的专家轨迹集(除了使用ZCA的那些)通过简单地添加--texture
标志将类别蒸馏为环面纹理。
例如,
python distill.py --texture --dataset=ImageNet --subset=imagesquawk --model=ConvNetD5 --ipc=1 --res=256 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}
将把imagesquawk
子集(256x256分辨率)蒸馏为以下10个纹理
致谢
我们要感谢Alexander Li、Assaf Shocher、Gokul Swamy、Kangle Deng、Ruihan Gao、Nupur Kumari、Muyang Li、Gaurav Parmar、Chonghyuk Song、Sheng-Yu Wang和Bingliang Zhang,以及阿德莱德大学Simon Lucey的视觉小组提供的宝贵反馈。这项工作部分得到了NSF研究生奖学金(Grant No. DGE1745016)以及J.P. Morgan Chase、IBM和SAP提供的资助。我们的代码改编自https://github.com/VICO-UoE/DatasetCondensation
相关工作
- Tongzhou Wang等人"数据集蒸馏",arXiv预印本2018
- Bo Zhao等人"使用梯度匹配进行数据集压缩",ICLR 2020
- Bo Zhao和Hakan Bilen"使用可微分孪生数据增强进行数据集压缩",ICML 2021
- Timothy Nguyen等人"从核岭回归进行数据集元学习",ICLR 2021
- Timothy Nguyen等人"使用无限宽卷积网络进行数据集蒸馏",NeurIPS 2021
- Bo Zhao和Hakan Bilen"使用分布匹配进行数据集压缩",arXiv预印本2021
- Kai Wang等人"CAFE:通过对齐特征学习压缩数据集",CVPR 2022
引用
如果你发现我们的代码对你的研究有用,请引用我们的论文。
@inproceedings{
cazenavette2022distillation,
title={Dataset Distillation by Matching Training Trajectories},
author={George Cazenavette and Tongzhou Wang and Antonio Torralba and Alexei A. Efros and Jun-Yan Zhu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}
@iniroceedings{
cazenavette2022textures,
title={Wearable ImageNet: Synthesizing Tileable Textures via Dataset Distillation},
author= {George Cazenavette and Tongzhou Wang and Antonio Torralba and Alexei A. Efros and Jun-Yan Zhu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
year={2022},
}