onnx2torch 是一个 ONNX 到 PyTorch 的转换器。 我们的转换器:
- 易于使用 - 通过调用
convert
函数转换 ONNX 模型; - 易于扩展 - 用 PyTorch 编写自定义层并使用
@add_converter
注册; - 可以转回 ONNX - 您可以使用
torch.onnx.export
函数将模型转回 ONNX。
如果您发现问题,请告诉我们! 随时创建合并请求。
请注意,此转换器仅支持有限数量的 PyTorch / ONNX 模型和操作。 让我们知道您使用或想从 ONNX 转换到 PyTorch 的模型在这里。
安装
pip install onnx2torch
或者
conda install -c conda-forge onnx2torch
使用方法
以下是一些使用示例。
转换
import onnx
import torch
from onnx2torch import convert
# ONNX 模型路径
onnx_model_path = "/some/path/mobile_net_v2.onnx"
# 您可以传递 ONNX 模型的路径进行转换,或者...
torch_model_1 = convert(onnx_model_path)
# 或者您可以加载常规 ONNX 模型并将其传递给转换器
onnx_model = onnx.load(onnx_model_path)
torch_model_2 = convert(onnx_model)
执行
我们可以像执行原始 torch 模型一样执行返回的 PyTorch 模型
。
import onnxruntime as ort
# 创建示例数据
x = torch.ones((1, 2, 224, 224)).cuda()
out_torch = torch_model_1(x)
ort_sess = ort.InferenceSession(onnx_model_path)
outputs_ort = ort_sess.run(None, {"input": x.numpy()})
# 检查 Onnx 输出与 PyTorch 的对比
print(torch.max(torch.abs(outputs_ort - out_torch.detach().numpy())))
print(np.allclose(outputs_ort, out_torch.detach().numpy(), atol=1.0e-7))
模型
我们已经测试了以下模型:
分割模型:
- DeepLabV3+
- DeepLabV3 ResNet-50 (TorchVision)
- HRNet
- UNet (TorchVision)
- FCN ResNet-50 (TorchVision)
- LRASPP MobileNetV3 (TorchVision)
来自 MMdetection 的检测模型:
来自 TorchVision 的分类模型:
- ResNet-18
- ResNet-50
- MobileNetV2
- MobileNetV3 Large
- EfficientNet-B{0, 1, 2, 3}
- WideResNet-50
- ResNext-50
- VGG-16
- GoogLeNet
- MnasNet
- RegNet
Transformers:
- ViT
- Swin
- GPT-J
:page_facing_up: 当前支持的操作列表可以在这里找到。
如何向转换器添加新操作
这里我们展示如何使用新的 ONNX 操作扩展 onnx2torch,这些操作同时被 PyTorch 和 ONNX 支持
并且具有相同的行为
这样一个模块的例子是 Relu
@add_converter(operation_type="Relu", version=6)
@add_converter(operation_type="Relu", version=13)
@add_converter(operation_type="Relu", version=14)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
return OperationConverterResult(
torch_module=nn.ReLU(),
onnx_mapping=onnx_mapping_from_node(node=node),
)
这里我们为 opset 版本 6、13、14 注册了一个名为 Relu
的操作。
注意,OperationConverterResult
中的 torch_module
参数必须是 torch.nn.Module,而不仅仅是一个可调用对象!
如果操作的行为在不同的 opset 版本之间有所不同,您应该单独实现它。
但有不同的行为
这样一个模块的例子是 ScatterND
# 建议对 ONNX 字符串属性使用 Enum。
class ReductionOnnxAttr(Enum):
NONE = "none"
ADD = "add"
MUL = "mul"
class OnnxScatterND(nn.Module, OnnxToTorchModuleWithCustomExport):
def __init__(self, reduction: ReductionOnnxAttr):
super().__init__()
self._reduction = reduction
# 以下方法应返回 ONNX 属性及其值作为字典。
# 属性的数量、名称和值取决于 opset 版本;
# 方法应返回正确的属性集。
# 注意:为每个键添加类型后缀:reduction -> reduction_s,其中 s 表示 "string"。
def _onnx_attrs(self, opset_version: int) -> Dict[str, Any]:
onnx_attrs: Dict[str, Any] = {}
# 这里我们处理 opset 版本 < 16 的情况,其中没有 "reduction" 属性。
if opset_version < 16:
if self._reduction != ReductionOnnxAttr.NONE:
raise ValueError(
"ScatterND from opset < 16 does not support"
f"reduction attribute != {ReductionOnnxAttr.NONE.value},"
f"got {self._reduction.value}"
)
return onnx_attrs
onnx_attrs["reduction_s"] = self._reduction.value
return onnx_attrs
def forward(
self,
data: torch.Tensor,
indices: torch.Tensor,
updates: torch.Tensor,
) -> torch.Tensor:
def _forward():
# ScatterND 前向实现...
return output
if torch.onnx.is_in_onnx_export():
# 请遵循我们的约定,args 包括:
# 前向函数、操作类型、操作输入、操作属性。
onnx_attrs = self._onnx_attrs(opset_version=get_onnx_version())
return DefaultExportToOnnx.export(
_forward, "ScatterND", data, indices, updates, onnx_attrs
)
return _forward()
@add_converter(operation_type="ScatterND", version=11)
@add_converter(operation_type="ScatterND", version=13)
@add_converter(operation_type="ScatterND", version=16)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
node_attributes = node.attributes
reduction = ReductionOnnxAttr(node_attributes.get("reduction", "none"))
return OperationConverterResult(
torch_module=OnnxScatterND(reduction=reduction),
onnx_mapping=onnx_mapping_from_node(node=node),
)
这里我们通过定义自定义的 _ScatterNDExportToOnnx
使用了一个技巧来将模型从 torch 转回 ONNX。
Opset 版本解决方案
如果您使用的是较旧 opset 的模型,请尝试以下解决方法:
示例
import onnx
from onnx import version_converter
import torch
from onnx2torch import convert
# 加载 ONNX 模型。
model = onnx.load("model.onnx")
# 将模型转换为目标版本。
target_version = 13
converted_model = version_converter.convert_version(model, target_version)
# 转换为 torch。
torch_model = convert(converted_model)
torch.save(torch_model, "model.pt")
注意:仅在使用现有 opset 版本无法将模型转换为 PyTorch 时使用此方法。结果可能会有所不同。
引用
要引用 onnx2torch,请使用 引用此仓库
按钮,或者:
@misc{onnx2torch,
title={onnx2torch},
author={ENOT developers and Kalgin, Igor and Yanchenko, Arseny and Ivanov, Pyoter and Goncharenko, Alexander},
year={2021},
howpublished={\url{https://enot.ai/}},
note={Version: x.y.z}
}
致谢
感谢 Dmitry Chudakov @cakeofwar42 的贡献。
特别感谢 Andrey Denisov @denisovap2013 设计的标志。