介绍
sklearn-onnx 将 scikit-learn 模型转换为 ONNX 格式。 一旦转换为 ONNX 格式,你可以使用 ONNX Runtime 之类的工具进行高性能评分。 所有转换器都经过 onnxruntime 测试。 任何外部转换器都可以注册来转换 scikit-learn 的流水线,包括来自外部库的模型或转换器。
文档
完整的文档包括教程,请参阅 https://onnx.ai/sklearn-onnx/。 支持的 scikit-learn 模型 最后支持的 opset 是 21。
你也可以在 现有问题 中找到答案,或者提交一个新问题。
安装
你可以从 PyPi 安装:
pip install skl2onnx
或者你可以从源代码安装最新的更改。
pip install git+https://github.com/onnx/sklearn-onnx.git
入门
# 训练一个模型
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
iris = load_iris()
X, y = iris.data, iris.target
X = X.astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = RandomForestClassifier()
clr.fit(X_train, y_train)
# 转换为 ONNX 格式
from skl2onnx import to_onnx
onx = to_onnx(clr, X[:1])
with open("rf_iris.onnx", "wb") as f:
f.write(onx.SerializeToString())
# 使用 onnxruntime 进行预测
import onnxruntime as rt
sess = rt.InferenceSession("rf_iris.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: X_test.astype(np.float32)})[0]
贡献
我们欢迎以反馈、点子或代码的形式进行贡献。