项目介绍:PyTorch Tabular
PyTorch Tabular 是一个旨在简化与表格数据进行深度学习的库,无论是现实中的应用场景还是科研项目,都可以从中受益。这个项目依托于强大的 PyTorch 及 PyTorch Lightning 框架而构建,强调了几个核心设计理念:
- 低难度的可用性
- 简单易用的自定义方式
- 强大的扩展功能及部署便利性
安装指南
推荐的最佳安装方式是先从 PyTorch 网站根据具体设备配置相应的 CUDA 版本进行安装。完成 PyTorch 的安装后,可以使用以下命令安装 PyTorch Tabular 包:
pip install -U “pytorch_tabular[extra]”
该命令将安装库的所有附加依赖项,包括 Weights&Biases 和 Plotly。如果只需要基本功能,可以使用:
pip install -U “pytorch_tabular”
要获取 PyTorch Tabular 的源代码,可以从 GitHub 仓库中克隆代码,并在本地进行安装:
git clone git://github.com/manujosephv/pytorch_tabular
cd pytorch_tabular && pip install .[extra]
模型支持
PyTorch Tabular 提供多种模型,用户可以根据需求选择不同的模型进行数据处理和分析:
- Category Embedding 的前馈网络:对于类别型列添加嵌入层。
- NODE、TabNet、Mixture Density Networks、AutoInt 和 TabTransformer 等先进模型。
- FT Transformer 和 Gated Additive Tree Ensemble 系列,帮助自动特征学习和表示。
- 支持半监督学习的 Denoising AutoEncoder。
使用示例
用户可以通过 Python 代码快速进行模型的训练和预测。以下为一个简单的使用案例:
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.config import (
DataConfig,
OptimizerConfig,
TrainerConfig,
)
data_config = DataConfig(
target=["target"],
continuous_cols=['col1', 'col2'],
categorical_cols=['col3', 'col4']
)
trainer_config = TrainerConfig(
auto_lr_find=True,
batch_size=1024,
max_epochs=100,
)
optimizer_config = OptimizerConfig()
model_config = CategoryEmbeddingModelConfig(
task="classification",
layers="1024-512-512",
activation="LeakyReLU",
learning_rate=1e-3,
)
tabular_model = TabularModel(
data_config=data_config,
model_config=model_config,
optimizer_config=optimizer_config,
trainer_config=trainer_config,
)
tabular_model.fit(train=train, validation=val)
result = tabular_model.evaluate(test)
pred_df = tabular_model.predict(test)
tabular_model.save_model("examples/basic")
loaded_model = TabularModel.load_model("examples/basic")
未来计划
- 集成Optuna进行超参数调优
- 数据模块迁移到Polars或NVTabular以增强数据加载性能
- 添加更多的架构和功能
贡献者
项目由多个开源贡献者协作完成,包括Manu Joseph、Jinu Sunil、Jiri Borovec等人。
学术引用
在科研出版物中使用 PyTorch Tabular 的用户,被鼓励引用相关软件及论文以给予支持。
PyTorch Tabular 旨在让开发者和科研人员在处理表格型数据时变得更加便捷,无需深入掌握复杂的深度学习算法即可快速应用。