Tab Transformer
在Pytorch中实现Tab Transformer,一个用于表格数据的注意力网络。这种简单的架构的性能几乎与GBDT不相上下。
更新:亚马逊AI声称在一个真实世界的表格数据集(预测运输成本)上,使用注意力机制击败了GBDT。
安装
$ pip install tab-transformer-pytorch
使用方法
import torch
import torch.nn as nn
from tab_transformer_pytorch import TabTransformer
cont_mean_std = torch.randn(10, 2)
model = TabTransformer(
categories = (10, 5, 6, 5, 8), # 包含每个类别中唯一值数量的元组
num_continuous = 10, # 连续值的数量
dim = 32, # 维度,论文设置为32
dim_out = 1, # 二元预测,但可以是任何值
depth = 6, # 深度,论文推荐6
heads = 8, # 头数,论文推荐8
attn_dropout = 0.1, # 注意力后的dropout
ff_dropout = 0.1, # 前馈网络的dropout
mlp_hidden_mults = (4, 2), # 最后mlp到logits的每个隐藏维度的相对倍数
mlp_act = nn.ReLU(), # 最后mlp的激活函数,默认为relu,但可以是其他任何函数(如selu等)
continuous_mean_std = cont_mean_std # (可选)- 在层归一化之前归一化连续值
)
x_categ = torch.randint(0, 5, (1, 5)) # 类别值,从0到最大类别数,按上面构造函数传入的顺序
x_cont = torch.randn(1, 10) # 假设连续值已经单独归一化
pred = model(x_categ, x_cont) # (1, 1)
FT Transformer
来自Yandex的这篇论文通过使用更简单的方案来嵌入连续数值,改进了Tab Transformer,如上图所示,图片来源于这个Reddit帖子。
为了方便与Tab Transformer进行比较,此实现也包含在本仓库中。
import torch
from tab_transformer_pytorch import FTTransformer
model = FTTransformer(
categories = (10, 5, 6, 5, 8), # 包含每个类别中唯一值数量的元组
num_continuous = 10, # 连续值的数量
dim = 32, # 维度,论文设置为32
dim_out = 1, # 二元预测,但可以是任何值
depth = 6, # 深度,论文推荐6
heads = 8, # 头数,论文推荐8
attn_dropout = 0.1, # 注意力后的dropout
ff_dropout = 0.1 # 前馈网络的dropout
)
x_categ = torch.randint(0, 5, (1, 5)) # 类别值,从0到最大类别数,按上面构造函数传入的顺序
x_numer = torch.randn(1, 10) # 数值
pred = model(x_categ, x_numer) # (1, 1)
无监督训练
要进行论文中描述的无监督训练类型,你可以首先将类别标记转换为适当的唯一ID,然后在model.transformer
上使用Electra。
待办事项
引用
@misc{huang2020tabtransformer,
title = {TabTransformer: Tabular Data Modeling Using Contextual Embeddings},
author = {Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin},
year = {2020},
eprint = {2012.06678},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@article{Gorishniy2021RevisitingDL,
title = {Revisiting Deep Learning Models for Tabular Data},
author = {Yu. V. Gorishniy and Ivan Rubachev and Valentin Khrulkov and Artem Babenko},
journal = {ArXiv},
year = {2021},
volume = {abs/2106.11959}
}