MLX ParaLLM
通过MLX在Apple Silicon设备上实现快速并行推理的批量KV缓存。
本仓库大量借鉴了mlx_lm
。将探索如何在那里添加批量生成作为非破坏性PR。
使用方法
需要安装mlx
和mlx_lm
。
from mlx_parallm.utils import load, batch_generate
model, tokenizer = load("google/gemma-1.1-2b-it")
prompts = ["prompt_0", ..., "prompt_k"]
responses = batch_generate(model, tokenizer, prompts=prompts_raw[:10], max_tokens=100, verbose=True, format_prompts=True, temp=0.0)
模型
已测试的模型:
meta-llama/Meta-Llama-3-8B-Instruct
microsoft/Phi-3-mini-4k-instruct
google/gemma-1.1-2b-it
mlx-community/Meta-Llama-3-8B-Instruct-4bit
mlx-community/Phi-3-mini-4k-instruct-4bit
mlx-community/gemma-1.1-2b-it-4bit
同时支持量化和float16
模型。如果有足够的RAM,float16
模型似乎通常表现更快(在M3 Max 128GB上,gemma-2b
的吞吐量最高可达1300+词/秒)。
可以通过从mlx_lm/models
复制架构文件并将任何对KVCache
的引用替换为BatchedKVCache
来添加其他模型。
特性
已支持:
batch_generate
方法(已测试len(prompts) > 500
)- 自动填充
- 使用提示模板自动格式化(
format_prompts=True
) temp = 0
、temp > 0
、top_p
采样- 单流
generate
方法
尚未支持:
- 重复惩罚
batch_generate
的流式输出- 异步请求的动态批处理