Project Icon

mmengine

深度学习训练引擎支持大规模模型训练和多种策略

MMEngine是基于PyTorch的深度学习模型训练基础库,作为OpenMMLab代码库的训练引擎。它集成主流大规模模型训练框架,支持混合精度训练等多种策略,提供友好的配置系统和主流监控平台支持。MMEngine不仅适用于OpenMMLab项目,还可广泛应用于其他深度学习项目。

 
OpenMMLab 官网 热门      OpenMMLab 平台 立即体验
 

PyPI - Python 版本 pytorch PyPI 许可证

简介 | 安装 | 快速入门 | 📘文档 | 🤔报告问题

English | 简体中文

新特性

v0.10.4 已于 2024-4-23 发布。

亮点:

  • 支持在 MLflowVisBackend 中自定义 artifact_location #1505
  • DeepSpeedEngine._zero3_consolidated_16bit_state_dict 启用 exclude_frozen_parameters #1517

详细内容请参阅更新日志

简介

MMEngine 是一个基于 PyTorch 的、通用的、为深度学习模型训练而生的训练引擎。它是 OpenMMLab 项目的核心基础库,支持了包含检测、分割、分类、自监督在内的多个领域hundreds多个算法。同时,MMEngine 也是一个面向非 OpenMMLab 项目的通用训练框架,其主要特性包括:

集成主流大规模模型训练框架

支持多种训练策略

提供友好的配置系统

覆盖主流训练监控平台

安装

支持的 PyTorch 版本
MMEnginePyTorchPython
main>=1.6 <=2.1>=3.8, <=3.11
>=0.9.0, <=0.10.4>=1.6 <=2.1>=3.8, <=3.11

在安装 MMEngine 之前,请确保按照官方指南成功安装了 PyTorch。

安装 MMEngine

pip install -U openmim
mim install mmengine

验证安装

python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'

快速上手

以在 CIFAR-10 数据集上训练 ResNet-50 模型为例,我们将使用 MMEngine 在不到 80 行代码内构建一个完整的、可配置的训练和验证流程。

构建模型

首先,我们需要定义一个模型,该模型 1)继承自 BaseModel,2)在 forward 方法中接受一个额外的 mode 参数,除了与数据集相关的参数之外。

  • 在训练过程中,mode 的值为 "loss",forward 方法应返回一个包含 "loss" 键的 dict
  • 在验证过程中,mode 的值为 "predict",forward 方法应返回包含预测结果和标签的结果。
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel

class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
构建数据集

接下来,我们需要为训练和验证创建数据集数据加载器。 在这个例子中,我们简单地使用 TorchVision 中支持的内置数据集。

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))
val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))
构建指标

为了验证和测试模型,我们需要定义一个名为 accuracy 的指标来评估模型。这个指标需要继承自 BaseMetric 并实现 processcompute_metrics 方法。

from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # 将一个批次的结果保存到 `self.results` 中
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })
    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # 返回一个包含评估指标结果的字典,
        # 其中键是指标的名称
        return dict(accuracy=100 * total_correct / total_size)
构建 Runner

最后,我们可以用之前定义的 ModelDataLoaderMetrics,以及一些其他配置来构建一个 Runner,如下所示。

from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    # 一个执行反向传播和梯度更新等操作的包装器
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # 设置一些训练配置,如 epochs
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
启动训练
runner.train()

了解更多

教程
高级教程 - [注册表](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html) - [配置](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html) - [基础数据集](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/basedataset.html) - [数据变换](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_transform.html) - [权重初始化](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/initialize.html) - [可视化](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/visualization.html) - [抽象数据元素](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_element.html) - [分布式通信](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/distributed.html) - [日志记录](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/logging.html) - [文件输入输出](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/fileio.html) - [全局管理器 (ManagerMixin)](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/manager_mixin.html) - [使用其他库的模块](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/cross_library.html) - [测试时数据增强](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/test_time_augmentation.html)
示例
常见用法
设计
迁移指南

贡献

我们感谢所有为改进 MMEngine 做出贡献的人。请参考 CONTRIBUTING.md 了解贡献指南。

引用

如果您在研究中发现本项目有用,请考虑引用:

@article{mmengine2022,
  title   = {{MMEngine}: OpenMMLab 深度学习模型训练基础库},
  author  = {MMEngine 贡献者},
  howpublished = {\url{https://github.com/open-mmlab/mmengine}},
  year={2022}
}

许可证

该项目采用 Apache 2.0 许可证

生态系统

OpenMMLab 的项目

  • MIM: MIM 安装 OpenMMLab 包。
  • MMCV: OpenMMLab 计算机视觉基础库。
  • MMEval: 适用于多个机器学习库的统一评估库。
  • MMPreTrain: OpenMMLab 预训练工具箱和基准测试。
  • MMagic: OpenMMLab 高级、生成和智能创作工具箱。
  • MMDetection: OpenMMLab 检测工具箱和基准测试。
  • MMYOLO: OpenMMLab YOLO 系列工具箱和基准测试。
  • MMDetection3D: OpenMMLab 新一代通用 3D 目标检测平台。
  • MMRotate: OpenMMLab 旋转目标检测工具箱和基准测试。
  • MMTracking: OpenMMLab 视频感知工具箱和基准测试。
  • MMPose: OpenMMLab 姿态估计工具箱和基准测试。
  • MMSegmentation: OpenMMLab 语义分割工具箱和基准测试。
  • MMOCR: OpenMMLab 文本检测、识别和理解工具箱。
  • MMHuman3D: OpenMMLab 3D 人体参数化模型工具箱和基准测试。
  • MMSelfSup: OpenMMLab 自监督学习工具箱和基准测试。
  • MMFewShot: OpenMMLab 少样本学习工具箱和基准测试。
  • MMAction2: OpenMMLab 新一代视频理解工具箱和基准测试。
  • MMFlow: OpenMMLab 光流估计工具箱和基准测试。
  • MMDeploy: OpenMMLab 模型部署框架。
  • MMRazor: OpenMMLab 模型压缩工具箱和基准测试。
  • Playground: 汇集和展示基于 OpenMMLab 的精彩项目的中心枢纽。
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号