FlashAttention 支持
对 FlashAttention 的支持是该库仅适用于分布式转换器模型的功能,这种模式是使用 smp.DistributedModel()
只有在将 attention_head_size
设置为 8 的倍数且小于 128 的值时,FlashAttention
例如,假设您使用 hidden_width=864
和 num_heads=48
配置转换器模型。FlashAttention 的头大小计算公式为 attention_head_size = hidden_width / num_heads = 864 / 48 = 18
。要启用 FlashAttention,您需要将 num_heads
参数调节为 54
,这样 attention_head_size = hidden_width / num_heads = 864
/ 54 = 16
,其值是 8 的倍数。