基于RNN的短文本分类
- 这是用于多类短文本分类。
- 模型使用词嵌入、LSTM(或GRU)和全连接层通过Pytorch构建。
- 通过零填充创建小批量,并使用torch.nn.utils.rnn.PackedSequence进行处理。
- 交叉熵损失 + Adam优化器。
- 支持预训练词嵌入(GloVe)。
模型
- 嵌入 --> 丢弃 --> LSTM(GRU) --> 丢弃 --> 全连接层。
预处理
- 以下命令将从此处下载《Learning to Classify Short and Sparse Text & Web with Hidden Topics from Large-scale Data Collections》中使用的数据集,并对其进行训练处理。
- 同时还会下载GloVe嵌入。
python preprocess.py
训练
- 以下命令开始训练。使用
-h
运行可查看可选参数。
python main.py