Endia是一个用于科学计算的动态数组库,类似于PyTorch、Numpy和JAX。它提供:
- 自动微分:计算任意阶导数。
- 复数支持:用Endia进行高级科学应用。
- 双重API:可选择类似PyTorch的命令式或类似JAX的函数式接口。
- JIT编译:利用MAX加速训练和推理。
安装
-
安装Mojo和MAX 🔥 (v24.4)
-
克隆仓库:选择以下选项之一克隆仓库:
git clone https://github.com/endia-org/Endia.git cd Endia
如果您想使用每日(开发)版本,请切换到
nightly
分支:git checkout nightly
-
设置环境:
chmod +x setup.sh ./setup.sh
所需依赖项:
torch
、numpy
、graphviz
。这些将由安装脚本自动安装。
一个小例子
在本指南中,我们将展示如何计算一个简单函数的值、梯度和黑塞矩阵(即二阶导数)。首先使用Endia的类PyTorch API,然后使用更类似JAX的函数式API。在两个示例中,我们首先定义一个函数foo,它接受一个数组并返回其元素平方和。
PyTorch方式
使用Endia的命令式(类PyTorch)接口时,我们通过在函数输出上调用backward方法来计算函数的梯度。这种命令式风格需要显式管理计算图,包括为输入数组(即叶节点)设置requires_grad=True
,并在计算高阶导数时在backward方法中使用create_graph=True
。
from endia import Tensor, sum, arange
import endia.autograd.functional as F
# 定义函数
def foo(x: Tensor) -> Tensor:
return sum(x ** 2)
def main():
# 初始化变量 - 需要requires_grad=True!
x = arange(1.0, 4.0, requires_grad=True) # [1.0, 2.0, 3.0]
# 计算结果、一阶和二阶导数
y = foo(x)
y.backward(create_graph=True)
dy_dx = x.grad()
d2y_dx2 = F.grad(outs=sum(dy_dx), inputs=x)[0]
# 打印结果
print(y) # 14.0
print(dy_dx) # [2.0, 4.0, 6.0]
print(d2y_dx2) # [2.0, 2.0, 2.0]
JAX方式
使用Endia的函数式(类JAX)接口时,计算图是隐式处理的。通过在foo上调用grad
或jacobian
函数,我们创建一个计算完整雅可比矩阵的Callable
。这个Callable
可以再次传递给grad
或jacobian
函数以计算高阶导数。
from endia import grad, jacobian
from endia.numpy import sum, arange, ndarray
def foo(x: ndarray) -> ndarray:
return sum(x**2)
def main():
# 创建一阶和二阶导数的Callable
foo_jac = grad(foo)
foo_hes = jacobian(foo_jac)
x = arange(1.0, 4.0) # [1.0, 2.0, 3.0]
print(foo(x)) # 14.0
print(foo_jac(x)[ndarray]) # [2.0, 4.0, 6.0]
print(foo_hes(x)[ndarray]) # [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
还有更多功能!Endia可以处理复值函数,可以执行前向和反向自动微分,甚至内置了JIT编译器来加速运算。在文档中探索完整的功能列表。
为什么要另一个ML框架?
- 🧠 推进AI和科学计算:通过清晰易懂的算法突破界限
- 🚀 Mojo驱动的清晰性:高性能开源代码始终保持可读性和Python风格
- 📐 可解释性:优先考虑清晰度和教育价值,而不是exhaustive功能
"生活中没有什么需要害怕,只有需要理解。现在是时候去理解更多,这样我们就能减少恐惧。" - 玛丽·居里
贡献
欢迎对Endia做出贡献!如果您想贡献,请遵循仓库中CONTRIBUTING.md文件中的贡献指南。
引用
如果您在研究或项目中使用Endia,请按以下方式引用:
@software{Fehrenbach_Endia_2024,
author = {Fehrenbach, Tillmann},
license = {Apache-2.0 with LLVM Exceptions},
doi = {10.5281/zenodo.12810766},
month = jul,
title = {{Endia}},
url = {https://github.com/endia-org/Endia},
version = {24.4.2},
year = {2024}
}
许可证
Endia根据Apache-2.0 license with LLVM Exeptions许可。