FlashAttention - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

FlashAttention

SMPv2 支持FlashAttention内核,可以轻松地将其应用于 Hugging Face Transformer 模型的各种场景。请注意,如果您使用 v2.0 或更高版本的 FlashAttention 软件包,则SMP使用 FlashAttention v2;但是,在 v FlashAttention 1.x 中,Triton 闪光注意力默认为 flash 注意力内核,因此在 v1 中仅支持该内核。 FlashAttention

模块 (nn.Module) 是一个低级别API,用于定义模型的注意力层。它应该在模型创建之后立即应用,AutoModelForCausalLM.from_config()API例如,在模型被转换或封装之前FSDP。

使用 FlashAttention 内核来集中注意力

以下代码片段显示了如何使用 SMP v2 torch.sagemaker.nn.attn.FlashSelfAttention API 提供的。

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

使用 FlashAttention 内核进行分组查询注意

SMPv2 还支持用于分组查询注意FlashAttention力的内核 (GQA),并且可以轻松地将其应用于 Hugging Face Transformer 模型的各种场景。与最初的注意力架构不同,GQA同样将查询头分成组,而同一组中的查询头共享相同的键和值头。因此,q 和 kv 磁头被分别传入前向调用。注意:q 磁头的数量需要可以被 kv 磁头的数量整除。

使用示例 FlashGroupedQueryAttention

以下代码片段显示了如何使用 SMP v2 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API 提供的。

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

该SMP库还提供torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention,它torch.sagemaker.nn.attn.FlashGroupedQueryAttentionAPI在低级别使用。Hugging Face 转换器在 4.36.0 版中也有一个名为 LlamaFlashAttention2 的类似实现。以下代码片段显示了如何使用 SMP v2 LlamaFlashAttention API 或 Transformers LlamaFlashAttention2 API 来替换现有 Llama 模型的注意力层。

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))