蜡烛
Candle 是一个专注于性能(包括 GPU 支持)和易用性的 Rust 极简机器学习框架。试试我们的在线演示: whisper, LLaMA2, T5, yolo, Segment Anything。
开始使用
确保您已按照安装中的说明正确安装了 candle-core
。
让我们看看如何运行一个简单的矩阵乘法。
在您的 myapp/src/main.rs
文件中写入以下内容:
use candle_core::{Device, Tensor};
fn main() -> Result<(), Box<dyn std::error::Error>> {
let device = Device::Cpu;
let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
let b = Tensor::randn(0f32, 1., (3, 4), &device)?;
let c = a.matmul(&b)?;
println!("{c}");
Ok(())
}
运行 cargo run
应该会显示一个形状为 Tensor[[2, 4], f32]
的张量。
如果已经安装了支持 Cuda 的 candle
,只需将 device
定义为 GPU:
- let device = Device::Cpu;
+ let device = Device::new_cuda(0)?;
要查看更高级的示例,请参阅以下部分。
查看我们的示例
这些在线演示完全在您的浏览器中运行:
- yolo:姿势估计和物体识别。
- whisper:语音识别。
- LLaMA2:文本生成。
- T5:文本生成。
- Phi-1.5 和 Phi-2:文本生成。
- Segment Anything Model:图像分割。
- BLIP:图像描述。
我们还提供了一些使用最先进模型的命令行示例:
- LLaMA v1、v2和v3:通用大语言模型,包括SOLAR-10.7B变体。
- Falcon:通用大语言模型。
- Codegeex4:代码补全、代码解释器、网络搜索、函数调用、存储库级别。
- GLM4:由清华大学开发的开放式多语言多模态聊天大语言模型。
- Gemma:来自Google DeepMind的2b和7b通用大语言模型。
- RecurrentGemma:基于Griffin的2b和7b模型,由Google开发,结合了注意力机制和类RNN状态。
- Phi-1、Phi-1.5、Phi-2和Phi-3:1.3b、2.7b和3.8b通用大语言模型,性能可与7b模型相媲美。
- StableLM-3B-4E1T:在1万亿个英语和代码数据集上预训练的3b通用大语言模型。还支持StableLM-2,一个在2万亿个词元上训练的1.6b大语言模型,以及代码变体。
- Mamba:Mamba状态空间模型的推理实现。
- Mistral7b-v0.1:截至2023年9月28日,性能优于所有公开可用的13b模型的7b通用大语言模型。
- Mixtral8x7b-v0.1:8x7b稀疏混合专家通用大语言模型,性能优于Llama 2 70B模型,且推理速度更快。
- StarCoder和StarCoder2:专门用于代码生成的大语言模型。
- Qwen1.5:双语(英文/中文)大语言模型。
- RWKV v5和v6:具有Transformer级别性能的RNN大语言模型。
- Replit-code-v1.5:专门用于代码补全的3.3b大语言模型。
- Yi-6B / Yi-34B:两个双语(英文/中文)通用大语言模型,分别拥有6b和34b参数。
- 量化LLaMA:使用与llama.cpp相同量化技术的LLaMA模型量化版本。
- Stable Diffusion:文本到图像生成模型,支持1.5、2.1、SDXL 1.0和Turbo版本。
- Wuerstchen:另一种文本到图像生成模型。
- segment-anything:带提示的图像分割模型。
使用以下命令运行它们:
cargo run --example quantized --release
要使用CUDA,请在示例命令行中添加--features cuda
。如果安装了cuDNN,使用--features cudnn
可获得更快的速度。
还有一些用于whisper和llama2.c的wasm示例。你可以使用trunk
构建它们,或者在线尝试:
whisper、
llama2、
T5、
Phi-1.5和Phi-2、
Segment Anything Model。
对于LLaMA2,运行以下命令获取权重文件并启动测试服务器:
cd candle-wasm-examples/llama2-c
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --port 8081
有用的外部资源
candle-tutorial
:一个非常详细的教程,展示如何将PyTorch模型转换为Candle。candle-lora
:Candle的高效且人性化的LoRA实现。candle-lora
为许多Candle模型提供了开箱即用的LoRA支持,可以在这里找到。optimisers
:优化器集合,包括带动量的SGD、AdaGrad、AdaDelta、AdaMax、NAdam、RAdam和RMSprop。candle-vllm
:高效的本地LLM推理和服务平台,包括兼容OpenAI的API服务器。candle-ext
:Candle的扩展库,提供目前Candle中不可用的PyTorch函数。candle-coursera-ml
:实现Coursera的机器学习专业课程中的ML算法。kalosm
:Rust中的多模态元框架,用于与本地预训练模型交互,支持受控生成、自定义采样器、内存向量数据库、音频转录等。candle-sampling
:Candle的采样技术。gpt-from-scratch-rs
:Andrej Karpathy在YouTube上的_Let's build GPT_教程的移植版,展示了Candle API在玩具问题上的应用。candle-einops
:Python einops库的纯Rust实现。
如果你有要添加到此列表的内容,请提交拉取请求。
特性
- 语法简单,外观和感觉类似 PyTorch。
- 模型训练。
- 嵌入用户自定义的操作/内核,例如 flash-attention v2。
- 后端。
- 针对 x86 优化的 CPU 后端,可选 MKL 支持,Mac 上使用 Accelerate。
- CUDA 后端用于在 GPU 上高效运行,通过 NCCL 实现多 GPU 分布式。
- WASM 支持,可在浏览器中运行模型。
- 包含的模型。
- 语言模型。
- LLaMA v1、v2 和 v3,包括 SOLAR-10.7B 等变体。
- Falcon。
- StarCoder、StarCoder2。
- Phi 1、1.5、2 和 3。
- Mamba、Minimal Mamba。
- Gemma 2b 和 7b。
- Mistral 7b v0.1。
- Mixtral 8x7b v0.1。
- StableLM-3B-4E1T、StableLM-2-1.6B、Stable-Code-3B。
- Replit-code-v1.5-3B。
- Bert。
- Yi-6B 和 Yi-34B。
- Qwen1.5、Qwen1.5 MoE。
- RWKV v5 和 v6。
- 量化 LLM。
- Llama 7b、13b、70b,以及聊天和代码变体。
- Mistral 7b 和 7b instruct。
- Mixtral 8x7b。
- Zephyr 7b a 和 b(基于 Mistral-7b)。
- OpenChat 3.5(基于 Mistral-7b)。
- 文本到文本。
- T5 及其变体:FlanT5、UL2、MADLAD400(翻译)、CoEdit(语法纠正)。
- Marian MT(机器翻译)。
- 文本到图像。
- Stable Diffusion v1.5、v2.1、XL v1.0。
- Wurstchen v2。
- 图像到文本。
- BLIP。
- TrOCR。
- 音频。
- Whisper,多语言语音转文本。
- EnCodec,音频压缩模型。
- MetaVoice-1B,文本转语音模型。
- 计算机视觉模型。
- DINOv2、ConvMixer、EfficientNet、ResNet、ViT、VGG、RepVGG、ConvNeXT、 ConvNeXTv2、MobileOne、EfficientVit (MSRA)、MobileNetv4、Hiera。
- yolo-v3、yolo-v8。
- Segment-Anything Model (SAM)。
- SegFormer。
- 语言模型。
- 文件格式:从 safetensors、npz、ggml 或 PyTorch 文件加载模型。
- 无服务器(在 CPU 上),小型且快速部署。
- 使用 llama.cpp 量化类型的量化支持。
如何使用
速查表:
使用 PyTorch | 使用 Candle | |
---|---|---|
创建 | torch.Tensor([[1, 2], [3, 4]]) | Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)? |
创建 | torch.zeros((2, 2)) | Tensor::zeros((2, 2), DType::F32, &Device::Cpu)? |
索引 | tensor[:, :4] | tensor.i((.., ..4))? |
操作 | tensor.view((2, 2)) | tensor.reshape((2, 2))? |
操作 | a.matmul(b) | a.matmul(&b)? |
算术 | a + b | &a + &b |
设备 | tensor.to(device="cuda") | tensor.to_device(&Device::new_cuda(0)?)? |
数据类型 | tensor.to(dtype=torch.float16) | tensor.to_dtype(&DType::F16)? |
保存 | torch.save({"A": A}, "model.bin") | candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")? |
加载 | weights = torch.load("model.bin") | candle::safetensors::load("model.safetensors", &device) |
结构
- candle-core:核心操作、设备和
Tensor
结构定义 - candle-nn:构建实际模型的工具
- candle-examples:在实际场景中使用库的示例
- candle-kernels:CUDA 自定义内核
- candle-datasets:数据集和数据加载器
- candle-transformers:与 transformers 相关的实用工具
- candle-flash-attn:Flash attention v2 层
- candle-onnx:ONNX 模型评估
常见问题
为什么我应该使用 Candle?
Candle 的核心目标是使无服务器推理成为可能。像 PyTorch 这样的完整机器学习框架非常庞大,这使得在集群上创建实例变得缓慢。Candle 允许部署轻量级二进制文件。
其次,Candle 让你可以从生产工作负载中移除 Python。Python 的开销可能会严重影响性能,而 GIL 是一个众所周知的麻烦源。
最后,Rust 很酷!HF 生态系统中已经有很多 Rust crate,比如 safetensors 和 tokenizers。
其他 ML 框架
-
dfdx 是一个出色的 crate,其形状包含在类型中。这通过让编译器直接指出形状不匹配来防止很多麻烦。 然而,我们发现一些功能仍需要使用 nightly 版本,而且对于非 Rust 专家来说,编写代码可能有点令人生畏。 我们正在利用并为运行时的其他核心 crate 做出贡献,希望两个 crate 都能互相受益。
-
burn 是一个通用的 crate,可以利用多个后端,让你为工作负载选择最佳引擎。
-
tch-rs 是 Rust 中的 torch 库绑定。非常versatile,但会将整个 torch 库引入运行时。
tch-rs
的主要贡献者也参与了candle
的开发。
常见错误
使用 mkl 特性编译时缺少符号
如果在使用 mkl 或 accelerate 特性编译二进制文件/测试时遇到缺少符号的情况,例如对于 mkl 你会看到:
= note: /usr/bin/ld: (....o): in function `blas::sgemm':
.../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_' collect2: error: ld returned 1 exit status
= note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
= note: use the `-l` flag to specify native libraries to link
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo
或对于 accelerate:
Undefined symbols for architecture arm64:
"_dgemm_", referenced from:
candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
"_sgemm_", referenced from:
candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
ld: symbol(s) not found for architecture arm64
这可能是由于缺少启用 mkl 库所需的链接器标志。你可以尝试在二进制文件顶部为 mkl 添加以下内容:
extern crate intel_mkl_src;
或对于 accelerate:
extern crate accelerate_src;
无法运行 LLaMA 示例:访问源需要登录凭证
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401
这可能是因为你没有 LLaMA-v2 模型的权限。要解决这个问题,你需要在 huggingface-hub 上注册,接受 LLaMA-v2 模型条件,并设置你的认证令牌。更多详情请参见 issue #350。
编译 flash-attn 时缺少 cute/cutlass 头文件
In file included from kernels/flash_fwd_launch_template.h:11:0,
from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
#include <cute/algorithm/copy.hpp>
^~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
Error: nvcc error while compiling:
cutlass 作为 git 子模块提供,你可能需要运行以下命令来正确检出它:
git submodule update --init
使用 flash-attention 编译失败
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with '...':
这是 Cuda 编译器触发的 gcc-11 中的一个 bug。要解决这个问题,请安装其他受支持的 gcc 版本(例如 gcc-10),并在 NVCC_CCBIN 环境变量中指定编译器的路径。
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
在 Windows 上运行 rustdoc 或 mdbook 测试时出现链接错误
Couldn't compile the test.
---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
error: linking with `link.exe` failed: exit code: 1181
//very long chain of linking
= note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
确保链接所有可能位于项目目标外的本地库,例如,要运行 mdbook 测试,你应该运行:
mdbook test candle-book -L .\target\debug\deps\ `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
WSL 下模型加载时间极慢
这可能是由于模型从 /mnt/c
加载导致的,更多详情请参见 stackoverflow。
追踪错误
你可以设置 RUST_BACKTRACE=1
来获得 candle 错误生成时的回溯信息。
CudaRC 错误
如果在 Windows 上遇到类似这样的错误 called
Result::unwrap()on an
Err value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }
,请复制并重命名这 3 个文件(确保它们在路径中)。路径取决于你的 CUDA 版本。
c:\Windows\System32\nvcuda.dll
-> cuda.dll
c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll
-> cublas.dll
c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll
-> curand.dll