JAXopt
⚠️ 我们正在将JAXopt合并到Optax中。因此,JAXopt现在处于维护模式,我们不会实现新功能 ⚠️
在JAX中实现的硬件加速、可批处理和可微分的优化器。
- 硬件加速: 我们的实现除了可在CPU上运行,还可在GPU和TPU上运行。
- 可批处理: 使用JAX的vmap可以自动向量化同一优化问题的多个实例。
- 可微分: 优化问题的解可以根据其输入进行隐式微分或通过展开算法迭代进行自动微分。
安装
要安装JAXopt的最新版本,请使用以下命令:
$ pip install jaxopt
要安装开发版本,请使用以下命令:
$ pip install git+https://github.com/google/jaxopt
或者,可以通过以下命令从源代码安装:
$ python setup.py install
引用我们
我们的隐式微分框架在这篇论文中有详细描述。引用格式如下:
@article{jaxopt_implicit_diff,
title={Efficient and Modular Implicit Differentiation},
author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy
and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian
and Vert, Jean-Philippe},
journal={arXiv preprint arXiv:2105.15183},
year={2021}
}
免责声明
JAXopt是由Google Research中的专门团队维护的开源项目,但不是Google的官方产品。