.. image:: https://github.com/skorch-dev/skorch/blob/master/assets/skorch_bordered.svg
   :width: 30%

|build| |coverage| |docs| |huggingface| |powered|

A scikit-learn compatible neural network library that wraps PyTorch.

.. |build| image:: https://github.com/skorch-dev/skorch/workflows/tests/badge.svg
    :alt: Test Status
    :scale: 100%

.. |coverage| image:: https://github.com/skorch-dev/skorch/blob/master/assets/coverage.svg
    :alt: Test Coverage
    :scale: 100%

.. |docs| image:: https://readthedocs.org/projects/skorch/badge/?version=latest
    :alt: Documentation Status
    :scale: 100%
    :target: https://skorch.readthedocs.io/en/latest/?badge=latest

.. |huggingface| image:: https://github.com/skorch-dev/skorch/actions/workflows/test-hf-integration.yml/badge.svg
    :alt: Hugging Face Integration
    :scale: 100%
    :target: https://github.com/skorch-dev/skorch/actions/workflows/test-hf-integration.yml

.. |powered| image:: https://github.com/skorch-dev/skorch/blob/master/assets/powered.svg
    :alt: Powered by
    :scale: 100%
    :target: https://github.com/ottogroup/

=========
Resources
=========

- `Documentation <https://skorch.readthedocs.io/en/latest/?badge=latest>`_
- `Source Code <https://github.com/skorch-dev/skorch/>`_
- `Installation <https://github.com/skorch-dev/skorch#installation>`_

========
Examples
========

To see more elaborate examples, look `here <https://github.com/skorch-dev/skorch/tree/master/notebooks/README.md>`__.

.. code:: python

    import numpy as np
    from sklearn.datasets import make_classification
    from torch import nn
    from skorch import NeuralNetClassifier

    # PyTorch expects float32 features and int64 class labels
    X, y = make_classification(1000, 20, n_informative=10, random_state=0)
    X = X.astype(np.float32)
    y = y.astype(np.int64)

    class MyModule(nn.Module):
        def __init__(self, num_units=10, nonlin=nn.ReLU()):
            super().__init__()

            self.dense0 = nn.Linear(20, num_units)
            self.nonlin = nonlin
            self.dropout = nn.Dropout(0.5)
            self.dense1 = nn.Linear(num_units, num_units)
            self.output = nn.Linear(num_units, 2)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, X, **kwargs):
            X = self.nonlin(self.dense0(X))
            X = self.dropout(X)
            X = self.nonlin(self.dense1(X))
            X = self.softmax(self.output(X))
            return X

    net = NeuralNetClassifier(
        MyModule,
        max_epochs=10,
        lr=0.1,
        # Shuffle training data on each epoch
        iterator_train__shuffle=True,
    )

    net.fit(X, y)
    y_proba = net.predict_proba(X)

In an `sklearn Pipeline <https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html>`_:

.. code:: python

    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    pipe = Pipeline([
        ('scale', StandardScaler()),
        ('net', net),
    ])

    pipe.fit(X, y)
    y_proba = pipe.predict_proba(X)

With `grid search <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html>`_:

.. code:: python

    from sklearn.model_selection import GridSearchCV

    # Deactivate skorch-internal train-valid split and verbose logging
    net.set_params(train_split=False, verbose=0)
    params = {
        'lr': [0.01, 0.02],
        'max_epochs': [10, 20],
        'module__num_units': [10, 20],
    }
    gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)

    gs.fit(X, y)
    print("best score: {:.3f}, best params: {}".format(gs.best_score_, gs.best_params_))

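
Because the search above sets ``refit=False``, ``GridSearchCV`` does not keep a fitted best estimator. The following is a minimal sketch of one way to inspect the per-combination scores and then train a final model from ``gs.best_params_``. ``TinyModule`` is a hypothetical stand-in for ``MyModule``, and the smaller grid, ``cv=2`` and ``max_epochs=5`` are assumptions chosen only to keep the run short.

.. code:: python

    import numpy as np
    import torch
    from sklearn.base import clone
    from sklearn.datasets import make_classification
    from sklearn.model_selection import GridSearchCV
    from torch import nn
    from skorch import NeuralNetClassifier

    X, y = make_classification(1000, 20, n_informative=10, random_state=0)
    X = X.astype(np.float32)
    y = y.astype(np.int64)

    class TinyModule(nn.Module):
        # Hypothetical, simplified stand-in for MyModule (assumption)
        def __init__(self, num_units=10):
            super().__init__()
            self.hidden = nn.Linear(20, num_units)
            self.output = nn.Linear(num_units, 2)

        def forward(self, X):
            X = torch.relu(self.hidden(X))
            return torch.softmax(self.output(X), dim=-1)

    net = NeuralNetClassifier(
        TinyModule, max_epochs=5, lr=0.1, train_split=False, verbose=0,
    )
    params = {'lr': [0.01, 0.1], 'module__num_units': [10, 20]}
    gs = GridSearchCV(net, params, refit=False, cv=2, scoring='accuracy')
    gs.fit(X, y)

    # Mean cross-validated accuracy for every parameter combination
    for combo, score in zip(gs.cv_results_['params'],
                            gs.cv_results_['mean_test_score']):
        print(combo, round(float(score), 3))

    # refit=False leaves no best_estimator_, so build and fit one explicitly
    final_net = clone(net).set_params(**gs.best_params_)
    final_net.fit(X, y)
    y_pred = final_net.predict(X)

Passing ``refit=True`` to ``GridSearchCV`` instead would make scikit-learn perform this last step itself and expose the result as ``gs.best_estimator_``.
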
skorch also provides many convenient features, among others:

- `Learning rate schedulers <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.LRScheduler>`_ (warm restarts, cyclic LR and many more)
- `Scoring using sklearn (and custom) scoring functions <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.EpochScoring>`_
- `Early stopping <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.EarlyStopping>`_
- `Checkpointing <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.Checkpoint>`_
- `Parameter freezing/unfreezing <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.Freezer>`_
- `Progress bar <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.ProgressBar>`_ (for CLI as well as jupyter)
- `Automatic inference of CLI parameters <https://github.com/skorch-dev/skorch/tree/master/examples/cli>`_
- `Integration with GPyTorch for Gaussian Processes <https://skorch.readthedocs.io/en/latest/user/probabilistic.html>`_
- `Integration with Hugging Face 🤗 <https://skorch.readthedocs.io/en/stable/user/huggingface.html>`_

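
Several of the features listed above are exposed as callbacks that are passed to the net. The following is a minimal sketch that combines custom scoring, a learning rate scheduler and early stopping. ``SmallModule`` is a hypothetical module, and the scheduler settings (``StepLR``, ``step_size=5``, ``gamma=0.5``) and ``patience=3`` are illustrative assumptions rather than recommended values.

.. code:: python

    import numpy as np
    from sklearn.datasets import make_classification
    from torch import nn
    from skorch import NeuralNetClassifier
    from skorch.callbacks import EarlyStopping, EpochScoring, LRScheduler

    X, y = make_classification(1000, 20, n_informative=10, random_state=0)
    X = X.astype(np.float32)
    y = y.astype(np.int64)

    class SmallModule(nn.Module):
        # Hypothetical module used only for this illustration
        def __init__(self, num_units=10):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(20, num_units),
                nn.ReLU(),
                nn.Linear(num_units, 2),
                nn.Softmax(dim=-1),
            )

        def forward(self, X):
            return self.layers(X)

    net = NeuralNetClassifier(
        SmallModule,
        max_epochs=20,
        lr=0.1,
        iterator_train__shuffle=True,
        verbose=0,
        callbacks=[
            # Record the validation F1 score after each epoch
            EpochScoring('f1', lower_is_better=False, name='valid_f1'),
            # Halve the learning rate every 5 epochs (assumed schedule)
            LRScheduler(policy='StepLR', step_size=5, gamma=0.5),
            # Stop once valid_loss has not improved for 3 epochs
            EarlyStopping(monitor='valid_loss', patience=3),
        ],
    )
    net.fit(X, y)

    print('epochs run:', len(net.history))
    print('last valid_f1:', net.history[-1, 'valid_f1'])

The early stopping callback monitors ``valid_loss``, so this sketch keeps skorch's internal train-validation split enabled, unlike the grid search example.
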
============
Installation
============

skorch requires Python 3.8 or higher.

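conda installation
==================

You need a working conda installation. Get the correct miniconda for your system from `here <https://conda.io/miniconda.html>`__.

To install skorch, you need to use the conda-forge channel:

.. code:: bash

    conda install -c conda-forge skorch

We recommend using a `conda virtual environment <https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html>`_.

**Note**: The conda channel is *not* managed by the skorch maintainers. More information is available `here <https://github.com/conda-forge/skorch-feedstock>`__.
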
pip installation
================

To install with pip, run:

.. code:: bash

    python -m pip install -U skorch

Again, we recommend using a `virtual environment <https://docs.python.org/3/tutorial/venv.html>`_ for this.

From source
===========

If you would like to use the most recent additions to skorch or help with development, you should install skorch from source.

Using conda
-----------

To install skorch from source using conda, proceed as follows:

.. code:: bash

    git clone https://github.com/skorch-dev/skorch.git
    cd skorch
    conda create -n skorch-env python=3.10
    conda activate skorch-env
    conda install -c pytorch pytorch
    python -m pip install -r requirements.txt
    python -m pip install .

If you want to help with development, run:

.. code:: bash

    git clone https://github.com/skorch-dev/skorch.git
    cd skorch
    conda create -n skorch-env python=3.10
    conda activate skorch-env
    conda install -c pytorch pytorch
    python -m pip install -r requirements.txt
    python -m pip install -r requirements-dev.txt
    python -m pip install -e .

    py.test  # unit tests
    pylint skorch  # static code checks

You may adjust the Python version to any of the supported Python versions.

Using pip
---------

For pip, follow these instructions instead:

.. code:: bash

    git clone https://github.com/skorch-dev/skorch.git
    cd skorch
    # create and activate a virtual environment
    python -m pip install -r requirements.txt
    # install pytorch version for your system (see below)
    python -m pip install .

If you want to help with development, run:

.. code:: bash

    git clone https://github.com/skorch-dev/skorch.git
    cd skorch
    # create and activate a virtual environment
    python -m pip install -r requirements.txt
    # install pytorch version for your system (see below)
    python -m pip install -r requirements-dev.txt
    python -m pip install -e .

    py.test  # unit tests
    pylint skorch  # static code checks

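PyTorch
=======

PyTorch is not covered by the dependencies, since the PyTorch version you need depends on your OS and device. For installation instructions for PyTorch, visit the `PyTorch website <http://pytorch.org/>`__. skorch officially supports the last four minor PyTorch versions, which currently are:

- 2.0.1
- 2.1.2
- 2.2.2
- 2.3.0

However, that doesn't mean that older versions don't work, just that they aren't tested. Since skorch mostly relies on the stable part of the PyTorch API, older PyTorch versions should work fine.

In general, running this to install PyTorch should work:

.. code:: bash

    # using conda:
    conda install pytorch pytorch-cuda -c pytorch
    # using pip
    python -m pip install torch
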
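
After installing, it can be useful to check which PyTorch version is active and whether it falls into one of the tested release lines listed above. The following is a minimal sketch. The helper ``is_tested_version`` and the constant ``TESTED_MINOR_VERSIONS`` are hypothetical names introduced for illustration, and the version list is copied from this document, so it will go out of date as new releases appear.

.. code:: python

    import importlib.util

    # Release lines named as officially supported in this document (assumption:
    # only major.minor matters, patch and local build suffixes are ignored)
    TESTED_MINOR_VERSIONS = {(2, 0), (2, 1), (2, 2), (2, 3)}

    def is_tested_version(version):
        """Return True if a version string such as '2.1.2+cu118' is in a tested line."""
        core = version.split('+')[0]  # drop a local build suffix like '+cu118'
        major, minor = core.split('.')[:2]
        return (int(major), int(minor)) in TESTED_MINOR_VERSIONS

    if importlib.util.find_spec('torch') is None:
        print('PyTorch is not installed')
    else:
        import torch
        print('PyTorch version:', torch.__version__)
        print('in a tested release line:', is_tested_version(torch.__version__))
        print('CUDA available:', torch.cuda.is_available())

A result of ``False`` from the helper only means the version is outside the tested lines, not that skorch will fail with it.
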
==================
External resources
==================

- @jakubczakon: `blog post <https://neptune.ai/blog/model-training-libraries-pytorch-ecosystem>`_ "8 Creators and Core Contributors Talk About Their Model Training Libraries From PyTorch Ecosystem" 2020
- @BenjaminBossan: `talk 1 <https://www.youtube.com/watch?v=Qbu_DCBjVEk>`_ "skorch: A scikit-learn compatible neural network library" at PyCon/PyData 2019
- @githubnemo: `poster <https://github.com/githubnemo/skorch-poster>`_ for the PyTorch developer conference 2019
- @thomasjpfan: `talk 2 <https://www.youtube.com/watch?v=0J7FaLk0bmQ>`_ "Skorch: A Union of Scikit learn and PyTorch" at SciPy 2019
- @thomasjpfan: `talk 3 <https://www.youtube.com/watch?v=yAXsxf2CQ8M>`_ "Skorch - A Union of Scikit-learn and PyTorch" at PyData 2018
- @BenjaminBossan: `talk 4 <https://youtu.be/y_n7BjDCS-M>`_ "Extend your scikit-learn workflow with Hugging Face and skorch" at PyData Amsterdam 2023 (`slides 4 <https://github.com/BenjaminBossan/presentations/blob/main/2023-09-14-pydata/presentation.org>`_)

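=============
Communication
=============

- `GitHub discussions <https://github.com/skorch-dev/skorch/discussions>`_: user questions, thoughts, install issues, general discussions.
- `GitHub issues <https://github.com/skorch-dev/skorch/issues>`_: bug reports, feature requests, RFCs, etc.
- Slack: We run the #skorch channel on the `PyTorch Slack server <https://pytorch.slack.com/>`_, for which you can `request access here <https://bit.ly/ptslack>`_.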