春分
春分是你的单一 JAX 库,提供核心 JAX 中没有的一切:
- 神经网络(或更广泛的任何模型),使用简便的类似 PyTorch 的语法;
- 变换的过滤 API;
- 有用的 PyTree 操作例程;
- 高级功能如运行时错误;
最重要的是,春分不是一个框架:你在春分中编写的一切都兼容 JAX 或其生态系统中的其他内容。
如果你完全不熟悉 JAX,请从这个 CNN on MNIST 示例 开始。
从 Flax 或 Haiku 转过来?主要区别在于春分 (a) 提供了许多这些库中没有的高级功能,例如 PyTree 操作或运行时错误;(b) 拥有更简单的模型构建方式:它们只是 PyTrees,所以它们可以顺利地通过 JIT/grad/etc. 边界。
安装
pip install equinox
需要 Python 3.9+ 和 JAX 0.4.13+。
文档
可在 https://docs.kidger.site/equinox 获取。
快速示例
使用类似 PyTorch 的语法定义模型:
import equinox as eqx
import jax
class Linear(eqx.Module):
weight: jax.Array
bias: jax.Array
def __init__(self, in_size, out_size, key):
wkey, bkey = jax.random.split(key)
self.weight = jax.random.normal(wkey, (out_size, in_size))
self.bias = jax.random.normal(bkey, (out_size,))
def __call__(self, x):
return self.weight @ x + self.bias
并与普通 JAX 操作完全兼容:
@jax.jit
@jax.grad
def loss_fn(model, x, y):
pred_y = jax.vmap(model)(x)
return jax.numpy.mean((y - pred_y) ** 2)
batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)
最后,背后没有任何魔法。所有 eqx.Module
做的就是将你的类注册为一个 PyTree。从那时起,JAX 已经知道如何处理 PyTrees。
引用
如果你发现在学术工作中有用,请引用: (arXiv 链接)
@article{kidger2021equinox,
author={Patrick Kidger and Cristian Garcia},
title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
year={2021},
journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}
(也请考虑在 GitHub 上给项目加星。)
另见:JAX 生态系统中的其他库
总是有用
jaxtyping: 数组形状/类型的注解。
深度学习
Optax: 一阶梯度(SGD、Adam 等)优化器。
Orbax: 检查点(异步/多主机/多设备)。
Levanter:基础模型(例如 LLMs)的可扩展且可靠的训练。
科学计算
Diffrax:数值微分方程求解器。
Optimistix:根查找、最小化、定点和最小二乘。
Lineax:线性求解器。
BlackJAX: 概率+贝叶斯采样。
sympy2jax: SymPy<->JAX 转换;通过梯度下降训练符号表达式。
PySR: 符号回归。(非 JAX 值得提及的项目!)
Awesome JAX
[Awesome JAX](https://github.com/n2cholas/awesome-jax):更多 JAX 项目列表。