Project Icon

jax

高性能科学计算和机器学习的Python加速库

JAX是一个专为高性能数值计算和大规模机器学习设计的Python库。它利用XLA编译器实现加速器导向的数组计算和程序转换,支持自动微分、GPU和TPU加速。JAX提供jit、vmap和pmap等函数转换工具,让研究人员能够方便地表达复杂算法并获得出色性能,同时保持Python的灵活性。

logo

JAX: 自动微分和 XLA

持续集成 PyPI 版本

快速入门 | 转换 | 安装指南 | 神经网络库 | 更新日志 | 参考文档

JAX 是什么?

JAX 是一个面向加速器的数组计算和程序转换 Python 库,专为高性能数值计算和大规模机器学习而设计。

通过其更新版本的 Autograd,JAX 可以自动微分原生 Python 和 NumPy 函数。它可以对循环、分支、递归和闭包进行微分,还可以求导数的导数的导数。它支持通过 grad 进行反向模式微分(又称反向传播),以及前向模式微分,两者可以任意组合到任何阶数。

JAX 的新特性是使用 XLA 来编译和在 GPU 和 TPU 上运行 NumPy 程序。编译默认在后台进行,库调用会即时编译并执行。但 JAX 还允许您使用单函数 API jit 将自己的 Python 函数即时编译成 XLA 优化的内核。编译和自动微分可以任意组合,因此您可以表达复杂的算法并获得最大性能,而无需离开 Python。您甚至可以使用 pmap 同时编程多个 GPU 或 TPU 核心,并对整个过程进行微分。

深入一点,您会发现 JAX 实际上是一个用于可组合函数转换的可扩展系统。gradjit 都是此类转换的实例。其他转换包括用于自动向量化的 vmap 和用于多个加速器单程序多数据 (SPMD) 并行编程的 pmap,未来还会有更多。

这是一个研究项目,而不是 Google 的官方产品。请预期会有错误和棘手问题。请通过尝试使用、报告错误并让我们知道您的想法来提供帮助!

import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # 下一层的输入
  return outputs                # 最后一层没有激活函数

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # 编译后的梯度评估函数
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # 快速逐样本梯度

目录

快速入门:云端 Colab

使用浏览器中的笔记本直接开始,连接到 Google Cloud GPU。 以下是一些入门笔记本:

JAX 现在可以在 Cloud TPU 上运行。 要试用预览版,请参阅 Cloud TPU Colabs

深入了解 JAX:

转换

JAX 的核心是一个用于转换数值函数的可扩展系统。以下是四个主要的转换:gradjitvmappmap

使用 grad 进行自动微分

JAX 的 API 与 Autograd 大致相同。最常用的函数是用于反向模式梯度的 grad

from jax import grad
import jax.numpy as jnp

def tanh(x):  # 定义一个函数
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # 获取其梯度函数
print(grad_tanh(1.0))   # 在 x = 1.0 处求值
# 输出 0.4199743

您可以使用 grad 求任意阶导数。

print(grad(grad(grad(tanh)))(1.0))
# 输出 0.62162673

对于更高级的自动微分,您可以使用 jax.vjp 进行反向模式向量-雅可比积,使用 jax.jvp 进行前向模式雅可比-向量积。这两者可以与其他 JAX 转换任意组合。以下是一种组合它们以创建高效计算完整 Hessian 矩阵的函数的方法:

from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

Autograd 一样,您可以自由地在 Python 控制结构中使用微分:

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # 输出 1.0
print(abs_val_grad(-1.0))  # 输出 -1.0(abs_val 被重新求值)

有关更多信息,请参阅自动微分参考文档JAX 自动微分食谱

使用 jit 进行编译

您可以使用 XLA 通过 jit 对函数进行端到端编译,可以用作 @jit 装饰器或高阶函数。

import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # 元素级操作从融合中获得巨大收益
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # 在 Titan X 上约 4.5 ms/循环
%timeit -n10 -r3 slow_f(x)  # 约 14.5 ms/循环(通过 JAX 也在 GPU 上)

您可以随意组合 jitgrad 和任何其他 JAX 转换。

使用 jit 会对函数可以使用的 Python 控制流类型施加限制;更多信息请参阅注意事项笔记本

使用 vmap 进行自动向量化

vmap 是向量化映射。它具有沿数组轴映射函数的熟悉语义,但不是将循环保持在外部,而是将循环下推到函数的基本操作中以获得更好的性能。 使用 vmap 可以避免在代码中携带批次维度。例如,考虑这个简单的非批处理神经网络预测函数:

def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` 在右侧!
    activations = jnp.tanh(outputs)        # 下一层的输入
  return outputs                           # 最后一层没有激活函数

我们通常会写成 jnp.dot(activations, W) 以允许 activations 左侧有一个批次维度,但这个特定的预测函数只适用于单个输入向量。如果我们想一次性对一批输入应用这个函数,从语义上讲,我们可以这样写:

from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))

但是一次只处理一个样本会很慢!最好将计算向量化,这样在每一层我们都在进行矩阵-矩阵乘法,而不是矩阵-向量乘法。

vmap 函数为我们完成了这种转换。也就是说,如果我们写:

from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# 或者,另一种写法
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)

那么 vmap 函数会将外部循环推到函数内部,我们的机器最终会执行矩阵-矩阵乘法,就好像我们手动进行了批处理一样。

不使用 vmap 手动批处理一个简单的神经网络很容易,但在其他情况下,手动向量化可能不切实际或不可能。比如高效计算每个样本梯度的问题:对于一组固定的参数,我们想要计算损失函数在批次中每个样本上单独评估的梯度。使用 vmap,这很容易:

per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)

当然,vmap 可以任意组合 jitgrad 和任何其他 JAX 变换!我们在 jax.jacfwdjax.jacrevjax.hessian 中使用 vmap 进行正向和反向自动微分,以快速计算雅可比矩阵和海森矩阵。

使用 pmap 进行 SPMD 编程

对于多个加速器(如多个 GPU)的并行编程,使用 pmap。使用 pmap,你可以编写单程序多数据(SPMD)程序,包括快速并行集体通信操作。应用 pmap 意味着你编写的函数将由 XLA 编译(类似于 jit),然后在设备上复制并并行执行。

以下是在 8 GPU 机器上的示例:

from jax import random, pmap
import jax.numpy as jnp

# 创建 8 个随机 5000 x 6000 矩阵,每个 GPU 一个
keys = random.split(random.key(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# 在每个设备上并行运行本地矩阵乘法(无数据传输)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape 是 (8, 5000, 5000)

# 在每个设备上并行计算平均值并打印结果
print(pmap(jnp.mean)(result))
# 打印 [1.1566595 1.1805978 ... 1.2321935 1.2015157]

除了表达纯映射外,你还可以使用设备间的快速集体通信操作

from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
# 打印 [0.         0.16666667 0.33333334 0.5       ]

你甚至可以嵌套 pmap 函数以实现更复杂的通信模式。

所有这些都可以组合,所以你可以自由地对并行计算进行微分:

from jax import grad

@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

print(f(x))
# [[ 0.        , -0.7170853 ],
#  [-3.1085174 , -0.4824318 ],
#  [10.366636  , 13.135289  ],
#  [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726,  -1.6356447],
#  [  4.7572474,  11.606951 ],
#  [-98.524414 ,  42.76499  ],
#  [ -1.6007166,  -1.2568436]]

当对 pmap 函数进行反向模式微分(例如使用 grad)时,计算的反向传播会像正向传播一样并行化。

更多信息请参见 SPMD CookbookSPMD MNIST 分类器从头开始示例

当前注意事项

要更全面地了解当前的注意事项,包括示例和解释,我们强烈建议阅读 Gotchas Notebook。一些突出的问题包括:

  1. JAX 变换仅适用于纯函数,这些函数没有副作用并遵守引用透明性(即使用is进行对象身份测试不会保留)。如果您对非纯Python函数使用JAX变换,可能会看到类似Exception: Can't lift Traced...Exception: Different traces at same level的错误。

  2. 数组的原地突变更新,如x[i] += y,不受支持,但有函数式替代方案。在jit下,这些函数式替代方案会自动在原地重用缓冲区。

  3. 随机数是不同的,但这是有充分理由的

  4. 如果您在寻找卷积运算符,它们在jax.lax包中。

  5. JAX默认强制使用单精度(32位,例如float32)值,要启用双精度(64位,例如float64),需要在启动时设置jax_enable_x64变量(或设置环境变量JAX_ENABLE_X64=True)。在TPU上,JAX默认对所有内容使用32位值,除了"类矩阵乘法"操作(如jax.numpy.dotlax.conv)中的内部临时变量。这些操作有一个precision参数,可以通过三次bfloat16传递来近似32位操作,可能会导致运行时间变慢。TPU上的非矩阵乘法操作会转换为通常强调速度而非精度的实现,因此实际上TPU上的计算会比其他后端上的类似计算精度更低。

  6. NumPy的一些涉及Python标量和NumPy类型混合的dtype提升语义没有保留,即np.add(1, np.array([2], np.float32)).dtypefloat64而不是float32

  7. 一些转换,如jit限制了您使用Python控制流的方式。如果出现问题,您总会收到明确的错误提示。您可能需要使用jitstatic_argnums参数结构化控制流原语lax.scan,或者仅对较小的子函数使用jit

安装

支持的平台

Linux x86_64Linux aarch64Mac x86_64Mac ARMWindows x86_64Windows WSL2 x86_64
CPU
NVIDIA GPU不适用实验性
Google TPU不适用不适用不适用不适用不适用
AMD GPU实验性不适用
Apple GPU不适用实验性实验性不适用不适用

安装指南

硬件指令
CPUpip install -U jax
NVIDIA GPUpip install -U "jax[cuda12]"
Google TPUpip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
AMD GPU使用Docker从源代码构建
Apple GPU按照Apple的说明进行操作。

有关其他安装策略的信息,请参阅文档。这些包括从源代码编译、使用Docker安装、使用其他版本的CUDA、社区支持的conda构建,以及一些常见问题的答案。

神经网络库

多个Google研究团队开发并分享了用JAX训练神经网络的库。如果您想要一个功能齐全的神经网络训练库,并附有示例和操作指南,可以尝试Flax。查看新的NNX API以获得简化的开发体验。

Google X维护神经网络库Equinox。这被用作JAX生态系统中几个其他库的基础。

此外,DeepMind已开源了围绕JAX的一系列库,包括用于梯度处理和优化的Optax,用于强化学习算法的RLax,以及用于可靠代码和测试的chex。(观看NeurIPS 2020 JAX Ecosystem在DeepMind的演讲点击这里

引用JAX

要引用此仓库:

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/google/jax},
  version = {0.3.13},
  year = {2018},
}

在上述bibtex条目中,名字按字母顺序排列,版本号应为jax/version.py中的版本,年份对应项目的开源发布。

JAX的一个初始版本,仅支持自动微分和编译到XLA,在2018年SysML会议上的一篇论文中有描述。我们目前正在撰写一篇更全面和最新的论文,涵盖JAX的理念和功能。

参考文档

有关JAX API的详细信息,请参阅参考文档

对于希望成为JAX开发者的入门指南,请参阅开发者文档

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号