FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Standard attention suffers from quadratic time and memory complexity in the sequence length (number of tokens). To reduce this complexity, efficient attention methods have proposed sparse and/or low-rank approximations, which bring the complexity down to linear or near-linear in the sequence length. Yet, these methods either lag in model quality or achieve no wall-clock speedup over standard attention.
This paper [1] proposes an IO-aware attention formulation that accounts for read and write operations to the different levels of fast and slow GPU memory. Modern GPUs have multiple memory levels with different sizes and speeds, as shown in Fig. 1. In this article, we focus on High Bandwidth Memory (HBM) and Static Random Access Memory (SRAM). As in other memory hierarchies, HBM is large and cheap but slow, while SRAM is small and expensive but fast.
HBM is used to store tensors (e.g., feature maps/activations), while SRAM is used to perform compute operations on those tensors. For instance, to apply a ReLU operation to a tensor x, we (1) move x from HBM to SRAM (read-op); (2) apply the ReLU operation to x (compute-op); and (3) move x back from SRAM to HBM (write-op).
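To get a feel for why memory traffic dominates for such cheap compute-ops, here is a minimal micro-benchmark sketch (my own illustration, not from the paper), assuming a CUDA-capable GPU: it times a large elementwise ReLU and reports the effective bandwidth, since the runtime is essentially the cost of reading x from HBM and writing the result back.

```python
import torch

# Illustrative micro-benchmark (assumes a CUDA GPU); not the paper's code.
x = torch.randn(256 * 1024 * 1024, device="cuda")  # ~1 GiB of fp32 data in HBM

torch.relu(x)  # warm-up
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
y = torch.relu(x)  # read x from HBM, compute ReLU on-chip, write y back to HBM
end.record()
torch.cuda.synchronize()

elapsed_s = start.elapsed_time(end) / 1e3          # elapsed_time returns milliseconds
bytes_moved = 2 * x.numel() * x.element_size()     # one read + one write of the tensor
print(f"ReLU: {elapsed_s * 1e3:.2f} ms, ~{bytes_moved / elapsed_s / 1e9:.0f} GB/s effective bandwidth")
```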
This paper [1] proposes an IO-aware algorithm that computes exact attention while reducing the number of memory read and write operations (read-ops and write-ops). To achieve this goal, the paper makes two contributions:
- Implement a CUDA kernel that fuses all the attention operations (matmul, mask, softmax, etc.) into a single GPU kernel.
- Compute the softmax operation without computing or storing the N×N attention matrix, where N is the number of tokens.
1st Contribution: A New CUDA Kernel
By default, every tensor operation is implemented as follows: (1) read-op, (2) compute-op, and (3) write-op. This design makes it trivial to stack tensor operations on top of each other. Concretely, one can apply (1) a matrix multiplication, then (2) masking, then (3) a softmax, while making no assumptions about the preceding or succeeding operations. These three operations would be executed as shown in Fig. 2.
Each operation requires a read-op and a write-op, as highlighted in green. These memory accesses become a bottleneck, especially for simple/fast compute-ops (e.g., ReLU). Dao et al. [1] observe that self-attention follows a standard sequence of operations: (1) matmul, (2) mask, (3) softmax, (4) dropout, (5) matmul. Accordingly, the paper implements a fused kernel, a single kernel that combines all five operations. Within the fused kernel, there is a single read-op and a single write-op, which reduces the cost of memory operations significantly. Fig. 3 compares standard attention (left) with flash attention (right).
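For reference, here is a sketch (my own simplification, not the paper's kernel) that spells out the five operations as separate PyTorch calls; each call launches its own kernel and materializes its output, including the N×N score matrix, in HBM, which is exactly the memory traffic the fused kernel avoids.

```python
import math
import torch
import torch.nn.functional as F

def standard_attention(q, k, v, mask=None, dropout_p=0.1):
    """Unfused attention: five separate kernels, each reading from and writing to HBM."""
    s = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (1) matmul -> N x N score matrix in HBM
    if mask is not None:
        s = s.masked_fill(mask, float("-inf"))           # (2) mask
    p = torch.softmax(s, dim=-1)                         # (3) softmax over full rows
    p = F.dropout(p, p=dropout_p, training=True)         # (4) dropout
    return p @ v                                         # (5) matmul -> output, written back to HBM
```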
By having a fused kernel, the proposed flash-attention reduces the number of memory operations, which translates into a large speed-up during training as shown in Fig. 4.
2nd Contribution: Computing softmax without realizing the attention matrix
Besides implementing a fused kernel, the paper [1] makes another contribution: flash attention computes exact attention without realizing the N×N attention matrix A! It had been (mistakenly) assumed that an exact softmax operation requires both computing and storing the attention matrix A. This assumption stems from the softmax denominator, which operates on an entire row of A, as shown in Fig. 5.
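For concreteness, the row-wise softmax can be written as follows (using A for the N×N score matrix, as in this article); the denominator sums over an entire row, which is what seems to force materializing the full matrix.

```latex
% Row-wise softmax over the N x N score matrix A:
% the denominator spans the entire row i.
P_{ij} = \mathrm{softmax}(A)_{ij} = \frac{e^{A_{ij}}}{\sum_{k=1}^{N} e^{A_{ik}}}
```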
Flash attention refutes this assumption through two tricks: (1) matrix-tiling as shown in Fig. 6; (2) summary statistics as shown in Figures 7 and 8.
Through tiling, flash attention splits the inputs Q, K, and V into blocks, loads them from the slow HBM into the fast SRAM, and then computes the attention output with respect to those blocks, as shown in Fig. 6. Of course, a softmax computed over separate blocks is inaccurate because softmax normalizes over the entire row. To tackle this, flash attention keeps track of summary statistics as it proceeds from one block to the next. When flash attention reaches the last block, the summary statistics contain the exact softmax denominator.
Fig. 7 (left) depicts toy Q, K, and V matrices to illustrate how the summary statistics work. Fig. 7 (right) presents pseudo-code for computing the softmax operation block-by-block using the summary statistics {D and O}, i.e., without storing the attention matrix A.
Fig. 8 uses the toy Q, K, and V from Fig. 7 and shows the pseudo-code in action. At any given iteration, flash attention accesses only the current block/item, not the entire row. To compute exact attention, flash attention keeps track of the summary statistics {D and O}, which are updated after each iteration/block.
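The sketch below is my own PyTorch illustration of this block-by-block computation, not the paper's kernel: it loops over K/V blocks while maintaining a running denominator d (the D above), a running unnormalized output o (the O above), and, as in the paper's algorithm, a running row-max m for numerical stability. The exact attention output is recovered without ever materializing the N×N matrix (for simplicity, Q is not tiled here, although the real kernel tiles it as well).

```python
import math
import torch

def blockwise_attention(q, k, v, block_size=128):
    """Exact attention computed block-by-block over K/V, keeping only summary statistics."""
    scale = 1.0 / math.sqrt(q.size(-1))
    n = q.size(0)
    m = torch.full((n, 1), float("-inf"))  # running row-wise max (numerical stability)
    d = torch.zeros(n, 1)                  # running softmax denominator (D)
    o = torch.zeros_like(q)                # running unnormalized output (O)

    for start in range(0, k.size(0), block_size):
        kb = k[start:start + block_size]   # one K block (HBM -> SRAM in the real kernel)
        vb = v[start:start + block_size]   # one V block
        s = (q @ kb.T) * scale             # scores for this block only: N x block_size
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)           # block-local exponentials
        rescale = torch.exp(m - m_new)     # correct the old statistics for the new max
        d = d * rescale + p.sum(dim=-1, keepdim=True)
        o = o * rescale + p @ vb
        m = m_new

    return o / d                           # exact softmax(QK^T / sqrt(d_head)) @ V

# Sanity check against the standard (full-matrix) computation.
q, k, v = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / math.sqrt(64), dim=-1) @ v
assert torch.allclose(blockwise_attention(q, k, v), ref, atol=1e-4)
```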
Before presenting the quantitative evaluations, there is one more technicality to address. So far, we have explained how flash attention works during the forward pass. Yet, the backward pass typically requires the gradients with respect to the attention matrix A in order to propagate the gradient to earlier layers. Since the attention matrix A is never realized, flash attention does not have these gradients, at least not without recomputation. Using the output and summary statistics computed during the forward pass, flash attention recomputes the attention matrix entries and their gradients during the backward pass, i.e., again without storing the entire matrix. This means that flash attention incurs more FLOPs than standard attention. Yet, even with the extra FLOPs, flash attention speeds up the backward pass thanks to reduced HBM accesses. Fig. 9 emphasizes this technicality by comparing standard attention with flash attention in terms of GFLOPs, memory accesses, and runtime.
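The snippet below mimics the recomputation idea with a custom autograd function; it is a simplified sketch under my own naming, not the paper's kernels. The forward pass saves only Q, K, and V (not the softmax matrix), and the backward pass recomputes the softmax before deriving the gradients, trading extra FLOPs for never storing the N×N matrix between passes. (The real kernel additionally reuses the saved output and softmax statistics; here the softmax is simply recomputed from Q and K for clarity.)

```python
import math
import torch

class RecomputeAttention(torch.autograd.Function):
    """Attention that recomputes softmax(QK^T) in the backward pass instead of saving it."""

    @staticmethod
    def forward(ctx, q, k, v):
        scale = 1.0 / math.sqrt(q.size(-1))
        p = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
        out = p @ v
        ctx.save_for_backward(q, k, v)  # note: the N x N matrix p is NOT saved
        ctx.scale = scale
        return out

    @staticmethod
    def backward(ctx, dout):
        q, k, v = ctx.saved_tensors
        scale = ctx.scale
        p = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)  # recompute (extra FLOPs)
        dv = p.transpose(-2, -1) @ dout
        dp = dout @ v.transpose(-2, -1)
        ds = p * (dp - (dp * p).sum(dim=-1, keepdim=True))            # softmax backward
        dq = (ds @ k) * scale
        dk = (ds.transpose(-2, -1) @ q) * scale
        return dq, dk, dv

# Verify the hand-written backward against numerical gradients (small, double precision).
q, k, v = (torch.randn(8, 4, dtype=torch.double, requires_grad=True) for _ in range(3))
assert torch.autograd.gradcheck(RecomputeAttention.apply, (q, k, v))
```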
While flash attention incurs more FLOPs, it is significantly faster in terms of runtime due to reduced HBM memory accesses.
Flash attention achieves two goals: (1) a training speedup and (2) support for longer sequences (larger context). Flash attention yields faster training for GPT-2: it delivers up to 3x and 1.7x end-to-end speedup over the Huggingface and Megatron-LM implementations, respectively, as shown in Fig. 10. These speedups come without any sacrifice in accuracy since flash attention computes exact attention.
The runtime and memory efficiency of flash attention enables a 4x larger context length for the GPT-2 baseline, which still runs faster than the optimized Megatron-LM implementation. Fig. 11 shows that GPT-2 with flash attention and a 4K context is still 30% faster than GPT-2 from Megatron-LM with a 1K context. Of course, the larger context boosts performance (0.7 better perplexity).
The paper [1] presents further quantitative evaluations for those interested in fast attention. This article wraps up with the Path-X and Path-256 benchmarks. These are challenging benchmarks where the task is to classify whether two points in a black-and-white 128×128 (or 256×256) image are connected by a path. Images are fed to the transformer one pixel at a time, which results in a very long sequence. Fig. 12 shows a pair of examples from the Path-X benchmark.
In prior work, transformer-based models have either run out of memory or achieved only random performance on these benchmarks. Flash attention is the first transformer to achieve better-than-chance performance: Fig. 13 shows that it reaches 61.4% accuracy on Path-X (seq. length 16K) and 63.1% accuracy on Path-256 (seq. length 64K).
Final thoughts:
- [S/W] Flash attention has been integrated into PyTorch 2.0, so it is easy to use, but it comes with a few constraints (e.g., it only supports certain GPUs and requires CUDA 11). Fig. 14 lists these constraints; a usage sketch follows this list.
- [W] Flash attention fuses all of standard attention's operations (e.g., matmul, mask, softmax) into a single fused CUDA kernel. Accordingly, any change to these operations requires a corresponding change to the CUDA kernel. Concretely, every time a new operation is introduced to improve standard attention, the flash-attention kernel needs to be updated accordingly.
- [W] Flash attention's speedups assume an ideal data-loading pipeline. The paper reports impressive training speedups, but if the training process is bottlenecked by data loading, these speedups won't be realized. A similar observation applies to memory savings: while flash attention always uses less memory, the savings are only significant for long sequences.
- [S] I like the simplicity of flash attention and how it reduces the cost of standard attention with smart engineering. Flash attention is just one facet of the great work coming out of Prof. Christopher Ré's lab at Stanford [2].
- I wrote a toy script to profile flash attention against standard attention; the script is available at [3]. Fig. 15 shows the running time (seconds) as the sequence length increases. Standard attention throws an out-of-memory error for sequences ≥ 4K, while flash attention supports seq_len=16K.
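As a usage note on the PyTorch 2.0 integration mentioned above, here is a sketch that calls torch.nn.functional.scaled_dot_product_attention and restricts the dispatcher to the flash-attention backend; the sdp_kernel context manager and the exact constraints (fp16/bf16 inputs, supported GPUs) vary across PyTorch versions, so treat this as illustrative rather than definitive.

```python
import torch
import torch.nn.functional as F

# Assumes a CUDA GPU supported by the flash-attention backend and a PyTorch 2.x build.
q, k, v = (torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Restrict scaled_dot_product_attention to the flash-attention kernel only.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 4096, 64])
```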
[1] Dao, T., Fu, D., Ermon, S., Rudra, A. and Ré, C., 2022. FlashAttention: Fast and memory-efficient exact attention with io-awareness. NeurIPS
[2] https://github.com/HazyResearch
[3] https://discuss.pytorch.org/t/flash-attention/174955/14?u=ahmdtaha