ETSformer - Pytorch
Implementation of ETSformer, a state-of-the-art time series Transformer, in Pytorch
Install
$ pip install etsformer-pytorch
Usage
import torch
from etsformer_pytorch import ETSFormer
model = ETSFormer(
    time_features = 4,
    model_dim = 512,            # in the paper they use 512
    embed_kernel_size = 3,      # kernel size of the 1d conv used for input embedding
    layers = 2,                 # number of encoder and corresponding decoder layers
    heads = 8,                  # number of exponential smoothing attention heads
    K = 4,                      # number of frequencies with the highest amplitude to keep (attend to)
    dropout = 0.2               # dropout (the paper uses 0.2)
)
timeseries = torch.randn(1, 1024, 4)
pred = model(timeseries, num_steps_forecast = 32) # (1, 32, 4) - (batch, forecast steps, time features)
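Below is a rough sketch of how the forecaster could be trained; the batch layout, optimizer settings, and MSE objective are illustrative assumptions rather than part of this repository.

import torch
import torch.nn.functional as F
from etsformer_pytorch import ETSFormer

model = ETSFormer(
    time_features = 4,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

# hypothetical learning rate; tune for your data
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

# hypothetical batch: 1024 context steps followed by 32 target steps per series
series = torch.randn(16, 1024 + 32, 4)
context, target = series[:, :1024], series[:, 1024:]

pred = model(context, num_steps_forecast = 32)  # (16, 32, 4)
loss = F.mse_loss(pred, target)                 # simple MSE objective, as an example
loss.backward()
optimizer.step()
optimizer.zero_grad()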
For classification with ETSFormer, using cross-attention pooling across all latents and the level output
import torch
from etsformer_pytorch import ETSFormer, ClassificationWrapper
etsformer = ETSFormer(
    time_features = 1,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)
adapter = ClassificationWrapper(
    etsformer = etsformer,
    dim_head = 32,
    heads = 16,
    dropout = 0.2,
    level_kernel_size = 5,
    num_classes = 10
)
timeseries = torch.randn(1, 1024)
logits = adapter(timeseries) # (1, 10)
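A minimal training sketch for the classification wrapper follows, continuing from the adapter defined above; the batch, the integer labels, and the cross-entropy objective are assumptions for illustration, not part of the repository.

import torch
import torch.nn.functional as F

# hypothetical batch of 8 univariate series with integer class labels in [0, 10)
timeseries = torch.randn(8, 1024)
labels = torch.randint(0, 10, (8,))

logits = adapter(timeseries)            # (8, 10)
loss = F.cross_entropy(logits, labels)  # standard cross-entropy, as an example objective
loss.backward()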
Citations
@misc{woo2022etsformer,
    title         = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
    author        = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year          = {2022},
    eprint        = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}