Project Icon

flax

灵活强大的JAX神经网络库和生态系统

Flax是一个基于JAX的高性能神经网络库,以灵活性为核心设计理念。它提供神经网络API、实用工具、教育示例和优化的大规模端到端示例。Flax支持MLP、CNN和自编码器等多种网络结构,并与Hugging Face集成,涵盖自然语言处理、计算机视觉和语音识别等领域。作为Google Research与开源社区合作开发的项目,Flax致力于促进JAX神经网络研究生态系统的发展。

logo

Flax:为JAX设计的灵活性神经网络库和生态系统

构建 覆盖率

概述 | 快速安装 | Flax是什么样子的? | 文档

📣 新消息:查看NNXAPI!

这个README只是一个简短的介绍。要了解关于Flax的所有信息,请参阅我们的完整文档

Flax最初由Google Research的Brain团队的工程师和研究人员发起(与JAX团队密切合作),现在与开源社区共同开发。

Flax正被Alphabet各研究部门的数百名人员在日常工作中使用,同时也被越来越多的开源项目社区采用。

Flax团队的使命是服务于不断增长的JAX神经网络研究生态系统——不仅在Alphabet内部,也包括更广泛的社区,并探索JAX表现出色的应用场景。我们几乎所有的协调和计划都在GitHub上进行,也在那里讨论即将到来的设计变更。我们欢迎对任何讨论、问题和拉取请求线程提供反馈。我们正在将一些剩余的内部设计文档和对话线程转移到GitHub的讨论、问题和拉取请求中。我们希望能越来越多地满足更广泛生态系统的需求和澄清要求。请告诉我们我们如何能帮到您!

请在我们的讨论论坛中报告任何功能请求、问题、疑问或担忧,或者只是让我们知道您正在进行什么工作!

我们预计会改进Flax,但不会对核心API进行重大的破坏性更改。我们尽可能使用Changelog条目和弃用警告。

如果您想直接联系我们,我们的邮箱是flax-dev@google.com

概述

Flax是一个为JAX设计的高性能神经网络库和生态系统,专为灵活性而设计: 通过分叉示例并修改训练循环来尝试新的训练形式,而不是通过向框架添加功能。

Flax正在与JAX团队密切合作开发,并提供开始研究所需的一切,包括:

  • 神经网络APIflax.linen):Dense、Conv、{Batch|Layer|Group} Norm、Attention、Pooling、{LSTM|GRU} Cell、Dropout

  • 实用工具和模式:复制训练、序列化和检查点、指标、设备预取

  • 开箱即用的教育示例:MNIST、LSTM seq2seq、图神经网络、序列标记

  • 快速、调优的大规模端到端示例:CIFAR10、ImageNet上的ResNet、Transformer LM1b

快速安装

您需要Python 3.6或更高版本,以及一个可用的JAX安装(无论是否支持GPU - 参考说明)。 对于仅CPU版本的JAX:

pip install --upgrade pip # 支持manylinux2010 wheels。
pip install --upgrade jax jaxlib # 仅CPU

然后,从PyPi安装Flax:

pip install flax

要升级到最新版本的Flax,可以使用:

pip install --upgrade git+https://github.com/google/flax.git

要安装一些额外的依赖项(如matplotlib),这些是某些依赖项需要但未包含的,可以使用:

pip install "flax[all]"

Flax是什么样子的?

我们提供了三个使用Flax API的示例:一个简单的多层感知器、一个CNN和一个自动编码器。 要了解更多关于Module抽象的信息,请查看我们的文档,以及我们的Module抽象广泛介绍。如需更多具体的最佳实践演示,请参考我们的指南开发者笔记

from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # 展平
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

model = CNN()
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) 格式
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
class AutoEncoder(nn.Module):
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Sequence[int]

  def setup(self):
    input_dim = np.prod(self.input_shape)
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(self.decoder_widths + (input_dim,))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[1:] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x

model = AutoEncoder(encoder_widths=[20, 10, 5],
                    decoder_widths=[5, 10, 20],
                    input_shape=(12,))
batch = jnp.ones((16, 12))
variables = model.init(jax.random.key(0), batch)
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)

🤗 Hugging Face

🤗 Transformers 仓库中,正在积极维护用于训练和评估各种Flax模型的详细示例,涵盖了自然语言处理、计算机视觉和语音识别领域。

截至2021年10月,Flax支持19种最常用的Transformer架构,并且已有超过5000个预训练的Flax检查点上传到🤗 Hub

引用Flax

要引用此仓库:

@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.8.6},
  year = {2023},
}

在上述bibtex条目中,姓名按字母顺序排列,版本号来自flax/version.py,年份对应项目的开源发布年份。

注意

Flax 是一个由谷歌研究院专门团队维护的开源项目,但并非谷歌的官方产品。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号