Infini-Transformer
概述
Infini-Transformer(https://arxiv.org/abs/2404.07143)是一个强大而多功能的transformer模型,专为各种自然语言处理任务而设计。它利用最先进的技术和架构,实现了卓越的性能,并可扩展到无限的上下文长度。
特性
- 可扩展的架构,能够处理长序列
- 在多样化数据集上进行大规模预训练
- 支持多种下游任务,包括文本分类、问答和语言生成
- 高效的微调,适应特定任务
- 包含一个融合了Infini-Attention的Mixture-of-Depths(https://arxiv.org/abs/2404.02258)transformer层
- 实现了符合Infini-Attention和Mixture-of-Depth内存高效设计的RoPE(https://arxiv.org/abs/2104.09864)
- 实现了符合Infini-Attention和Mixture-of-Depth内存高效设计的YaRN(https://arxiv.org/abs/2309.00071)
目录结构
infini-transformer/
│
├── infini_transformer/
│ ├── __init__.py
│ ├── transformer.py
│ ├── compressive_memory.py
│ ├── positional_embedder.py
│ └── activations.py
│
├── examples/
│ ├── __init__.py
│ └── modinfiniformer.py
│
├── tests/
│ ├── __init__.py
│ └── test_transformer.py
│
├── LICENSE
├── README.md
├── requirements.txt
├── MANIFEST.in
└── pyproject.toml
入门指南
要开始使用Infini-Transformer,你可以克隆仓库并从源代码安装:
git clone https://github.com/dingo-actual/infini-transformer.git
cd infini-transformer
pip install -e .
使用方法
CompressiveMemory
CompressiveMemory
模块是Infini-Transformer架构的关键组件。它旨在通过压缩并将输入标记存储在内存矩阵和归一化向量中,有效地处理长序列。这使得模型能够在保持有限内存使用的同时维持大型上下文窗口。
它通过沿序列维度(假定为维度1)划分输入张量,执行多头自注意力的变体,并带有循环更新步骤。它首先对输入进行学习的线性投影,生成键、查询和值张量,然后从中提取每个循环步骤的片段。
在每个循环步骤中,它计算线性注意力(使用内存和归一化矩阵)和SDP注意力的学习线性组合。然后使用当前步骤的键和值矩阵,以及当前内存矩阵和归一化向量来更新内存矩阵和归一化向量。在输出之前,组合的注意力张量在所有头上堆叠,然后投影回输入维度。
每个循环步骤的输出沿序列维度(维度1)连接,生成最终的输出张量。
内存矩阵的更新有两种变体:线性和增量。
线性更新规则为: $$M_t = M_{t-1} + \bigl(\textrm{ELU}(K_{t-1}\bigr) + 1)^TV_{t-1}$$
增量更新规则为: $$M_t = M_{t-1} + \bigl(\textrm{ELU}(K_{t-1}) + 1\bigr)^T \biggl( V_{t-1} - \frac{(\textrm{ELU}(K_{t-1}) + 1)M_{t-1}}{(\textrm{ELU}(K_{t-1}) + 1)z_{t-1}}\biggr)$$
其中$M_i$是步骤$i$的内存矩阵,$z_i$是步骤$i$的归一化向量。$K$和$V$矩阵的下标表示它们对应的循环步骤。
计算尽可能沿嵌入维度堆叠,以高效利用多头注意力。
CompressiveMemory
模块接受以下参数:
dim_input
:张量的输入维度。dim_key
:键张量和查询张量的维度。dim_value
:值张量的维度。num_heads
:注意力头的数量。segment_len
:递归注意力计算中每个段的长度。sampling_factor
:如果使用混合深度(Mixture-of-Depths)则使用的采样因子(如果不使用混合深度则为None)。(默认为None。)update
:用于内存矩阵更新的类型。可以是"linear"或"delta"。(默认为"linear"。)causal
:是否在SDP计算中使用因果注意力(每个位置只能关注之前的位置)。(默认为False。)positional_embedder
:可选的PositionEmbeddings
对象:RoPEEmbeddings
或YaRNEmbeddings
(默认为None。)init_state_learnable
:初始内存状态和归一化向量是否为可学习参数。(默认为False。)
CompressiveMemory
模块的示例用法如下:
import torch
from infini_transformer.compressive_memory import CompressiveMemory
cm = CompressiveMemory(
dim_input=768,
dim_key=64,
dim_value=64,
num_heads=8,
segment_len=2048,
sampling_factor=None,
update="linear",
causal=True,
positional_embedder="rope",
init_state_learnable=False
)
batch = torch.randn(
2, # 批量大小
65536, # 序列长度
768 # 输入维度
)
output = cm(batch)
在训练过程中,不需要对输出进行特殊处理。
InfiniTransformer
InfiniTransformer
类实现了原始transformer的一个变体,它使用CompressiveMemory
代替标准的自注意力机制。这使得模型能够通过压缩和存储输入标记到内存矩阵和归一化向量中来高效处理长序列。它利用CompressiveMemory
模块执行多头自注意力的变体,并包含一个递归更新步骤。
InfiniTransformer
与普通transformer的主要区别在于用CompressiveMemory
替换了标准的多头自注意力机制。
InfiniTransformer
模块接受以下参数:
-
dim_input
:张量的输入维度。 -
dim_hidden
:多头自注意力后应用的MLP的隐藏维度。 -
dim_key
:键张量和查询张量的维度。 -
dim_value
:值张量的维度。 -
num_heads
:注意力头的数量。 -
activation
:在MLP中应用的非线性激活函数。支持以下激活函数:"relu"
:ReLU激活"abs"
:绝对值激活"gelu"
:高斯误差线性单元(GELU)激活"swish"
:Swish激活"swiglu"
:SwiGLU激活"geglu"
:门控高斯误差线性单元(GeGELU)激活"ffnglu"
:带门控线性单元的前馈网络(FFNGLU)激活"ffngeglu"
:带门控高斯误差线性单元的前馈网络(FFNGeGLU)激活"ffnswiglu"
:带Swish门控线性单元的前馈网络(FFNSwiGLU)激活
-
segment_len
:递归注意力计算中每个段的长度。 -
update
:用于内存矩阵更新的类型。可以是"linear"或"delta"。(默认为"linear"。) -
causal
:是否在SDP计算中使用因果注意力(每个位置只能关注之前的位置)。(默认为False。) -
positional_embedder
:可选的PositionEmbeddings
对象:RoPEEmbeddings
或YaRNEmbeddings
(默认为None。) -
init_state_learnable
:初始内存状态和归一化向量是否为可学习参数。(默认为False。) -
dropout
:在MLP中应用的dropout率。(默认为0.0。)
InfiniTransformer
模块的示例用法如下:
import torch
from infini_transformer import InfiniTransformer
tfm = InfiniTransformer(
dim_input=768,
dim_hidden=2048,
dim_key=64,
dim_value=64,
num_heads=8,
activation="ffngeglu",
segment_len=2048,
update="delta",
causal=True,
positional_embedder=None,
init_state_learnable=False,
dropout=0.1
)
batch = torch.randn(
2, # 批量大小
65536, # 序列长度
768 # 输入维度
)
output = tfm(batch)
在训练过程中,不需要对输出进行特殊处理。
MoDInfiniTransformer
MoDInfiniTransformer
模块扩展了InfiniTransformer
模块,引入了混合深度(Mixture-of-Depths)(Raposo等人;https://arxiv.org/abs/2404.02258)。MoDInfiniTransformer
块将其输入进行学习的线性投影到单一维度,并使用具有最高前k个值的标记执行InfiniTransformer
的操作,将所有剩余标记添加到残差连接中。这使得模型能够将其容量集中在输入序列中最重要的部分,进一步减少了整体计算和内存需求,比单独使用InfiniTransformer
更加高效。
前k个选择通常会导致递归循环中的段具有不同的长度。我们通过在所有段中均匀分配选择来避免这种情况。
由于top-k选择的非因果性质,在推理时,投影到1维时产生的分数被视为独立二元分类器的logits。因此,我们在训练模型时为每个ModInfiniFormer
层添加了一个额外的损失项,即logits与训练期间选择的top-k词元之间的二元交叉熵损失。
因此,ModInfiniTransformer
的输出是由三个张量组成的元组:
- 常规输出张量,其维度与输入张量相匹配
- 形状为
(batch_size * sequence_length, 1)
的张量,表示训练期间选择的top-k词元的二元掩码。这将作为额外二元交叉熵损失的目标。 - 形状为
(batch_size * sequence_length, 1)
的张量,包含与上述二元掩码对应的logits。这表示用于选择top-k词元的分数,被视为额外二元交叉熵损失的预测。
在推理时,可以安全地忽略元组的第二和第三个元素,因为所有词元选择逻辑都在MoDInfiniTransformer
模块内部处理。
重要提示:基于二元分类器的词元选择机制在推理时无法保证为批次中的每个元素选择相同数量的词元。如果不加以控制,这将导致一个不规则数组,目前PyTorch不支持这种情况。当前的解决方案是将批次大小强制设为1,并在单个观察值上连接前向传播。我们意识到这并不是最优解,希望在不久的将来能够解决这个问题。
MoDInfiniTransformer
模块接受以下参数:
-
dim_input
:张量的输入维度。 -
dim_hidden
:多头自注意力后应用的MLP的隐藏维度。 -
dim_key
:键张量和查询张量的维度。 -
dim_value
:值张量的维度。 -
num_heads
:注意力头的数量。 -
activation
:在MLP中应用的非线性激活函数。支持以下激活函数:"relu"
:ReLU激活"abs"
:绝对值激活"gelu"
:高斯误差线性单元(GELU)激活"swish"
:Swish激活"swiglu"
:SwiGLU激活"geglu"
:门控高斯误差线性单元(GeGELU)激活"ffnglu"
:带门控线性单元的前馈网络(FFNGLU)激活"ffngeglu"
:带门控高斯误差线性单元的前馈网络(FFNGeGLU)激活"ffnswiglu"
:带Swish门控线性单元的前馈网络(FFNSwiGLU)激活
-
segment_len
:循环注意力计算中每个段的长度。 -
sampling_factor
:区间(1,segment_len
)内的数值,决定在top-k选择期间从每个段中选择的词元数量。sampling_factor
值越大,选择的词元越少。 -
update
:用于更新记忆矩阵的方式。可以是"linear"或"delta"。(默认为"linear"。) -
causal
:在SDP计算中是否使用因果注意力(每个位置只能关注之前的位置)。(默认为False。) -
positional_embedder
:可选的PositionEmbeddings
对象:RoPEEmbeddings
或YaRNEmbeddings
(默认为None。) -
init_state_learnable
:初始记忆状态和标准化向量是否为可学习参数。(默认为False。) -
dropout
:在MLP中应用的dropout率。(默认为0.0。)
InfiniTransformer
模块的使用示例如下:
import torch
from infini_transformer import MoDInfiniTransformer
tfm = MoDInfiniTransformer(
dim_input=768,
dim_hidden=2048,
dim_key=64,
dim_value=64,
num_heads=8,
activation="ffngeglu",
segment_len=2048,
sampling_factor=8,
update="delta",
causal=True,
init_state_learnable=False,
positional_embedder=None,
dropout=0.1
)
batch = torch.randn(
2, # 批次大小
65536, # 序列长度
768 # 输入维度
)
output, select_target, select_pred = tfm(batch)
在训练过程中,我们必须考虑MoDInfiniFormer
的额外输出,以便将它们用于二元交叉熵损失。请参阅infini_transformer/example/modinfiniformer.py,了解如何将额外输出整合到整体模型输出和训练循环中的示例。
RoPEEmbeddings
RoPEEmbeddings
模块应用了Su等人的论文"RoFormer: Enhanced Transformer with Rotary Position Embedding"(https://arxiv.org/abs/2104.09864)中的RoPE。一旦实例化,它可以作为positional_embedder
参数传递给InfiniTransformer
或MoDInfiniTransformer
模块,然后传递给CompressiveMemory
,在那里将位置感知嵌入应用于键和查询张量。
RoPEEmbeddings
模块接受以下参数:
dim
: 键/值张量的维度。seq_len
:CompressiveMemory
输入序列的最大长度(必须与CompressiveMemory
的segment_len
参数匹配)。dim_embeddings_pct
: 用于位置感知嵌入的键/值张量维度比例。例如,如果dim
为64,dim_embeddings_pct
为0.5,则将使用32个维度用于位置感知嵌入。(默认为0.5)base
: 用于位置嵌入角度的基值。(默认为10000)
RoPEEmbeddings
模块的使用示例如下:
import torch
from infini_transformer import InfiniTransformer
from infini_transformer import RoPEEmbeddings
embedder = RoPEEmbeddings(
dim=64, # 必须与InfiniTransformer中的dim_key参数匹配
seq_len=2048, # 必须与InfiniTransformer中的segment_len参数匹配
dim_embeddings_pct=0.5,
base=10000
)
tfm = InfiniTransformer(
dim_input=768,
dim_hidden=2048,
dim_key=64, # 必须与RoPEEmbeddings中的dim参数匹配
dim_value=64,
num_heads=8,
activation="ffngeglu",
segment_len=2048, # 必须与RoPEEmbeddings中的seq_len参数匹配
update="delta",
causal=True,
positional_embedder=embedder,
init_state_learnable=False,
dropout=0.1
)
batch = torch.randn(
2, # 批次大小
65536, # 序列长度
768 # 输入维度
)
output = tfm(batch)
YaRNEmbeddings
YaRNEmbeddings
模块应用了Peng等人的论文"YaRN: Efficient Context Window Extension of Large Language Models"中的YaRN技术(https://arxiv.org/abs/2309.00071)。实例化后,它可以作为positional_embedder
参数传递给InfiniTransformer
或MoDInfiniTransformer
模块,然后传递给CompressiveMemory
,在那里将位置感知嵌入应用于键和查询张量。
YaRNEmbeddings
模块接受以下参数:
dim
: 键/值张量的维度。seq_len
:CompressiveMemory
输入序列的最大长度(必须与CompressiveMemory
的segment_len
参数匹配)。context_len
: 训练期间使用的上下文长度。context_len_ext
: 要扩展到的上下文长度。dim_embeddings_pct
: 用于位置感知嵌入的键/值张量维度比例。例如,如果dim
为64,dim_embeddings_pct
为0.5,则将使用32个维度用于位置感知嵌入。(默认为0.5)base
: 用于位置嵌入角度的基值。(默认为10000)alpha
: 动态缩放的插值最小值。(默认为1)beta
: 动态缩放的插值最小值。(默认为32)len_scale
: 注意力计算的长度缩放。默认为None(自动计算)。
YaRNEmbeddings
模块的使用示例如下:
import torch
from infini_transformer import InfiniTransformer
from infini_transformer import YaRNEmbeddings
embedder = YaRNEmbeddings(
dim=64, # 必须与InfiniTransformer中的dim_key匹配
seq_len=2048, # 必须与InfiniTransformer中的segment_len参数匹配
context_len=32768,
context_len_ext=65536,
dim_embeddings_pct=0.5,
base=10000,
alpha=1,
beta=32,
len_scale=None
)
tfm = InfiniTransformer(
dim_input=768,
dim_hidden=2048,
dim_key=64, # 必须与YaRNEmbeddings中的dim匹配
dim_value=64,
num_heads=8,
activation="ffngeglu",
segment_len=2048, # 必须与YaRNEmbeddings中的seq_len参数匹配
update="delta",
causal=True,
positional_embedder=embedder,
init_state_learnable=False,
dropout=0.1
)
batch = torch.randn(
2, # 批次大小
65536, # 序列长度
768 # 输入维度
)
output = tfm(batch)
使用示例
请参阅infini_transformer/example/modinfiniformer.py,了解使用MoDInfiniTransformer
模块的模型和训练流程示例。
更多示例将陆续推出。
许可证
本项目采用MIT许可证。
致谢
我们要感谢那些启发和促进Infini-Transformer和Mixture-of-Depths Transformer开发的研究人员和开发者。
同时,我们要特别感谢所有贡献者、合作者以及提供反馈的人。你们的努力使一个粗略的实现框架变成了真正可用的东西。
如果您有任何问题或需要进一步的帮助,请随时联系我,邮箱是ryan@beta-reduce.net。