项目介绍:数据集蒸馏与训练轨迹匹配
背景
数据集蒸馏是一项极具挑战性的任务,其目标是通过学习少量的合成图像,使模型仅在该数据集上进行训练便可获得接近全真实数据集训练效果的测试表现。这项研究由多位学者共同进行,发表在2022年的计算机视觉与模式识别会议(CVPR)上。
方法
mtt-distillation项目提出了一种通过训练轨迹匹配来进行数据集蒸馏的方法。具体地,研究者通过优化合成图像来诱导出与完整真实数据集相似的网络训练动态。他们通过两个步骤实现这一目标:
- 训练学生网络:在合成数据上进行多次迭代训练,并测量学生与在真实数据上训练的专家网络在参数空间的误差。
- 反向传播优化:利用反向传播调整合成图像,使其在训练性能上趋于与真实数据训练的结果一致。
可穿戴ImageNet:合成平铺纹理
mtt-distillation项目不仅可以处理合成图像,还能够通过在更大的像素画布上随机裁剪(并使用循环填充)来生成良好的训练轨迹,从而形成基于类的连续纹理。这些纹理可以应用于需要这种特性的区域,比如衣物图案。
使用指南
-
仓库下载与环境设置
首先从GitHub克隆项目仓库,并根据自己的显卡型号(RTX 30XX或RTX 20XX)安装对应的Python环境:
git clone https://github.com/GeorgeCazenavette/mtt-distillation.git cd mtt-distillation conda env create -f requirements_11_3.yaml # For RTX 30XX conda env create -f requirements_10_2.yaml # For RTX 20XX conda activate distillation
-
生成专家轨迹
在进行蒸馏前,需要利用
buffer.py
生成一些专家轨迹:python buffer.py --dataset=CIFAR100 --model=ConvNet --train_epochs=50 --num_experts=100 --zca --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}
-
蒸馏过程
使用生成的专家轨迹进行CIFAR-100数据集的蒸馏,将其压缩至每类仅一张合成图像:
python distill.py --dataset=CIFAR100 --ipc=1 --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}
扩展应用:ImageNet和纹理蒸馏
该方法同样适用于ImageNet子集,通过指定子集名称和相应参数,可以实现不同子集的蒸馏任务。同时,还可以通过添加--texture
标志,利用相同的专家轨迹进行纹理的蒸馏,生成可以应用于服装设计的平铺纹理。
致谢
该项目得到了多个组织的支持,包括美国国家科学基金会的研究生奖学金计划,以及J.P. Morgan Chase, IBM和SAP的资助。项目代码基于VICO-UoE的DatasetCondensation项目进行改编。
结语
mtt-distillation项目不仅在学术界取得了突破性进展,也为图像合成技术应用于实际提供了有力的工具。通过该项目,研究者展示了合成数据在机器学习领域中潜在的巨大价值。