Optimistix
Optimistix 是一个用于非线性求解器的 JAX 库:用于求根、最小化、不动点和最小二乘法。
特点包括:
- 可互操作的求解器:例如,将求根问题自动转换为最小二乘问题,然后使用最小化算法求解。
- 模块化优化器:例如,使用带有信赖域更新的狗腿下降路径的 BFGS 二次碗。
- 使用 PyTree 作为状态。
- 快速编译和运行时间。
- 与 Optax 的互操作性。
- 使用 JAX 的所有优势:自动微分、自动并行、GPU/TPU 支持等。
安装
pip install optimistix
需要 Python 3.9+ 和 JAX 0.4.14+ 以及 Equinox 0.11.0+。
文档
可在 https://docs.kidger.site/optimistix 获取。
快速示例
import jax.numpy as jnp
import optimistix as optx
# 让我们用隐式欧拉法求解微分方程 dy/dt=tanh(y(t))。
# 我们需要找到 y1,使得 y1 = y0 + tanh(y1)dt。
y0 = jnp.array(1.)
dt = jnp.array(0.1)
def fn(y, args):
return y0 + jnp.tanh(y) * dt
solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.fixed_point(fn, solver, y0)
y1 = sol.value # 满足 y1 == fn(y1)
引用
如果您在学术工作中发现这个库很有用,请引用:(arXiv 链接)
@article{optimistix2024,
title={Optimistix: modular optimisation in JAX and Equinox},
author={Jason Rader and Terry Lyons and Patrick Kidger},
journal={arXiv:2402.09983},
year={2024},
}
另请参阅:JAX 生态系统中的其他库
始终有用
Equinox:神经网络和核心 JAX 中尚未包含的所有内容!
jaxtyping:数组形状/数据类型的类型注解。
深度学习
Optax:一阶梯度(SGD、Adam 等)优化器。
Orbax:检查点(异步/多主机/多设备)。
Levanter:可扩展且可靠的基础模型(如 LLM)训练。
科学计算
Diffrax:数值微分方程求解器。
Lineax:线性求解器。
BlackJAX:概率和贝叶斯采样。
sympy2jax:SymPy<->JAX 转换;通过梯度下降训练符号表达式。
PySR:符号回归。(非 JAX 的荣誉提名!)
Awesome JAX
Awesome JAX:更多 JAX 项目的列表。
致谢
Optimistix 主要由 Jason Rader (@packquickly) 构建:Twitter;GitHub;网站。