Project Icon

apex

NVIDIA Apex加速PyTorch混合精度与分布式训练

Apex是NVIDIA开发的PyTorch扩展库,专注于优化混合精度和分布式训练。该工具提供自动混合精度、分布式数据并行和同步批量归一化等功能,大幅提高训练效率。Apex还集成了多个CUDA优化扩展,如快速层归一化和融合优化器,进一步增强性能。作为持续更新的开源项目,Apex为PyTorch用户提供了最新的训练加速工具。

简介

本代码库包含由NVIDIA维护的实用工具,用于简化PyTorch中的混合精度和分布式训练。这里的部分代码最终将被纳入上游PyTorch。Apex的目的是尽快向用户提供最新的实用工具。

完整API文档:https://nvidia.github.io/apex

GTC 2019PyTorch DevCon 2019幻灯片

内容

1. Amp:自动混合精度

已弃用。请使用PyTorch AMP

apex.amp是一个工具,只需更改脚本中的3行代码即可启用混合精度训练。用户可以通过向amp.initialize提供不同的标志,轻松尝试不同的纯精度和混合精度训练模式。

介绍Amp的网络研讨会 (标志cast_batchnorm已更名为keep_batchnorm_fp32)。

API文档

全面的Imagenet示例

DCGAN示例即将推出...

迁移到新的Amp API(适用于已弃用的"Amp"和"FP16_Optimizer" API的用户)

2. 分布式训练

apex.parallel.DistributedDataParallel已弃用。请使用torch.nn.parallel.DistributedDataParallel

apex.parallel.DistributedDataParallel是一个模块包装器,类似于torch.nn.parallel.DistributedDataParallel。它能够方便地进行多进程分布式训练,针对NVIDIA的NCCL通信库进行了优化。

API文档

Python源代码

示例/演练

Imagenet示例展示了apex.parallel.DistributedDataParallelapex.amp的使用。

同步批量归一化

已弃用。请使用torch.nn.SyncBatchNorm

apex.parallel.SyncBatchNorm扩展了torch.nn.modules.batchnorm._BatchNorm以支持同步BN。在多进程(DistributedDataParallel)训练期间,它会在进程之间进行统计信息的全局规约。当每个GPU只能容纳小型本地小批量时,同步BN被用于这些情况。全局规约的统计信息将BN层的有效批量大小增加到所有进程的全局批量大小(从技术上讲,这是正确的公式)。在我们的一些研究模型中,同步BN被观察到可以提高收敛精度。

检查点

为了正确保存和加载amp训练,我们引入了amp.state_dict(),它包含所有loss_scalers及其对应的未跳过步数,以及amp.load_state_dict()来恢复这些属性。

为了获得逐位精度,我们建议以下工作流程:

# 初始化
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

# 训练模型
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...

# 保存检查点
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...

# 恢复
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])

# 继续训练
...

请注意,我们建议使用相同的opt_level恢复模型。另外,我们建议在amp.initialize之后调用load_state_dict方法。

安装

每个apex.contrib模块除了--cpp_ext--cuda_ext之外,还需要一个或多个安装选项。请注意,contrib模块不一定支持稳定的PyTorch版本。

容器

NVIDIA PyTorch容器可在NGC上获取:https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch。 这些容器包含当前可用的所有自定义扩展。

有关详细信息,如:

  • 如何拉取容器
  • 如何运行已拉取的容器
  • 发行说明 请参阅NGC文档

从源代码安装

要从源代码安装Apex,我们建议使用可从https://github.com/pytorch/pytorch获取的nightly PyTorch版本。

https://pytorch.org获取的最新稳定版本也应该可以使用。

我们建议安装Ninja以加快编译速度。

Linux

为了获得最佳性能和完整功能,我们建议通过以下方式安装带有CUDA和C++扩展的Apex:

git clone https://github.com/NVIDIA/apex
cd apex
# 如果pip版本 >= 23.1(参考:https://pip.pypa.io/en/stable/news/#v23-1),支持多个具有相同键的`--config-settings`...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# 否则
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./

APEX也支持纯Python构建:

pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./

纯Python构建将省略:

  • 使用apex.optimizers.FusedAdam所需的融合内核。
  • 使用apex.normalization.FusedLayerNormapex.normalization.FusedRMSNorm所需的融合内核。
  • 提高apex.parallel.SyncBatchNorm性能和数值稳定性的融合内核。
  • 提高apex.parallel.DistributedDataParallelapex.amp性能的融合内核。 DistributedDataParallelampSyncBatchNorm仍然可用,但可能会较慢。

[试验性] Windows

如果你能够在你的系统上从源代码构建 PyTorch,那么 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" . 可能会起作用。通过 pip install -v --no-cache-dir . 进行纯 Python 构建更有可能成功。 如果你在 Conda 环境中安装了 PyTorch,请确保在同一环境中安装 Apex。

自定义 C++/CUDA 扩展和安装选项

如果某个模块的要求未满足,则不会构建该模块。

模块名称安装选项其他
apex_C--cpp_ext
amp_C--cuda_ext
syncbn--cuda_ext
fused_layer_norm_cuda--cuda_extapex.normalization
mlp_cuda--cuda_ext
scaled_upper_triang_masked_softmax_cuda--cuda_ext
generic_scaled_masked_softmax_cuda--cuda_ext
scaled_masked_softmax_cuda--cuda_ext
fused_weight_gradient_mlp_cuda--cuda_ext需要 CUDA>=11
permutation_search_cuda--permutation_searchapex.contrib.sparsity
bnp--bnpapex.contrib.groupbn
xentropy--xentropyapex.contrib.xentropy
focal_loss_cuda--focal_lossapex.contrib.focal_loss
fused_index_mul_2d--index_mul_2dapex.contrib.index_mul_2d
fused_adam_cuda--deprecated_fused_adamapex.contrib.optimizers
fused_lamb_cuda--deprecated_fused_lambapex.contrib.optimizers
fast_layer_norm--fast_layer_normapex.contrib.layer_norm。与 fused_layer_norm 不同
fmhalib--fmhaapex.contrib.fmha
fast_multihead_attn--fast_multihead_attnapex.contrib.multihead_attn
transducer_joint_cuda--transducerapex.contrib.transducer
transducer_loss_cuda--transducerapex.contrib.transducer
cudnn_gbn_lib--cudnn_gbn需要 cuDNN>=8.5, apex.contrib.cudnn_gbn
peer_memory_cuda--peer_memoryapex.contrib.peer_memory
nccl_p2p_cuda--nccl_p2p需要 NCCL >= 2.10, apex.contrib.nccl_p2p
fast_bottleneck--fast_bottleneck需要 peer_memory_cudanccl_p2p_cuda, apex.contrib.bottleneck
fused_conv_bias_relu--fused_conv_bias_relu需要 cuDNN>=8.4, apex.contrib.conv_bias_relu
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号