FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Ahmed Taha
8 min read · Jun 5, 2023

Standard attention suffers from quadratic complexity in the sequence length (number of tokens). To reduce this complexity, efficient attention methods have proposed sparse and/or low-rank approximations. These approximations reduce complexity to linear or near-linear in the sequence length. Yet, these methods either lag in performance or achieve no wall-clock speedup compared to standard attention.

Figure 1: Memory Hierarchy with Bandwidth & Memory Size.

This paper [1] proposes an IO-aware attention formulation that accounts for read and write operations across the different levels of fast and slow GPU memory. Modern GPUs have multiple memory levels with different sizes and speeds, as shown in Fig. 1. In this article, we focus on High Bandwidth Memory (HBM) and Static Random Access Memory (SRAM). As in other memory hierarchies, HBM is large and cheap but slow, while SRAM is small and expensive but fast.

HBM is used to store tensors (e.g., feature maps/activations), while SRAM is used to perform compute operations on those tensors. For instance, when applying a ReLU operation to a tensor x, we (1) move x from HBM to SRAM (read-op); (2) apply the ReLU operation to x (compute-op); and (3) move x back from SRAM to HBM (write-op).
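To make this read/compute/write pattern concrete, here is a minimal NumPy sketch. NumPy arrays live in CPU RAM rather than GPU HBM, so this is only an analogy for the data movement involved, not a GPU measurement:

```python
import numpy as np

# Toy tensor standing in for an activation stored in HBM.
x = np.random.randn(4, 4).astype(np.float32)

# Conceptually: (1) read x from slow memory into fast memory,
# (2) compute ReLU on chip, (3) write the result back to slow memory.
y = np.maximum(x, 0.0)
```

Even for such a cheap compute-op, steps (1) and (3) dominate the cost on a GPU, which is exactly the bottleneck the paper targets.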

This paper [1] proposes an IO-aware algorithm that computes exact attention while reducing the number of memory read and write operations (read-ops and write-ops). To achieve this goal, the paper makes two contributions:

  1. Implement a CUDA kernel to fuse all the attention operations (matmul, mask, softmax, etc) into a single GPU kernel.
  2. Compute the softmax operation without computing or storing the N×N attention matrix, where N is the number of tokens.

1st Contribution: A New CUDA Kernel

By default, every tensor operation is implemented as follows: (1) a read operation (read-op), (2) a compute-op, and (3) a write-op. This design makes it trivial to apply multiple tensor operations on top of each other. Concretely, one can apply (1) a matrix multiplication, then (2) masking, then (3) a softmax while making no assumption about the preceding or succeeding operation. These three operations will be executed as shown in Fig. 2.

Figure 2: Every tensor operation has a (1) read-op, (2) compute-op, and (3) write-op. These read/write operations — highlighted in green — can become a bottleneck for a cheap compute-op.

Each operation requires a read-op and a write-op, as highlighted in green. These memory accesses become a bottleneck, especially for a simple/fast compute-op (e.g., ReLU). Dao et al. [1] observe that self-attention consists of a standard sequence of operations: (1) matmul, (2) mask, (3) softmax, (4) dropout, (5) matmul. Accordingly, the paper implements a fused kernel: a single GPU kernel that combines all five operations. Within the fused kernel, there is a single read-op and a single write-op, which significantly reduces the cost of memory operations. Fig. 3 compares standard attention (left) with flash attention (right).
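As a reference point, the unfused sequence of operations can be sketched in NumPy. This is a simplified single-head version; dropout and the 1/√d scaling are omitted, matching the figure:

```python
import numpy as np

def standard_attention(Q, K, V, mask=None):
    """Unfused attention: each step reads and writes a full N x N intermediate."""
    S = Q @ K.T                            # (1) matmul: N x N scores
    if mask is not None:
        S = np.where(mask, S, -np.inf)     # (2) mask out disallowed positions
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P = P / P.sum(axis=1, keepdims=True)   # (3) softmax over each row
    return P @ V                           # (4) matmul with values

N, d = 8, 4
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention(Q, K, V)
```

Each numbered step materializes an N×N intermediate in slow memory; the fused kernel keeps all of them in SRAM.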

Figure 3: A comparison between standard attention (left) and flash attention (right). This comparison leverages three operations (matmul, mask, softmax) only. Other operations (e.g., dropout) are omitted for presentation purposes.

By having a fused kernel, the proposed flash-attention reduces the number of memory operations, which translates into a large speed-up during training as shown in Fig. 4.

Figure 4: A comparison between standard attention and flash attention. (Left) Flash attention delivers a 7.6x speedup over the PyTorch implementation. (Right) Flash attention is 15% faster than an Nvidia implementation that set the training speed record for MLPerf 1.1.

2nd Contribution: Computing Softmax without Realizing the Attention Matrix

Besides implementing a fused kernel, the paper [1] makes another contribution: flash attention computes exact attention without realizing the N×N attention matrix A! It has been mistakenly assumed that an exact softmax operation requires both computing and storing the attention matrix A. This assumption stems from the softmax denominator, which operates on an entire row of A, as shown in Fig. 5.

Figure 5: The softmax denominator operates on an entire row. Yet, flash attention computes the exact operation without storing the attention matrix A.

Flash attention refutes this assumption through two tricks: (1) matrix-tiling as shown in Fig. 6; (2) summary statistics as shown in Figures 7 and 8.

Figure 6: FlashAttention avoids the materialization of the large 𝑁 × 𝑁 attention matrix. In the outer loop (red arrows), FlashAttention loops through blocks of the K and V matrices and loads them to fast on-chip SRAM. In the inner loop (blue arrows), FlashAttention loops over blocks of the Q matrix, loading them to SRAM, and writes the output of the attention computation back to HBM.

Through tiling, flash attention splits the inputs Q, K, and V into blocks, then loads them from the slow HBM to the fast SRAM, then computes the attention output with respect to those blocks as shown in Fig. 6. Of course, softmax computed over separate blocks is inaccurate because softmax normalizes by the entire row. To tackle this, flash attention keeps track of summary statistics as it proceeds from one block to the next. When flash attention reaches the last block, the summary statistics will contain the exact softmax denominator.

Fig. 7 (Left) depicts toy Q, K, and V matrices to illustrate how summary statistics work. Fig. 7 (Right) presents pseudocode showing how the softmax operation can be computed block-by-block using the summary statistics {D, O}, i.e., without storing the attention matrix A.

Figure 7: Toy Q, K, and V matrices illustrating the difference between standard and flash attention. (Left) Standard attention computes and stores the entire attention matrix A to compute the attention output O. (Right) Flash attention operates on individual blocks of the attention matrix A (A[i]=Q[i]*K[i]), so there is no need to compute and store the entire attention matrix A.
Figure 8: An illustration of the pseudocode from Fig. 7 applied to the toy Q, K, V matrices. Flash attention computes the exact softmax operation using the summary statistics {D, O} and without storing the attention matrix A. The official flash attention tracks additional statistics (e.g., m=max(row)) for numerical stability; these are omitted for presentation purposes. D_b denotes the current block's denominator contribution, while D is a summary statistic that tracks the running row denominator.

Fig. 8 leverages the toy Q, K, V from Fig. 7, and shows the pseudo-code in action. At any given iteration, flash attention accesses the current block/item only and not the entire row. To compute exact attention, flash attention keeps track of summary statistics {D and O} which are updated after each iteration/block.
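The block-by-block computation with the summary statistics {D, O} can be sketched in NumPy. Like the figures, this simplified version omits the running max m that the official implementation tracks for numerical stability (the small random inputs here keep exp() well-behaved):

```python
import numpy as np

def flash_like_attention(Q, K, V, block=2):
    """Exact softmax attention computed block-by-block over K/V,
    without ever materializing the full N x N attention matrix."""
    N = Q.shape[0]
    D = np.zeros((N, 1))               # running softmax denominator per row
    O = np.zeros((N, V.shape[1]))      # running (unnormalized) output
    for s in range(0, N, block):
        Kb, Vb = K[s:s+block], V[s:s+block]
        Ab = np.exp(Q @ Kb.T)          # scores for this block only
        D += Ab.sum(axis=1, keepdims=True)
        O += Ab @ Vb
    return O / D                       # normalize once at the end

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))

# Matches the dense reference softmax(Q K^T) V exactly.
E = np.exp(Q @ K.T)
reference = (E / E.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_like_attention(Q, K, V), reference)
```

At no point does the loop hold more than one N×block slice of scores, which is the whole trick: only {D, O} survive from one block to the next.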

Before presenting the quantitative evaluations, there is one more technicality to address. So far, we have explained how flash attention works during the forward pass. Yet, the backward pass typically requires the gradients with respect to the attention matrix A in order to propagate gradients to earlier layers. Since the attention matrix A is never realized, flash attention does not have these gradients, at least not without recomputation. Using the output and summary statistics computed during the forward pass, flash attention recomputes the attention matrix entries and their gradients during the backward pass, again without storing the entire matrix. This means that flash attention incurs more FLOPs than standard attention. Yet, even with more FLOPs, flash attention speeds up the backward pass thanks to reduced HBM accesses. Fig. 9 emphasizes this technicality by comparing standard attention with flash attention in terms of GFLOPs, memory accesses, and runtime.

Figure 9: Forward and backward runtime of standard and flash attention. Flash attention incurs more GFLOPs because it re-computes the attention operation during the backward pass. This increases the GFLOPs but reduces the number of memory-access operations. Since memory (HBM) access is the primary factor affecting runtime, flash attention runs faster.

While flash attention incurs more FLOPs, it is significantly faster in terms of runtime due to reduced HBM memory accesses.

Flash attention achieves two goals: (1) training speedup and (2) support for longer sequences (context). Flash attention yields faster training times for GPT-2, with up to 3x and 1.7x end-to-end speedups compared to Huggingface and Megatron-LM, respectively, as shown in Fig. 10. These speedups come without any sacrifice in accuracy, since flash attention computes exact attention.

Figure 10: GPT-2 small and medium using FlashAttention achieve up to 3x speed up compared to Huggingface implementation and up to 1.7x compared to Megatron-LM. Training time reported on 8xA100s GPUs.

The runtime and memory efficiency of flash attention enable a 4x larger context length for the GPT-2 baseline, which still runs faster than the optimized Megatron-LM implementation. Fig. 11 shows that GPT-2 with flash attention and a 4K context length is still 30% faster than Megatron-LM's GPT-2 with a 1K context length. This larger context also boosts performance (0.7 better perplexity).

Figure 11: GPT-2 small with FlashAttention, with 4x larger context length compared to Megatron-LM, is still 30% faster while achieving 0.7 better perplexity. Training time on 8xA100 GPUs is reported.

The paper [1] presents other quantitative evaluations for those interested in fast attention. This article wraps up with the Path-X and Path-256 benchmarks. These are challenging benchmarks where the task is to classify whether two points in a black-and-white 128×128 (or 256×256) image are connected by a path. Images are fed to the transformer one pixel at a time, which results in a very long sequence. Fig. 12 shows a pair of examples from the Path-X benchmark.

Figure 12: Samples of the Path-finder benchmark. (Left) A positive example where the two dots are connected by a path. (Right) A negative example where no path connects the two dots.

In prior work, transformer-based models have either run out of memory or achieved only random performance on these benchmarks. Flash attention is the first transformer model to achieve better-than-chance performance: Fig. 13 shows that it reaches 61.4% accuracy on Path-X (seq. length 16K) and 63.1% accuracy on Path-256 (seq. length 64K).

Figure 13: FlashAttention is the first transformer model that achieves non-random performance on Path-X and Path-256 benchmarks.

Final thoughts:

Figure 14: FlashAttention supports certain GPUs and data types. These constraints are listed at https://github.com/HazyResearch/flash-attention
  1. [S/W] Flash attention has been integrated into PyTorch 2.0. So it is easy to use, but it comes with a few constraints (e.g., it only supports certain GPUs and requires CUDA 11). Fig. 14 lists these constraints.
  2. [W] Flash attention fuses all of standard attention's operations (matmul, mask, softmax, etc.) into a single CUDA kernel. Accordingly, any change to these operations requires a corresponding change to the CUDA kernel: every time a new operation is introduced to augment standard attention, the flash-attention kernel needs to be updated accordingly.
  3. [W] Flash attention's speedups assume an ideal data-loading procedure. The paper reports impressive speedups during training, yet if the training process is bottlenecked by data loading, these speedups won't be realized. A similar observation holds for memory savings: while flash attention always uses less memory, the savings are significant only for long sequences.
  4. [S] I like the simplicity of flash attention and how it reduces standard attention's costs with smart engineering. Flash attention is just one of many contributions from Prof. Christopher Ré's lab at Stanford [2].
  5. I wrote a toy script [3] to profile flash attention against standard attention. Fig. 15 shows the running time (in seconds) as the sequence length increases. Standard attention throws an out-of-memory error for sequences ≥ 4K, while flash attention supports seq_len = 16K.
Figure 15: Running time of flash and standard attention on my personal GPU. Flash attention brings significant speedups and enables longer sequence length (attention context).
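For point 1 above: the PyTorch 2.0 integration is exposed through torch.nn.functional.scaled_dot_product_attention. A minimal sketch follows; PyTorch dispatches to the flash-attention kernel only when the GPU, dtype, and other constraints from Fig. 14 are met, and otherwise falls back to the memory-efficient or plain math implementation (so this snippet also runs on CPU):

```python
import torch
import torch.nn.functional as F

# Input layout: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 2, 16, 8)
k = torch.randn(1, 2, 16, 8)
v = torch.randn(1, 2, 16, 8)

# Fused attention in one call; the backend (flash, memory-efficient,
# or math) is chosen automatically based on hardware and dtype.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

This is why "integrated into PyTorch 2.0" does not always mean "running FlashAttention": the same call silently degrades to a slower backend on unsupported hardware.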

[1] Dao, T., Fu, D., Ermon, S., Rudra, A. and Ré, C., 2022. FlashAttention: Fast and memory-efficient exact attention with io-awareness. NeurIPS

[2] https://github.com/HazyResearch

[3] https://discuss.pytorch.org/t/flash-attention/174955/14?u=ahmdtaha
