Project Icon

dm-haiku

JAX神经网络构建的简洁解决方案

Haiku是一个为JAX设计的简洁神经网络库,具备面向对象编程模型和纯函数转换功能。由Sonnet的开发者创建,Haiku能简化模型参数和状态管理,并与其他JAX库无缝集成。虽然Google DeepMind建议新项目使用Flax,Haiku仍将在维护模式下持续支持,专注于修复bug和兼容性更新。

俳句:{JAX}的{十四行诗}

概述 | 为什么选择俳句? | 快速开始 | 安装 | 示例 | 用户手册 | 文档 | 引用俳句

pytest docs pypi

[!重要] 📣 截至2023年7月,Google DeepMind建议新项目采用Flax而不是俳句。Flax是由Google Brain最初开发,现在由Google DeepMind开发的神经网络库。 📣

在撰写本文时,Flax拥有比俳句更多的特性,较大的开发团队更活跃的社区Flax在用户群体中有更高的接受度,具有更广泛的文档示例和一个活跃的社区来创建从头到尾的示例。

俳句将保持尽力支持,但项目将进入维护模式,这意味着开发工作的重点将放在错误修复和与JAX新版本的兼容性上。

为了保持俳句与Python和JAX的新版本的兼容,我们将继续发布新的版本,但不会添加(或接受PR)新功能。

我们在Google DeepMind内部广泛使用俳句,并计划无限期地以这种模式支持俳句。

什么是俳句?

俳句是一种工具
用于构建神经网络
想象一下:“{JAX}的{十四行诗}”

俳句是由Sonnet的一些作者为JAX开发的简单神经网络库,Sonnet是一个为TensorFlow开发的神经网络库。

俳句的文档可以在https://dm-haiku.readthedocs.io/找到。

**说明:**如果您在寻找操作系统俳句,请访问https://haiku-os.org/。

概述

JAX是一个结合NumPy、自动微分和一流GPU/TPU支持的数值计算库。

俳句是一个简单的JAX神经网络库,允许用户使用熟悉的面向对象编程模型,同时允许完全访问JAX的纯函数转换。

俳句提供了两个核心工具:模块抽象hk.Module和一个简单的函数转换hk.transform

hk.Module是持有自身参数、其他模块和应用用户输入函数的方法的Python对象。

hk.transform将使用这些面向对象、功能上“纯净”模块的函数转换成可以与jax.jitjax.gradjax.pmap等一起使用的纯函数。

为什么选择俳句?

有很多为JAX开发的神经网络库。为什么要选择俳句?

俳句已经通过DeepMind的研究人员在大规模测试。

  • DeepMind已经用俳句和JAX重新实现了许多实验的结果,包括图像和语言处理、大规模的生成模型和强化学习。

俳句是一个库,而不是框架。

  • 俳句旨在简化特定任务:管理模型参数和其他模型状态。
  • 你可以期望俳句与其他库组合,并且与JAX的其余部分很好地协作。
  • 除此之外,俳句设计得很少干涉用户的工作——它不定义自定义优化器、检查点格式或复制API。

俳句没有重复发明轮子。

  • 俳句基于DeepMind几乎普遍采用的神经网络库Sonnet的编程模型和API。它保留了Sonnet的用于状态管理的Module(模块)编程模型,同时保留了对JAX的函数转换的访问。
  • 俳句的API和抽象与Sonnet尽可能接近。许多用户发现Sonnet是TensorFlow中一个富有成效的编程模型;俳句在JAX中实现了相同的体验。

迁移到俳句很容易。

  • 通过设计,从TensorFlow和Sonnet迁移到JAX和俳句是很容易的。
  • 除了新功能(如hk.transform)外,俳句旨在匹配Sonnet 2的API。模块、方法、参数名称、默认值和初始化方案应该匹配。

俳句简化了JAX的其他方面。

  • 俳句提供了一个简单的模型来处理随机数。在转换后的函数中,hk.next_rng_key()返回一个唯一的随机数生成器键。
  • 这些唯一键是从传递给顶级转换函数的初始随机键派生的,因此可以安全地与JAX程序转换一起使用。

快速开始

让我们来看一个示例神经网络、损失函数和训练循环。(欲了解更多示例, 请参见我们的示例目录MNIST 示例 是一个很好的起点。)

import haiku as hk
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels):
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def loss_fn(images, labels):
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
params = loss_fn_t.init(rng, dummy_images, dummy_labels)

def update_rule(param, update):
  return param - 0.01 * update

for images, labels in input_dataset:
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  params = jax.tree_util.tree_map(update_rule, params, grads)

俳句的核心是hk.transformtransform函数允许您编写依赖于参数的神经网络函数(此处是Linear层的权重)而不需要明确编写初始化那些参数的样板代码。transform通过将函数转换成纯函数对initapply实现的形式来做到这点。

init

init函数,签名为params = init(rng, ...)(其中...是未转换函数的参数),允许您收集网络中任何参数的初始值。俳句通过运行您的函数,跟踪任何通过hk.get_parameter(由如hk.Linear调用)请求的参数并返回给您。

params对象是您的网络中所有参数的嵌套数据结构,设计供您检查和操作。 具体地,它是模块名称到模块参数的映射,其中模块参数是参数名称到参数值的映射。例如:

{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
            'w': ndarray(..., shape=(28, 300), dtype=float32)},
 'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
              'w': ndarray(..., shape=(1000, 100), dtype=float32)},
 'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
              'w': ndarray(..., shape=(100, 10), dtype=float32)}}

apply

apply函数,签名为result = apply(params, rng, ...),允许您向函数注入参数值。当调用hk.get_parameter时,返回的值将来自您提供给apply作为输入的params

loss = loss_fn_t.apply(params, rng, images, labels)

注意,由于我们的损失函数执行的实际计算不依赖于随机数,因此传入一个随机数生成器是没有必要的,因此我们也可以传入Nonerng参数。 (请注意,如果您的计算确实使用了随机数,传入Nonerng将引发错误)。在上面的示例中,我们让俳句自动对我们执行此操作:

loss_fn_t = hk.without_apply_rng(loss_fn_t)

既然apply是一个纯函数,我们可以将其传递给jax.grad(或其他JAX变换):

grads = jax.grad(loss_fn_t.apply)(params, images, labels)

训练

此示例中的训练循环非常简单。一个需要注意的细节是使用jax.tree_util.tree_map来将sgd函数应用于paramsgrads中的所有匹配条目。结果具有与原始params相同的结构,并可以再次用于apply

安装

俳句是用纯Python编写的,但依赖于通过JAX的C++代码。

由于JAX的安装因CUDA版本而异,俳句没有在requirements.txt中列出JAX作为依赖项。

首先,按照这些说明来安装带有相关加速器支持的JAX。

然后,使用pip安装俳句:

$ pip install git+https://github.com/deepmind/dm-haiku

或者,您可以通过PyPI安装:

$ pip install -U dm-haiku

我们的示例依赖于额外的库(例如bsuite)。您可以使用pip安装所有额外的依赖:

$ pip install -r examples/requirements.txt

用户手册

编写您自己的模块

在俳句中,所有模块都是hk.Module的子类。您可以实现任何方法(没有被特殊处理),但通常模块实现__init____call__

让我们一起实现一个线性层:

class MyLinear(hk.Module):

  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
    return jnp.dot(x, w) + b

所有模块都有一个名称。如果没有传递name参数,模块名将从Python类名中推断(例如MyLinear变为my_linear)。模块可以有通过hk.get_parameter(param_name, ...)访问的命名参数。我们使用这个API(而不是直接使用对象属性),这样我们可以使用hk.transform将您的代码转换为纯函数。

使用模块时,您需要定义函数并使用hk.transform将其转换为一对纯函数。有关转换后函数返回的函数的详细信息,请参见我们的快速开始

def forward_fn(x):
  model = MyLinear(10)
  return model(x)

# 将`forward_fn`转换为具有`init`和`apply`方法的对象。 默认情况下,
# `apply`将需要一个rng(可以是None),用于
# `hk.next_rng_key`。
forward = hk.transform(forward_fn)

x = jnp.ones([1, 1])

# 当我们运行`forward.init`时,俳句将运行`forward_fn(x)`并收集初始参数值。由于参数
# 通常是随机初始化的,因此俳句要求您传递一个随机数生成器键给`init`:
key = hk.PRNGSequence(42)
params = forward.init(next(key), x)

# 当我们运行`forward.apply`时,俳句将运行`forward_fn(x)`并从作为第一个参数传递的`params`中注入参数
# 值。 请注意,通过`hk.transform(f)`转换的模型必须使用额外的
# `rng`参数调用:`forward.apply(params, rng, x)`。 使用
# `hk.without_apply_rng(hk.transform(f))`如果不需要这样做。
y = forward.apply(params, None, x)

处理随机模型

一些模型可能需要随机采样作为计算的一部分。例如,在使用重参数化技巧的变分自编码器中,需要从标准正态分布中进行随机采样。对于dropout,我们需要一个随机掩码来从输入中丢弃单元。使其与JAX一起工作的主要障碍在于PRNG键的管理。

在Haiku中,我们提供了一个简单的API来维护与模块关联的PRNG键序列:hk.next_rng_key()(或对于多个键使用next_rng_keys()):

class MyDropout(hk.Module):

  def __init__(self, rate=0.5, name=None):
    super().__init__(name=name)
    self.rate = rate

  def __call__(self, x):
    key = hk.next_rng_key()
    p = jax.random.bernoulli(key, 1.0 - self.rate, shape=x.shape)
    return x * p / (1.0 - self.rate)

forward = hk.transform(lambda x: MyDropout()(x))

key1, key2 = jax.random.split(jax.random.PRNGKey(42), 2)
params = forward.init(key1, x)
prediction = forward.apply(params, key2, x)

要更全面地了解与随机模型一起工作的情况,请参阅我们的VAE示例

注意: hk.next_rng_key()不是功能纯的,这意味着你应该避免在hk.transform内使用它与JAX变换一起使用。欲了解更多信息和可能的解决方法,请查阅Haiku变换的文档和可用的Haiku网络中的JAX变换包装器

使用不可训练状态

一些模型可能希望维护一些内部的、可变的状态。例如,在批量归一化中,训练过程中遇到的值的移动平均值是维护的。

在Haiku中,我们提供了一个简单的API来维护与模块关联的可变状态:hk.set_statehk.get_state。使用这些函数时,您需要使用hk.transform_with_state转换您的函数,因为返回的函数对的签名是不同的:

def forward(x, is_training):
  net = hk.nets.ResNet50(1000)
  return net(x, is_training)

forward = hk.transform_with_state(forward)

# `init`函数现在返回参数和状态。状态包含使用`hk.set_state`创建的任何内容。结构与参数相同(例如,这是一个按模块命名值的映射)。
params, state = forward.init(rng, x, is_training=True)

# `apply`函数现在接受参数和状态。此外,它将返回更新的状态值。在resnet示例中,这将是用于批量规范化层中的移动平均值的更新值。
logits, state = forward.apply(params, state, rng, x, is_training=True)

如果你忘记使用hk.transform_with_state,不要担心,我们会打印一个明确的错误,指向你hk.transform_with_state,而不是默默地丢弃你的状态。

使用jax.pmap进行分布式训练

hk.transform(或hk.transform_with_state)返回的纯函数完全兼容jax.pmap。有关使用jax.pmap进行SPMD编程的更多详细信息,请查看此处

在Haiku中使用jax.pmap的一个常见用途是对许多加速器进行数据并行训练,可能跨多个主机。在Haiku中,这可能如下所示:

def loss_fn(inputs, labels):
  logits = hk.nets.MLP([8, 4, 2])(x)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

# 在单个设备上初始化模型。
rng = jax.random.PRNGKey(428)
sample_image, sample_label = next(input_dataset)
params = loss_fn_t.init(rng, sample_image, sample_label)

# 将参数复制到所有设备上。
num_devices = jax.local_device_count()
params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)

def make_superbatch():
  """构造一个超级批次,即每个设备一个数据批次"""
  # 获取N个批次,然后拆分成图像列表和标签列表。
  superbatch = [next(input_dataset) for _ in range(num_devices)]
  superbatch_images, superbatch_labels = zip(*superbatch)
  # 将超级批次堆叠为一个具有前导维度的数组,而不是一个Python列表。这是`jax.pmap`期望的输入。
  superbatch_images = np.stack(superbatch_images)
  superbatch_labels = np.stack(superbatch_labels)
  return superbatch_images, superbatch_labels

def update(params, inputs, labels, axis_name='i'):
  """基于输入和标签的表现更新参数。"""
  grads = jax.grad(loss_fn_t.apply)(params, inputs, labels)
  # 在所有数据并行副本之间取梯度的平均值。
  grads = jax.lax.pmean(grads, axis_name)
  # 使用SGD或Adam或...更新参数
  new_params = my_update_rule(params, grads)
  return new_params

# 进行几次训练更新。
for _ in range(10):
  superbatch_images, superbatch_labels = make_superbatch()
  params = jax.pmap(update, axis_name='i')(params, superbatch_images,
                                           superbatch_labels)

要更全面地了解分布式Haiku训练,请查看我们的ImageNet上的ResNet-50示例

引用Haiku

要引用此存储库:

@software{haiku2020github,
  author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
  title = {{H}aiku: {S}onnet for {JAX}},
  url = {http://github.com/deepmind/dm-haiku},
  version = {0.0.10},
  year = {2020},
}

在此bibtex条目中,版本号应取自haiku/__init__.py,年份对应于项目的开源发布年份。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号