Project Icon

equinox

强大且易用的JAX兼容神经网络库

Equinox是一款专为JAX设计的神经网络库,拥有类似PyTorch的语法。该库支持过滤API和PyTree操作,并兼容JAX及其生态系统中的所有工具。对于新手用户,推荐使用MNIST卷积神经网络示例,简化模型构建过程。Equinox还提供运行时错误处理等高级功能。

春分

春分是你的单一 JAX 库,提供核心 JAX 中没有的一切:

  • 神经网络(或更广泛的任何模型),使用简便的类似 PyTorch 的语法;
  • 变换的过滤 API;
  • 有用的 PyTree 操作例程;
  • 高级功能如运行时错误;

最重要的是,春分不是一个框架:你在春分中编写的一切都兼容 JAX 或其生态系统中的其他内容。

如果你完全不熟悉 JAX,请从这个 CNN on MNIST 示例 开始。

FlaxHaiku 转过来?主要区别在于春分 (a) 提供了许多这些库中没有的高级功能,例如 PyTree 操作或运行时错误;(b) 拥有更简单的模型构建方式:它们只是 PyTrees,所以它们可以顺利地通过 JIT/grad/etc. 边界。

安装

pip install equinox

需要 Python 3.9+ 和 JAX 0.4.13+。

文档

可在 https://docs.kidger.site/equinox 获取。

快速示例

使用类似 PyTorch 的语法定义模型:

import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

并与普通 JAX 操作完全兼容:

@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)

最后,背后没有任何魔法。所有 eqx.Module 做的就是将你的类注册为一个 PyTree。从那时起,JAX 已经知道如何处理 PyTrees。

引用

如果你发现在学术工作中有用,请引用: (arXiv 链接)

@article{kidger2021equinox,
    author={Patrick Kidger and Cristian Garcia},
    title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}

(也请考虑在 GitHub 上给项目加星。)

另见:JAX 生态系统中的其他库

总是有用
jaxtyping: 数组形状/类型的注解。

深度学习
Optax: 一阶梯度(SGD、Adam 等)优化器。
Orbax: 检查点(异步/多主机/多设备)。
Levanter:基础模型(例如 LLMs)的可扩展且可靠的训练。

科学计算
Diffrax:数值微分方程求解器。
Optimistix:根查找、最小化、定点和最小二乘。
Lineax:线性求解器。
BlackJAX: 概率+贝叶斯采样。
sympy2jax: SymPy<->JAX 转换;通过梯度下降训练符号表达式。
PySR: 符号回归。(非 JAX 值得提及的项目!)

Awesome JAX
[Awesome JAX](https://github.com/n2cholas/awesome-jax):更多 JAX 项目列表。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号