RWKV_Pytorch
这是一个用纯Pytorch原生实现的RWKV大语言模型的推理框架,官方的原生实现过于复杂且无法拓展生态,让我们加入灵活的Pytorch阵营,一起开源起来吧!
特性
- 原生pytorch实现!
- 支持batch推理!
- 支持并行推理!充分发挥RWKV优势!
- 代码整洁,容易阅读和二次开发!
- 支持导出并推理onnx格式模型!
使用方法
- 克隆仓库
git clone -b dev https://github.com/yuunnn-w/RWKV_Pytorch.git
- 执行
cd RWKV_Pytorch
进入仓库目录,执行pip install -r requirements.txt
安装依赖。 - 下载 RWKV6 模型,官方仓库地址:BlinkDL/rwkv-6-world,将模型权重放置在
weight
文件夹中。 - 修改
main.py
文件的MODEL_NAME
参数。 - 执行
python main.py
,即可看到batch推理效果。
流水并行(pipeline parallel)使用方法
- 克隆仓库
git clone -b pipeline https://github.com/yuunnn-w/RWKV_Pytorch.git
- 执行
cd RWKV_Pytorch
进入仓库目录,执行pip install -r requirements.txt
安装依赖。 - 下载 RWKV6 模型,官方仓库地址:BlinkDL/rwkv-6-world,将模型权重放置在
weight
文件夹中。 - 修改
train/params.json
文件的MODEL_NAME
参数。 - 执行
torchrun --nproc-per-node 3 train/train-parallel.py
开始训练。
导出onnx方法
- 修改
onnx_export.py
文件参数为你想导出的模型。 - 执行
python onnx_export.py
即可导出到./onnx路径。 - (可选)执行
mkdir ONNX_Simplified
创建一个用于存放简化算子模型的目录。 - (可选)执行
python simplify_large_onnx.py -m onnx/{model name}.onnx -o ONNX_Simplified/{model name}.onnx
来简化模型,简化后的模型将存放在ONNX_Simplified目录。 - (可选)修改
onnx_infer.py
文件内的模型路径参数,执行python onnx_infer.py
即可推理onnx格式模型。
本地部署体验
- 修改
openai_api.py
文件中的模型配置参数。 - 执行
python openai_api.py
即可启动后端。 - 用任意符合 OpenAI API 规范的客户端,填入
http://127.0.0.1:8848
作为API_URL
参数,即可体验。
已知的问题:
- 已知op17版本才支持LayerNorm算子,op18版本才支持GroupNorm算子,目前torch的preview版本支持op18,但是无法导出,current版本只支持op17,能够正常导出含LayerNorm算子的模型。你可以参照main.py 使用opset参数指定
注意,本框架目前仅支持RWKV v6模型,具体版本号为x060
预计未来基于本项目适配香橙派推出的AI Pro开发板,实现在昇腾的生态上推理国产大语言模型RWKV!!!
另外,经过测试,v6 1.6B导出并优化后的onnx模型含有如下算子:
- 算子类型:
Gather
,数量:145 - 算子类型:
Squeeze
,数量:121 - 算子类型:
ReduceMean
,数量:148 - 算子类型:
Sub
,数量:122 - 算子类型:
Mul
,数量:484 - 算子类型:
Add
,数量:675 - 算子类型:
Sqrt
,数量:74 - 算子类型:
Div
,数量:74 - 算子类型:
Shape
,数量:240 - 算子类型:
Expand
,数量:240 - 算子类型:
Range
,数量:72 - 算子类型:
Reshape
,数量:384 - 算子类型:
Equal
,数量:72 - 算子类型:
Where
,数量:72 - 算子类型:
Unsqueeze
,数量:192 - 算子类型:
Concat
,数量:192 - 算子类型:
ScatterND
,数量:72 - 算子类型:
MatMul
,数量:337 - 算子类型:
Tanh
,数量:48 - 算子类型:
Split
,数量:24 - 算子类型:
Exp
,数量:48 - 算子类型:
Neg
,数量:24 - 算子类型:
Sigmoid
,数量:48 - 算子类型:
Slice
,数量:24 - 算子类型:
Flatten
,数量:24 - 算子类型:
Relu
,数量:24
优化模型用到的仓库:onnxsim_large_model
贡献者
技术交流群
感谢各位大佬做出的贡献!欢迎各路大神为本项目提PR和Issue!你们的贡献对本项目十分有价值!!!