JAX: 自动微分和 XLA
快速入门 | 转换 | 安装指南 | 神经网络库 | 更新日志 | 参考文档
JAX 是什么?
JAX 是一个面向加速器的数组计算和程序转换 Python 库,专为高性能数值计算和大规模机器学习而设计。
通过其更新版本的 Autograd,JAX 可以自动微分原生 Python 和 NumPy 函数。它可以对循环、分支、递归和闭包进行微分,还可以求导数的导数的导数。它支持通过 grad
进行反向模式微分(又称反向传播),以及前向模式微分,两者可以任意组合到任何阶数。
JAX 的新特性是使用 XLA 来编译和在 GPU 和 TPU 上运行 NumPy 程序。编译默认在后台进行,库调用会即时编译并执行。但 JAX 还允许您使用单函数 API jit
将自己的 Python 函数即时编译成 XLA 优化的内核。编译和自动微分可以任意组合,因此您可以表达复杂的算法并获得最大性能,而无需离开 Python。您甚至可以使用 pmap
同时编程多个 GPU 或 TPU 核心,并对整个过程进行微分。
深入一点,您会发现 JAX 实际上是一个用于可组合函数转换的可扩展系统。grad
和 jit
都是此类转换的实例。其他转换包括用于自动向量化的 vmap
和用于多个加速器单程序多数据 (SPMD) 并行编程的 pmap
,未来还会有更多。
这是一个研究项目,而不是 Google 的官方产品。请预期会有错误和棘手问题。请通过尝试使用、报告错误并让我们知道您的想法来提供帮助!
import jax.numpy as jnp
from jax import grad, jit, vmap
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jnp.tanh(outputs) # 下一层的输入
return outputs # 最后一层没有激活函数
def loss(params, inputs, targets):
preds = predict(params, inputs)
return jnp.sum((preds - targets)**2)
grad_loss = jit(grad(loss)) # 编译后的梯度评估函数
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # 快速逐样本梯度
目录
快速入门:云端 Colab
使用浏览器中的笔记本直接开始,连接到 Google Cloud GPU。 以下是一些入门笔记本:
JAX 现在可以在 Cloud TPU 上运行。 要试用预览版,请参阅 Cloud TPU Colabs。
深入了解 JAX:
转换
JAX 的核心是一个用于转换数值函数的可扩展系统。以下是四个主要的转换:grad
、jit
、vmap
和 pmap
。
使用 grad
进行自动微分
JAX 的 API 与 Autograd 大致相同。最常用的函数是用于反向模式梯度的 grad
:
from jax import grad
import jax.numpy as jnp
def tanh(x): # 定义一个函数
y = jnp.exp(-2.0 * x)
return (1.0 - y) / (1.0 + y)
grad_tanh = grad(tanh) # 获取其梯度函数
print(grad_tanh(1.0)) # 在 x = 1.0 处求值
# 输出 0.4199743
您可以使用 grad
求任意阶导数。
print(grad(grad(grad(tanh)))(1.0))
# 输出 0.62162673
对于更高级的自动微分,您可以使用 jax.vjp
进行反向模式向量-雅可比积,使用 jax.jvp
进行前向模式雅可比-向量积。这两者可以与其他 JAX 转换任意组合。以下是一种组合它们以创建高效计算完整 Hessian 矩阵的函数的方法:
from jax import jit, jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
与 Autograd 一样,您可以自由地在 Python 控制结构中使用微分:
def abs_val(x):
if x > 0:
return x
else:
return -x
abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0)) # 输出 1.0
print(abs_val_grad(-1.0)) # 输出 -1.0(abs_val 被重新求值)
有关更多信息,请参阅自动微分参考文档和 JAX 自动微分食谱。
使用 jit
进行编译
您可以使用 XLA 通过 jit
对函数进行端到端编译,可以用作 @jit
装饰器或高阶函数。
import jax.numpy as jnp
from jax import jit
def slow_f(x):
# 元素级操作从融合中获得巨大收益
return x * x + x * 2.0
x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x) # 在 Titan X 上约 4.5 ms/循环
%timeit -n10 -r3 slow_f(x) # 约 14.5 ms/循环(通过 JAX 也在 GPU 上)
您可以随意组合 jit
、grad
和任何其他 JAX 转换。
使用 jit
会对函数可以使用的 Python 控制流类型施加限制;更多信息请参阅注意事项笔记本。
使用 vmap
进行自动向量化
vmap
是向量化映射。它具有沿数组轴映射函数的熟悉语义,但不是将循环保持在外部,而是将循环下推到函数的基本操作中以获得更好的性能。
使用 vmap
可以避免在代码中携带批次维度。例如,考虑这个简单的非批处理神经网络预测函数:
def predict(params, input_vec):
assert input_vec.ndim == 1
activations = input_vec
for W, b in params:
outputs = jnp.dot(W, activations) + b # `activations` 在右侧!
activations = jnp.tanh(outputs) # 下一层的输入
return outputs # 最后一层没有激活函数
我们通常会写成 jnp.dot(activations, W)
以允许 activations
左侧有一个批次维度,但这个特定的预测函数只适用于单个输入向量。如果我们想一次性对一批输入应用这个函数,从语义上讲,我们可以这样写:
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
但是一次只处理一个样本会很慢!最好将计算向量化,这样在每一层我们都在进行矩阵-矩阵乘法,而不是矩阵-向量乘法。
vmap
函数为我们完成了这种转换。也就是说,如果我们写:
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# 或者,另一种写法
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
那么 vmap
函数会将外部循环推到函数内部,我们的机器最终会执行矩阵-矩阵乘法,就好像我们手动进行了批处理一样。
不使用 vmap
手动批处理一个简单的神经网络很容易,但在其他情况下,手动向量化可能不切实际或不可能。比如高效计算每个样本梯度的问题:对于一组固定的参数,我们想要计算损失函数在批次中每个样本上单独评估的梯度。使用 vmap
,这很容易:
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
当然,vmap
可以任意组合 jit
、grad
和任何其他 JAX 变换!我们在 jax.jacfwd
、jax.jacrev
和 jax.hessian
中使用 vmap
进行正向和反向自动微分,以快速计算雅可比矩阵和海森矩阵。
使用 pmap
进行 SPMD 编程
对于多个加速器(如多个 GPU)的并行编程,使用 pmap
。使用 pmap
,你可以编写单程序多数据(SPMD)程序,包括快速并行集体通信操作。应用 pmap
意味着你编写的函数将由 XLA 编译(类似于 jit
),然后在设备上复制并并行执行。
以下是在 8 GPU 机器上的示例:
from jax import random, pmap
import jax.numpy as jnp
# 创建 8 个随机 5000 x 6000 矩阵,每个 GPU 一个
keys = random.split(random.key(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
# 在每个设备上并行运行本地矩阵乘法(无数据传输)
result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape 是 (8, 5000, 5000)
# 在每个设备上并行计算平均值并打印结果
print(pmap(jnp.mean)(result))
# 打印 [1.1566595 1.1805978 ... 1.2321935 1.2015157]
除了表达纯映射外,你还可以使用设备间的快速集体通信操作:
from functools import partial
from jax import lax
@partial(pmap, axis_name='i')
def normalize(x):
return x / lax.psum(x, 'i')
print(normalize(jnp.arange(4.)))
# 打印 [0. 0.16666667 0.33333334 0.5 ]
你甚至可以嵌套 pmap
函数以实现更复杂的通信模式。
所有这些都可以组合,所以你可以自由地对并行计算进行微分:
from jax import grad
@pmap
def f(x):
y = jnp.sin(x)
@pmap
def g(z):
return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
return grad(lambda w: jnp.sum(g(w)))(x)
print(f(x))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
当对 pmap
函数进行反向模式微分(例如使用 grad
)时,计算的反向传播会像正向传播一样并行化。
更多信息请参见 SPMD Cookbook 和 SPMD MNIST 分类器从头开始示例。
当前注意事项
要更全面地了解当前的注意事项,包括示例和解释,我们强烈建议阅读 Gotchas Notebook。一些突出的问题包括:
-
JAX 变换仅适用于纯函数,这些函数没有副作用并遵守引用透明性(即使用
is
进行对象身份测试不会保留)。如果您对非纯Python函数使用JAX变换,可能会看到类似Exception: Can't lift Traced...
或Exception: Different traces at same level
的错误。 -
数组的原地突变更新,如
x[i] += y
,不受支持,但有函数式替代方案。在jit
下,这些函数式替代方案会自动在原地重用缓冲区。 -
如果您在寻找卷积运算符,它们在
jax.lax
包中。 -
JAX默认强制使用单精度(32位,例如
float32
)值,要启用双精度(64位,例如float64
),需要在启动时设置jax_enable_x64
变量(或设置环境变量JAX_ENABLE_X64=True
)。在TPU上,JAX默认对所有内容使用32位值,除了"类矩阵乘法"操作(如jax.numpy.dot
和lax.conv
)中的内部临时变量。这些操作有一个precision
参数,可以通过三次bfloat16传递来近似32位操作,可能会导致运行时间变慢。TPU上的非矩阵乘法操作会转换为通常强调速度而非精度的实现,因此实际上TPU上的计算会比其他后端上的类似计算精度更低。 -
NumPy的一些涉及Python标量和NumPy类型混合的dtype提升语义没有保留,即
np.add(1, np.array([2], np.float32)).dtype
是float64
而不是float32
。 -
一些转换,如
jit
,限制了您使用Python控制流的方式。如果出现问题,您总会收到明确的错误提示。您可能需要使用jit
的static_argnums
参数,结构化控制流原语如lax.scan
,或者仅对较小的子函数使用jit
。
安装
支持的平台
Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac ARM | Windows x86_64 | Windows WSL2 x86_64 | |
---|---|---|---|---|---|---|
CPU | 是 | 是 | 是 | 是 | 是 | 是 |
NVIDIA GPU | 是 | 是 | 否 | 不适用 | 否 | 实验性 |
Google TPU | 是 | 不适用 | 不适用 | 不适用 | 不适用 | 不适用 |
AMD GPU | 实验性 | 否 | 否 | 不适用 | 否 | 否 |
Apple GPU | 不适用 | 否 | 实验性 | 实验性 | 不适用 | 不适用 |
安装指南
硬件 | 指令 |
---|---|
CPU | pip install -U jax |
NVIDIA GPU | pip install -U "jax[cuda12]" |
Google TPU | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
AMD GPU | 使用Docker或从源代码构建。 |
Apple GPU | 按照Apple的说明进行操作。 |
有关其他安装策略的信息,请参阅文档。这些包括从源代码编译、使用Docker安装、使用其他版本的CUDA、社区支持的conda构建,以及一些常见问题的答案。
神经网络库
多个Google研究团队开发并分享了用JAX训练神经网络的库。如果您想要一个功能齐全的神经网络训练库,并附有示例和操作指南,可以尝试Flax。查看新的NNX API以获得简化的开发体验。
Google X维护神经网络库Equinox。这被用作JAX生态系统中几个其他库的基础。
此外,DeepMind已开源了围绕JAX的一系列库,包括用于梯度处理和优化的Optax,用于强化学习算法的RLax,以及用于可靠代码和测试的chex。(观看NeurIPS 2020 JAX Ecosystem在DeepMind的演讲点击这里)
引用JAX
要引用此仓库:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax},
version = {0.3.13},
year = {2018},
}
在上述bibtex条目中,名字按字母顺序排列,版本号应为jax/version.py中的版本,年份对应项目的开源发布。
JAX的一个初始版本,仅支持自动微分和编译到XLA,在2018年SysML会议上的一篇论文中有描述。我们目前正在撰写一篇更全面和最新的论文,涵盖JAX的理念和功能。
参考文档
有关JAX API的详细信息,请参阅参考文档。
对于希望成为JAX开发者的入门指南,请参阅开发者文档。