Project Icon

pytorch-frame

模块化深度学习框架用于异构表格数据

PyTorch Frame是一个为异构表格数据设计的深度学习框架,支持数值、分类、时间、文本和图像等多种列类型。它采用模块化架构,实现了先进的深度表格模型,并可与大型语言模型集成。该框架提供了便捷的mini-batch加载器、基准数据集和自定义数据接口,简化了表格数据的深度学习研究过程,适用于各层次研究人员。框架内置多个预实现的深度表格模型,如Trompt、FTTransformer和TabNet等,并提供与XGBoost等GBDT模型的性能对比基准。PyTorch Frame无缝集成于PyTorch生态系统,便于与其他PyTorch库协同使用,为端到端的深度学习研究提供了便利。



一个用于在异构表格数据上构建神经网络模型的模块化深度学习框架。


arXiv PyPI 版本 测试状态 文档状态 贡献 Slack

文档 | 论文

PyTorch Frame 是 PyTorch 的深度学习扩展,专为包含数值、类别、时间、文本和图像等不同列类型的异构表格数据设计。它提供了一个模块化框架,用于实现现有和未来的方法。该库包含最先进模型的方法、用户友好的小批量加载器、基准数据集以及自定义数据集成接口。

PyTorch Frame 让表格数据的深度学习研究变得更加普及,既适合新手也适合专家。我们的目标是:

  1. 促进表格数据的深度学习: 历史上,基于树的模型(如 GBDT)在表格学习方面表现出色,但存在一些明显的局限性,例如与下游模型的集成困难,以及处理复杂列类型(如文本、序列和嵌入)的能力。深度表格模型有望解决这些局限性。我们的目标是通过模块化实现并支持多样的列类型,来促进表格数据的深度学习研究。

  2. 与大型语言模型等多样化模型架构集成: PyTorch Frame 支持与各种不同的架构集成,包括大型语言模型。使用任何下载的模型或嵌入 API 端点,您可以为文本数据生成嵌入,并与其他复杂语义类型一起使用深度学习模型进行训练。我们支持以下(但不限于):


库亮点

PyTorch Frame 直接构建在 PyTorch 之上,确保现有 PyTorch 用户能够顺利过渡。主要特性包括:

  • 多样化列类型: PyTorch Frame 支持跨各种列类型的学习:numericalcategoricalmulticategoricaltext_embeddedtext_tokenizedtimestampimage_embeddedembedding。详细教程请参见此处
  • 模块化模型设计: 支持模块化深度学习模型实现,促进代码重用、清晰编码和实验灵活性。更多详情请参见架构概览
  • 模型 实现了许多最先进的深度表格模型以及强大的 GBDT(XGBoost、CatBoost 和 LightGBM),并支持超参数调优。
  • 数据集: 提供了一系列可直接使用的表格数据集。同时支持自定义数据集以解决您自己的问题。 我们对深度表格模型与 GBDT 进行了基准测试
  • PyTorch 集成: 与其他 PyTorch 库无缝集成,便于 PyTorch Frame 与下游 PyTorch 模型的端到端训练。例如,通过与 PyTorch 图神经网络库 PyG 集成,我们可以对关系数据库进行深度学习。更多信息请参见 RelBench示例代码(进行中)

架构概览

PyTorch Frame 中的模型遵循 FeatureEncoderTableConvDecoder 的模块化设计,如下图所示:

本质上,这种模块化设置使用户能够轻松尝试各种架构:

  • Materialization 处理将原始 pandas DataFrame 转换为适合 PyTorch 训练和建模的 TensorFrame
  • FeatureEncoderTensorFrame 编码为大小为 [batch_size, num_cols, channels] 的隐藏列嵌入。
  • TableConv 对隐藏嵌入进行列间交互建模。
  • Decoder 为每行生成嵌入/预测。

快速上手

在这个快速上手中,我们将展示如何仅用几行代码就能轻松创建和训练一个深度表格模型。

构建并训练您自己的深度表格模型

作为示例,我们将按照 PyTorch Frame 的模块化架构实现一个简单的 ExampleTransformer。 在下面的示例中:

  • self.encoder 将输入的 TensorFrame 映射到大小为 [batch_size, num_cols, channels] 的嵌入。
  • self.convs 迭代地将大小为 [batch_size, num_cols, channels] 的嵌入转换为相同大小的嵌入。
  • self.decoder 将大小为 [batch_size, num_cols, channels] 的嵌入池化为 [batch_size, out_channels]
from torch import Tensor
from torch.nn import Linear, Module, ModuleList

from torch_frame import TensorFrame, stype
from torch_frame.nn.conv import TabTransformerConv
from torch_frame.nn.encoder import (
    EmbeddingEncoder,
    LinearEncoder,
    StypeWiseFeatureEncoder,
)

class ExampleTransformer(Module):
    def __init__(
        self,
        channels, out_channels, num_layers, num_heads,
        col_stats, col_names_dict,
    ):
        super().__init__()
        self.encoder = StypeWiseFeatureEncoder(
            out_channels=channels,
            col_stats=col_stats,
            col_names_dict=col_names_dict,
            stype_encoder_dict={
                stype.categorical: EmbeddingEncoder(),
                stype.numerical: LinearEncoder()
            },
        )
        self.convs = ModuleList([
            TabTransformerConv(
                channels=channels,
                num_heads=num_heads,
            ) for _ in range(num_layers)
        ])
        self.decoder = Linear(channels, out_channels)

    def forward(self, tf: TensorFrame) -> Tensor:
        x, _ = self.encoder(tf)
        for conv in self.convs:
            x = conv(x)
        out = self.decoder(x.mean(dim=1))
        return out

要准备数据,我们可以快速实例化一个预定义的数据集并创建一个与 PyTorch 兼容的数据加载器,如下所示:

from torch_frame.datasets import Yandex
from torch_frame.data import DataLoader

dataset = Yandex(root='/tmp/adult', name='adult')
dataset.materialize()
train_dataset = dataset[:0.8]
train_loader = DataLoader(train_dataset.tensor_frame, batch_size=128,
                          shuffle=True)

然后,我们只需按照标准PyTorch训练流程来优化模型参数。就这么简单!

import torch
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ExampleTransformer(
    channels=32,
    out_channels=dataset.num_classes,
    num_layers=2,
    num_heads=8,
    col_stats=train_dataset.col_stats,
    col_names_dict=train_dataset.tensor_frame.col_names_dict,
).to(device)

optimizer = torch.optim.Adam(model.parameters())

for epoch in range(50):
    for tf in train_loader:
        tf = tf.to(device)
        pred = model.forward(tf)
        loss = F.cross_entropy(pred, tf.y)
        optimizer.zero_grad()
        loss.backward()

已实现的深度表格模型

以下是目前支持的深度表格模型列表:

此外,我们还为想要将模型性能与GBDT进行比较的用户实现了XGBoostCatBoostLightGBM示例,这些示例使用Optuna进行超参数调优。

基准测试

我们在多种规模和任务类型的公共数据集上对最近的表格深度学习模型与GBDT进行了基准测试。

下图展示了各种模型在小型回归数据集上的性能,其中行代表模型名称,列代表数据集索引(这里我们有13个数据集)。有关分类和更大数据集的更多结果,请查看基准测试文档

模型名称数据集0数据集1数据集2数据集3数据集4数据集5数据集6数据集7数据集8数据集9数据集10数据集11数据集12
XGBoost0.250±0.0000.038±0.0000.187±0.0000.475±0.0000.328±0.0000.401±0.0000.249±0.0000.363±0.0000.904±0.0000.056±0.0000.820±0.0000.857±0.0000.418±0.000
CatBoost0.265±0.0000.062±0.0000.128±0.0000.336±0.0000.346±0.0000.443±0.0000.375±0.0000.273±0.0000.881±0.0000.040±0.0000.756±0.0000.876±0.0000.439±0.000
LightGBM0.253±0.0000.054±0.0000.112±0.0000.302±0.0000.325±0.0000.384±0.0000.295±0.0000.272±0.0000.877±0.0000.011±0.0000.702±0.0000.863±0.0000.395±0.000
Trompt0.261±0.0030.015±0.0050.118±0.0010.262±0.0010.323±0.0010.418±0.0030.329±0.0090.312±0.002OOM0.008±0.0010.779±0.0060.874±0.0040.424±0.005
ResNet0.288±0.0060.018±0.0030.124±0.0010.268±0.0010.335±0.0010.434±0.0040.325±0.0120.324±0.0040.895±0.0050.036±0.0020.794±0.0060.875±0.0040.468±0.004
FTTransformerBucket0.325±0.0080.096±0.0050.360±0.3540.284±0.0050.342±0.0040.441±0.0030.345±0.0070.339±0.003OOM0.105±0.0110.807±0.0100.885±0.0080.468±0.006
ExcelFormer0.262±0.0040.099±0.0030.128±0.0000.264±0.0030.331±0.0030.411±0.0050.298±0.0120.308±0.007OOM0.011±0.0010.785±0.0110.890±0.0030.431±0.006
FTTransformer0.335±0.0100.161±0.0220.140±0.0020.277±0.0040.335±0.0030.445±0.0030.361±0.0180.345±0.005OOM0.106±0.0120.826±0.0050.896±0.0070.461±0.003
TabNet0.279±0.0030.224±0.0160.141±0.0100.275±0.0020.348±0.0030.451±0.0070.355±0.0300.332±0.0040.992±0.1820.015±0.0020.805±0.0140.885±0.0130.544±0.011
TabTransformer0.624±0.0030.229±0.0030.369±0.0050.340±0.0040.388±0.0020.539±0.0030.619±0.0050.351±0.0010.893±0.0050.431±0.0010.819±0.0020.886±0.0050.545±0.004

我们可以看到,一些最新的深度表格模型能够达到与强大的GBDT相当的模型性能(尽管训练速度慢5-100倍)。使深度表格模型在更少计算资源下表现更好是未来研究的一个富有成果的方向。

我们还在一个带有一列文本的真实世界表格数据集(葡萄酒评论)上对不同的文本编码器进行了基准测试。下表显示了性能结果:

测试准确率方法模型名称来源
0.7926预训练sentence-transformers/all-distilroberta-v1 (1.25亿参数)Hugging Face
0.7998预训练embed-english-v3.0 (维度大小: 1024)Cohere
0.8102预训练text-embedding-ada-002 (维度大小: 1536)OpenAI
0.8147预训练voyage-01 (维度大小: 1024)Voyage AI
0.8203预训练intfloat/e5-mistral-7b-instruct (70亿参数)Hugging Face
0.8230LoRA微调DistilBERT (6600万参数)Hugging Face

Hugging Face文本编码器的基准测试脚本在这个文件中,其他文本编码器的基准测试脚本在这个文件中。

安装

PyTorch Frame支持Python 3.8到Python 3.11版本。

pip install pytorch_frame

查看安装指南了解其他安装选项。

引用

如果您在工作中使用了PyTorch Frame,请引用我们的论文(Bibtex如下)。

@article{hu2024pytorch,
  title={PyTorch Frame: A Modular Framework for Multi-Modal Tabular Learning},
  author={Hu, Weihua and Yuan, Yiwen and Zhang, Zecheng and Nitta, Akihiro and Cao, Kaidi and Kocijan, Vid and Leskovec, Jure and Fey, Matthias},
  journal={arXiv preprint arXiv:2404.00776},
  year={2024}
}
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号