chatglm-maths
chatglm-6b微调/LORA/PPO/推理, 样本为自动生成的整数/小数加减乘除运算, 可gpu/cpu
踩坑
1. eps=1e-5(不要改小), 半精度float16, 以及LN采用的是Post-LN(泛化性更好) + DeepNorm, 【害, Attention前也有LN】目的是大模型为了防止梯度溢出等;
2. 模型输入输出, 默认的tokenization_chatglm.py/modeling_chatglm.py不能用, 因为那是完全为生成generate设置的, 需要自己写好所有缩入参数, 或者机子改成适配的;
2.1 ChatGLMModel中, get_masks()正常, get_position_ids()函数中‘context_length = seq.index(150004) + 1’ 改为 ‘context_length = len(seq)’;
2.2 训练输入input_ids格式暂定为(训练后post-padding, 推理前pre-padding[tokenization_chatglm.py默认pre-padding])
x: prompt_1 + "_" + text_1 + "\n" + prompt_2 + [gMASK] + [BOS] + "_" + text_2 + [PAD]*N
2.3 训练输入label_ids格式暂定为(CrossEntropyLoss默认忽略-100不参与计算loss)
y = [-100]*len(text_1) + [BOS] + text_2 + [EOS] + [-100]*N
2.4 注意position/mask(自带的只是推理用的batch_size=1, 所以训练输入还得自己写), 可参考GLM-130的README.md, huozhe 查看GLM-1源码https://github.com/THUDM/GLM/blob/main/tasks/seq2seq/dataset.py
3. 注意chatglm-6b权重是float16的, 不过计算loss时候会转成float32计算, 最后loss再转回float16更新梯度;
4. ChatGLMTokenizer有时候会报奇奇怪怪的错误, 建议生成时候设置max_new_tokens, 最大{"max_new_tokens": 2048}; decode有时候会出现不存在id;
5. 低秩自适应LORA, RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
尝试 transformers升级到最新, get_peft_model后再.cuda(), device_map={'':torch.cuda.current_device()},
微调数据
-
原始数据来自https://github.com/LYH-YF/MWPToolkit
处理后的微调数据(算式/解方程)-MWP: https://huggingface.co/datasets/Macropodus/MWP-Instruct
-
大数加减乘除来自: https://github.com/liutiedong/goat.git
LoRA权重
Baichuan-7B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct
Bloomz-7B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct
ChatGLM-6B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct
LlaMA-7B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct
ChatGLM-6B-MWP: https://huggingface.co/Macropodus/MWP-Instruct
数据集-中文
- https://github.com/tatsu-lab/stanford_alpaca
- https://github.com/LianjiaTech/BELLE
- https://github.com/carbonz0/alpaca-chinese-dataset
环境配置
transformers>=4.26.1
cpm_kernels==1.0.11
icetk==0.0.4
torch>=1.10.1
rouge==1.0.1
nltk==3.6.6
peft>=0.2.0
numpy
tqdm
lion_pytorch
macropodus
trl>=0.4.1
微调-计算题
lora
微调: python c00_toy_lora_train_6b.py
推理: python p00_toy_lora_predict_6b.py
ppo
训练: python t10_toy_trl_train_ppo.py
测试: python t10_toy_trl_predict_ppo.py
6b
微调: python c00_toy_cpu_train_6b.py
推理: python p00_toy_cpu_predit_6b.py
small-layer
微调: python c01_toy_cpu_train_small.py
推理: python p01_toy_cpu_predict_small.py
参考/感谢
- https://github.com/THUDM/ChatGLM-6B
- https://github.com/THUDM/GLM
- https://github.com/tatsu-lab/stanford_alpaca
- https://github.com/LianjiaTech/BELLE
- https://github.com/huggingface/peft
- https://github.com/mymusise/ChatGLM-Tuning
- https://github.com/bojone/bert4keras
- trl
- math23k
推理日志toy
generator_calculate_line: ('13+75=', '13+75=88')
tokenizer.vocab_size: 150344
eval: 0%| | 0/1 [00:00<?, ?it/s]batch_query: ['简便运算: 98+83= 剖析: 98+83=181']
batch_qtext_0: 简便运算: 98+83= 剖析:
batch_qans_0: 98+83=181
response_0: 98+83=171
{'rouge-1': 0.0, 'rouge-2': 0.0, 'rouge-l': 0.0, 'bleu': 0.0}
请输入:
25.31+86.35=
请稍等...
25.31+86.35=101.66
微调日志toy
generator_calculate_line: ('13+75=', '13+75=88')
tokenizer.vocab_size: 150344
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:10<00:00, 1.31s/it]
transformer.word_embeddings.weight False
......
transformer.layers.26.mlp.dense_4h_to_h.bias False
transformer.layers.27.input_layernorm.weight True
transformer.layers.27.input_layernorm.bias True
transformer.layers.27.attention.query_key_value.weight True
transformer.layers.27.attention.query_key_value.bias True
transformer.layers.27.attention.dense.weight True
transformer.layers.27.attention.dense.bias True
transformer.layers.27.post_attention_layernorm.weight True
transformer.layers.27.post_attention_layernorm.bias True
transformer.layers.27.mlp.dense_h_to_4h.weight True
transformer.layers.27.mlp.dense_h_to_4h.bias True
transformer.layers.27.mlp.dense_4h_to_h.weight True
transformer.layers.27.mlp.dense_4h_to_h.bias True
transformer.final_layernorm.weight True
transformer.final_layernorm.bias True
model.chat start
13+75=88, but that's not the correct answer. The correct answer is 13+75=88, which is 90.
/anaconda3/envs/py371/lib/python3.7/site-packages/transformers/optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
FutureWarning,
epoch: 0%| | 0/21 [00:00<?, ?it/s]epochs:
batch_query: ['简便运算: 98+83= 剖析: 98+83=181'] | 0/8 [00:00<?, ?it/s]
epoch_global: 0, step_global: 1, step: 0, loss: 4.0625
batch_query: ['口算: 57.84+13.64 解: 57.84+13.64=71.48']
epoch_global: 0, step_global: 2, step: 1, loss: 2.5625███▌ | 2/8 [00:17<00:51, 8.54s/it]
batch_query: ['计算题: 48+1 解答: 48+1=49']
epoch_global: 0, step_global: 3, step: 2, loss: 4.15625█████████████████████▎ | 3/8 [00:38<01:09, 13.94s/it]
batch_query: ['计算题: 61.65+33.05 解答: 61.65+33.05=94.7']
epoch_global: 0, step_global: 4, step: 3, loss: 2.40625████████████████████████████████████████ | 4/8 [01:01<01:09, 17.43s/it]
batch_query: ['计算: 81+75 回答: 81+75=156']
epoch_global: 0, step_global: 5, step: 4, loss: 3.546875█████████████████████████████████████████████████████████▊ | 5/8 [01:27<01:01, 20.41s/it]
epoch: 5%|███████▎ | 1/21 [03:07<1:02:30, 187.52s/it]epochs: step: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [02:41<00:00, 23.15s/it]
epoch_0_step: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [03:07<00:00, 23.44s/it]
batch_query: ['问题: 99+37 答案: 99+37=136']
epoch_global: 1, step_global: 9, step: 0, loss: 3.640625 | 0/8 [00:00<?, ?it/s]
batch_query: ['问题: 26.81+55.91 答案: 26.81+55.91=82.72'] | 0/1 [00:00<?, ?it/s]
batch_qtext_0: 问题: 26.81+55.91 答案:
batch_qans_0: 26.81+55.91=82.72
response_0: 26.81+55.91=83.72
{'rouge-1': 0.749999995, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.749999995, 'bleu': 0.0}
epoch_global: 1, step_global: 9, step: 0
best_score_avg: 0.45833
current_mertics: {'rouge-1': 0.749999995, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.749999995, 'bleu': 0.0}
batch_query: ['数学题: 23.34+68.45 点拨: 23.34+68.45=91.79']
epoch_global: 1, step_global: 10, step: 1, loss: 2.09375
batch_query: ['计算: 77+14 回答: 77+14=91']█████████████▌ | 2/8 [00:33<01:39, 16.58s/it]
epoch_global: 1, step_global: 11, step: 2, loss: 3.265625
batch_query: ['口算: 79.69+17.43= 解: 79.69+17.43=97.12']██████████████████▎ | 3/8 [00:35<00:53, 10.75s/it]
epoch_global: 1, step_global: 12, step: 3, loss: 2.171875
batch_query: ['简便运算: 59.67+86.73 剖析: 59.67+86.73=146.4']████████████████████████████████ | 4/8 [00:37<00:29, 7.43s/it]
epoch_global: 1, step_global: 13, step: 4, loss: 2.328125
epoch: 10%|██████████████▊ | 2/21 [03:56<33:33, 105.97s/it]epochs:
epoch_1_step: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:48<00:00, 6.11s/it]
batch_query: ['初等数学: 24.29+76.26 解析: 24.29+76.26=100.55']
epoch_global: 2, step_global: 17, step: 0, loss: 2.046875
epoch_2_step: 0%| | 0/8 [00:00<?, ?it/sbatch_query: ['计算题: 69.85+28.46= 解答: 69.85+28.46=98.31']
batch_qtext_0: 计算题: 69.85+28.46= 解答: | 0/1 [00:00<?, ?it/s]
batch_qans_0: 69.85+28.46=98.31
response_0: 69.85+28.46=97.21
{'rouge-1': 0.4999999950000001, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.4999999950000001, 'bleu': 0.0}
eval: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00, 7.83s/it]
epoch_global: 2, step_global: 17, step: 0
best_score_avg: 0.33333
current_mertics: {'rouge-1': 0.4999999950000001, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.4999999950000001, 'bleu': 0.0}
batch_query: ['问题: 113.79+81.78= 答案: 113.79+81.78=195.57']
epoch_global: 2, step_global: 18, step: 1, loss: 1.8515625
batch_query: ['计算: 10.74+17.87= 回答: