SynJax
什么是SynJax? | 安装 | 示例 | 引用SynJax
什么是SynJax?
SynJax是一个用于JAX的结构化概率分布的神经网络库。目前支持的分布包括:
- 线性链条件随机场(CRF),
- 半马尔可夫条件随机场,
- 成分树条件随机场,
- 生成树条件随机场 -- 包括可选的投影性、(非)定向性和单根边约束,
- 对齐条件随机场 -- 包括单调(一对多和多对多)和非单调(一对一)对齐,
- CTC对齐,
- 概率上下文无关文法(PCFG),
- 张量分解PCFG,
- 隐马尔可夫模型(HMM)。
所有这些分布都支持标准操作,如计算结构的对数概率、计算结构部分的边际概率、寻找最可能的结构、采样、top-k、熵、交叉熵、KL散度等。
所有操作都支持标准的JAX转换,包括jax.vmap
、jax.jit
、jax.pmap
和jax.grad
。唯一的例外是argmax、sample和top-k,它们不支持jax.grad
。
如果你想了解SynJax的详细信息,可以查看这篇论文。
安装
SynJax使用纯Python编写,但通过JAX依赖C++代码。由于JAX的安装方式取决于你的CUDA版本,SynJax在requirements.txt
中没有列出JAX作为依赖项。
首先,按照这些说明安装带有相关加速器支持的JAX。
然后,使用pip安装SynJax:
$ pip install git+https://github.com/google-deepmind/synjax
示例
notebooks目录包含了展示SynJax工作原理的示例:
引用SynJax
引用SynJax时,请同时使用SynJax论文引用:
@article{synjax2023,
title="{SynJax: Structured Probability Distributions for JAX}",
author={Milo\v{s} Stanojevi\'{c} and Laurent Sartran},
year={2023},
journal={arXiv preprint arXiv:2308.03291},
url={https://arxiv.org/abs/2308.03291},
}
以及当前的DeepMind JAX生态系统引用。