Diffrax
JAX中的数值微分方程求解器。自动微分和支持GPU。
Diffrax 是一个基于 JAX 的库,提供数值微分方程求解器。
功能包括:
- ODE/SDE/CDE(常微分/随机微分/受控微分)求解器;
- 多种不同的求解器(包括
Tsit5
、Dopri8
、辛求解器、隐式求解器); - vmappable 所有东西(包括积分区域);
- 使用 PyTree 作为状态;
- 密集解;
- 多种反向传播的伴随方法;
- 支持神经微分方程。
从技术角度看,该库的内部结构非常酷——所有类型的方程(ODEs、SDEs、CDEs)都以统一的方式被求解(而不是分别处理),生成一个小巧紧凑的库。
安装
pip install diffrax
需要 Python 3.9+、JAX 0.4.13+ 和 Equinox 0.10.11+。
文档
可在 https://docs.kidger.site/diffrax 获取。
快速示例
from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp
def f(t, y, args):
return -y
term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
这里,Dopri5
是指 Dormand--Prince 5(4) 数值微分方程求解器,这是许多问题的标准选择。
引用
如果您在学术研究中发现这个库有用,请引用:(arXiv链接)
@phdthesis{kidger2021on,
title={{O}n {N}eural {D}ifferential {E}quations},
author={Patrick Kidger},
year={2021},
school={University of Oxford},
}
(也请考虑在GitHub上为该项目加星。)
另见:JAX生态系统中的其他库
总是有用
Equinox: 神经网络和核心JAX中还没有的所有东西!
jaxtyping: 数组形状/数据类型的类型注解。
深度学习
Optax: 一阶梯度(SGD、Adam等)优化器。
Orbax: 检查点(异步/多宿主/多设备)。
Levanter: 可扩展且可靠的基础模型(如LLMs)训练。
科学计算
Optimistix: 求根、最小化、固定点和最小二乘法。
Lineax: 线性求解器。
BlackJAX: 概率+贝叶斯采样。
sympy2jax: SymPy<->JAX 转换;通过梯度下降训练符号表达式。
PySR: 符号回归。(非JAX荣誉提名!)
Awesome JAX
Awesome JAX: 更长的其他JAX项目列表。