简介
本代码库包含由NVIDIA维护的实用工具,用于简化PyTorch中的混合精度和分布式训练。这里的部分代码最终将被纳入上游PyTorch。Apex的目的是尽快向用户提供最新的实用工具。
完整API文档:https://nvidia.github.io/apex
GTC 2019和PyTorch DevCon 2019幻灯片
内容
1. Amp:自动混合精度
已弃用。请使用PyTorch AMP
apex.amp
是一个工具,只需更改脚本中的3行代码即可启用混合精度训练。用户可以通过向amp.initialize
提供不同的标志,轻松尝试不同的纯精度和混合精度训练模式。
介绍Amp的网络研讨会
(标志cast_batchnorm
已更名为keep_batchnorm_fp32
)。
迁移到新的Amp API(适用于已弃用的"Amp"和"FP16_Optimizer" API的用户)
2. 分布式训练
apex.parallel.DistributedDataParallel
已弃用。请使用torch.nn.parallel.DistributedDataParallel
apex.parallel.DistributedDataParallel
是一个模块包装器,类似于torch.nn.parallel.DistributedDataParallel
。它能够方便地进行多进程分布式训练,针对NVIDIA的NCCL通信库进行了优化。
Imagenet示例展示了apex.parallel.DistributedDataParallel
与apex.amp
的使用。
同步批量归一化
已弃用。请使用torch.nn.SyncBatchNorm
apex.parallel.SyncBatchNorm
扩展了torch.nn.modules.batchnorm._BatchNorm
以支持同步BN。在多进程(DistributedDataParallel)训练期间,它会在进程之间进行统计信息的全局规约。当每个GPU只能容纳小型本地小批量时,同步BN被用于这些情况。全局规约的统计信息将BN层的有效批量大小增加到所有进程的全局批量大小(从技术上讲,这是正确的公式)。在我们的一些研究模型中,同步BN被观察到可以提高收敛精度。
检查点
为了正确保存和加载amp
训练,我们引入了amp.state_dict()
,它包含所有loss_scalers
及其对应的未跳过步数,以及amp.load_state_dict()
来恢复这些属性。
为了获得逐位精度,我们建议以下工作流程:
# 初始化
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# 训练模型
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
...
# 保存检查点
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
# 恢复
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
# 继续训练
...
请注意,我们建议使用相同的opt_level
恢复模型。另外,我们建议在amp.initialize
之后调用load_state_dict
方法。
安装
每个apex.contrib
模块除了--cpp_ext
和--cuda_ext
之外,还需要一个或多个安装选项。请注意,contrib模块不一定支持稳定的PyTorch版本。
容器
NVIDIA PyTorch容器可在NGC上获取:https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch。 这些容器包含当前可用的所有自定义扩展。
有关详细信息,如:
- 如何拉取容器
- 如何运行已拉取的容器
- 发行说明 请参阅NGC文档。
从源代码安装
要从源代码安装Apex,我们建议使用可从https://github.com/pytorch/pytorch获取的nightly PyTorch版本。
从https://pytorch.org获取的最新稳定版本也应该可以使用。
我们建议安装Ninja
以加快编译速度。
Linux
为了获得最佳性能和完整功能,我们建议通过以下方式安装带有CUDA和C++扩展的Apex:
git clone https://github.com/NVIDIA/apex
cd apex
# 如果pip版本 >= 23.1(参考:https://pip.pypa.io/en/stable/news/#v23-1),支持多个具有相同键的`--config-settings`...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# 否则
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
APEX也支持纯Python构建:
pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./
纯Python构建将省略:
- 使用
apex.optimizers.FusedAdam
所需的融合内核。 - 使用
apex.normalization.FusedLayerNorm
和apex.normalization.FusedRMSNorm
所需的融合内核。 - 提高
apex.parallel.SyncBatchNorm
性能和数值稳定性的融合内核。 - 提高
apex.parallel.DistributedDataParallel
和apex.amp
性能的融合内核。DistributedDataParallel
、amp
和SyncBatchNorm
仍然可用,但可能会较慢。
[试验性] Windows
如果你能够在你的系统上从源代码构建 PyTorch,那么 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" .
可能会起作用。通过 pip install -v --no-cache-dir .
进行纯 Python 构建更有可能成功。
如果你在 Conda 环境中安装了 PyTorch,请确保在同一环境中安装 Apex。
自定义 C++/CUDA 扩展和安装选项
如果某个模块的要求未满足,则不会构建该模块。
模块名称 | 安装选项 | 其他 |
---|---|---|
apex_C | --cpp_ext | |
amp_C | --cuda_ext | |
syncbn | --cuda_ext | |
fused_layer_norm_cuda | --cuda_ext | apex.normalization |
mlp_cuda | --cuda_ext | |
scaled_upper_triang_masked_softmax_cuda | --cuda_ext | |
generic_scaled_masked_softmax_cuda | --cuda_ext | |
scaled_masked_softmax_cuda | --cuda_ext | |
fused_weight_gradient_mlp_cuda | --cuda_ext | 需要 CUDA>=11 |
permutation_search_cuda | --permutation_search | apex.contrib.sparsity |
bnp | --bnp | apex.contrib.groupbn |
xentropy | --xentropy | apex.contrib.xentropy |
focal_loss_cuda | --focal_loss | apex.contrib.focal_loss |
fused_index_mul_2d | --index_mul_2d | apex.contrib.index_mul_2d |
fused_adam_cuda | --deprecated_fused_adam | apex.contrib.optimizers |
fused_lamb_cuda | --deprecated_fused_lamb | apex.contrib.optimizers |
fast_layer_norm | --fast_layer_norm | apex.contrib.layer_norm 。与 fused_layer_norm 不同 |
fmhalib | --fmha | apex.contrib.fmha |
fast_multihead_attn | --fast_multihead_attn | apex.contrib.multihead_attn |
transducer_joint_cuda | --transducer | apex.contrib.transducer |
transducer_loss_cuda | --transducer | apex.contrib.transducer |
cudnn_gbn_lib | --cudnn_gbn | 需要 cuDNN>=8.5, apex.contrib.cudnn_gbn |
peer_memory_cuda | --peer_memory | apex.contrib.peer_memory |
nccl_p2p_cuda | --nccl_p2p | 需要 NCCL >= 2.10, apex.contrib.nccl_p2p |
fast_bottleneck | --fast_bottleneck | 需要 peer_memory_cuda 和 nccl_p2p_cuda , apex.contrib.bottleneck |
fused_conv_bias_relu | --fused_conv_bias_relu | 需要 cuDNN>=8.4, apex.contrib.conv_bias_relu |