俳句:{JAX}的{十四行诗}
概述 | 为什么选择俳句? | 快速开始 | 安装 | 示例 | 用户手册 | 文档 | 引用俳句
[!重要] 📣 截至2023年7月,Google DeepMind建议新项目采用Flax而不是俳句。Flax是由Google Brain最初开发,现在由Google DeepMind开发的神经网络库。 📣
在撰写本文时,Flax拥有比俳句更多的特性,较大的开发团队 和更活跃的社区。Flax在用户群体中有更高的接受度,具有更广泛的文档, 示例和一个活跃的社区来创建从头到尾的示例。
俳句将保持尽力支持,但项目将进入维护模式,这意味着开发工作的重点将放在错误修复和与JAX新版本的兼容性上。
为了保持俳句与Python和JAX的新版本的兼容,我们将继续发布新的版本,但不会添加(或接受PR)新功能。
我们在Google DeepMind内部广泛使用俳句,并计划无限期地以这种模式支持俳句。
什么是俳句?
俳句是一种工具
用于构建神经网络
想象一下:“{JAX}的{十四行诗}”
俳句是由Sonnet的一些作者为JAX开发的简单神经网络库,Sonnet是一个为TensorFlow开发的神经网络库。
俳句的文档可以在https://dm-haiku.readthedocs.io/找到。
**说明:**如果您在寻找操作系统俳句,请访问https://haiku-os.org/。
概述
JAX是一个结合NumPy、自动微分和一流GPU/TPU支持的数值计算库。
俳句是一个简单的JAX神经网络库,允许用户使用熟悉的面向对象编程模型,同时允许完全访问JAX的纯函数转换。
俳句提供了两个核心工具:模块抽象hk.Module
和一个简单的函数转换hk.transform
。
hk.Module
是持有自身参数、其他模块和应用用户输入函数的方法的Python对象。
hk.transform
将使用这些面向对象、功能上“纯净”模块的函数转换成可以与jax.jit
、jax.grad
、jax.pmap
等一起使用的纯函数。
为什么选择俳句?
有很多为JAX开发的神经网络库。为什么要选择俳句?
俳句已经通过DeepMind的研究人员在大规模测试。
- DeepMind已经用俳句和JAX重新实现了许多实验的结果,包括图像和语言处理、大规模的生成模型和强化学习。
俳句是一个库,而不是框架。
- 俳句旨在简化特定任务:管理模型参数和其他模型状态。
- 你可以期望俳句与其他库组合,并且与JAX的其余部分很好地协作。
- 除此之外,俳句设计得很少干涉用户的工作——它不定义自定义优化器、检查点格式或复制API。
俳句没有重复发明轮子。
- 俳句基于DeepMind几乎普遍采用的神经网络库Sonnet的编程模型和API。它保留了Sonnet的用于状态管理的
Module
(模块)编程模型,同时保留了对JAX的函数转换的访问。 - 俳句的API和抽象与Sonnet尽可能接近。许多用户发现Sonnet是TensorFlow中一个富有成效的编程模型;俳句在JAX中实现了相同的体验。
迁移到俳句很容易。
- 通过设计,从TensorFlow和Sonnet迁移到JAX和俳句是很容易的。
- 除了新功能(如
hk.transform
)外,俳句旨在匹配Sonnet 2的API。模块、方法、参数名称、默认值和初始化方案应该匹配。
俳句简化了JAX的其他方面。
- 俳句提供了一个简单的模型来处理随机数。在转换后的函数中,
hk.next_rng_key()
返回一个唯一的随机数生成器键。 - 这些唯一键是从传递给顶级转换函数的初始随机键派生的,因此可以安全地与JAX程序转换一起使用。
快速开始
让我们来看一个示例神经网络、损失函数和训练循环。(欲了解更多示例, 请参见我们的示例目录。 MNIST 示例 是一个很好的起点。)
import haiku as hk
import jax.numpy as jnp
def softmax_cross_entropy(logits, labels):
one_hot = jax.nn.one_hot(labels, logits.shape[-1])
return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)
def loss_fn(images, labels):
mlp = hk.Sequential([
hk.Linear(300), jax.nn.relu,
hk.Linear(100), jax.nn.relu,
hk.Linear(10),
])
logits = mlp(images)
return jnp.mean(softmax_cross_entropy(logits, labels))
loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)
rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
params = loss_fn_t.init(rng, dummy_images, dummy_labels)
def update_rule(param, update):
return param - 0.01 * update
for images, labels in input_dataset:
grads = jax.grad(loss_fn_t.apply)(params, images, labels)
params = jax.tree_util.tree_map(update_rule, params, grads)
俳句的核心是hk.transform
。transform
函数允许您编写依赖于参数的神经网络函数(此处是Linear
层的权重)而不需要明确编写初始化那些参数的样板代码。transform
通过将函数转换成纯函数对init
和apply
实现的形式来做到这点。
init
init
函数,签名为params = init(rng, ...)
(其中...
是未转换函数的参数),允许您收集网络中任何参数的初始值。俳句通过运行您的函数,跟踪任何通过hk.get_parameter
(由如hk.Linear
调用)请求的参数并返回给您。
params
对象是您的网络中所有参数的嵌套数据结构,设计供您检查和操作。
具体地,它是模块名称到模块参数的映射,其中模块参数是参数名称到参数值的映射。例如:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
apply
函数,签名为result = apply(params, rng, ...)
,允许您向函数注入参数值。当调用hk.get_parameter
时,返回的值将来自您提供给apply
作为输入的params
:
loss = loss_fn_t.apply(params, rng, images, labels)
注意,由于我们的损失函数执行的实际计算不依赖于随机数,因此传入一个随机数生成器是没有必要的,因此我们也可以传入None
给rng
参数。 (请注意,如果您的计算确实使用了随机数,传入None
给rng
将引发错误)。在上面的示例中,我们让俳句自动对我们执行此操作:
loss_fn_t = hk.without_apply_rng(loss_fn_t)
既然apply
是一个纯函数,我们可以将其传递给jax.grad
(或其他JAX变换):
grads = jax.grad(loss_fn_t.apply)(params, images, labels)
训练
此示例中的训练循环非常简单。一个需要注意的细节是使用jax.tree_util.tree_map
来将sgd
函数应用于params
和grads
中的所有匹配条目。结果具有与原始params
相同的结构,并可以再次用于apply
。
安装
俳句是用纯Python编写的,但依赖于通过JAX的C++代码。
由于JAX的安装因CUDA版本而异,俳句没有在requirements.txt
中列出JAX作为依赖项。
首先,按照这些说明来安装带有相关加速器支持的JAX。
然后,使用pip安装俳句:
$ pip install git+https://github.com/deepmind/dm-haiku
或者,您可以通过PyPI安装:
$ pip install -U dm-haiku
我们的示例依赖于额外的库(例如bsuite)。您可以使用pip安装所有额外的依赖:
$ pip install -r examples/requirements.txt
用户手册
编写您自己的模块
在俳句中,所有模块都是hk.Module
的子类。您可以实现任何方法(没有被特殊处理),但通常模块实现__init__
和__call__
。
让我们一起实现一个线性层:
class MyLinear(hk.Module):
def __init__(self, output_size, name=None):
super().__init__(name=name)
self.output_size = output_size
def __call__(self, x):
j, k = x.shape[-1], self.output_size
w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
return jnp.dot(x, w) + b
所有模块都有一个名称。如果没有传递name
参数,模块名将从Python类名中推断(例如MyLinear
变为my_linear
)。模块可以有通过hk.get_parameter(param_name, ...)
访问的命名参数。我们使用这个API(而不是直接使用对象属性),这样我们可以使用hk.transform
将您的代码转换为纯函数。
使用模块时,您需要定义函数并使用hk.transform
将其转换为一对纯函数。有关转换后函数返回的函数的详细信息,请参见我们的快速开始:
def forward_fn(x):
model = MyLinear(10)
return model(x)
# 将`forward_fn`转换为具有`init`和`apply`方法的对象。 默认情况下,
# `apply`将需要一个rng(可以是None),用于
# `hk.next_rng_key`。
forward = hk.transform(forward_fn)
x = jnp.ones([1, 1])
# 当我们运行`forward.init`时,俳句将运行`forward_fn(x)`并收集初始参数值。由于参数
# 通常是随机初始化的,因此俳句要求您传递一个随机数生成器键给`init`:
key = hk.PRNGSequence(42)
params = forward.init(next(key), x)
# 当我们运行`forward.apply`时,俳句将运行`forward_fn(x)`并从作为第一个参数传递的`params`中注入参数
# 值。 请注意,通过`hk.transform(f)`转换的模型必须使用额外的
# `rng`参数调用:`forward.apply(params, rng, x)`。 使用
# `hk.without_apply_rng(hk.transform(f))`如果不需要这样做。
y = forward.apply(params, None, x)
处理随机模型
一些模型可能需要随机采样作为计算的一部分。例如,在使用重参数化技巧的变分自编码器中,需要从标准正态分布中进行随机采样。对于dropout,我们需要一个随机掩码来从输入中丢弃单元。使其与JAX一起工作的主要障碍在于PRNG键的管理。
在Haiku中,我们提供了一个简单的API来维护与模块关联的PRNG键序列:hk.next_rng_key()
(或对于多个键使用next_rng_keys()
):
class MyDropout(hk.Module):
def __init__(self, rate=0.5, name=None):
super().__init__(name=name)
self.rate = rate
def __call__(self, x):
key = hk.next_rng_key()
p = jax.random.bernoulli(key, 1.0 - self.rate, shape=x.shape)
return x * p / (1.0 - self.rate)
forward = hk.transform(lambda x: MyDropout()(x))
key1, key2 = jax.random.split(jax.random.PRNGKey(42), 2)
params = forward.init(key1, x)
prediction = forward.apply(params, key2, x)
要更全面地了解与随机模型一起工作的情况,请参阅我们的VAE示例。
注意: hk.next_rng_key()
不是功能纯的,这意味着你应该避免在hk.transform
内使用它与JAX变换一起使用。欲了解更多信息和可能的解决方法,请查阅Haiku变换的文档和可用的Haiku网络中的JAX变换包装器。
使用不可训练状态
一些模型可能希望维护一些内部的、可变的状态。例如,在批量归一化中,训练过程中遇到的值的移动平均值是维护的。
在Haiku中,我们提供了一个简单的API来维护与模块关联的可变状态:hk.set_state
和hk.get_state
。使用这些函数时,您需要使用hk.transform_with_state
转换您的函数,因为返回的函数对的签名是不同的:
def forward(x, is_training):
net = hk.nets.ResNet50(1000)
return net(x, is_training)
forward = hk.transform_with_state(forward)
# `init`函数现在返回参数和状态。状态包含使用`hk.set_state`创建的任何内容。结构与参数相同(例如,这是一个按模块命名值的映射)。
params, state = forward.init(rng, x, is_training=True)
# `apply`函数现在接受参数和状态。此外,它将返回更新的状态值。在resnet示例中,这将是用于批量规范化层中的移动平均值的更新值。
logits, state = forward.apply(params, state, rng, x, is_training=True)
如果你忘记使用hk.transform_with_state
,不要担心,我们会打印一个明确的错误,指向你hk.transform_with_state
,而不是默默地丢弃你的状态。
使用jax.pmap
进行分布式训练
从hk.transform
(或hk.transform_with_state
)返回的纯函数完全兼容jax.pmap
。有关使用jax.pmap
进行SPMD编程的更多详细信息,请查看此处。
在Haiku中使用jax.pmap
的一个常见用途是对许多加速器进行数据并行训练,可能跨多个主机。在Haiku中,这可能如下所示:
def loss_fn(inputs, labels):
logits = hk.nets.MLP([8, 4, 2])(x)
return jnp.mean(softmax_cross_entropy(logits, labels))
loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)
# 在单个设备上初始化模型。
rng = jax.random.PRNGKey(428)
sample_image, sample_label = next(input_dataset)
params = loss_fn_t.init(rng, sample_image, sample_label)
# 将参数复制到所有设备上。
num_devices = jax.local_device_count()
params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)
def make_superbatch():
"""构造一个超级批次,即每个设备一个数据批次"""
# 获取N个批次,然后拆分成图像列表和标签列表。
superbatch = [next(input_dataset) for _ in range(num_devices)]
superbatch_images, superbatch_labels = zip(*superbatch)
# 将超级批次堆叠为一个具有前导维度的数组,而不是一个Python列表。这是`jax.pmap`期望的输入。
superbatch_images = np.stack(superbatch_images)
superbatch_labels = np.stack(superbatch_labels)
return superbatch_images, superbatch_labels
def update(params, inputs, labels, axis_name='i'):
"""基于输入和标签的表现更新参数。"""
grads = jax.grad(loss_fn_t.apply)(params, inputs, labels)
# 在所有数据并行副本之间取梯度的平均值。
grads = jax.lax.pmean(grads, axis_name)
# 使用SGD或Adam或...更新参数
new_params = my_update_rule(params, grads)
return new_params
# 进行几次训练更新。
for _ in range(10):
superbatch_images, superbatch_labels = make_superbatch()
params = jax.pmap(update, axis_name='i')(params, superbatch_images,
superbatch_labels)
要更全面地了解分布式Haiku训练,请查看我们的ImageNet上的ResNet-50示例。
引用Haiku
要引用此存储库:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.10},
year = {2020},
}
在此bibtex条目中,版本号应取自haiku/__init__.py
,年份对应于项目的开源发布年份。