einx - 基于爱因斯坦启发的符号系统的通用张量操作
einx 是一个 Python 库,为 Numpy、PyTorch、Jax 和 Tensorflow 等框架提供了一个通用接口来表述张量操作。其设计基于以下原则:
- 提供一套基本的张量操作,遵循类似 Numpy 的命名方式:
einx.{sum|max|where|add|dot|flip|get_at|...}
- 使用 einx 符号系统来表达基本操作的向量化。einx 符号系统受 einops 启发,但引入了一些新概念,如
[]
括号符号和完全可组合性,使其成为张量操作的通用语言。
einx 可以无缝地集成和混合到现有代码中。所有操作都通过 Python 的 exec() 即时编译成常规 Python 函数,并调用相应框架的操作。
入门:
安装
pip install einx
更多信息请参见安装说明。
einx 是什么样的?
张量操作
import einx
x = {np.asarray|torch.as_tensor|jnp.asarray|...}(...) # 创建张量
einx.sum("a [b]", x) # 沿第二个轴进行求和归约
einx.flip("... (g [c])", x, c=2) # 沿最后一个轴翻转成对的值
einx.mean("b [...] c", x) # 空间平均池化
einx.multiply("a..., b... -> (a b)...", x, y) # 克罗内克积
einx.sum("b (s [ds])... c", x, ds=(2, 2)) # 使用 2x2 核进行求和池化
einx.add("a, b -> a b", x, y) # 外加
einx.dot("a [b], [b] c -> a c", x, y) # 矩阵乘法
einx.get_at("b [h w] c, b i [2] -> b i c", x, y) # 在坐标处获取值
einx.rearrange("b (q + k) -> b q, b k", x, q=2) # 拆分
einx.rearrange("b c, 1 -> b (c + 1)", x, [42]) # 在每个通道后追加数字
# 应用自定义操作:
einx.vmap("b [s...] c -> b c", x, op=np.mean) # 空间平均池化
einx.vmap("a [b], [b] c -> a c", x, y, op=np.dot) # 矩阵乘法
常见神经网络操作
# 层归一化
mean = einx.mean("b... [c]", x, keepdims=True)
var = einx.var("b... [c]", x, keepdims=True)
x = (x - mean) * torch.rsqrt(var + epsilon)
# 预置类别标记
einx.rearrange("b s... c, c -> b (1 + (s...)) c", x, cls_token)
# 多头注意力
attn = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=8)
attn = einx.softmax("b q [k] h", attn)
x = einx.dot("b q k h, b k (h c) -> b q (h c)", attn, v)
# 线性层的矩阵乘法
einx.dot("b... [c1->c2]", x, w) # - 常规
einx.dot("b... (g [c1->c2])", x, w) # - 分组:每组使用相同权重
einx.dot("b... ([g c1->g c2])", x, w) # - 分组:每组使用不同权重
einx.dot("b [s...->s2] c", x, w) # - MLP-mixer 中的空间混合
更多示例请参见常见神经网络操作。
可选:深度学习模块
import einx.nn.{torch|flax|haiku|equinox|keras} as einn
batchnorm = einn.Norm("[b...] c", decay_rate=0.9)
layernorm = einn.Norm("b... [c]") # 在 transformers 中使用
instancenorm = einn.Norm("b [s...] c")
groupnorm = einn.Norm("b [s...] (g [c])", g=8)
rmsnorm = einn.Norm("b... [c]", mean=False, bias=False)
channel_mix = einn.Linear("b... [c1->c2]", c2=64)
spatial_mix1 = einn.Linear("b [s...->s2] c", s2=64)
spatial_mix2 = einn.Linear("b [s2->s...] c", s=(64, 64))
patch_embed = einn.Linear("b (s [s2->])... [c1->c2]", s2=4, c2=64)
dropout = einn.Dropout("[...]", drop_rate=0.2)
spatial_dropout = einn.Dropout("[b] ... [c]", drop_rate=0.2)
droppath = einn.Dropout("[b] ...", drop_rate=0.2)
查看 examples/train_{torch|flax|haiku|equinox|keras}.py
以获取 CIFAR10 的示例训练,GPT-2 和 Mamba 以获取使用 einx 实现的语言模型示例,以及教程:神经网络以获取更多详细信息。
即时编译
einx 将给定调用所需的后端操作追踪到图形表示中,并使用 Python 的 exec()
即时编译成常规 Python 函数。这将开销减少到单次缓存查找,并允许检查生成的函数。例如:
>>> x = np.zeros((3, 10, 10))
>>> graph = einx.sum("... (g [c])", x, g=2, graph=True)
>>> print(graph)
import numpy as np
def op0(i0):
x0 = np.reshape(i0, (3, 10, 2, 5))
x1 = np.sum(x0, axis=3)
return x1
更多详细信息请参见即时编译。