FlashAttention - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

FlashAttention

SMP v2 支持FlashAttention内核,可以轻松地将其应用于 Hugging Face Transformer 模型的各种场景。请注意,如果您使用 v2.0 或更高版本的 FlashAttention 软件包,SMP 使用 FlashAttention v2;但是,在 v FlashAttention 1.x 中,Triton 闪光注意默认使用闪光注意内核,因此在 v1 中仅支持该内核。 FlashAttention

模块 (nn.Module) 是一个低级 API,用于定义模型的注意力层。它应该在创建模型之后立即应用,例如从 AutoModelForCausalLM.from_config() API 中创建,也应该在使用 FSDP 转换或封装模型之前应用。

使用 FlashAttention 内核来集中注意力

以下代码片段展示了如何使用 SMP v torch.sagemaker.nn.attn.FlashSelfAttention 2 提供的 API。

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

使用 FlashAttention 内核进行分组查询注意

SMP v2 还支持用于分组查询注意力 (GQA) 的FlashAttention内核,并且可以轻松地将其应用于 Hugging Face Transformer 模型的各种场景。与最初的注意力架构不同,GQA 将查询头部平均分成组,同一组中的查询头共享相同的键和值标头。因此,q 和 kv 头分别传递到前向调用中。注意:q 磁头的数量需要可以被 kv 磁头的数量整除。

使用示例 FlashGroupedQueryAttention

以下代码片段展示了如何使用 SMP v torch.sagemaker.nn.attn.FlashGroupedQueryAttention 2 提供的 API。

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

SMP 库还提供torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention,它在低级别使用 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API。Hugging Face Transformers 也有类似的实现,LlamaFlashAttention2名为 v4.36.0。以下代码片段显示了如何使用 SMP v2 LlamaFlashAttention API 或 Transformers AP LlamaFlashAttention2 I 来替换现有 Llama 模型的注意力层。

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))