Awesome JAX
JAX 通过使用类似 NumPy 的 API 将自动微分和 XLA 编译器 集成在一起,以实现 GPU 和 TPU 等加速器上高性能的机器学习研究。
这是一个精心整理的 JAX 库、项目及其他资源的优秀列表。欢迎任何贡献!
目录
库
- 神经网络库
- Flax - 注重灵活性和清晰性。
- Haiku - 由 DeepMind 的 Sonnet 作者创建,注重简单性。
- Objax - 拥有类似于 PyTorch 的面向对象设计。
- Elegy - JAX 中的高级深度学习 API,支持 Flax、Haiku 和 Optax。
- Trax - 提供常见工作负载解决方案的“内置电池”深度学习库。
- Jraph - 轻量级图神经网络库。
- Neural Tangents - 用于指定有限和_无限_宽度神经网络的高级 API。
- HuggingFace - 为广泛的自然语言任务提供预训练的 Transformers 生态系统(Flax)。
- Equinox - 可调用的 PyTrees 和过滤 JIT/grad 转换 => JAX 中的神经网络。
- Scenic - 一个用于计算机视觉研究及其扩展的 JAX 库。
- Levanter - 具备名称张量和 JAX 的易懂、可扩展、可复制的基础模型。
- EasyLM - 简化 LLMs 的预训练、微调、评估和服务(基于 JAX/Flax 实现)。
- NumPyro - 基于 Pyro 库的概率编程。
- Chex - 编写和测试高可靠性 JAX 代码的工具库。
- Optax - 梯度处理和优化库。
- RLax - 实现强化学习代理的库。
- JAX, M.D. - 加速、可微分的分子动力学。
- Coax - 将强化学习论文转化为代码,轻松实现。
- Distrax - 重新实现 TensorFlow Probability,包含概率分布及其双射函数。
- cvxpylayers - 构建可微分的凸优化层。
- TensorLy - 简化张量学习。
- NetKet - 用于量子物理的机器学习工具箱。
- Fortuna - AWS 的深度学习不确定性量化库。
- BlackJAX - 针对 JAX 的采样器库。
新库
此部分包含制作精良且实用的库,但不一定经过大量用户严格测试。
- 神经网络库
- FedJAX - 基于Optax和Haiku在JAX中的联邦学习。
- Equivariant MLP - 构建等变神经网络层。
- jax-resnet - 实现和检查点在Flax中的ResNet变体。
- Parallax - JAX中不可变的Torch模块。
- jax-unirep - 实现了用于蛋白质机器学习应用的UniRep模型的库。
- jax-flows - JAX中的归一化流。
- sklearn-jax-kernels - 使用JAX的
scikit-learn
核矩阵。 - jax-cosmo - 微分宇宙学库。
- efax - JAX中的指数家族。
- mpi4jax - 在CPU和GPU上将MPI操作与Jax代码结合。
- imax - 图像增强和变换。
- FlaxVision - TorchVision的Flax版本。
- Oryx - 基于程序转换的概率编程语言。
- Optimal Transport Tools - 解决优化传输问题的工具箱。
- delta PV - 具有自动微分功能的光伏模拟器。
- jaxlie - 刚体转换和优化的李理论库。
- BRAX - 可微物理引擎,用于模拟环境以及训练这些环境的代理学习算法。
- flaxmodels - 适用于Jax/Flax的预训练模型。
- CR.Sparse - 用于稀疏表示和压缩感知的XLA加速算法。
- exojax - 自动可微的系外行星/棕矮星光谱建模库,兼容JAX。
- JAXopt - JAX中的硬件加速(GPU/TPU)、可批处理和可微优化器。
- PIX - JAX中的图像处理库,为JAX而生。
- bayex - 由JAX驱动的贝叶斯优化库。
- JaxDF - 具有任意离散化的微分模拟框架。
- tree-math - 将操作数组的函数转换为操作PyTrees的函数。
- jax-models - 实现最初没有代码或用其他框架编写的研究论文。
- PGMax - 一个用于构建离散概率图模型(PGM)并在JAX中进行推理的框架。
- EvoJAX - 硬件加速的神经进化库。
- evosax - 基于JAX的进化策略。
- SymJAX - 符号CPU/GPU/TPU编程。
- mcx - 表达和编译概率程序以实现高效推理。
- Einshape - 基于DSL的JAX和其他框架的重构库。
- ALX - 使用交替最小二乘法进行分布式矩阵分解的开源库,更多信息见ALX: Large Scale Matrix Factorization on TPUs。
- Diffrax - JAX中的数值微分方程求解器。
- tinygp - JAX中“最小”的高斯过程库。
- gymnax - 具有知名gym API的强化学习环境。
- Mctx - 原生JAX中的蒙特卡洛树搜索算法。
- KFAC-JAX - 用于神经网络的近似曲率二阶优化。
- TF2JAX - 将函数/图转换为JAX函数。
- jwave - 用于可微分声学模拟的库。
- GPJax - JAX中的高斯过程。
- Jumanji - 一套行业驱动的硬件加速强化学习环境,使用JAX编写。
- Eqxvision - Torchvision的Equinox版本。
- JAXFit - 用于非线性最小二乘问题的加速曲线拟合库(见arXiv论文)。
- econpizza - 使用JAX求解具有异质代理的宏观经济模型。
- SPU - 一个特定领域的编译器和运行时套件,用于使用JAX代码进行MPC(安全多方计算)。
- jax-tqdm - 为JAX扫描和循环添加tqdm进度条。
- safejax - 使用🤗
safetensors
序列化JAX、Flax、Haiku或Objax模型参数。 - Kernex - JAX中的可微分矩阵装饰器。
- MaxText - 一个使用纯Python/JAX编写的简单、高性能和可扩展的JAX大语言模型,面向Google Cloud TPU。
- Pax - 一个基于JAX的大规模模型训练机器学习框架。
- Praxis - Pax的层库,旨在供其他JAX基础的机器学习项目使用。
- purejaxrl - JAX中的可向量化端到端强化学习算法。
- Lorax - 自动将LoRA应用于JAX模型(Flax、Haiku等)。
- SCICO - JAX中的科学计算成像。
- Spyx - 用于神经形态硬件机器学习的JAX尖峰神经网络。
- BrainPy - 用Python编写的大脑动态编程。
- OTT-JAX - JAX中的最优传输工具。
- QDax - JAX中的质量多样性优化。
- JAX Toolbox - 使用T5x、Paxml和Transformer Engine等库在NVIDIA GPU上优化的JAX例子和夜间CI。
- Pgx - 带有AlphaZero示例的强化学习向量化棋盘游戏环境。
- EasyDeL - EasyDeL 🔮是一个开源库,使您的训练更快、更优化,并提供用于JAX中训练和服务(Llama、MPT、Mixtral、Falcon等)的酷选项。
- XLB - Python中用于基于物理的机器学习的可微分大规模并行格子Boltzmann库。
- dynamiqs - JAX中高性能和可微分的量子系统模拟。
模型和项目
JAX
- Fourier Feature Networks - 官方实现 Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains。
- kalman-jax - 使用迭代卡尔曼滤波和平滑进行马尔可夫(即时间)高斯过程的近似推理。
- jaxns - 在 JAX 中的嵌套采样。
- Amortized Bayesian Optimization - 与 Amortized Bayesian Optimization over Discrete Spaces 相关的代码。
- Accurate Quantized Training - 在 JAX 和 Flax 中运行和分析神经网络量化实验的工具和库。
- BNN-HMC - 为论文 What Are Bayesian Neural Network Posteriors Really Like? 提供的实现。
- JAX-DFT - 在 JAX 中的一维密度泛函理论 (DFT),实现了 Kohn-Sham equations as regularizer: building prior knowledge into machine-learned physics。
- Robust Loss - 为论文 A General and Adaptive Robust Loss Function 提供的参考代码。
- Symbolic Functionals - 来自 Evolving symbolic density functionals 的演示。
- TriMap - TriMap: Large-scale Dimensionality Reduction Using Triplets 的官方 JAX 实现。
Flax
- Performer - Performer (通过 FAVOR+ 的线性 Transformer) 架构的 Flax 实现。
- JaxNeRF - 支持多设备 GPU/TPU 的 NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis 实现。
- mip-NeRF - Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields 的官方实现。
- RegNeRF - RegNeRF: Regularizing Neural Radiance Fields for View Synthesis from Sparse Inputs 的官方实现。
- Big Transfer (BiT) - Big Transfer (BiT): General Visual Representation Learning 的实现。
- JAX RL - 强化学习算法的实现。
- gMLP - Pay Attention to MLPs 的实现。
- MLP Mixer - MLP-Mixer: An all-MLP Architecture for Vision 的简洁实现。
- Distributed Shampoo - Second Order Optimization Made Practical 的实现。
- NesT - Aggregating Nested Transformers 的官方实现。
- XMC-GAN - Cross-Modal Contrastive Learning for Text-to-Image Generation 的官方实现。
- FNet - FNet: Mixing Tokens with Fourier Transforms 的官方实现。
- GFSA - Learning Graph Structure With A Finite-State Automaton Layer 的官方实现。
- IPA-GNN - Learning to Execute Programs with Instruction Pointer Attention Graph Neural Networks 的官方实现。
- Flax Models - 在 Flax 中实现的模型和方法集合。
- Protein LM - 实现了 BERT 和自回归蛋白质模型,如 Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences 和 ProGen: Language Modeling for Protein Generation 中所述。
- Slot Attention - Differentiable Patch Selection for Image Recognition 的参考实现。
- Vision Transformer - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale 的官方实现。
- FID computation - 将 mseitzer/pytorch-fid 移植到 Flax 上。
- ARDM - Autoregressive Diffusion Models 的官方实现。
- D3PM - Structured Denoising Diffusion Models in Discrete State-Spaces 的官方实现。
- Gumbel-max Causal Mechanisms - Learning Generalized Gumbel-max Causal Mechanisms 的代码,额外代码见 GuyLor/gumbel_max_causal_gadgets_part2。
- Latent Programmer - ICML 2021 论文 Latent Programmer: Discrete Latent Codes for Program Synthesis 的代码。
- SNeRG - ICLR 2022 论文 Baking Neural Radiance Fields for Real-Time View Synthesis 的官方实现。
- Spin-weighted Spherical CNNs - Spin-Weighted Spherical CNNs 的改编。
- VDVAE - Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images 的改编,原始代码见 openai/vdvae。
- MUSIQ - ICCV 2021 论文 MUSIQ: Multi-scale Image Quality Transformer 的检查点和模型推理代码。
- AQuaDem - Continuous Control with Action Quantization from Demonstrations 的官方实现。
- Combiner - Combiner: Full Attention Transformer with Sparse Computation Cost 的官方实现。
- Dreamfields - ICLR 2022 论文 Progressive Distillation for Fast Sampling of Diffusion Models 的官方实现。
- GIFT - Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent 的官方实现。
- Light Field Neural Rendering - Light Field Neural Rendering 的官方实现。
- Sharpened Cosine Similarity in JAX by Raphael Pisoni - Sharpened Cosine Similarity 层的 JAX/Flax 实现。
- GNNs for Solving Combinatorial Optimization Problems - 使用 JAX + Flax 实现的 Combinatorial Optimization with Physics-Inspired Graph Neural Networks。
Haiku
- AlphaFold - AlphaFold v2.0 推理流程的实现,见论文 Highly accurate protein structure prediction with AlphaFold。
- Adversarial Robustness - 参考代码,见论文 Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples 和 Fixing Data Augmentation to Improve Adversarial Robustness。
- Bootstrap Your Own Latent - 自监督学习新方法,见论文 Bootstrap your own latent: A new approach to self-supervised Learning。
- Gated Linear Networks - GLN 是一类不依赖反向传播的神经网络。
- Glassy Dynamics - 开源实现,见论文 Unveiling the predictive power of static structure in glassy systems。
- MMV - 实现模型代码,见论文 Self-Supervised MultiModal Versatile Networks。
- Normalizer-Free Networks - 官方 Haiku 实现,见论文 NFNets。
- NuX - 使用 JAX 的正则化流。
- OGB-LSC - 包含 DeepMind 在 OGB Large-Scale Challenge (OGB-LSC) 中的 PCQM4M-LSC (量子化学) 和 MAG240M-LSC (学术图) 赛道的参赛代码。
- Persistent Evolution Strategies - 代码实现,见论文 Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies。
- Two Player Auction Learning - JAX 实现,见论文 Auction learning as a two-player game。
- WikiGraphs - 基准代码,见论文 WikiGraphs: A Wikipedia Text - Knowledge Graph Paired Datase。
Trax
- Reformer - 高效 Transformer 架构 Reformer 的实现。
NumPyro
- lqg - 线性-二次高斯问题的贝叶斯逆最优控制的官方实现,见论文 Putting perception into action with inverse optimal control for continuous psychophysics。
视频
- NeurIPS 2020: JAX Ecosystem Meetup - JAX 在 DeepMind 的使用及工程师、科学家和 JAX 核心团队之间的讨论。
- Introduction to JAX - 从零开始用 JAX 构建简单神经网络。
- JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas - JAX 的核心设计、推动新研究的方式及如何使用它。
- Bayesian Programming with JAX + NumPyro — Andy Kitchen - 使用 NumPyro 进行贝叶斯建模的介绍。
- JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne - 在 Program Transformations for Machine Learning 研讨会上,JAX 简介演讲。
- JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury - TPU 主机访问演示。
- Deep Implicit Layers - Neural ODEs, Deep Equilibirum Models, and Beyond | NeurIPS 2020 - 教程由 Zico Kolter、David Duvenaud 和 Matt Johnson 创建,Colab 笔记本见 Deep Implicit Layers。
- Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey - 一系列四部分 YouTube 教程及 Colab 笔记,本系列从 Jax 基础讲起,最后讨论在 v3-32 TPU Pod 切片上进行数据并行训练。
- JAX, Flax & Transformers 🤗 - 围绕 JAX / Flax、Transformers、大规模语言建模及其他精彩话题的三天讲座。
论文
此部分包含围绕 JAX 的论文(例如基于 JAX 的库白皮书、JAX 研究等)。使用 JAX 实现的论文见 Models/Projects 章节。
- Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. - 描述 JAX 早期版本的白皮书,详述了计算如何被追踪和编译。
- JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. - 介绍 JAX, M.D.,一个包含仿真环境、相互作用势、神经网络等在内的可微物理学库。
- Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. - 使用 JAX 的 JIT 和 VMAP 实现比现有库更快的差分隐私 SGD。
- XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python. Mohammadmehdi Ataei, Hesam Salehipour. arXiv 2023. - 描述 XLB 库的白皮书:基准测试、验证及库的更多细节。
教程和博文
- 使用 JAX 加速我们的研究 by David Budden 和 Matteo Hessel - 描述了 DeepMind 中 JAX 和 JAX 生态系统的现状。
- 开始使用 JAX (MLPs, CNNs & RNNs) by Robert Lange - 从头开始使用基本的 JAX 算子构建神经网络模块。
- 学习 JAX: 从线性回归到神经网络 by Rito Ghosh - 一个温和的介绍,使用 JAX 实现线性和逻辑回归,以及神经网络模型,并使用它们解决实际问题。
- 教程:使用 JAX 和 Flax Linen 进行图像分类 by 8bitmp3 - 学习如何使用 Flax 的 Linen API 创建一个简单的卷积网络,并训练其识别手写数字。
- 使用 JAX by Nick Doiron - 比较了 Flax、Haiku 和 Objax 在 Kaggle 花卉分类挑战中的表现。
- 在 JAX 中的 50 行元学习 by Eric Jang - 介绍了 JAX 和元学习。
- 在 JAX 中的 100 行归一化流 by Eric Jang - 简明实现 RealNVP。
- 在 GPU/TPU 上进行可微路径追踪 by Eric Jang - 实现路径追踪的教程。
- 集成网络 by Mat Kelcey - 集成网络是一种将多个模型集成在一个逻辑模型中的方法。
- 分布外 (OOD) 检测 by Mat Kelcey - 实现不同的 OOD 检测方法。
- 使用 JAX 理解自动微分 by Srihari Radhakrishna - 使用 JAX 理解自动微分的工作原理。
- 从 PyTorch 到 JAX: 向净化状态代码的神经网络框架过渡 by Sabrina J. Mielke - 展示如何从类似 PyTorch 的编码风格转向更具功能性的编码风格。
- 使用自定义 C++ 和 CUDA 代码扩展 JAX by Dan Foreman-Mackey - 介绍为 JAX 提供自定义操作所需的基础设施的教程。
- 在 JAX 中进化神经网络 by Robert Tjarko Lange - 探索如何利用 JAX 推动下一代可扩展的神经进化算法。
- 使用 JAX 探索超参数元损失图景 by Luke Metz - 演示如何使用 JAX 进行内损失优化(使用 SGD 和 Momentum)、外损失优化(使用梯度)以及外损失优化(使用进化策略)。
- 在 JAX 中实现确定性 ADVI by Martin Ingram - 使用 JAX 轻松、清晰地实现自动微分变分推断(ADVI)的教程。
- 进化通道选择 by Mat Kelcey - 训练一个能适应不同分辨率下输入通道组合的分类模型,然后使用遗传算法决定最佳组合。
- JAX 入门 by Kevin Murphy - 介绍语言各方面并应用于简单的机器学习问题的 Colab。
- 在 JAX 中编写 MCMC 采样器 by Jeremie Coullon - 关于在 JAX 中编写 MCMC 采样器的不同方法和速度基准的教程。
- 如何为 JAX 扫描和循环添加进度条 by Jeremie Coullon - 教程,教你如何使用
host_callback
模块为编译的循环添加进度条。 - 开始使用 JAX by Aleksa Gordić - 一系列讲解如何从零学习 JAX 知识到在 Haiku 中构建神经网络的笔记本和视频。
- 在 JAX+FLAX 中编写训练循环 by Saurav Maheshkar 和 Soumik Rakshit - 关于在 JAX、Flax 和 Optax 中编写简单的端到端训练和评估管道的教程。
- 在 JAX 中实现 NeRF by Soumik Rakshit 和 Saurav Maheshkar - 关于 JAX 中进行由神经辐射场表示的场景的 3D 体素渲染教程。
- 使用 JAX+Flax 进行深度学习教程 by Phillip Lippe - 一系列讲解各种深度学习概念的笔记,从基础(例如 JAX/Flax 简介、激活函数)到最新进展(例如视觉变压器、SimCLR),并翻译成 PyTorch 代码。
- 使用 PureJaxRL 实现 4000 倍加速 - 介绍 JAX 如何通过向量化大幅加速强化学习训练的一篇博客文章。
书籍
- Jax 实战 - 一本关于使用 JAX 进行深度学习和其他数学密集型应用的实践指南。
社区
贡献
欢迎贡献!请先阅读 贡献指南。