GPU/TPU Jax实现的环形注意力机制
本代码库提供了带有块状变换器的环形注意力机制的实现。该模型在论文《使用块状变换器的环形注意力机制实现近乎无限的上下文》和《用于大上下文模型的块状并行变换器》中有所描述。
带有块状并行变换器的环形注意力机制使得训练序列长度可以达到"设备数量"倍于BPT可能实现的长度。这是通过在多个设备上分配注意力和前馈计算,并将通信与计算重叠来实现的。由于注意力和前馈网络的块状计算,可以在不增加任何通信或计算开销的情况下,使用数千万个标记作为上下文大小进行训练。
示例用法及代码片段
首先,安装软件包:
pip install ringattention
然后,可以按如下方式导入和使用ringattention
和blockwise_feedforward
:
from ringattention import ringattention, blockwise_feedforward
您可以通过将ringattention
函数包装在shard_map
中来跨多个设备分片计算。以下是如何使用带分片的ringattention
函数的示例:
ring_attention_sharded = shard_map(
partial(
ringattention,
axis_name="sp",
float32_logits=True,
cache_idx=None,
blockwise_kwargs=dict(
causal_block_size=1,
deterministic=True,
dropout_rng=None,
attn_pdrop=0.0,
query_chunk_size=512,
key_chunk_size=512,
policy=jax.checkpoint_policies.nothing_saveable,
dtype=jax.numpy.float32,
precision=None,
prevent_cse=True,
)
),
mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
in_specs=(
PS(("dp", "fsdp"), "sp", "tp", None),
PS(("dp", "fsdp"), "sp", "tp", None),
PS(("dp", "fsdp"), "sp", "tp", None),
PS(("dp", "fsdp"), None, None, None),
PS(("dp", "fsdp"), None),
),
out_specs=PS(("dp", "fsdp"), "sp", "tp", None),
check_rep=False
)
attn_output = ring_attention_sharded(xq, xk, xv, attention_bias, segment_ids)
参数说明:
-
query_chunk_size
和key_chunk_size
是查询和键的块大小。选择尽可能大的值以加速计算,直到内存不足为止。 -
policy
是注意力权重的检查点策略,使用jax.checkpoint_policies.nothing_saveable
来启用检查点。 -
causal_block_size
是块因果注意力的块大小。causal_block_size=1
等同于因果注意力。 -
cache_idx
是推理时的缓存索引。如果cache_idx
不为None
,注意力权重将被缓存并在下一次推理中重复使用。
环形注意力机制在大世界模型(LWM)中用于百万长度的视觉语言训练,在那里可以找到使用环形注意力和块状变换器的完整示例:LWM代码库
参考文献
如果您觉得我们的工作与您的研究相关,请引用:
@article{liu2023blockwise,
title={Blockwise Parallel Transformer for Large Context Models},
author={Liu, Hao and Abbeel, Pieter},
journal={Advances in neural information processing systems},
year={2023}
}
@article{liu2023ring,
title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
author={Liu, Hao and Zaharia, Matei and Abbeel, Pieter},
journal={arXiv preprint arXiv:2310.01889},
year={2023}
}