使用LLaMA进行文本分类
该代码库提供了一个使用LLaMA进行文本分类的基本代码库。
我使用什么系统进行开发?
- 设备:Nvidia 1xV100 GPU
- 设备内存:34G
- 主机内存:252G
如果您需要其他关于硬件的信息,请打开一个issue。
使用方法
实验设置
-
从官方的LLaMA库获取检查点 这里。 1-1. 我假设检查点将位于项目根目录并按照如下内容排列。
checkpoints ├── llama │ ├── 7B │ │ ├── checklist.chk │ │ ├── consolidated.00.pth │ │ └── params.json │ └── tokenizer.model
-
准备好您的python环境。我推荐使用anaconda来分隔您的本地机器的CUDA版本。
conda create -y -n llama-classification python=3.8 conda activate llama-classification conda install cudatoolkit=11.7 -y -c nvidia conda list cudatoolkit # 检查已安装的cuda版本(11.7) pip install -r requirements.txt
方法:直接
直接
比较条件概率p(y|x)
。
-
使用以下脚本预处理来自huggingface datasets的数据。从现在起,我们使用ag_news数据集。
python run_preprocess_direct_ag_news.py python run_preprocess_direct_ag_news.py --sample=False --data_path=real/inputs_direct_ag_news.json # 用于全面评估
-
使用LLaMA推断计算条件概率并预测类别。
torchrun --nproc_per_node 1 run_evaluate_direct_llama.py \ --data_path samples/inputs_direct_ag_news.json \ --output_path samples/outputs_direct_ag_news.json \ --ckpt_dir checkpoints/llama/7B \ --tokenizer_path checkpoints/llama/tokenizer.model
校准
是用校准方法改进直接方法。
- 使用以下命令进行校准。
torchrun --nproc_per_node 1 run_evaluate_direct_calibrate_llama.py \ --direct_input_path samples/inputs_direct_ag_news.json \ --direct_output_path samples/outputs_direct_ag_news.json \ --output_path samples/outputs_direct_calibrate_ag_news.json \ --ckpt_dir checkpoints/llama/7B \ --tokenizer_path checkpoints/llama/tokenizer.model
方法:信道
信道
比较条件概率p(x|y)
。
-
使用以下脚本预处理来自huggingface datasets的数据。从现在起,我们使用ag_news数据集。
python run_preprocess_channel_ag_news.py python run_preprocess_channel_ag_news.py --sample=False --data_path=real/inputs_channel_ag_news.json # 用于全面评估
-
使用LLaMA推断计算条件概率并预测类别。
torchrun --nproc_per_node 1 run_evaluate_channel_llama.py \ --data_path samples/inputs_channel_ag_news.json \ --output_path samples/outputs_channel_ag_news.json \ --ckpt_dir checkpoints/llama/7B \ --tokenizer_path checkpoints/llama/tokenizer.model
方法:纯生成
- 为了使用
生成
模式进行评估,您可以使用预处理的直接版本。torchrun --nproc_per_node 1 run_evaluate_generate_llama.py \ --data_path samples/inputs_direct_ag_news.json \ --output_path samples/outputs_generate_ag_news.json \ --ckpt_dir checkpoints/llama/7B \ --tokenizer_path checkpoints/llama/tokenizer.model
实验
数据集 | 样本数 | k | 方法 | 准确率 | 推理时间 |
---|---|---|---|---|---|
ag_news | 7600 | 1 | 直接 | 0.7682 | 00:38:40 |
ag_news | 7600 | 1 | 直接+校准 | 0.8567 | 00:38:40 |
ag_news | 7600 | 1 | 信道 | 0.7825 | 00:38:37 |
待办事项
- 实现信道方法
- 实验报告
- 直接
- 信道
- 生成
- 实现其他校准方法
- 支持huggingface datasets中的其他数据集
- 实现LLM.int8
- 其他评估指标来衡量基础模型(LLaMA)的不同特性
最后的备注
- 我真的很感谢LLaMA项目团队发布检查点和他们高效的推理代码。该代码库中的大部分工作基于官方库。
- 对于读者,欢迎打开issue或pull requests。您可以提出...
- 关于其他功能请求的任何问题
- 关于详细实现的任何问题
- 关于研究方向的任何讨论
引用
如果您使用我的代码库进行研究,欢迎引用我的工作。
@software{Lee_Simple_Text_Classification_2023,
author = {Lee, Seonghyeon},
month = {3},
title = {{Simple Text Classification Codebase using LLaMA}},
url = {https://github.com/github/sh0416/llama-classification},
version = {1.1.0},
year = {2023}
}