#JAX

keras - 多后端支持的深度学习框架,兼容JAX、TensorFlow和PyTorch
Keras 3深度学习框架JAXTensorFlowPyTorchGithub开源项目
Keras 3 提供高效的模型开发,支持计算机视觉、自然语言处理等任务。选择最快的后端(如JAX),性能提升高达350%。无缝扩展,从本地到大规模集群,适合企业和初创团队。安装简单,支持GPU,兼容tf.keras代码,避免框架锁定。
dopamine - 用于快速原型设计的强化学习研究框架
Dopamine强化学习JAXDQNTensorflowGithub开源项目
Dopamine是一个用于快速原型设计强化学习算法的研究框架,旨在便于用户进行自由实验。其设计原则包括易于实验、灵活开发、紧凑可靠和结果可重复。支持的算法有DQN、C51、Rainbow、IQN和SAC,主要实现于jax。Dopamine提供了Docker容器及源码安装方法,适用于Atari和Mujoco环境,并推荐使用虚拟环境。更多信息请参阅官方文档。
EasyDeL - 多模型训练优化框架
EasyDeL机器学习JAXFlax模型训练Github开源项目
EasyDeL是一个开源框架,用于通过Jax/Flax优化机器学习模型的训练,特别适合在TPU/GPU上进行大规模部署。它支持多种模型架构和量化方法,包括Transformers、Mamba等,并提供高级训练器和API引擎。EasyDeL的架构完全可定制和透明,允许用户修改每个组件,并促进实验和社区驱动的开发。不论是前沿研究还是生产系统构建,EasyDeL都提供灵活强大的工具以满足不同需求。最新更新包括性能优化、KV缓存改进和新模型支持。
keras-nlp - 兼容多框架的自然语言处理工具和预训练模型
KerasNLPTensorFlowJAXPyTorch自然语言处理Github开源项目
KerasNLP 是一个兼容 TensorFlow、JAX 和 PyTorch 的自然语言处理库,提供预训练模型和低级模块。基于 Keras 3,支持 GPU 和 TPU 的微调,并可跨框架训练和序列化。设置 KERAS_BACKEND 环境变量即可切换框架,安装方便,立即体验强大 NLP 功能。
EasyLM - 简化的大规模语言模型训练与部署
EasyLMJAXLLaMATPUGPT-JGithub开源项目
EasyLM提供了一站式解决方案,用于在JAX/Flax中预训练、微调、评估和部署大规模语言模型。通过JAX的pjit功能,可以扩展到数百个TPU/GPU加速器。基于Hugginface的transformers和datasets,EasyLM代码库易于使用和定制。支持Google Cloud TPU Pods上的多TPU/GPU和多主机训练,兼容LLaMA系列模型。推荐加入非官方的Discord社区,了解更多关于Koala聊天机器人和OpenLLaMA的详细信息及安装指南。
dm_pix - 基于JAX的高性能图像处理库
PIXJAX图像处理机器学习平行优化Github开源项目
PIX是一个基于JAX的开源图像处理库,具备优化和并行化能力。支持通过jax.jit、jax.vmap和jax.pmap进行加速与并行处理,适用于高性能计算需求。安装便捷,只需通过pip安装后即可使用。提供丰富的示例代码,易于上手操作,同时配备完整的测试套件,确保开发环境的可靠性,并接受社区贡献。
penzai - 用于构建、编辑和可视化神经网络的 JAX 研究工具包
PenzaiJAX深度学习模型可视化神经网络Github开源项目
Penzai是一个基于JAX的库,专为通过函数式pytree数据结构编写模型而设计,并提供丰富的工具用于可视化、修改和分析。适用于反向工程、模型组件剥离、内部激活检查、模型手术和调试等领域。Penzai包括Treescope交互式Python打印工具、JAX树和数组操作工具、声明式神经网络库及常见Transformer架构的模块化实现。该库简化了模型处理过程,为研究神经网络的内部机制与训练动态提供了支持。
GradCache - 突破GPU/TPU内存限制,实现对比学习无限扩展
Gradient Cache对比学习PytorchJAXGPUGithub开源项目
Gradient Cache技术突破了GPU/TPU内存限制,可以无限扩展对比学习的批处理大小。仅需一个GPU即可完成原本需要8个V100 GPU的训练,并能够用更具成本效益的高FLOP低内存系统替换大内存GPU/TPU。该项目支持Pytorch和JAX框架,并已整合至密集段落检索工具DPR。
dm-haiku - JAX神经网络构建的简洁解决方案
JAXHaikuDeepMind神经网络谷歌Github开源项目
Haiku是一个为JAX设计的简洁神经网络库,具备面向对象编程模型和纯函数转换功能。由Sonnet的开发者创建,Haiku能简化模型参数和状态管理,并与其他JAX库无缝集成。虽然Google DeepMind建议新项目使用Flax,Haiku仍将在维护模式下持续支持,专注于修复bug和兼容性更新。
equinox - 强大且易用的JAX兼容神经网络库
EquinoxJAX神经网络转换APIPyTreeGithub开源项目
Equinox是一款专为JAX设计的神经网络库,拥有类似PyTorch的语法。该库支持过滤API和PyTree操作,并兼容JAX及其生态系统中的所有工具。对于新手用户,推荐使用MNIST卷积神经网络示例,简化模型构建过程。Equinox还提供运行时错误处理等高级功能。
diffrax - JAX 自动微分与 GPU 支持的数值微分方程解析工具
DiffraxJAXSDEODECDEGithub开源项目
Diffrax 是基于 JAX 的数值微分方程解析库,适用于常微分方程、随机微分方程和受控微分方程的求解。其特点包括多种解析器选择(如 Tsit5、Dopri8、辛解析器、隐式解析器)、使用 PyTree 作为状态存储、支持稠密解和多种反向传播方法,并支持神经微分方程。兼容 Python 3.9+、JAX 0.4.13+ 和 Equinox 0.10.11+。
axlearn - 支持构建大规模深度学习模型的高效工具库
AXLearnJAXXLA深度学习机器学习Github开源项目
AXLearn是一个基于JAX和XLA的深度学习库,支持大规模模型的构建、迭代和维护。该库允许用户通过配置系统从可重用模块中组合模型,并兼容Flax和Hugging Face transformers等库。AXLearn能够高效地在众多加速器上训练数百亿参数的模型,涵盖自然语言处理、计算机视觉和语音识别等领域,还支持在公共云上运行并提供作业和数据管理工具。了解更多详情,请参阅其核心组件和设计文档。
awesome-jax - 自动微分与XLA在高性能机器学习中的应用
JAX机器学习自动微分XLA编译器加速器Github开源项目
该页面收录了JAX相关的优质库、项目和资源,旨在帮助机器学习研究人员在GPU和TPU等加速器上实现高性能计算。资源涵盖神经网络库、强化学习工具和概率编程等多个领域,并提供了详细的库介绍、学术论文和教程。用户可以找到如Flax、Haiku、Objax等知名库,以及新兴的FedJAX、BRAX等库,适用于机器学习和科研项目中使用JAX进行快速原型开发和高效计算。
tf2jax - 实验性TensorFlow到JAX函数转换库
TF2JAXTensorFlowJAX函数转换机器学习Github开源项目
tf2jax是一个实验性库,用于将TensorFlow函数和计算图转换为JAX函数。它支持SavedModel和TensorFlow Hub格式,使现有TensorFlow模型能够在JAX环境中重用。该库提供透明的转换过程,便于调试和分析。tf2jax支持自定义梯度和随机性处理,并提供灵活的配置选项。尽管存在一些限制,tf2jax为JAX用户提供了一种集成TensorFlow功能的有效方法。
gymnax - JAX驱动的高效强化学习环境集合
gymnax强化学习JAX环境仿真加速计算Github开源项目
gymnax是基于JAX构建的强化学习环境库,充分利用JAX的即时编译和向量化功能,显著提升了传统gym API的性能。该库涵盖经典控制、bsuite和MinAtar等多种环境,支持精确控制环境参数。通过在加速器上并行处理环境和策略,gymnax实现了高效的强化学习实验,尤其适合大规模并行和元强化学习研究。
jaxlie - JAX Lie群库为计算机视觉和机器人应用提供刚体变换
jaxlieLie群计算机视觉机器人学JAXGithub开源项目
jaxlie是一个基于JAX的Lie群实现库,专注于计算机视觉和机器人应用中的刚体变换。它实现了SO2、SE2、SO3和SE3等常用Lie群,支持自动微分、优化和JAX函数变换。该库提供前向和反向模式AD、流形优化、广播和序列化等功能,为开发者提供刚体变换的高效工具。
jaxdf - JAX框架打造可微分物理模拟器
jaxdfJAX数值模拟偏微分方程自动微分Github开源项目
jaxdf是基于JAX的开源框架,用于创建可微分数值模拟器。该框架支持任意离散化,主要应用于物理系统建模,如波传播和偏微分方程求解。jaxdf生成的纯函数模型可与JAX编写的可微分程序无缝集成,适用于神经网络层或物理损失函数。框架提供自定义算子、多种离散化方法,并附有详细文档和示例。
learned_optimization - 基于JAX的元学习优化器研究框架
learned_optimization元学习优化器JAX机器学习Github开源项目
learned_optimization是一个研究代码库,主要用于学习型优化器的训练、设计、评估和应用。该项目实现了多种优化器和训练算法,包括手工设计的优化器、学习型优化器、元训练任务以及ES、PES和截断反向传播等外部训练方法。项目提供了详细的文档和教程,包括Colab笔记本,方便用户快速入门。learned_optimization适用于元学习和动态系统训练的研究,为相关领域提供了功能丰富的工具。
brax - 基于JAX的高性能物理引擎 适用于机器人和强化学习仿真
Brax物理引擎JAX机器学习仿真Github开源项目
Brax是一款基于JAX的高性能物理引擎,专注于机器人、人体感知、材料科学和强化学习等领域的仿真应用。它支持单设备高效仿真和多设备并行仿真,无需依赖大型数据中心。Brax提供多种物理模拟管道,如MuJoCo XLA、广义坐标和基于位置的动力学,并统一API接口。此外,Brax集成了多种高效学习算法,能在短时间内完成智能体训练。
XLB - 基于JAX的可微分格子玻尔兹曼方法库
XLBLattice Boltzmann流体动力学JAX深度学习Github开源项目
XLB是一款开源的格子玻尔兹曼方法库,基于JAX构建。该库支持2D和3D模拟,具有全可微分特性,能高效解决流体动力学问题。XLB支持多GPU分布式计算,可进行大规模模拟。提供多种边界条件和碰撞核选择,并采用Python接口设计,便于使用和扩展。这些特性使XLB成为物理驱动机器学习研究的有力工具。
evosax - 基于JAX的高性能进化策略框架
evosax进化策略JAX优化算法机器学习Github开源项目
evosax是基于JAX的进化策略框架,通过XLA编译和自动向量化/并行化技术实现大规模进化策略的高效计算。它支持CMA-ES、OpenAI-ES等多种经典和现代神经进化算法,采用ask-evaluate-tell API设计。evosax兼容JAX的jit、vmap和lax.scan,可扩展至不同硬件加速器。该框架为进化计算研究和应用提供了高性能、灵活的工具。
levanter - 专注可读性与可扩展性的大语言模型训练框架
Levanter大语言模型机器学习框架分布式训练JAXGithub开源项目
Levanter是一个用于训练大型语言模型和基础模型的框架。该框架使用Haliax命名张量库编写易读的深度学习代码,同时保持高性能。Levanter支持大型模型训练,兼容GPU和TPU等硬件。框架具有比特级确定性,保证配置一致性。其功能包括分布式训练、Hugging Face生态系统兼容、在线数据预处理缓存、Sophia优化器支持和多种日志后端。
mctx - 高效JAX实现的蒙特卡洛树搜索库
MctxJAX蒙特卡洛树搜索强化学习深度学习Github开源项目
Mctx是一个基于JAX的蒙特卡洛树搜索库,实现了AlphaZero和MuZero等算法。该库支持JIT编译和并行批处理,以提高计算效率。Mctx平衡了性能和易用性,为研究人员提供了探索搜索型强化学习算法的便利工具。它包含通用搜索函数和具体策略实现,用户只需提供学习到的环境模型组件即可使用。
evojax - 基于JAX的高性能神经进化工具包
EvoJAX神经进化JAX硬件加速机器学习Github开源项目
EvoJAX是基于JAX库开发的神经进化工具包,支持在多个TPU/GPU上并行运行神经网络。通过在NumPy中实现进化算法、神经网络和任务,并即时编译到加速器上运行,EvoJAX显著提升了神经进化算法的性能。该工具包提供了多个示例,涵盖监督学习、强化学习和生成艺术等领域,展示了如何在几分钟内完成原本需要数小时或数天的进化实验。EvoJAX为研究人员提供了一个高效、灵活的神经进化开发平台。
synjax - 基于JAX的结构化概率分布神经网络库
SynJaxJAX概率分布神经网络库机器学习Github开源项目
SynJax是一个基于JAX的神经网络库,专注于结构化概率分布处理。它支持多种分布类型,包括线性链CRF、半马尔可夫CRF和成分树CRF等。该库提供计算对数概率、边际概率和最可能结构等标准操作,并兼容JAX的主要转换功能。SynJax采用纯Python编写,结合JAX的C++代码,为结构化概率建模提供了高效灵活的解决方案。
jumanji - JAX驱动的多样化强化学习环境套件 加速研究与应用
Jumanji强化学习JAX环境套件开源项目Github
Jumanji是一个基于JAX的强化学习环境套件,提供22个可扩展环境。通过硬件加速,它支持快速迭代和大规模实验。简洁API、丰富环境、主流框架兼容性和示例代码使强化学习研究更易开展,同时促进研究成果向工业应用转化。
blackjax - JAX贝叶斯采样库 支持CPU和GPU运算
BlackJAXJAX采样器概率编程GPUGithub开源项目
BlackJAX是一个为JAX开发的贝叶斯采样库,支持CPU和GPU计算。它提供多种采样器,可与概率编程语言集成。适用于需要采样器的研究人员、算法开发者和概率编程语言开发者。其模块化设计便于创建和定制采样算法,促进采样算法研究。BlackJAX通过简洁API和高性能,连接了简单框架与可定制库。
s2fft - 基于JAX和PyTorch的球谐变换Python库
S2FFT球谐变换Wigner变换JAXPyTorchGithub开源项目
S2FFT是一个用于计算球面和旋转群傅里叶变换的Python库。基于JAX和PyTorch实现,S2FFT提供可微分的球谐变换和维格纳变换,支持在GPU和TPU等硬件加速器上运行。该库采用高度并行化的新算法结构,提供多种优化选项和采样方案,包括等角采样和HEALPix采样。用户可根据可用资源和所需角分辨率灵活选择。
optimistix - JAX生态系统中的高效非线性求解器
OptimistixJAX非线性求解器数值优化Python库Github开源项目
Optimistix是一个基于JAX的非线性求解器库,专门用于根查找、最小化、不动点和最小二乘问题。该库提供可互操作的求解器和模块化优化器,支持PyTree状态,并与Optax兼容。Optimistix具有快速编译和运行时间,充分利用JAX的自动微分、自动并行和GPU/TPU支持等特性,为科学计算和机器学习领域提供高效工具。
jaxtyping - JAX数组与PyTrees的类型注解和运行时检查工具
jaxtyping类型注解运行时类型检查JAXPyTreeGithub开源项目
jaxtyping是一款为JAX数组和PyTrees提供类型注解及运行时类型检查的开源工具。除JAX外,它还支持PyTorch、NumPy和TensorFlow等主流框架,使用户能够为数组的形状和数据类型添加精确的类型提示。该项目安装简便,与多种运行时类型检查包兼容,并提供完整的在线文档。通过增强类型安全,jaxtyping为科学计算和深度学习项目提供了更可靠的开发环境。
jax - 高性能科学计算和机器学习的Python加速库
JAX自动微分XLAGPU加速神经网络Github开源项目
JAX是一个专为高性能数值计算和大规模机器学习设计的Python库。它利用XLA编译器实现加速器导向的数组计算和程序转换,支持自动微分、GPU和TPU加速。JAX提供jit、vmap和pmap等函数转换工具,让研究人员能够方便地表达复杂算法并获得出色性能,同时保持Python的灵活性。
scenic - 多模态视觉智能研究框架
Scenic计算机视觉JAXTransformer深度学习Github开源项目
Scenic是一个基于JAX的开源视觉智能研究框架,聚焦注意力机制模型。它提供轻量级共享库和完整项目实现,支持分类、分割、检测等任务,可处理图像、视频、音频等多模态数据。Scenic内置多个前沿模型和基线,有助于快速原型设计和大规模实验。
flashbax - JAX强化学习高效体验回放缓冲库
Flashbax经验回放缓冲区强化学习JAX深度学习Github开源项目
Flashbax是一个为JAX设计的高效体验回放缓冲库,适用于强化学习算法。它提供平坦缓冲、轨迹缓冲及其优先级变体等多种缓冲类型,特点是高效内存使用、易于集成到编译函数中,并支持优先级采样。Flashbax还具有Vault功能,可将大型缓冲区保存到磁盘。这个简单灵活的框架适用于学术研究、工业应用和个人项目中的体验回放处理。
gemma - Google DeepMind开源的Gemma大语言模型
Gemma大语言模型Google DeepMind开源权重JAXGithub开源项目
Gemma是Google DeepMind推出的开源大语言模型系列,基于Gemini技术开发。项目提供Flax和JAX框架的推理实现和示例,支持CPU、GPU和TPU等多种硬件平台。包括模型权重下载、入门指南、示例代码和教程,便于开发者学习和应用。Gemma共有2B和7B两种参数规模的模型可供选择。
jax-triton - JAX与Triton集成实现GPU计算加速
JAXTritonjax-tritonCUDAGPU加速Github开源项目
jax-triton项目实现了JAX和Triton的集成,让开发者能在JAX中使用Triton的GPU计算功能。通过triton_call函数,可在JAX编译函数中应用Triton内核,提高计算密集型任务效率。项目提供文档和示例,适合机器学习和科学计算领域的GPU计算优化需求。
Mava - 基于JAX的高效多智能体强化学习框架
Mava多智能体强化学习JAX分布式计算环境包装器Github开源项目
Mava是基于JAX的分布式多智能体强化学习框架,提供精简代码实现和快速迭代工具。它集成了MARL算法、环境封装、教学资源和评估方法,充分利用JAX并行计算优势,在多个环境中实现卓越性能和训练速度。Mava设计简洁易懂,便于扩展,适合MARL研究人员和实践者使用。