TabFormer 项目介绍
TabFormer 是一个基于PyTorch的项目,旨在为多变量时间序列建模提供一种新的方法。该项目在ICASSP 2021大会上进行了展示。以下是TabFormer项目的具体介绍。
项目概述
TabFormer 项目提供了一系列工具和数据集,用于构建分层结构的表格数据变换器,具体包括以下几个方面:
- 模块化的层级变换器,用于处理表格数据。
- 一个模拟的信用卡交易数据集。
- 改进版的Adaptive Softmax,用于处理数据屏蔽问题。
- 针对表格数据定制的 DataCollatorForLanguageModeling 模块。
- 所有模块基于 HuggingFace 🤗 的transformers构建。
项目要求
为了运行TabFormer 项目,需要具备以下软件环境:
- Python 版本 3.7
- Pytorch 版本 1.6.0
- HuggingFace / Transformer 版本 3.2.0
- scikit-learn 版本 0.23.2
- Pandas 版本 1.1.2
可以通过运行以下命令来安装这些库:
conda env create -f setup.yml
信用卡交易数据集
项目中提供的信用卡交易数据集位于目录 ./data/credit_card
中,包含2400万条记录和12个字段。需要使用Git LFS来访问该数据。如果LFS带宽受限,可以通过直接链接访问数据,然后使用 GIT_LFS_SKIP_SMUDGE=1
参数访问文件。
PRSA 数据集
对于PRSA数据集,用户需从Kaggle下载并放置在 ./data/card
目录下。
训练Tabular模型
Tabular BERT
要在信用卡交易数据或PRSA数据集上训练一个Tabular BERT模型,可以使用以下命令:
python main.py --do_train --mlm --field_ce --lm_type bert \
--field_hs 64 --data_type [prsa/card] \
--output_dir [output_dir]
Tabular GPT2
对于特定用户的信用卡交易数据,可以训练Tabular GPT2模型:
python main.py --do_train --lm_type gpt2 --field_ce --flatten --data_type card \
--data_root [path_to_data] --user_ids [user-id] \
--output_dir [output_dir]
部分命令选项说明(更多选项可以在args.py
中查找):
--data_type
用于指定数据集类型,可选项包括prsa
和card
。--mlm
表示使用屏蔽语言模型,这是BERT变换器的一个选项。--field_hs
表示字段级变换器的隐藏层大小。--lm_type
可以选择bert
或gpt2
。--user_ids
用于只选择特定用户ID的交易数据。
引用
如果在研究中使用了该项目,建议引用以下文献:
@inproceedings{padhi2021tabular,
title={Tabular transformers for modeling multivariate time series},
author={Padhi, Inkit and Schiff, Yair and Melnyk, Igor and Rigotti, Mattia and Mroueh, Youssef and Dognin, Pierre and Ross, Jerret and Nair, Ravi and Altman, Erik},
booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={3565--3569},
year={2021},
organization={IEEE}
}
通过以上介绍,可以对TabFormer项目有一个全面的了解。项目提供了强大的工具和数据,助力多变量时间序列建模。