English | 简体中文
新特性
v0.10.4 已于 2024-4-23 发布。
亮点:
- 支持在 MLflowVisBackend 中自定义
artifact_location
#1505 - 为
DeepSpeedEngine._zero3_consolidated_16bit_state_dict
启用exclude_frozen_parameters
#1517
详细内容请参阅更新日志。
简介
MMEngine 是一个基于 PyTorch 的、通用的、为深度学习模型训练而生的训练引擎。它是 OpenMMLab 项目的核心基础库,支持了包含检测、分割、分类、自监督在内的多个领域hundreds多个算法。同时,MMEngine 也是一个面向非 OpenMMLab 项目的通用训练框架,其主要特性包括:
集成主流大规模模型训练框架
支持多种训练策略
提供友好的配置系统
覆盖主流训练监控平台
安装
支持的 PyTorch 版本
MMEngine | PyTorch | Python |
---|---|---|
main | >=1.6 <=2.1 | >=3.8, <=3.11 |
>=0.9.0, <=0.10.4 | >=1.6 <=2.1 | >=3.8, <=3.11 |
在安装 MMEngine 之前,请确保按照官方指南成功安装了 PyTorch。
安装 MMEngine
pip install -U openmim
mim install mmengine
验证安装
python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'
快速上手
以在 CIFAR-10 数据集上训练 ResNet-50 模型为例,我们将使用 MMEngine 在不到 80 行代码内构建一个完整的、可配置的训练和验证流程。
构建模型
首先,我们需要定义一个模型,该模型 1)继承自 BaseModel
,2)在 forward
方法中接受一个额外的 mode
参数,除了与数据集相关的参数之外。
- 在训练过程中,
mode
的值为 "loss",forward
方法应返回一个包含 "loss" 键的dict
。 - 在验证过程中,
mode
的值为 "predict",forward 方法应返回包含预测结果和标签的结果。
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel
class MMResNet50(BaseModel):
def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()
def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == 'loss':
return {'loss': F.cross_entropy(x, labels)}
elif mode == 'predict':
return x, labels
构建数据集
接下来,我们需要为训练和验证创建数据集和数据加载器。 在这个例子中,我们简单地使用 TorchVision 中支持的内置数据集。
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
shuffle=True,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
val_dataloader = DataLoader(batch_size=32,
shuffle=False,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
构建指标
为了验证和测试模型,我们需要定义一个名为 accuracy 的指标来评估模型。这个指标需要继承自 BaseMetric
并实现 process
和 compute_metrics
方法。
from mmengine.evaluator import BaseMetric
class Accuracy(BaseMetric):
def process(self, data_batch, data_samples):
score, gt = data_samples
# 将一个批次的结果保存到 `self.results` 中
self.results.append({
'batch_size': len(gt),
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
})
def compute_metrics(self, results):
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
# 返回一个包含评估指标结果的字典,
# 其中键是指标的名称
return dict(accuracy=100 * total_correct / total_size)
构建 Runner
最后,我们可以用之前定义的 Model
、DataLoader
和 Metrics
,以及一些其他配置来构建一个 Runner,如下所示。
from torch.optim import SGD
from mmengine.runner import Runner
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
# 一个执行反向传播和梯度更新等操作的包装器
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
# 设置一些训练配置,如 epochs
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
)
启动训练
runner.train()
了解更多
高级教程
- [注册表](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html) - [配置](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html) - [基础数据集](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/basedataset.html) - [数据变换](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_transform.html) - [权重初始化](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/initialize.html) - [可视化](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/visualization.html) - [抽象数据元素](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_element.html) - [分布式通信](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/distributed.html) - [日志记录](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/logging.html) - [文件输入输出](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/fileio.html) - [全局管理器 (ManagerMixin)](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/manager_mixin.html) - [使用其他库的模块](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/cross_library.html) - [测试时数据增强](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/test_time_augmentation.html)示例
迁移指南
贡献
我们感谢所有为改进 MMEngine 做出贡献的人。请参考 CONTRIBUTING.md 了解贡献指南。
引用
如果您在研究中发现本项目有用,请考虑引用:
@article{mmengine2022,
title = {{MMEngine}: OpenMMLab 深度学习模型训练基础库},
author = {MMEngine 贡献者},
howpublished = {\url{https://github.com/open-mmlab/mmengine}},
year={2022}
}
许可证
该项目采用 Apache 2.0 许可证。
生态系统
OpenMMLab 的项目
- MIM: MIM 安装 OpenMMLab 包。
- MMCV: OpenMMLab 计算机视觉基础库。
- MMEval: 适用于多个机器学习库的统一评估库。
- MMPreTrain: OpenMMLab 预训练工具箱和基准测试。
- MMagic: OpenMMLab 高级、生成和智能创作工具箱。
- MMDetection: OpenMMLab 检测工具箱和基准测试。
- MMYOLO: OpenMMLab YOLO 系列工具箱和基准测试。
- MMDetection3D: OpenMMLab 新一代通用 3D 目标检测平台。
- MMRotate: OpenMMLab 旋转目标检测工具箱和基准测试。
- MMTracking: OpenMMLab 视频感知工具箱和基准测试。
- MMPose: OpenMMLab 姿态估计工具箱和基准测试。
- MMSegmentation: OpenMMLab 语义分割工具箱和基准测试。
- MMOCR: OpenMMLab 文本检测、识别和理解工具箱。
- MMHuman3D: OpenMMLab 3D 人体参数化模型工具箱和基准测试。
- MMSelfSup: OpenMMLab 自监督学习工具箱和基准测试。
- MMFewShot: OpenMMLab 少样本学习工具箱和基准测试。
- MMAction2: OpenMMLab 新一代视频理解工具箱和基准测试。
- MMFlow: OpenMMLab 光流估计工具箱和基准测试。
- MMDeploy: OpenMMLab 模型部署框架。
- MMRazor: OpenMMLab 模型压缩工具箱和基准测试。
- Playground: 汇集和展示基于 OpenMMLab 的精彩项目的中心枢纽。