tch-rs入门指南 - Rust绑定PyTorch C++ API的高效工具
tch-rs是一个优秀的Rust库,它为PyTorch的C++ API提供了Rust绑定。通过tch-rs,Rust开发者可以方便地使用PyTorch强大的深度学习功能,同时享受Rust语言的安全性和高性能。本文将介绍tch-rs的基本用法,帮助读者快速上手这个工具。
项目简介
tch-rs的目标是为PyTorch的C++ API提供轻量级的Rust封装。它尽可能保持与原始C++ API的一致性,同时也为开发更符合Rust风格的绑定奠定了基础。项目的主要特点包括:
- 提供PyTorch张量和自动微分功能的Rust接口
- 支持构建和训练神经网络模型
- 可以加载和使用预训练的PyTorch模型
- 包含丰富的示例代码,涵盖基本操作到复杂应用
安装配置
tch-rs需要系统中安装有PyTorch的C++库(libtorch)。你可以通过以下几种方式配置:
- 使用系统全局安装的libtorch(默认方式)
- 手动安装libtorch,并通过
LIBTORCH
环境变量指定路径 - 使用Python版PyTorch,设置
LIBTORCH_USE_PYTORCH=1
- 通过
download-libtorch
特性自动下载预编译的libtorch二进制文件
对于Linux和macOS用户,可以将以下内容添加到.bashrc
中:
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
Windows用户需要设置相应的环境变量。
基本用法
下面是一个简单的张量操作示例:
use tch::Tensor;
fn main() {
let t = Tensor::from_slice(&[3, 1, 4, 1, 5]);
let t = t * 2;
t.print();
}
构建和训练模型
tch-rs支持使用nn::VarStore
创建变量,并通过梯度下降进行优化。以下是一个简单的神经网络示例:
use tch::{nn, nn::Module, nn::OptimizerConfig, Device};
const IMAGE_DIM: i64 = 784;
const HIDDEN_NODES: i64 = 128;
const LABELS: i64 = 10;
fn net(vs: &nn::Path) -> impl Module {
nn::seq()
.add(nn::linear(vs, IMAGE_DIM, HIDDEN_NODES, Default::default()))
.add_fn(|xs| xs.relu())
.add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
}
fn main() -> anyhow::Result<()> {
let m = tch::vision::mnist::load_dir("data")?;
let vs = nn::VarStore::new(Device::Cpu);
let net = net(&vs.root());
let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
for epoch in 1..200 {
let loss = net
.forward(&m.train_images)
.cross_entropy_for_logits(&m.train_labels);
opt.backward_step(&loss);
let test_accuracy = net
.forward(&m.test_images)
.accuracy_for_logits(&m.test_labels);
println!(
"epoch: {:4} train loss: {:8.5} test acc: {:5.2}%,"
epoch,
f64::from(&loss),
100. * f64::from(&test_accuracy),
);
}
Ok(())
}
使用预训练模型
tch-rs还支持加载和使用预训练的PyTorch模型。以下是使用预训练ResNet18模型进行图像分类的示例:
use tch::{nn, vision};
fn main() -> anyhow::Result<()> {
let image = vision::imagenet::load_image_and_resize("tiger.jpg")?;
let vs = nn::VarStore::new(tch::Device::Cpu);
let net = vision::resnet18(&vs.root(), Default::default());
vs.load("resnet18.ot")?;
let output = net.forward_t(&image.unsqueeze(0), false).softmax(-1);
for (probability, class) in vision::imagenet::top(&output, 5).iter() {
println!("{:50} {:5.2}%", class, 100.0 * probability)
}
Ok(())
}
更多资源
通过tch-rs,Rust开发者可以方便地将PyTorch的强大功能与Rust的安全性和性能结合起来。无论是构建简单的神经网络还是复杂的深度学习应用,tch-rs都是一个值得尝试的优秀工具。🚀