FlashAttention - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

FlashAttention

SMP v2 支持FlashAttention内核,可以轻松地将其应用于 Hugging Face Transformer 模型的各种场景。请注意,如果您使用 v2.0 或更高版本的 FlashAttention 软件包,SMP 使用 FlashAttention v2;但是,在 v FlashAttention 1.x 中,Triton 闪光注意力默认为闪光注意内核,因此在 v1 中仅支持该内核。 FlashAttention

模块 (nn.Module) 是一种低级 API,用于定义模型的注意层。它应在模型创建后立即应用,例如从 AutoModelForCausalLM.from_config() API,并在使用 FSDP 对模型进行转换或封装之前应用。

使用 FlashAttention 内核来集中注意力

下面的代码片段显示了如何使用 SMP v2 提供的 torch.sagemaker.nn.attn.FlashSelfAttention API。

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

使用 FlashAttention 内核进行分组查询注意

SMP v2 还支持用于分组查询注意力 (GQA) 的FlashAttention内核,并且可以轻松地将其应用于 Hugging Face Transformer 模型的各种场景。与最初的注意架构不同,GQA 将查询磁头平均分为若干组,同一组中的查询磁头共享相同的键和值磁头。因此,q 和 kv 磁头被分别传入前向调用。注意:q 磁头的数量需要可以被 kv 磁头的数量整除。

使用示例 FlashGroupedQueryAttention

下面的代码片段显示了如何使用 SMP v2 提供的 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API。

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

SMP 库还提供 torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention,它在低级别使用 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API。Hugging Face 转换器在 4.36.0 版中也有一个名为 LlamaFlashAttention2 的类似实现。下面的代码片段显示了如何使用 SMP v2 LlamaFlashAttention API 或转换器 LlamaFlashAttention2 API 替换现有 Llama 模型的注意层。

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))