pytorch-widedeep
一个灵活的多模态深度学习包,用于在 PyTorch 中结合表格数据与文本和图像的广度和深度模型
文档: https://pytorch-widedeep.readthedocs.io
配套文章和教程: infinitoml
与 LightGBM
的实验和对比: TabularDL vs LightGBM
Slack: 如果您想贡献或只是想与我们聊天,请加入 slack
本文档内容组织如下:
简介
pytorch-widedeep
基于 Google 的 Wide and Deep 算法,针对多模态数据集进行了调整。
总的来说,pytorch-widedeep
是一个用于表格数据深度学习的包。特别是,它旨在通过使用广度和深度模型来方便地将文本和图像与相应的表格数据结合起来。考虑到这一点,可以使用该库实现多种架构。这些架构的主要组件如下图所示:
从数学角度来看,按照论文中的表示方法,没有 deephead
组件的架构可以表示为:
其中 σ 是 sigmoid 函数,'W' 是应用于广度模型和深度模型最终激活的权重矩阵,'a' 是这些最终激活,φ(x) 是原始特征 'x' 的交叉乘积转换,'b' 是偏置项。如果您想知道什么是*"交叉乘积转换"*,这里直接引用论文中的一段话:"对于二元特征,交叉乘积转换(例如,"AND(gender=female, language=en)")当且仅当组成特征("gender=female" 和 "language=en")都为 1 时为 1,否则为 0。"
完全可以使用自定义模型(不一定是库中的模型),只要自定义模型有一个名为 output_dim
的属性,表示最后一层激活的大小,这样就可以构建 WideDeep
。有关如何使用自定义组件的示例可以在 Examples 文件夹和下面的部分中找到。
架构
pytorch-widedeep
库提供了多种不同的架构。在本节中,我们将以最简单的形式展示其中的一些架构(即在大多数情况下使用默认参数值),并附上相应的代码片段。请注意,以下所有代码片段都应该能在本地运行。有关不同组件及其参数的更详细解释,请参阅文档。
对于以下示例,我们将使用如下生成的玩具数据集:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from faker import Faker
def create_and_save_random_image(image_number, size=(32, 32)):
if not os.path.exists("images"):
os.makedirs("images")
array = np.random.randint(0, 256, (size[0], size[1], 3), dtype=np.uint8)
image = Image.fromarray(array)
image_name = f"image_{image_number}.png"
image.save(os.path.join("images", image_name))
return image_name
fake = Faker()
cities = ["纽约", "洛杉矶", "芝加哥", "休斯顿"]
names = ["爱丽丝", "鲍勃", "查理", "大卫", "伊娃"]
data = {
"city": [random.choice(cities) for _ in range(100)],
"name": [random.choice(names) for _ in range(100)],
"age": [random.uniform(18, 70) for _ in range(100)],
"height": [random.uniform(150, 200) for _ in range(100)],
"sentence": [fake.sentence() for _ in range(100)],
"other_sentence": [fake.sentence() for _ in range(100)],
"image_name": [create_and_save_random_image(i) for i in range(100)],
"target": [random.choice([0, 1]) for _ in range(100)],
}
df = pd.DataFrame(data)
这将创建一个包含100行的数据框,并在本地文件夹中创建一个名为images
的目录,其中包含100张随机图像(或只有噪声的图像)。
最简单的架构可能只包含一个组件,如wide
、deeptabular
、deeptext
或deepimage
,这也是可能的,但让我们从标准的Wide and Deep架构开始示例。从那里开始,如何构建仅由一个组件组成的模型将变得非常简单。
请注意,下面显示的示例使用库中任何可用的模型都几乎相同。例如,TabMlp
可以替换为TabResnet
、TabNet
、TabTransformer
等。同样,BasicRNN
可以替换为AttentiveRNN
、StackedAttentiveRNN
或HFModel
,并使用相应的参数和预处理器(在Hugging Face模型的情况下)。
1. Wide和Tabular组件(又称deeptabular)
from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.training import Trainer
# Wide
wide_cols = ["city"]
crossed_cols = [("city", "name")]
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df)
wide = Wide(input_dim=np.unique(X_wide).shape[0])
# Tabular
tab_preprocessor = TabPreprocessor(
embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=tab_preprocessor.continuous_cols,
mlp_hidden_dims=[64, 32],
)
# WideDeep
model = WideDeep(wide=wide, deeptabular=tab_mlp)
# Train
trainer = Trainer(model, objective="binary")
trainer.fit(
X_wide=X_wide,
X_tab=X_tab,
target=df["target"].values,
n_epochs=1,
batch_size=32,
)
2. Tabular和Text数据
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.training import Trainer
# Tabular
tab_preprocessor = TabPreprocessor(
embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=tab_preprocessor.continuous_cols,
mlp_hidden_dims=[64, 32],
)
# Text
text_preprocessor = TextPreprocessor(
text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text = text_preprocessor.fit_transform(df)
rnn = BasicRNN(
vocab_size=len(text_preprocessor.vocab.itos),
embed_dim=16,
hidden_dim=8,
n_layers=1,
)
# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=rnn)
# Train
trainer = Trainer(model, objective="binary")
trainer.fit(
X_tab=X_tab,
X_text=X_text,
target=df["target"].values,
n_epochs=1,
batch_size=32,
)
3. Tabular和text,通过WideDeep
中的head_hidden_dims
参数在顶部添加一个FC头
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.training import Trainer
# Tabular
tab_preprocessor = TabPreprocessor(
embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=tab_preprocessor.continuous_cols,
mlp_hidden_dims=[64, 32],
)
# Text
text_preprocessor = TextPreprocessor(
text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text = text_preprocessor.fit_transform(df)
rnn = BasicRNN(
vocab_size=len(text_preprocessor.vocab.itos),
embed_dim=16,
hidden_dim=8,
n_layers=1,
)
# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=rnn, head_hidden_dims=[32, 16])
# Train
trainer = Trainer(model, objective="binary")
trainer.fit(
X_tab=X_tab,
X_text=X_text,
target=df["target"].values,
n_epochs=1,
batch_size=32,
)
4. Tabular和多个文本列直接传递给WideDeep
```python from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep from pytorch_widedeep.training import Trainer
表格数据
tab_preprocessor = TabPreprocessor( embed_cols=["city", "name"], continuous_cols=["age", "height"] ) X_tab = tab_preprocessor.fit_transform(df) tab_mlp = TabMlp( column_idx=tab_preprocessor.column_idx, cat_embed_input=tab_preprocessor.cat_embed_input, continuous_cols=tab_preprocessor.continuous_cols, mlp_hidden_dims=[64, 32], )
文本数据
text_preprocessor_1 = TextPreprocessor( text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1 ) X_text_1 = text_preprocessor_1.fit_transform(df) text_preprocessor_2 = TextPreprocessor( text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1 ) X_text_2 = text_preprocessor_2.fit_transform(df) rnn_1 = BasicRNN( vocab_size=len(text_preprocessor_1.vocab.itos), embed_dim=16, hidden_dim=8, n_layers=1, ) rnn_2 = BasicRNN( vocab_size=len(text_preprocessor_2.vocab.itos), embed_dim=16, hidden_dim=8, n_layers=1, )
WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=[rnn_1, rnn_2])
训练
trainer = Trainer(model, objective="binary")
trainer.fit( X_tab=X_tab, X_text=[X_text_1, X_text_2], target=df["target"].values, n_epochs=1, batch_size=32, )
**5. 表格数据和多个文本列通过库的`ModelFuser`类进行融合**
<p align="center">
<img width="500" src="https://yellow-cdn.veclightyear.com/835a84d5/c3d45f38-2409-46bb-9b59-09aa2cc2ad76.png">
</p>
```python
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser
from pytorch_widedeep import Trainer
# 表格数据
tab_preprocessor = TabPreprocessor(
embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=tab_preprocessor.continuous_cols,
mlp_hidden_dims=[64, 32],
)
# 文本数据
text_preprocessor_1 = TextPreprocessor(
text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_1 = text_preprocessor_1.fit_transform(df)
text_preprocessor_2 = TextPreprocessor(
text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_2 = text_preprocessor_2.fit_transform(df)
rnn_1 = BasicRNN(
vocab_size=len(text_preprocessor_1.vocab.itos),
embed_dim=16,
hidden_dim=8,
n_layers=1,
)
rnn_2 = BasicRNN(
vocab_size=len(text_preprocessor_2.vocab.itos),
embed_dim=16,
hidden_dim=8,
n_layers=1,
)
models_fuser = ModelFuser(models=[rnn_1, rnn_2], fusion_method="mult")
# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=models_fuser)
# 训练
trainer = Trainer(model, objective="binary")
trainer.fit(
X_tab=X_tab,
X_text=[X_text_1, X_text_2],
target=df["target"].values,
n_epochs=1,
batch_size=32,
)
6. 表格数据、多个文本列和一个图像列。文本列通过库的ModelFuser
进行融合,然后所有数据通过用户自定义的WideDeep
中的deephead参数(一个自定义的ModelFuser
)进行融合
这可能是最不优雅的解决方案,因为它涉及用户的自定义组件和对"传入"张量的切片。未来,我们将包含一个TextAndImageModelFuser
来使这个过程更加直观。尽管如此,这并不真的复杂,而且它是一个很好的例子,展示了如何在pytorch-widedeep
中使用自定义组件。
请注意,自定义组件的唯一要求是它有一个名为output_dim
的属性,该属性返回最后一层激活的大小。换句话说,它不需要继承自BaseWDModelComponent
。这个基类只是检查这种属性的存在,并在内部避免一些类型错误。
import torch
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor, ImagePreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser, Vision
from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
from pytorch_widedeep import Trainer
# 表格数据
tab_preprocessor = TabPreprocessor(
embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=tab_preprocessor.continuous_cols,
mlp_hidden_dims=[16, 8],
)
# 文本数据
text_preprocessor_1 = TextPreprocessor(
text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_1 = text_preprocessor_1.fit_transform(df)
text_preprocessor_2 = TextPreprocessor(
text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_2 = text_preprocessor_2.fit_transform(df)
rnn_1 = BasicRNN(
vocab_size=len(text_preprocessor_1.vocab.itos),
embed_dim=16,
hidden_dim=8,
n_layers=1,
)
rnn_2 = BasicRNN(
vocab_size=len(text_preprocessor_2.vocab.itos),
embed_dim=16,
hidden_dim=8,
n_layers=1,
)
models_fuser = ModelFuser(
models=[rnn_1, rnn_2],
fusion_method="mult",
)
# 图像数据
image_preprocessor = ImagePreprocessor(img_col="image_name", img_path="images")
X_img = image_preprocessor.fit_transform(df)
vision = Vision(pretrained_model_setup="resnet18", head_hidden_dims=[16, 8])
# deephead(自定义模型融合器)
class MyModelFuser(BaseWDModelComponent):
"""
在文本+图像之上简单地使用Linear + Relu序列,然后对张量的表格切片和
文本和图像序列模型的输出的连接使用Linear -> Relu -> Linear
"""
def __init__(
self,
tab_incoming_dim: int,
text_incoming_dim: int,
image_incoming_dim: int,
output_units: int,
):
super(MyModelFuser, self).__init__()
self.tab_incoming_dim = tab_incoming_dim self.text_incoming_dim = text_incoming_dim self.image_incoming_dim = image_incoming_dim self.output_units = output_units self.text_and_image_fuser = torch.nn.Sequential( torch.nn.Linear(text_incoming_dim + image_incoming_dim, output_units), torch.nn.ReLU(), ) self.out = torch.nn.Sequential( torch.nn.Linear(output_units + tab_incoming_dim, output_units * 4), torch.nn.ReLU(), torch.nn.Linear(output_units * 4, output_units), )
def forward(self, X: torch.Tensor) -> torch.Tensor: tab_slice = slice(0, self.tab_incoming_dim) text_slice = slice( self.tab_incoming_dim, self.tab_incoming_dim + self.text_incoming_dim ) image_slice = slice( self.tab_incoming_dim + self.text_incoming_dim, self.tab_incoming_dim + self.text_incoming_dim + self.image_incoming_dim, ) X_tab = X[:, tab_slice] X_text = X[:, text_slice] X_img = X[:, image_slice] X_text_and_image = self.text_and_image_fuser(torch.cat([X_text, X_img], dim=1)) return self.out(torch.cat([X_tab, X_text_and_image], dim=1))
@property def output_dim(self): return self.output_units
deephead = MyModelFuser( tab_incoming_dim=tab_mlp.output_dim, text_incoming_dim=models_fuser.output_dim, image_incoming_dim=vision.output_dim, output_units=8, )
WideDeep
model = WideDeep( deeptabular=tab_mlp, deeptext=models_fuser, deepimage=vision, deephead=deephead, )
训练
trainer = Trainer(model, objective="binary")
trainer.fit( X_tab=X_tab, X_text=[X_text_1, X_text_2], X_img=X_img, target=df["target"].values, n_epochs=1, batch_size=32, )
7. 带有多目标损失的表格数据
这个例子主要是为了演示多目标损失的使用,而不是一个真正不同的架构。
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor, ImagePreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser, Vision
from pytorch_widedeep.losses_multitarget import MultiTargetClassificationLoss
from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
from pytorch_widedeep import Trainer
# 让我们给数据框添加第二个目标
df["target2"] = [random.choice([0, 1]) for _ in range(100)]
# 表格数据
tab_preprocessor = TabPreprocessor(
embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=tab_preprocessor.continuous_cols,
mlp_hidden_dims=[64, 32],
)
# 'pred_dim=2'是因为我们有两个二元目标。对于其他类型的目标,
# 请参阅文档
model = WideDeep(deeptabular=tab_mlp, pred_dim=2)
loss = MultiTargetClassificationLoss(binary_config=[0, 1], reduction="mean")
# 当使用多目标损失时,'custom_loss_function'不能为None。
# 请参阅文档
trainer = Trainer(model, objective="multitarget", custom_loss_function=loss)
trainer.fit(
X_tab=X_tab,
target=df[["target", "target2"]].values,
n_epochs=1,
batch_size=32,
)
deeptabular
组件
再次强调,每个独立的组件,wide
、deeptabular
、deeptext
和deepimage
,都可以独立使用。例如,可以只使用wide
,它本质上就是一个线性模型。事实上,pytorch-widedeep
中最有趣的功能之一是单独使用deeptabular
组件,即通常所说的表格数据深度学习。目前,pytorch-widedeep
为该组件提供了以下不同的模型:
- Wide:一个简单的线性模型,其中非线性通过交叉乘积变换捕获,如前所述。
- TabMlp:一个简单的MLP,接收表示分类特征的嵌入,与连续特征(也可以嵌入)连接。
- TabResnet:与前一个模型类似,但嵌入通过一系列由密集层构建的ResNet块传递。
- TabNet:TabNet的详细信息可以在TabNet: Attentive Interpretable Tabular Learning中找到。
两个更简单的基于注意力的模型:
- ContextAttentionMLP:带有"顶部"注意力机制的MLP,基于Hierarchical Attention Networks for Document Classification
- SelfAttentionMLP:带有注意力机制的MLP,是transformer块的简化版本,我们称之为"查询-键自注意力"。
Tabformer
系列,即用于表格数据的Transformers:
- TabTransformer:TabTransformer的详细信息可以在TabTransformer: Tabular Data Modeling Using Contextual Embeddings中找到。
- SAINT:SAINT的详细信息可以在SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training中找到。
- FT-Transformer:FT-Transformer的详细信息可以在Revisiting Deep Learning Models for Tabular Data中找到。
- TabFastFormer:FastFormer在表格数据上的适配。FastFormer的详细信息可以在FastFormers: Highly Efficient Transformer Models for Natural Language Understanding中找到。
- TabPerceiver:Perceiver在表格数据上的适配。Perceiver的详细信息可以在Perceiver: General Perception with Iterative Attention中找到。
基于Weight Uncertainty in Neural Networks的表格数据概率深度学习模型:
- BayesianWide:
Wide
模型的概率适配。 - BayesianTabMlp:
TabMlp
模型的概率适配。
需要注意的是,虽然TabTransformer、SAINT和FT-Transformer有相关的科学出版物,但TabFastFormer和TabPerceiver是我们自己对这些算法在表格数据上的适配。
此外,自监督预训练可用于所有deeptabular
模型,但TabPerceiver
除外。自监督预训练可通过两种方法或例程使用,我们称之为:编码器-解码器方法和对比去噪方法。请查看文档和示例以了解此功能的详细信息,以及库中的所有其他选项。
文本和图像
对于文本组件deeptext
,该库提供以下模型:
- BasicRNN:简单的RNN
- AttentiveRNN:基于用于文档分类的分层注意力网络的带注意力机制的RNN
- StackedAttentiveRNN:AttentiveRNN的堆叠
- HFModel:Hugging Face Transformer模型的封装器。目前仅支持BERT、RoBERTa、DistilBERT、ALBERT和ELECTRA系列的模型。这是因为该库旨在解决分类和回归任务,而这些是最"流行"的仅编码器模型,已被证明最适合这些任务。如果有其他模型的需求,将来会包含在内。
对于图像组件deepimage
,该库支持以下系列的模型:
'resnet'、'shufflenet'、'resnext'、'wide_resnet'、'regnet'、'densenet'、'mobilenetv3'、'mobilenetv2'、'mnasnet'、'efficientnet'和'squeezenet'。这些通过torchvision
提供,并封装在Vision
类中。
安装
使用pip安装:
pip install pytorch-widedeep
或直接从GitHub安装
pip install git+https://github.com/jrzaurin/pytorch-widedeep.git
开发者安装
# 克隆仓库
git clone https://github.com/jrzaurin/pytorch-widedeep
cd pytorch-widedeep
# 以开发模式安装
pip install -e .
快速入门
这是一个使用成人数据集进行二元分类的端到端示例,使用Wide
和DeepDense
以及默认设置。
使用pytorch-widedeep
构建宽(线性)和深度模型:
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
df = load_adult(as_frame=True)
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)
# 定义'列设置'
wide_cols = [
"education",
"relationship",
"workclass",
"occupation",
"native-country",
"gender",
]
crossed_cols = [("education", "occupation"), ("native-country", "occupation")]
cat_embed_cols = [
"workclass",
"education",
"marital-status",
"occupation",
"relationship",
"race",
"gender",
"capital-gain",
"capital-loss",
"native-country",
]
continuous_cols = ["age", "hours-per-week"]
target = "income_label"
target = df_train[target].values
# 准备数据
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df_train)
tab_preprocessor = TabPreprocessor(
cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols # type: ignore[arg-type]
)
X_tab = tab_preprocessor.fit_transform(df_train)
# 构建模型
wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deeptabular=tab_mlp)
# 训练和验证
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(
X_wide=X_wide,
X_tab=X_tab,
target=target,
n_epochs=5,
batch_size=256,
)
# 在测试集上预测
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)
# 保存和加载
# 选项1:这也会保存训练历史和学习率历史(如果使用了LRHistory回调)
trainer.save(path="model_weights", save_state_dict=True)
# 选项2:像其他torch模型一样保存
torch.save(model.state_dict(), "model_weights/wd_model.pt")
# 从这里开始,选项1或2是相同的。我假设用户已准备好数据并定义了新的模型组件:
# 1. 构建模型
model_new = WideDeep(wide=wide, deeptabular=tab_mlp)
model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))
# 2. 实例化训练器
trainer_new = Trainer(model_new, objective="binary")
# 3. 开始拟合或直接预测
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab, batch_size=32)
当然,可以做更多。查看Examples文件夹、文档或配套文章以更好地理解包的内容及其功能。
测试
pytest tests
如何贡献
查看CONTRIBUTING页面。
致谢
这个库借鉴了一系列其他库,所以我认为在README中提及它们是公平的(具体提及也包含在代码中)。
Callbacks
和Initializers
的结构和代码受到torchsample
库的启发,而后者部分受到Keras
的启发。
本库中的TextProcessor
类使用了fastai
的Tokenizer
和Vocab
。utils.fastai_transforms
中的代码是他们代码的微小改编,以便在本库中使用。根据我的经验,他们的Tokenizer
是同类中最好的。
本库中的ImageProcessor
类使用了Adrian Rosebrock的精彩著作《计算机视觉深度学习》(DL4CV)中的代码。
许可证
本作品采用Apache 2.0和MIT(或任何更新版本)双重许可。如果使用本作品,可以选择其中之一。
SPDX-License-Identifier: Apache-2.0 AND MIT
引用
BibTex
@article{Zaurin_pytorch-widedeep_灵活的_2023,
作者 = {Zaurin, Javier Rodriguez 和 Mulinka, Pavol},
doi = {10.21105/joss.05027},
期刊 = {开源软件期刊},
月份 = 六月,
期号 = {86},
页码 = {5027},
标题 = {{pytorch-widedeep:一个用于多模态深度学习的灵活包}},
网址 = {https://joss.theoj.org/papers/10.21105/joss.05027},
卷号 = {8},
年份 = {2023}
}
APA格式
Zaurin, J. R., 和 Mulinka, P. (2023). pytorch-widedeep:一个用于多模态深度
学习的灵活包. 开源软件期刊, 8(86), 5027.
https://doi.org/10.21105/joss.05027