Flash Attention
Flash attention在做一件什么事
Flash Attention在做的事情,其实都包含在它的命名中了(Fast and Memory Efficient Exact Attention with IO-Awareness),我们逐一来看:
(1)Fast(with IO-Awareness),计算快。在Flash Attention之前,也出现过一些加速Transformer计算的方法,这些方法的着眼点是“减少计算量FLOPs”,例如用一个稀疏attention做近似计算。但是Flash attention就不一样了,它并没有减少总的计算量,因为它发现:计算慢的卡点不在运算能力,而是在读写速度上。所以它通过降低对显存(HBM)的访问次数来加快整体运算速度,这种方法又被称为O-Awareness。在后文中,我们会详细来看Flash Attention是如何通过分块计算(tiling)和核函数融合(kernel fusion)来降低对显存的访问。
(2)Memory Efficicent,节省显存。在标准attention场景中,forward时我们会计算并保存N*N大小的注意力矩阵;在backward时我们又会读取它做梯度计算,这就给硬件造成了 的存储压力。在Flash Attention中,则巧妙避开了这点,使得存储压力降至 。在后文中我们会详细看这个trick。
(3)Exact Attention,精准注意力。在(1)中我们说过,之前的办法会采用类似于“稀疏attention”的方法做近似。这样虽然能减少计算量,但算出来的结果并不完全等同于标准attention下的结果。但是Flash Attention却做到了完全等同于标准attention的实现方式,这也是后文我们讲述的要点。
计算限制与内存限制
在第一部分中我们提过,Flash Attention一个很重要的改进点是:由于它发现Transformer的计算瓶颈不在运算能力,而在读写速度上。因此它着手降低了对显存数据的访问次数,这才把整体计算效率提了上来。所以现在我们要问了:它是怎么知道卡点在读写速度上的?
为了解答这个问题,我们先来看几个重要概念:
- :硬件算力上限。指的是一个计算平台倾尽全力每秒钟所能完成的浮点运算数。单位是 FLOPS or FLOP/s。
- :硬件带宽上限。指的是一个计算平台倾尽全力每秒所能完成的内存交换量。单位是Byte/s。
- :某个算法所需的总运算量,单位是FLOPs。下标 表示total。
- :某个算法所需的总数据读取存储量,单位是Byte。下标 表示total。
这里再强调一下对FLOPS和FLOPs的解释:
FLOPS:等同于FLOP/s,表示Floating Point Operations Per Second,即每秒执行的浮点数操作次数,用于衡量硬件计算性能。 FLOPs:表示Floating Point Operations,表示某个算法的总计算量(即总浮点运算次数),用于衡量一个算法的复杂度。