Flax:为JAX设计的灵活性神经网络库和生态系统
概述 | 快速安装 | Flax是什么样子的? | 文档
📣 新消息:查看NNXAPI!
这个README只是一个简短的介绍。要了解关于Flax的所有信息,请参阅我们的完整文档。
Flax最初由Google Research的Brain团队的工程师和研究人员发起(与JAX团队密切合作),现在与开源社区共同开发。
Flax正被Alphabet各研究部门的数百名人员在日常工作中使用,同时也被越来越多的开源项目社区采用。
Flax团队的使命是服务于不断增长的JAX神经网络研究生态系统——不仅在Alphabet内部,也包括更广泛的社区,并探索JAX表现出色的应用场景。我们几乎所有的协调和计划都在GitHub上进行,也在那里讨论即将到来的设计变更。我们欢迎对任何讨论、问题和拉取请求线程提供反馈。我们正在将一些剩余的内部设计文档和对话线程转移到GitHub的讨论、问题和拉取请求中。我们希望能越来越多地满足更广泛生态系统的需求和澄清要求。请告诉我们我们如何能帮到您!
请在我们的讨论论坛中报告任何功能请求、问题、疑问或担忧,或者只是让我们知道您正在进行什么工作!
我们预计会改进Flax,但不会对核心API进行重大的破坏性更改。我们尽可能使用Changelog条目和弃用警告。
如果您想直接联系我们,我们的邮箱是flax-dev@google.com。
概述
Flax是一个为JAX设计的高性能神经网络库和生态系统,专为灵活性而设计: 通过分叉示例并修改训练循环来尝试新的训练形式,而不是通过向框架添加功能。
Flax正在与JAX团队密切合作开发,并提供开始研究所需的一切,包括:
-
神经网络API(
flax.linen
):Dense、Conv、{Batch|Layer|Group} Norm、Attention、Pooling、{LSTM|GRU} Cell、Dropout -
实用工具和模式:复制训练、序列化和检查点、指标、设备预取
-
开箱即用的教育示例:MNIST、LSTM seq2seq、图神经网络、序列标记
-
快速、调优的大规模端到端示例:CIFAR10、ImageNet上的ResNet、Transformer LM1b
快速安装
您需要Python 3.6或更高版本,以及一个可用的JAX安装(无论是否支持GPU - 参考说明)。 对于仅CPU版本的JAX:
pip install --upgrade pip # 支持manylinux2010 wheels。
pip install --upgrade jax jaxlib # 仅CPU
然后,从PyPi安装Flax:
pip install flax
要升级到最新版本的Flax,可以使用:
pip install --upgrade git+https://github.com/google/flax.git
要安装一些额外的依赖项(如matplotlib
),这些是某些依赖项需要但未包含的,可以使用:
pip install "flax[all]"
Flax是什么样子的?
我们提供了三个使用Flax API的示例:一个简单的多层感知器、一个CNN和一个自动编码器。
要了解更多关于Module
抽象的信息,请查看我们的文档,以及我们的Module抽象广泛介绍。如需更多具体的最佳实践演示,请参考我们的指南和开发者笔记。
from typing import Sequence
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
class MLP(nn.Module):
features: Sequence[int]
@nn.compact
def __call__(self, x):
for feat in self.features[:-1]:
x = nn.relu(nn.Dense(feat)(x))
x = nn.Dense(self.features[-1])(x)
return x
model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
class CNN(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # 展平
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
x = nn.log_softmax(x)
return x
model = CNN()
batch = jnp.ones((32, 64, 64, 10)) # (N, H, W, C) 格式
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
class AutoEncoder(nn.Module):
encoder_widths: Sequence[int]
decoder_widths: Sequence[int]
input_shape: Sequence[int]
def setup(self):
input_dim = np.prod(self.input_shape)
self.encoder = MLP(self.encoder_widths)
self.decoder = MLP(self.decoder_widths + (input_dim,))
def __call__(self, x):
return self.decode(self.encode(x))
def encode(self, x):
assert x.shape[1:] == self.input_shape
return self.encoder(jnp.reshape(x, (x.shape[0], -1)))
def decode(self, z):
z = self.decoder(z)
x = nn.sigmoid(z)
x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
return x
model = AutoEncoder(encoder_widths=[20, 10, 5],
decoder_widths=[5, 10, 20],
input_shape=(12,))
batch = jnp.ones((16, 12))
variables = model.init(jax.random.key(0), batch)
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)
🤗 Hugging Face
在🤗 Transformers 仓库中,正在积极维护用于训练和评估各种Flax模型的详细示例,涵盖了自然语言处理、计算机视觉和语音识别领域。
截至2021年10月,Flax支持19种最常用的Transformer架构,并且已有超过5000个预训练的Flax检查点上传到🤗 Hub。
引用Flax
要引用此仓库:
@software{flax2020github,
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
version = {0.8.6},
year = {2023},
}
在上述bibtex条目中,姓名按字母顺序排列,版本号来自flax/version.py,年份对应项目的开源发布年份。
注意
Flax 是一个由谷歌研究院专门团队维护的开源项目,但并非谷歌的官方产品。