Project Icon

chatglm-maths

ChatGLM-6B数学运算能力优化项目

该项目旨在优化ChatGLM-6B模型的整数和小数四则运算能力。项目采用LORA、PPO等多种训练方法,支持GPU和CPU环境。内容包括自动生成的训练样本、微调数据集、LORA权重,以及环境配置和使用说明。这一工具主要面向开发者和研究人员,用于提升大语言模型的数学计算表现。

chatglm-maths

chatglm-6b微调/LORA/PPO/推理, 样本为自动生成的整数/小数加减乘除运算, 可gpu/cpu

踩坑

1. eps=1e-5(不要改小), 半精度float16, 以及LN采用的是Post-LN(泛化性更好) + DeepNorm, 【害, Attention前也有LN】目的是大模型为了防止梯度溢出等;
2. 模型输入输出, 默认的tokenization_chatglm.py/modeling_chatglm.py不能用, 因为那是完全为生成generate设置的, 需要自己写好所有缩入参数, 或者机子改成适配的;
   2.1 ChatGLMModel中, get_masks()正常, get_position_ids()函数中‘context_length = seq.index(150004) + 1’ 改为 ‘context_length = len(seq)’;
   2.2 训练输入input_ids格式暂定为(训练后post-padding, 推理前pre-padding[tokenization_chatglm.py默认pre-padding])
       x: prompt_1 + "_" + text_1 + "\n" + prompt_2 + [gMASK] + [BOS] + "_" + text_2 + [PAD]*N
   2.3 训练输入label_ids格式暂定为(CrossEntropyLoss默认忽略-100不参与计算loss)  
       y = [-100]*len(text_1) + [BOS] + text_2 + [EOS] + [-100]*N
   2.4 注意position/mask(自带的只是推理用的batch_size=1, 所以训练输入还得自己写), 可参考GLM-130的README.md, huozhe 查看GLM-1源码https://github.com/THUDM/GLM/blob/main/tasks/seq2seq/dataset.py
3. 注意chatglm-6b权重是float16的, 不过计算loss时候会转成float32计算, 最后loss再转回float16更新梯度;
4. ChatGLMTokenizer有时候会报奇奇怪怪的错误, 建议生成时候设置max_new_tokens, 最大{"max_new_tokens": 2048}; decode有时候会出现不存在id;
5. 低秩自适应LORA, RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
   尝试 transformers升级到最新, get_peft_model后再.cuda(), device_map={'':torch.cuda.current_device()}, 

微调数据

  1. 原始数据来自https://github.com/LYH-YF/MWPToolkit

    处理后的微调数据(算式/解方程)-MWP: https://huggingface.co/datasets/Macropodus/MWP-Instruct

  2. 大数加减乘除来自: https://github.com/liutiedong/goat.git

LoRA权重

Baichuan-7B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct
Bloomz-7B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct
ChatGLM-6B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct
LlaMA-7B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct
ChatGLM-6B-MWP: https://huggingface.co/Macropodus/MWP-Instruct

数据集-中文

环境配置

transformers>=4.26.1
cpm_kernels==1.0.11
icetk==0.0.4
torch>=1.10.1
rouge==1.0.1
nltk==3.6.6
peft>=0.2.0
numpy
tqdm

lion_pytorch
macropodus
trl>=0.4.1

微调-计算题

lora
微调: python c00_toy_lora_train_6b.py
推理: python p00_toy_lora_predict_6b.py

ppo
训练: python t10_toy_trl_train_ppo.py
测试: python t10_toy_trl_predict_ppo.py

6b
微调: python c00_toy_cpu_train_6b.py
推理: python p00_toy_cpu_predit_6b.py

small-layer
微调: python c01_toy_cpu_train_small.py
推理: python p01_toy_cpu_predict_small.py

参考/感谢

推理日志toy

generator_calculate_line: ('13+75=', '13+75=88')
tokenizer.vocab_size: 150344
eval:   0%|                                                                                                                                                                      | 0/1 [00:00<?, ?it/s]batch_query: ['简便运算: 98+83= 剖析: 98+83=181']
batch_qtext_0: 简便运算: 98+83= 剖析:
batch_qans_0: 98+83=181
response_0: 98+83=171
{'rouge-1': 0.0, 'rouge-2': 0.0, 'rouge-l': 0.0, 'bleu': 0.0}
请输入:
25.31+86.35=
请稍等...
25.31+86.35=101.66

微调日志toy

generator_calculate_line: ('13+75=', '13+75=88')
tokenizer.vocab_size: 150344
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:10<00:00,  1.31s/it]
transformer.word_embeddings.weight False
......
transformer.layers.26.mlp.dense_4h_to_h.bias False
transformer.layers.27.input_layernorm.weight True
transformer.layers.27.input_layernorm.bias True
transformer.layers.27.attention.query_key_value.weight True
transformer.layers.27.attention.query_key_value.bias True
transformer.layers.27.attention.dense.weight True
transformer.layers.27.attention.dense.bias True
transformer.layers.27.post_attention_layernorm.weight True
transformer.layers.27.post_attention_layernorm.bias True
transformer.layers.27.mlp.dense_h_to_4h.weight True
transformer.layers.27.mlp.dense_h_to_4h.bias True
transformer.layers.27.mlp.dense_4h_to_h.weight True
transformer.layers.27.mlp.dense_4h_to_h.bias True
transformer.final_layernorm.weight True
transformer.final_layernorm.bias True
model.chat start
13+75=88, but that's not the correct answer. The correct answer is 13+75=88, which is 90.
/anaconda3/envs/py371/lib/python3.7/site-packages/transformers/optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,   
epoch:   0%|                                                                                                                                                                    | 0/21 [00:00<?, ?it/s]epochs:
                                                                                                                                                                                                      batch_query: ['简便运算: 98+83= 剖析: 98+83=181']                                                                                                                                 | 0/8 [00:00<?, ?it/s]
epoch_global: 0, step_global: 1, step: 0, loss: 4.0625
batch_query: ['口算: 57.84+13.64 解: 57.84+13.64=71.48']
                                                                                                                                                                                                      epoch_global: 0, step_global: 2, step: 1, loss: 2.5625███▌                                                                                                                | 2/8 [00:17<00:51,  8.54s/it]
batch_query: ['计算题: 48+1 解答: 48+1=49']
                                                                                                                                                                                                      epoch_global: 0, step_global: 3, step: 2, loss: 4.15625█████████████████████▎                                                                                             | 3/8 [00:38<01:09, 13.94s/it]
batch_query: ['计算题: 61.65+33.05 解答: 61.65+33.05=94.7']
                                                                                                                                                                                                      epoch_global: 0, step_global: 4, step: 3, loss: 2.40625████████████████████████████████████████                                                                           | 4/8 [01:01<01:09, 17.43s/it]
batch_query: ['计算: 81+75 回答: 81+75=156']
                                                                                                                                                                                                      epoch_global: 0, step_global: 5, step: 4, loss: 3.546875█████████████████████████████████████████████████████████▊                                                        | 5/8 [01:27<01:01, 20.41s/it]
epoch:   5%|███████▎                                                                                                                                                 | 1/21 [03:07<1:02:30, 187.52s/it]epochs: step: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [02:41<00:00, 23.15s/it]
epoch_0_step: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [03:07<00:00, 23.44s/it]
batch_query: ['问题: 99+37 答案: 99+37=136']
epoch_global: 1, step_global: 9, step: 0, loss: 3.640625                                                                                                                         | 0/8 [00:00<?, ?it/s]
                                                                                                                                                                                                      batch_query: ['问题: 26.81+55.91 答案: 26.81+55.91=82.72']                                                                                                                        | 0/1 [00:00<?, ?it/s]
batch_qtext_0: 问题: 26.81+55.91 答案:
batch_qans_0: 26.81+55.91=82.72
response_0: 26.81+55.91=83.72
{'rouge-1': 0.749999995, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.749999995, 'bleu': 0.0}
epoch_global: 1, step_global: 9, step: 0
best_score_avg: 0.45833

current_mertics: {'rouge-1': 0.749999995, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.749999995, 'bleu': 0.0}
batch_query: ['数学题: 23.34+68.45 点拨: 23.34+68.45=91.79']
                                                                                                                                                                                                      epoch_global: 1, step_global: 10, step: 1, loss: 2.09375
batch_query: ['计算: 77+14 回答: 77+14=91']█████████████▌                                                                                                                | 2/8 [00:33<01:39, 16.58s/it]
                                                                                                                                                                                                      epoch_global: 1, step_global: 11, step: 2, loss: 3.265625
batch_query: ['口算: 79.69+17.43= 解: 79.69+17.43=97.12']██████████████████▎                                                                                             | 3/8 [00:35<00:53, 10.75s/it]
                                                                                                                                                                                                      epoch_global: 1, step_global: 12, step: 3, loss: 2.171875
batch_query: ['简便运算: 59.67+86.73 剖析: 59.67+86.73=146.4']████████████████████████████████                                                                           | 4/8 [00:37<00:29,  7.43s/it]
                                                                                                                                                                                                      epoch_global: 1, step_global: 13, step: 4, loss: 2.328125
epoch:  10%|██████████████▊                                                                                                                                            | 2/21 [03:56<33:33, 105.97s/it]epochs:
epoch_1_step: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:48<00:00,  6.11s/it]
batch_query: ['初等数学: 24.29+76.26 解析: 24.29+76.26=100.55']
epoch_global: 2, step_global: 17, step: 0, loss: 2.046875
epoch_2_step:   0%|                                                                                                                                                              | 0/8 [00:00<?, ?it/sbatch_query: ['计算题: 69.85+28.46= 解答: 69.85+28.46=98.31']
batch_qtext_0: 计算题: 69.85+28.46= 解答:                                                                                                                                        | 0/1 [00:00<?, ?it/s]
batch_qans_0: 69.85+28.46=98.31
response_0: 69.85+28.46=97.21
{'rouge-1': 0.4999999950000001, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.4999999950000001, 'bleu': 0.0}
eval: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00,  7.83s/it]
epoch_global: 2, step_global: 17, step: 0
best_score_avg: 0.33333

current_mertics: {'rouge-1': 0.4999999950000001, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.4999999950000001, 'bleu': 0.0}
batch_query: ['问题: 113.79+81.78= 答案: 113.79+81.78=195.57']
                                                                                                                                                                                                      epoch_global: 2, step_global: 18, step: 1, loss: 1.8515625
batch_query: ['计算: 10.74+17.87= 回答:
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号