Project Icon

sk2torch

实现scikit-learn模型到PyTorch模块的转换

sk2torch是一个开源工具,用于将scikit-learn模型转换为PyTorch模块。它解决了GPU加速推理、模型序列化和梯度计算等问题。sk2torch支持多种scikit-learn模型,使机器学习从业者能够利用PyTorch的GPU加速、TorchScript序列化和反向传播功能。这个项目为scikit-learn用户提供了更多的灵活性和性能优化选择。

sk2torch

sk2torchscikit-learn模型转换为可以通过反向传播进行微调,甚至可以编译为TorchScriptPyTorch模块。

该项目解决的问题:

  1. scikit-learn无法在GPU上执行推理。像SVM这样的模型可以从快速GPU原语中获得很多好处,将模型转换为PyTorch可以立即访问这些原语。
  2. 虽然scikit-learn支持通过pickle进行序列化,但保存的模型无法在不同版本的库之间复现。另一方面,TorchScript提供了一种方便、安全的方式来保存模型及其对应的实现。生成的模型可以在任何安装了PyTorch的地方加载,甚至不需要导入sk2torch。
  3. 虽然某些模型如SVM和线性分类器在理论上是端到端可微的,但scikit-learn没有提供计算已训练模型梯度的机制。PyTorch几乎免费提供了这种功能。

查看使用方法获取使用该库的高级示例。查看工作原理了解支持哪些模块。

有趣的是,这里有一个通过对两类SVM的概率预测进行微分生成的向量场(由此脚本生成):

具有两种模式的向量场箭头图

使用方法

首先,像往常一样使用scikit-learn训练模型:

from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

x, y = create_some_dataset()
model = Pipeline([
    ("center", StandardScaler(with_std=False)),
    ("classify", SGDClassifier()),
])
model.fit(x, y)

然后调用sk2torch.wrap将模型转换为PyTorch等效模型:

import sk2torch
import torch

torch_model = sk2torch.wrap(model)
print(torch_model.predict(torch.tensor([[1., 2., 3.]]).double()))

您可以使用TorchScript保存模型:

import torch.jit

torch.jit.script(torch_model).save("path.pt")

# ... 加载模型时无需安装sk2torch
loaded_model = torch.jit.load("path.pt")

有关训练模型并使用其PyTorch转换的完整示例,请参见examples/svm_vector_field.py

工作原理

sk2torch包含支持的scikit-learn模型的PyTorch重新实现。对于支持的估计器X,sk2torch中的TorchX类能够读取X的属性并将它们转换为torch.Tensor或简单的Python类型。TorchX继承自torch.nn.Module,并为X的每个推理API(如predictdecision_function等)提供相应的方法。

支持哪些模块?获取最新列表的最简单方法是通过supported_classes()函数,它返回所有可以用wrap()包装的scikit-learn类:

>>> import sk2torch
>>> sk2torch.supported_classes()
[<class 'sklearn.tree._classes.DecisionTreeClassifier'>, <class 'sklearn.tree._classes.DecisionTreeRegressor'>, <class 'sklearn.dummy.DummyClassifier'>, <class 'sklearn.ensemble._gb.GradientBoostingClassifier'>, <class 'sklearn.preprocessing._label.LabelBinarizer'>, <class 'sklearn.svm._classes.LinearSVC'>, <class 'sklearn.svm._classes.LinearSVR'>, <class 'sklearn.neural_network._multilayer_perceptron.MLPClassifier'>, <class 'sklearn.kernel_approximation.Nystroem'>, <class 'sklearn.pipeline.Pipeline'>, <class 'sklearn.linear_model._stochastic_gradient.SGDClassifier'>, <class 'sklearn.preprocessing._data.StandardScaler'>, <class 'sklearn.svm._classes.SVC'>, <class 'sklearn.svm._classes.NuSVC'>, <class 'sklearn.svm._classes.SVR'>, <class 'sklearn.svm._classes.NuSVR'>, <class 'sklearn.compose._target.TransformedTargetRegressor'>]

sklearn-onnx的比较

sklearn-onnx是一个开源包,用于将训练好的scikit-learn模型转换为ONNX。与sk2torch一样,sklearn-onnx重新实现了各种模型的推理函数,这意味着它也可以为支持的模块提供序列化和GPU加速。

自然地,这两个库都不会支持没有手动移植的模块。因此,两个库支持所有可用模型/方法的不同子集。例如,sk2torch支持SVC概率预测方法predict_probapredict_log_prob,而sklearn-onnx不支持。

虽然sklearn-onnx将模型导出为ONNX,但sk2torch将模型导出为具有熟悉方法名称的Python对象,这些对象可以进行微调、反向传播,并以用户友好的方式序列化。PyTorch比ONNX更通用,因为PyTorch模型可以根据需要转换为ONNX。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号