Project Icon

pytorch-widedeep

基于PyTorch的多模式深度学习工具包,结合表格、文本和图像数据

pytorch-widedeep是一个基于Google的Wide and Deep算法的开源项目,专为多模式数据集设计,支持结合表格、文本和图像数据。该工具包提供多种架构和自定义模型支持,如TabMlp、BasicRNN、TabTransformer等。详细的安装、快速入门和使用扩展步骤可在官方文档中找到。pytorch-widedeep适合多模式数据的深度学习研究和应用。

PyPI 版本 Python 3.8 3.9 3.10 3.11 构建状态 文档状态 codecov 代码风格: black 维护 欢迎贡献 Slack DOI

pytorch-widedeep

一个灵活的多模态深度学习包,用于在 PyTorch 中结合表格数据与文本和图像的广度和深度模型

文档: https://pytorch-widedeep.readthedocs.io

配套文章和教程: infinitoml

LightGBM 的实验和对比: TabularDL vs LightGBM

Slack: 如果您想贡献或只是想与我们聊天,请加入 slack

本文档内容组织如下:

简介

pytorch-widedeep 基于 Google 的 Wide and Deep 算法,针对多模态数据集进行了调整。

总的来说,pytorch-widedeep 是一个用于表格数据深度学习的包。特别是,它旨在通过使用广度和深度模型来方便地将文本和图像与相应的表格数据结合起来。考虑到这一点,可以使用该库实现多种架构。这些架构的主要组件如下图所示:

从数学角度来看,按照论文中的表示方法,没有 deephead 组件的架构可以表示为:

其中 σ 是 sigmoid 函数,'W' 是应用于广度模型和深度模型最终激活的权重矩阵,'a' 是这些最终激活,φ(x) 是原始特征 'x' 的交叉乘积转换,'b' 是偏置项。如果您想知道什么是*"交叉乘积转换"*,这里直接引用论文中的一段话:"对于二元特征,交叉乘积转换(例如,"AND(gender=female, language=en)")当且仅当组成特征("gender=female" 和 "language=en")都为 1 时为 1,否则为 0。"

完全可以使用自定义模型(不一定是库中的模型),只要自定义模型有一个名为 output_dim 的属性,表示最后一层激活的大小,这样就可以构建 WideDeep。有关如何使用自定义组件的示例可以在 Examples 文件夹和下面的部分中找到。

架构

pytorch-widedeep 库提供了多种不同的架构。在本节中,我们将以最简单的形式展示其中的一些架构(即在大多数情况下使用默认参数值),并附上相应的代码片段。请注意,以下所有代码片段都应该能在本地运行。有关不同组件及其参数的更详细解释,请参阅文档。

对于以下示例,我们将使用如下生成的玩具数据集:

import os
import random

import numpy as np
import pandas as pd
from PIL import Image
from faker import Faker
def create_and_save_random_image(image_number, size=(32, 32)):

    if not os.path.exists("images"):
        os.makedirs("images")

    array = np.random.randint(0, 256, (size[0], size[1], 3), dtype=np.uint8)

    image = Image.fromarray(array)

    image_name = f"image_{image_number}.png"
    image.save(os.path.join("images", image_name))

    return image_name


fake = Faker()

cities = ["纽约", "洛杉矶", "芝加哥", "休斯顿"]
names = ["爱丽丝", "鲍勃", "查理", "大卫", "伊娃"]

data = {
    "city": [random.choice(cities) for _ in range(100)],
    "name": [random.choice(names) for _ in range(100)],
    "age": [random.uniform(18, 70) for _ in range(100)],
    "height": [random.uniform(150, 200) for _ in range(100)],
    "sentence": [fake.sentence() for _ in range(100)],
    "other_sentence": [fake.sentence() for _ in range(100)],
    "image_name": [create_and_save_random_image(i) for i in range(100)],
    "target": [random.choice([0, 1]) for _ in range(100)],
}

df = pd.DataFrame(data)

这将创建一个包含100行的数据框,并在本地文件夹中创建一个名为images的目录,其中包含100张随机图像(或只有噪声的图像)。

最简单的架构可能只包含一个组件,如widedeeptabulardeeptextdeepimage,这也是可能的,但让我们从标准的Wide and Deep架构开始示例。从那里开始,如何构建仅由一个组件组成的模型将变得非常简单。

请注意,下面显示的示例使用库中任何可用的模型都几乎相同。例如,TabMlp可以替换为TabResnetTabNetTabTransformer等。同样,BasicRNN可以替换为AttentiveRNNStackedAttentiveRNNHFModel,并使用相应的参数和预处理器(在Hugging Face模型的情况下)。

1. Wide和Tabular组件(又称deeptabular)

from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.training import Trainer

# Wide
wide_cols = ["city"]
crossed_cols = [("city", "name")]
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df)
wide = Wide(input_dim=np.unique(X_wide).shape[0])

# Tabular
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# WideDeep
model = WideDeep(wide=wide, deeptabular=tab_mlp)

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)

2. Tabular和Text数据

from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.training import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# Text
text_preprocessor = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text = text_preprocessor.fit_transform(df)
rnn = BasicRNN(
    vocab_size=len(text_preprocessor.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=rnn)

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=X_text,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)

3. Tabular和text,通过WideDeep中的head_hidden_dims参数在顶部添加一个FC头

from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.training import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# Text
text_preprocessor = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text = text_preprocessor.fit_transform(df)
rnn = BasicRNN(
    vocab_size=len(text_preprocessor.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=rnn, head_hidden_dims=[32, 16])

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=X_text,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)

4. Tabular和多个文本列直接传递给WideDeep

```python from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep from pytorch_widedeep.training import Trainer

表格数据

tab_preprocessor = TabPreprocessor( embed_cols=["city", "name"], continuous_cols=["age", "height"] ) X_tab = tab_preprocessor.fit_transform(df) tab_mlp = TabMlp( column_idx=tab_preprocessor.column_idx, cat_embed_input=tab_preprocessor.cat_embed_input, continuous_cols=tab_preprocessor.continuous_cols, mlp_hidden_dims=[64, 32], )

文本数据

text_preprocessor_1 = TextPreprocessor( text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1 ) X_text_1 = text_preprocessor_1.fit_transform(df) text_preprocessor_2 = TextPreprocessor( text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1 ) X_text_2 = text_preprocessor_2.fit_transform(df) rnn_1 = BasicRNN( vocab_size=len(text_preprocessor_1.vocab.itos), embed_dim=16, hidden_dim=8, n_layers=1, ) rnn_2 = BasicRNN( vocab_size=len(text_preprocessor_2.vocab.itos), embed_dim=16, hidden_dim=8, n_layers=1, )

WideDeep

model = WideDeep(deeptabular=tab_mlp, deeptext=[rnn_1, rnn_2])

训练

trainer = Trainer(model, objective="binary")

trainer.fit( X_tab=X_tab, X_text=[X_text_1, X_text_2], target=df["target"].values, n_epochs=1, batch_size=32, )


**5. 表格数据和多个文本列通过库的`ModelFuser`类进行融合**

<p align="center">
    <img width="500" src="https://yellow-cdn.veclightyear.com/835a84d5/c3d45f38-2409-46bb-9b59-09aa2cc2ad76.png">
</p>

```python
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser
from pytorch_widedeep import Trainer

# 表格数据
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# 文本数据
text_preprocessor_1 = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_1 = text_preprocessor_1.fit_transform(df)
text_preprocessor_2 = TextPreprocessor(
    text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_2 = text_preprocessor_2.fit_transform(df)

rnn_1 = BasicRNN(
    vocab_size=len(text_preprocessor_1.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)
rnn_2 = BasicRNN(
    vocab_size=len(text_preprocessor_2.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

models_fuser = ModelFuser(models=[rnn_1, rnn_2], fusion_method="mult")

# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=models_fuser)

# 训练
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=[X_text_1, X_text_2],
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)

6. 表格数据、多个文本列和一个图像列。文本列通过库的ModelFuser进行融合,然后所有数据通过用户自定义的WideDeep中的deephead参数(一个自定义的ModelFuser)进行融合

这可能是最不优雅的解决方案,因为它涉及用户的自定义组件和对"传入"张量的切片。未来,我们将包含一个TextAndImageModelFuser来使这个过程更加直观。尽管如此,这并不真的复杂,而且它是一个很好的例子,展示了如何在pytorch-widedeep中使用自定义组件。

请注意,自定义组件的唯一要求是它有一个名为output_dim的属性,该属性返回最后一层激活的大小。换句话说,它不需要继承自BaseWDModelComponent。这个基类只是检查这种属性的存在,并在内部避免一些类型错误。

import torch

from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor, ImagePreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser, Vision
from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
from pytorch_widedeep import Trainer

# 表格数据
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[16, 8],
)

# 文本数据
text_preprocessor_1 = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_1 = text_preprocessor_1.fit_transform(df)
text_preprocessor_2 = TextPreprocessor(
    text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_2 = text_preprocessor_2.fit_transform(df)
rnn_1 = BasicRNN(
    vocab_size=len(text_preprocessor_1.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)
rnn_2 = BasicRNN(
    vocab_size=len(text_preprocessor_2.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)
models_fuser = ModelFuser(
    models=[rnn_1, rnn_2],
    fusion_method="mult",
)

# 图像数据
image_preprocessor = ImagePreprocessor(img_col="image_name", img_path="images")
X_img = image_preprocessor.fit_transform(df)
vision = Vision(pretrained_model_setup="resnet18", head_hidden_dims=[16, 8])

# deephead(自定义模型融合器)
class MyModelFuser(BaseWDModelComponent):
    """
    在文本+图像之上简单地使用Linear + Relu序列,然后对张量的表格切片和
    文本和图像序列模型的输出的连接使用Linear -> Relu -> Linear
    """
    def __init__(
        self,
        tab_incoming_dim: int,
        text_incoming_dim: int,
        image_incoming_dim: int,
        output_units: int,
    ):

        super(MyModelFuser, self).__init__()

self.tab_incoming_dim = tab_incoming_dim self.text_incoming_dim = text_incoming_dim self.image_incoming_dim = image_incoming_dim self.output_units = output_units self.text_and_image_fuser = torch.nn.Sequential( torch.nn.Linear(text_incoming_dim + image_incoming_dim, output_units), torch.nn.ReLU(), ) self.out = torch.nn.Sequential( torch.nn.Linear(output_units + tab_incoming_dim, output_units * 4), torch.nn.ReLU(), torch.nn.Linear(output_units * 4, output_units), )

def forward(self, X: torch.Tensor) -> torch.Tensor: tab_slice = slice(0, self.tab_incoming_dim) text_slice = slice( self.tab_incoming_dim, self.tab_incoming_dim + self.text_incoming_dim ) image_slice = slice( self.tab_incoming_dim + self.text_incoming_dim, self.tab_incoming_dim + self.text_incoming_dim + self.image_incoming_dim, ) X_tab = X[:, tab_slice] X_text = X[:, text_slice] X_img = X[:, image_slice] X_text_and_image = self.text_and_image_fuser(torch.cat([X_text, X_img], dim=1)) return self.out(torch.cat([X_tab, X_text_and_image], dim=1))

@property def output_dim(self): return self.output_units

deephead = MyModelFuser( tab_incoming_dim=tab_mlp.output_dim, text_incoming_dim=models_fuser.output_dim, image_incoming_dim=vision.output_dim, output_units=8, )

WideDeep

model = WideDeep( deeptabular=tab_mlp, deeptext=models_fuser, deepimage=vision, deephead=deephead, )

训练

trainer = Trainer(model, objective="binary")

trainer.fit( X_tab=X_tab, X_text=[X_text_1, X_text_2], X_img=X_img, target=df["target"].values, n_epochs=1, batch_size=32, )

7. 带有多目标损失的表格数据

这个例子主要是为了演示多目标损失的使用,而不是一个真正不同的架构。

from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor, ImagePreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser, Vision
from pytorch_widedeep.losses_multitarget import MultiTargetClassificationLoss
from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
from pytorch_widedeep import Trainer

# 让我们给数据框添加第二个目标
df["target2"] = [random.choice([0, 1]) for _ in range(100)]

# 表格数据
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# 'pred_dim=2'是因为我们有两个二元目标。对于其他类型的目标,
# 请参阅文档
model = WideDeep(deeptabular=tab_mlp, pred_dim=2)

loss = MultiTargetClassificationLoss(binary_config=[0, 1], reduction="mean")

# 当使用多目标损失时,'custom_loss_function'不能为None。
# 请参阅文档
trainer = Trainer(model, objective="multitarget", custom_loss_function=loss)

trainer.fit(
    X_tab=X_tab,
    target=df[["target", "target2"]].values,
    n_epochs=1,
    batch_size=32,
)

deeptabular组件

再次强调,每个独立的组件,widedeeptabulardeeptextdeepimage,都可以独立使用。例如,可以只使用wide,它本质上就是一个线性模型。事实上,pytorch-widedeep中最有趣的功能之一是单独使用deeptabular组件,即通常所说的表格数据深度学习。目前,pytorch-widedeep为该组件提供了以下不同的模型:

  1. Wide:一个简单的线性模型,其中非线性通过交叉乘积变换捕获,如前所述。
  2. TabMlp:一个简单的MLP,接收表示分类特征的嵌入,与连续特征(也可以嵌入)连接。
  3. TabResnet:与前一个模型类似,但嵌入通过一系列由密集层构建的ResNet块传递。
  4. TabNet:TabNet的详细信息可以在TabNet: Attentive Interpretable Tabular Learning中找到。

两个更简单的基于注意力的模型:

  1. ContextAttentionMLP:带有"顶部"注意力机制的MLP,基于Hierarchical Attention Networks for Document Classification
  2. SelfAttentionMLP:带有注意力机制的MLP,是transformer块的简化版本,我们称之为"查询-键自注意力"。

Tabformer系列,即用于表格数据的Transformers:

  1. TabTransformer:TabTransformer的详细信息可以在TabTransformer: Tabular Data Modeling Using Contextual Embeddings中找到。
  2. SAINT:SAINT的详细信息可以在SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training中找到。
  3. FT-Transformer:FT-Transformer的详细信息可以在Revisiting Deep Learning Models for Tabular Data中找到。
  4. TabFastFormer:FastFormer在表格数据上的适配。FastFormer的详细信息可以在FastFormers: Highly Efficient Transformer Models for Natural Language Understanding中找到。
  5. TabPerceiver:Perceiver在表格数据上的适配。Perceiver的详细信息可以在Perceiver: General Perception with Iterative Attention中找到。

基于Weight Uncertainty in Neural Networks的表格数据概率深度学习模型:

  1. BayesianWideWide模型的概率适配。
  2. BayesianTabMlpTabMlp模型的概率适配。

需要注意的是,虽然TabTransformer、SAINT和FT-Transformer有相关的科学出版物,但TabFastFormer和TabPerceiver是我们自己对这些算法在表格数据上的适配。 此外,自监督预训练可用于所有deeptabular模型,但TabPerceiver除外。自监督预训练可通过两种方法或例程使用,我们称之为:编码器-解码器方法和对比去噪方法。请查看文档和示例以了解此功能的详细信息,以及库中的所有其他选项。

文本和图像

对于文本组件deeptext,该库提供以下模型:

  1. BasicRNN:简单的RNN
  2. AttentiveRNN:基于用于文档分类的分层注意力网络的带注意力机制的RNN
  3. StackedAttentiveRNN:AttentiveRNN的堆叠
  4. HFModel:Hugging Face Transformer模型的封装器。目前仅支持BERT、RoBERTa、DistilBERT、ALBERT和ELECTRA系列的模型。这是因为该库旨在解决分类和回归任务,而这些是最"流行"的仅编码器模型,已被证明最适合这些任务。如果有其他模型的需求,将来会包含在内。

对于图像组件deepimage,该库支持以下系列的模型: 'resnet'、'shufflenet'、'resnext'、'wide_resnet'、'regnet'、'densenet'、'mobilenetv3'、'mobilenetv2'、'mnasnet'、'efficientnet'和'squeezenet'。这些通过torchvision提供,并封装在Vision类中。

安装

使用pip安装:

pip install pytorch-widedeep

或直接从GitHub安装

pip install git+https://github.com/jrzaurin/pytorch-widedeep.git

开发者安装

# 克隆仓库
git clone https://github.com/jrzaurin/pytorch-widedeep
cd pytorch-widedeep

# 以开发模式安装
pip install -e .

快速入门

这是一个使用成人数据集进行二元分类的端到端示例,使用WideDeepDense以及默认设置。

使用pytorch-widedeep构建宽(线性)和深度模型:

import numpy as np
import torch
from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult


df = load_adult(as_frame=True)
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)

# 定义'列设置'
wide_cols = [
    "education",
    "relationship",
    "workclass",
    "occupation",
    "native-country",
    "gender",
]
crossed_cols = [("education", "occupation"), ("native-country", "occupation")]

cat_embed_cols = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "gender",
    "capital-gain",
    "capital-loss",
    "native-country",
]
continuous_cols = ["age", "hours-per-week"]
target = "income_label"
target = df_train[target].values

# 准备数据
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df_train)

tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols  # type: ignore[arg-type]
)
X_tab = tab_preprocessor.fit_transform(df_train)

# 构建模型
wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deeptabular=tab_mlp)

# 训练和验证
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    n_epochs=5,
    batch_size=256,
)

# 在测试集上预测
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)

# 保存和加载

# 选项1:这也会保存训练历史和学习率历史(如果使用了LRHistory回调)
trainer.save(path="model_weights", save_state_dict=True)

# 选项2:像其他torch模型一样保存
torch.save(model.state_dict(), "model_weights/wd_model.pt")

# 从这里开始,选项1或2是相同的。我假设用户已准备好数据并定义了新的模型组件:
# 1. 构建模型
model_new = WideDeep(wide=wide, deeptabular=tab_mlp)
model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))

# 2. 实例化训练器
trainer_new = Trainer(model_new, objective="binary")

# 3. 开始拟合或直接预测
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab, batch_size=32)

当然,可以做更多。查看Examples文件夹、文档或配套文章以更好地理解包的内容及其功能。

测试

pytest tests

如何贡献

查看CONTRIBUTING页面。

致谢

这个库借鉴了一系列其他库,所以我认为在README中提及它们是公平的(具体提及也包含在代码中)。

CallbacksInitializers的结构和代码受到torchsample库的启发,而后者部分受到Keras的启发。

本库中的TextProcessor类使用了fastaiTokenizerVocabutils.fastai_transforms中的代码是他们代码的微小改编,以便在本库中使用。根据我的经验,他们的Tokenizer是同类中最好的。

本库中的ImageProcessor类使用了Adrian Rosebrock的精彩著作《计算机视觉深度学习》(DL4CV)中的代码。

许可证

本作品采用Apache 2.0和MIT(或任何更新版本)双重许可。如果使用本作品,可以选择其中之一。

SPDX-License-Identifier: Apache-2.0 AND MIT

引用

BibTex

@article{Zaurin_pytorch-widedeep_灵活的_2023,
作者 = {Zaurin, Javier Rodriguez 和 Mulinka, Pavol},
doi = {10.21105/joss.05027},
期刊 = {开源软件期刊},
月份 = 六月,
期号 = {86},
页码 = {5027},
标题 = {{pytorch-widedeep:一个用于多模态深度学习的灵活包}},
网址 = {https://joss.theoj.org/papers/10.21105/joss.05027},
卷号 = {8},
年份 = {2023}
}

APA格式

Zaurin, J. R., 和 Mulinka, P. (2023). pytorch-widedeep:一个用于多模态深度
学习的灵活包. 开源软件期刊, 8(86), 5027.
https://doi.org/10.21105/joss.05027
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号