本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
FlashAttention 支持
对 FlashAttention 的支持是该库仅适用于分布式转换器模型的功能,这种模式是使用 smp.DistributedModel()
只有在将 attention_head_size 设置为 8 的倍数且小于 128 的值时,FlashAttention
例如,假设您使用 hidden_width=864 和 num_heads=48 配置转换器模型。FlashAttention 的头大小计算公式为 attention_head_size = hidden_width / num_heads = 864 / 48 = 18。要启用 FlashAttention,您需要将 num_heads 参数调节为 54,这样 attention_head_size = hidden_width / num_heads = 864
/ 54 = 16,其值是 8 的倍数。