Project Icon

dfdx

Rust中的深度学习库,提供GPU加速和编译时类型检查

dfdx是一个注重人体工学和安全性的Rust深度学习库,支持GPU加速和最多6维的张量形状。它在编译时进行形状和类型检查,提供多种张量操作,例如矩阵乘法和卷积。该库还包含神经网络构建模块和标准的深度学习优化器,如Sgd和Adam。设计目标是性能最大化和最小化不安全代码。用户可以启用CUDA特性进行GPU加速,非常适合在Rust中进行深度学习开发的用户。

dfdx: 形状检查的 Rust 深度学习

crates.io docs.rs

注重易用性和安全性的 Rust 深度学习。

仍处于预先发布状态。接下来的几个版本计划会有破坏性更改。

功能一览:

  1. :fire: GPU 加速的张量库,支持多达 6 维的形状!
  2. 具有编译时和运行时尺寸的形状(如 Tensor<(usize, Const<10>)>Tensor<Rank2<5, 10>>)。
  3. 丰富的张量操作库(包括 matmulconv2d 等)。
    1. 所有张量操作在编译时进行形状和类型检查!!
  4. 人性化的神经网络构建块(如 LinearConv2DTransformer)。
  5. 标准深度学习优化器,如 SgdAdamAdamWRMSprop 等。

dfdx 可以在 crates.io 上获取!在你的 Cargo.toml 中添加以下内容进行使用:

dfdx = "0.13.0"

docs.rs/dfdx 查看文档。

[1] https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation

设计目标

  1. 全程人性化(包括前端接口和内部)。
  2. 尽可能多地在编译时进行检查(即如果有错误就不进行编译)。
  3. 最大化性能。
  4. 最小化不安全代码[1]
  5. 最小化内部代码中使用的 Rc<RefCell>[2]

[1] 目前唯一的不安全调用是用于矩阵乘法。

[2] 只有张量用于存储数据时使用 Arc。使用 Arc 而不是 Box,以减少克隆张量时的分配。

使用 CUDA 进行 GPU 加速

启用 cuda 特性以开始使用 Cuda 设备!需要安装 Nvidia 的 CUDA 工具包。详细信息请参见特性标志文档

API 预览

查看 examples/ 了解更多详情。

  1. 👌 简单的神经网络 API,完全在编译时进行形状检查。
type Mlp = (
    (Linear<10, 32>, ReLU),
    (Linear<32, 32>, ReLU),
    (Linear<32, 2>, Tanh),
);

fn main() {
    let dev: Cuda = Default::default(); // 或 `Cpu`
    let mlp = dev.build_module::<Mlp, f32>();
    let x: Tensor<Rank1<10>, f32, Cpu> = dev.zeros();
    let y: Tensor<Rank1<2>, f32, Cpu> = mlp.forward(x);
    mlp.save("checkpoint.npz")?;
}
  1. 📈 人性化的优化器 API
type Model = ...
let mut model = dev.build_module::<Model, f32>();
let mut grads = model.alloc_grads();
let mut sgd = Sgd::new(&model, SgdConfig {
    lr: 1e-2,
    momentum: Some(Momentum::Nesterov(0.9))
});

let loss = ...
grads = loss.backward();

sgd.update(&mut model, &grads);
  1. 💡 常量张量可以与普通 Rust 数组相互转换
let t0: Tensor<Rank0, f32, _> = dev.tensor(0.0);
assert_eq!(t0.array(), &0.0);

let t1 /*: Tensor<Rank1<3>, f32, _>*/ = dev.tensor([1.0, 2.0, 3.0]);
assert_eq!(t1.array(), [1.0, 2.0, 3.0]);

let t2: Tensor<Rank2<2, 3>, f32, _> = dev.sample_normal();
assert_ne!(t2.array(), [[0.0; 3]; 2]);

有趣/值得注意的实现细节

模块

pub trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

基于这个灵活的特性,我们可以实现:

  1. 单个和批处理输入(只需实现多个 impls 即可!)
  2. 多输入/多输出(多头模块或 RNNs)
  3. 不同的行为取决于是否存在 Tape(不是其他库中的 .train()/.eval() 行为!)。

元组表示前馈(即顺序)模块

由于我们可以为元组实现特性,这在其他语言中是不可能的,它们为顺序执行模块提供了一个非常好的前端。

// 不知道为什么你会这样做,但你可以!
type Model = (ReLU, Sigmoid, Tanh);
let model = dev.build_module::<Model, f32>();
type Model = (Linear<10, 5>, Tanh)
let model = dev.build_module::<Model, f32>();

为包含两个元素的元组实现 Module 的样子:

impl<Input, A, B> Module<Input> for (A, B)
where
    Input: Tensor,
    A: Module<Input>,        // A 是一个可以接受 Input 的模块
    B: Module<A::Output>,    // B 是一个可以接受 A 的输出的模块
{
    type Output = B::Output; // 输出是 B 的输出
    fn forward(&self, x: Input) -> Self::Output {
        let x = self.0.forward(x);
        let x = self.1.forward(x);
        x
    }
}

模块可以为多达 6 个元素的元组实现,但你可以任意嵌套它们

没有使用 Rc<RefCells<T>> - 梯度记录带不保存在单元中!

其他实现可能会直接在张量上存储对梯度记录带的引用,这需要经常改变张量或到处使用 Rc/RefCells。

我们找到了一个优雅的方法来避免这一点,将引用和动态借用检查减少为 0!

由于所有操作的结果恰好只有 1 个子代,我们可以始终将梯度记录带移动到最后一个操作的子代。同时,任何模型参数(所有张量)都不会拥有梯度记录带,因为它们永远不会成为任何操作的结果。这意味着我们确切知道哪个张量拥有梯度记录带,并且拥有它的张量将始终是中间结果,不需要在整个梯度计算过程中维护。

所有这些都为用户提供了前所未有的控制/精度,以记录梯度的张量!

一个高级用例需要在计算图中多次重用张量。这可以通过克隆张量并手动移动梯度记录带来处理。

类型检查反向传播

简要说明:如果你忘记调用 trace()traced(),程序将无法编译!

-let pred = module.forward(x);
+let pred = module.forward(x.traced(grads));
let loss = (y - pred).square().mean();
let gradients = loss.backward();

由于我们确切知道哪个张量拥有梯度记录带,我们可以要求传递到 .backward() 的张量拥有梯度记录带!并且进一步,我们可以要求它被移动到 .backward(),以便它可以销毁记录带并构建梯度!

所有这些都可以在编译时检查 🎉

📄 验证 PyTorch 的兼容性

所有函数和操作都与 PyTorch 中类似代码的行为进行了测试。

许可证

双重许可证以兼容 Rust 项目。

根据 Apache 许可证 2.0 版许可:http://www.apache.org/licenses/LICENSE-2.0 或 MIT 许可证:http://opensource.org/licenses/MIT。根据这些条款,这个文件可能不会被复制、修改或分发。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号