Rethinking Attention with Performers — Part II & Final

Ahmed Taha
Feb 14, 2023

This article’s objective is to summarize the Performers [1] paper. The article highlights key details and documents some personal comments at the end. A previous article presents a hand-wavy understanding of Performers using a hashing analogy.

Vanilla Transformers leverage self-attention layers defined as follows

Figure 1: Vanilla self-attention with quadratic space complexity

This formula has quadratic space complexity O(L²) where L is the input sequence length. This hinders transformers for long-input sequences, e.g., long text sequences, high-resolution images, and long protein sequences.
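To make the quadratic cost concrete, here is a minimal NumPy sketch of single-head softmax attention (the shapes and the 1/√d scaling are my assumptions for illustration). Note that the L×L score matrix is materialized explicitly.

```python
import numpy as np

def vanilla_attention(Q, K, V):
    """Single-head softmax attention. Q, K, V: (L, d) arrays."""
    L, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                        # (L, L) -> quadratic in L
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    A = np.exp(scores)
    A = A / A.sum(axis=1, keepdims=True)                 # row-wise softmax
    return A @ V                                         # (L, d)
```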

To tackle this problem, Choromanski et al. [1] have proposed Performers to estimate full-rank attention through Fast Attention Via Positive Orthogonal Random features (FAVOR+). The term FAVOR+ comprises three parts:

  1. Fast-Attention (FA),
  2. Positive Random Features (+/PRF),
  3. Orthogonal Random Features (ORF).

This article presents each part separately and then covers the quantitative evaluations.

  1. Fast-Attention (FA)

Choromanski et al. [1] view the attention-softmax mechanism through the lens of kernelization. Concretely, the attention-softmax SM(Q, K) is a kernel that can be approximated using random feature maps ϕ. Fig. 2 summarizes this idea.

Figure 2: The attention-softmax SM(Q, K) can be approximated using random feature maps ϕ. Please note that the softmax function always returns positive values.

With the right ϕ, the original self-attention formulation (Fig. 1) can be reformulated as follows

Figure 3: Performers approximate vanilla self-attention using random feature maps ϕ(Q)=Q’ and ϕ(K)=K’. Please note the order of matrix multiplication.

By changing the order of matrix multiplication, this formulation escapes the O(L²) space complexity and reduces the time complexity from O(L²d) to O(Lrd), where r is the dimensionality of the random feature maps ϕ, as shown in Fig. 4.

Figure 4: Approximation of the regular attention mechanism AV (before D^{-1}-renormalization) via (random) feature maps. Dashed blocks indicate the order of computation with the corresponding time complexities attached.
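As a sketch of this reordering, the snippet below assumes the feature maps Q’ = ϕ(Q) and K’ = ϕ(K), each of shape (L, r), are already available; computing K’ᵀV (an r×d matrix) first means no L×L matrix is ever formed. This is a simplified, non-causal sketch and the names are mine.

```python
import numpy as np

def linear_attention(Q_prime, K_prime, V):
    """Approximate attention from feature maps Q' = phi(Q), K' = phi(K), both (L, r).
    Computes D^{-1} (Q' (K'^T V)) without materializing an L x L matrix."""
    KV = K_prime.T @ V                           # (r, d)  -- O(L r d)
    numerator = Q_prime @ KV                     # (L, d)  -- O(L r d)
    normalizer = Q_prime @ K_prime.sum(axis=0)   # (L,)    -- the diagonal of D
    return numerator / normalizer[:, None]
```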

Given this kernelization trick, the next question becomes: how to approximate the softmax-kernel? Concretely, what is the right ϕ?

2. Positive Random Features (PRF)

Choromanski et al. [1] propose the following formula for ϕ

Figure 5: When regarded as a kernel, Softmax attention can be approximated using random feature maps ϕ. Besides the softmax kernel, this ϕ formulation can model most kernels used in practice.

This formula can model most kernels used in practice, where w_1, …, w_m are random vectors sampled from a distribution D. By choosing h, l, and f_1, …, f_l appropriately, this formula approximates different kernels. Tab. 1 summarizes some of the proposed choices of (h, l, f) and the corresponding kernels.

Table 1: ϕ(x) can model different kernels by setting (h, l, and f) differently. For instance, the Gaussian kernel can be approximated by setting h(x)=1, l=2, f_1(u)=sin(u), and f_2(u)=cos(u).
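To make this template concrete, here is a minimal NumPy sketch of the Gaussian-kernel instance from Tab. 1 (random Fourier features with h(x)=1, l=2, f_1=sin, f_2=cos, and w_i ~ N(0, I_d); the unit bandwidth is my simplifying assumption).

```python
import numpy as np

def gaussian_rff(X, W):
    """Random Fourier features for the Gaussian kernel exp(-||x - y||^2 / 2).
    X: (L, d) inputs; W: (m, d) with rows drawn from N(0, I_d).
    phi(x) = [sin(Wx), cos(Wx)] / sqrt(m), i.e., h(x) = 1 and l = 2."""
    m = W.shape[0]
    proj = X @ W.T
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=1) / np.sqrt(m)

# Sanity check: phi(x)^T phi(y) ~ exp(-||x - y||^2 / 2)
rng = np.random.default_rng(0)
d, m = 8, 5000
x, y = rng.standard_normal(d), rng.standard_normal(d)
W = rng.standard_normal((m, d))
print((gaussian_rff(x[None], W) @ gaussian_rff(y[None], W).T).item(),
      np.exp(-np.sum((x - y) ** 2) / 2))
```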

The softmax-kernel SM^{trig}, the third kernel in Tab. 1, should approximate the softmax used in self-attention. However, there is a caveat: this approximation uses sin and cos functions, which can generate negative values. This contradicts the vanilla softmax function, which always returns positive values, as shown in Fig. 2. Choromanski et al. [1] found SM^{trig} unstable, especially when kernel scores are close to 0, i.e., for low-relevance tokens. Low relevance indicates a large angle (low similarity) between tokens, while high relevance indicates a small angle (high similarity) between tokens.

To resolve this stability issue, Choromanski et al. [1] approximate the softmax-kernel using positive random features (PRF) built on exp functions instead of sin/cos. This leads to SM^{+} and SM^{hyp+}, as shown in Tab. 1. To evaluate these positive random features quantitatively, the paper plots the ratio between the mean square errors (MSE) of SM^{trig} and SM^{+}, as shown in Fig. 6.
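To make SM^{+} concrete, here is a minimal NumPy sketch of the positive random feature map (my rendering of the h(x)=exp(-||x||²/2), l=1, f_1=exp row of Tab. 1; the numerical-stabilization details used in the released implementations are omitted).

```python
import numpy as np

def positive_feature_map(X, W):
    """Positive random features for the softmax kernel SM(x, y) = exp(x^T y).
    X: (L, d) inputs; W: (m, d) with rows ~ N(0, I_d).
    phi(x) = exp(W x - ||x||^2 / 2) / sqrt(m)  -- strictly positive."""
    m = W.shape[0]
    return np.exp(X @ W.T - 0.5 * np.sum(X ** 2, axis=1, keepdims=True)) / np.sqrt(m)

# Sanity check: phi(x)^T phi(y) ~ exp(x^T y)
rng = np.random.default_rng(0)
d, m = 16, 10_000
x, y = 0.3 * rng.standard_normal(d), 0.3 * rng.standard_normal(d)
W = rng.standard_normal((m, d))
print((positive_feature_map(x[None], W) @ positive_feature_map(y[None], W).T).item(),
      np.exp(x @ y))
```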

Figure 6: The ratio between the mean square error (MSE) of SM^{trig} and SM^{+}. The vertical axis r denotes the ratio of the mean squared errors (MSEs) of estimators built on: trigonometric and positive random features. The horizontal plane plots the angle ϕ (in radians) between input feature vectors and their lengths l. Larger values (vertical r-axis) indicate regions of (ϕ, l)-space where positive random features SM^{+} outperform trigonometric features SM^{trig}. For critical regions with ϕ large enough (softmax-kernel scores close to 0 for low relevance tokens), SM^{+} is more accurate than SM^{trig}.

3. Orthogonal Random Features (ORF)

Now, the vanilla softmax can be approximated using both the kernelization trick (FA) and positive random features (PRF). Yet, how many random features r are needed? Can we get away with a small r?

To reduce both the variance of the kernel estimator and the number of random features r, Choromanski et al. [1] leverage orthogonal random features (ORF): the randomly sampled vectors (w_1, …, w_m) are orthogonalized using the Gram-Schmidt procedure. Gram-Schmidt is a linear-algebra procedure for orthonormalizing a set of vectors. Fig. 7 shows a toy Gram-Schmidt example where two random vectors (v_1, v_2) are orthonormalized into (u_1, u_2).

Figure 7: Gram-Schmidt process using a toy 2D space with basis V =[v1, v2]. While V is not an orthogonal basis, the Gram-Schmidt process turns it into an orthogonal basis U. Gram-Schmidt computes u_i by (1) projecting v_i orthogonally onto the subspace U generated by u_1, …, u_{i−1} to compute proj(v_i); then (2) subtracting the aforementioned projection (proj(v_i)) from v_i.
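As an illustration, here is a minimal NumPy sketch of Gram-Schmidt applied to randomly drawn Gaussian directions. In FAVOR+ the orthogonalized rows are additionally rescaled so each w_i keeps an appropriate norm, and practical implementations often use a QR decomposition instead; both details are omitted here.

```python
import numpy as np

def gram_schmidt(V):
    """Orthonormalize the rows of V (assumes the rows are linearly independent)."""
    U = []
    for v in V:
        for u in U:                       # subtract projections onto earlier u's
            v = v - (v @ u) * u
        U.append(v / np.linalg.norm(v))
    return np.array(U)

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 16))          # 4 random directions in R^16
W_orth = gram_schmidt(W)
print(np.round(W_orth @ W_orth.T, 6))     # ~ identity matrix: rows are orthonormal
```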

While (1) FA, (2) PRF, and (3) ORF are the paper's primary ideas, the paper also includes secondary tricks to increase robustness: (1) the random feature maps are periodically redrawn, and (2) a regularized softmax-kernel SMREG is introduced. For those interested in these tricks, please read the paper.
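Putting the three pieces together, the sketch below reuses the helper functions sketched above (vanilla_attention, positive_feature_map, gram_schmidt, linear_attention, all names of my own) to compare a FAVOR+-style approximation against exact softmax attention. The d^{-1/4} rescaling of queries and keys and the chi-distributed row norms are my reading of the paper's construction; treat this as an illustrative sketch, not the reference implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d, m = 128, 16, 256                    # sequence length, head dim, #features

Q = 0.5 * rng.standard_normal((L, d))
K = 0.5 * rng.standard_normal((L, d))
V = rng.standard_normal((L, d))

# Block-orthogonal random projections: stack independent orthonormal d x d blocks.
blocks = [gram_schmidt(rng.standard_normal((d, d))) for _ in range(m // d)]
W = np.vstack(blocks)
# Give each direction a chi(d)-distributed length so each w_i is marginally N(0, I_d).
W = W * np.linalg.norm(rng.standard_normal((m, d)), axis=1, keepdims=True)

# Rescale Q, K by d^{-1/4} so exp(q^T k) matches softmax's exp(q^T k / sqrt(d)).
Qp = positive_feature_map(Q / d ** 0.25, W)
Kp = positive_feature_map(K / d ** 0.25, W)

exact = vanilla_attention(Q, K, V)
approx = linear_attention(Qp, Kp, V)
print("max abs error:", np.abs(exact - approx).max())   # shrinks as m grows
```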

Now that we have presented FAVOR+ (a.k.a. Performers), we turn to the experiments. We start by evaluating the computational cost of both the forward and backward passes. Fig. 8 compares vanilla Transformers with Performers on long-range sequence modeling. Performer reaches nearly linear time and sub-quadratic memory consumption. With long sequences L, the memory and backward-pass efficiency allow larger-batch training and lower wall-clock time, respectively.

Figure 8: Comparison of Transformer and Performer in terms of forward and backward pass speed and maximum L allowed. “X” (OPT) denotes the maximum possible speedup. Plots are shown up to the point where a model produces an out-of-memory error on a V100 GPU with 16GB. The vocabulary size used was 256.

Performers are first evaluated on two language benchmarks: (1) the small-range LM1B benchmark and (2) the large-range PG-19 benchmark. Performers cannot substitute for vanilla attention out of the box. Fig. 9 (left) shows that the Performer achieves an inferior accuracy (0.07) when directly replacing vanilla attention, even on the small-range LM1B benchmark. Fortunately, by fine-tuning for a small number of gradient steps, the accuracy is recovered. On the large-range PG-19 benchmark (Fig. 9, right), the trigonometric softmax SM^{trig} becomes highly unstable. Even with positive feature maps, the positive softmax (SM^{+}) plateaus at an inferior perplexity. To match vanilla attention, the positive softmax (SM^{+}) needs the secondary tricks: (1) feature redrawing and (2) SMREG.

Figure 9: (Left) Performer evaluation on small sequence benchmark (LM1B). When replacing the original attention layers, Performer (SM^{trig}) produces an initial non-zero 0.07 accuracy (dotted orange line). Fortunately, Performer recovers accuracy quickly via a small number of gradient steps. (Right) Performer evaluation on large sequence benchmark (PG-19). Performer (SM^{trig}) becomes highly unstable, while Performer (SM^{+}) plateaus at an inferior perplexity. To match vanilla attention, Performer (SM^{+}) needs other secondary tricks like (1) redrawing and (2) SMREG. These tricks are not presented in this article. For those interested in these tricks, please read the paper.

Besides language models, Performers are valuable for protein sequences with L=1024. Using the TrEMBL benchmark, Fig. 10 evaluates Performers with a 36-layer model. Linformer, an alternative efficient transformer, suffers significantly, while Performer-Softmax and Performer-ReLU achieve the highest accuracy. It is worth noting that this is the first experiment to introduce and evaluate Performer-ReLU.

Figure 10: Quantitative evaluation on the protein-sequence TrEMBL benchmark. The model parameters are (n_heads, n_layers, d_ff, d) = (8, 36, 1024, 512). While Linformer suffers, Performer-Softmax and Performer-ReLU achieve superior performance.

Finally, Choromanski et al. [1] evaluate Performer on an image-generation task using the ImageNet64 benchmark. In this task, 64×64×3 images are generated, i.e., L=12288, which is infeasible for regular Transformers. Thus, Fig. 11 evaluates Performers and Reformers only. The Performer/6-layers matches the Reformer/12-layers, while the Performer/12-layers matches the Reformer/24-layers.

Figure 11: Train = Dashed, Validation = Solid. On ImageNet64 benchmark, Performer/6-layers matches the Reformer/12-layers, while the Performer/12-layers matches the Reformer/24-layers.

Fig. 12 evaluates different Performer kernels (ReLU and Softmax) on the same ImageNet64 benchmark. Performer-Softmax with positive features achieves the same result as Performer-ReLU.

Figure 12: Train = Dashed, Validation = Solid. Quantitative evaluation of Performer using 6 layers on ImageNet64 benchmark. Approximate softmax with positive features achieves the same result as generalized ReLU attention.

Comments

  1. I like this paper, but it could have been written/presented better:

a. The same symbols (ϕ, l, r) have multiple meanings in the paper: ϕ denotes the random feature maps in Sec. 2.3 but also the angle between tokens in the paper's Fig. 2; l denotes the number of functions in Eq. 5 but also the vector length in that figure; r denotes the dimension of the random feature maps in Sec. 2.2 but also the ratio of the mean squared errors (MSEs) of the estimators in the same figure.

b. The paper elaborates on the softmax-kernel approximations and their estimated MSE (Sec. 2 & 3). Yet, the paper seems to conclude that the ReLU kernel is better (Sec. 4.4), right before the conclusion section! If ReLU is indeed better, it deserves more elaboration. Moreover, the ReLU kernel is vaguely defined, e.g., h(x) is not explicitly specified. Fortunately, the paper's code has been released for both JAX [8] and PyTorch [9], so one can infer the missing details.

2. Do we need Performers for high-resolution spatial data? CoAtNet [2] argues that downsampling (pooling) a high-resolution spatial input (image) to a low resolution and then performing vanilla attention on the low-resolution features is good enough. While this downsampling trick works for natural images, downsampling a sentence or a protein sequence is non-trivial. Furthermore, I am concerned about this downsampling approach when an image contains a tiny object of interest, e.g., malignant tissue in a medical image or a ship in a satellite image.

3. In terms of similar methods, the work from Prof. Chris Ré's lab [3–4] looks interesting, with promising results.

4. In terms of vision applications, these long-range modeling approaches deliver great value for satellite imagery [5–6] and medical imaging [7] applications.

5. Funny joke: the Gram-Schmidt method is a simple method with a single core idea. If Gram came up with this idea, what did Schmidt do?! :)

References

[1] Choromanski, K., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J., Mohiuddin, A., Kaiser, L. and Belanger, D., 2020. Rethinking attention with performers.

[2] Dai, Z., Liu, H., Le, Q.V. and Tan, M., 2021. CoAtNet: Marrying convolution and attention for all data sizes. Advances in Neural Information Processing Systems.

[3] Gu, A., Goel, K. and Ré, C., 2021. Efficiently modeling long sequences with structured state spaces. arXiv preprint arXiv:2111.00396.

[4] Dao, T., Fu, D.Y., Saab, K.K., Thomas, A.W., Rudra, A. and Ré, C., 2022. Hungry Hungry Hippos: Towards Language Modeling with State Space Models. arXiv preprint arXiv:2212.14052.

[5] Zhang, Z., Zhang, L., Wang, Y., Feng, P. and He, R., 2021. ShipRSImageNet: A large-scale fine-grained dataset for ship detection in high-resolution optical remote sensing images. IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing.

[6] Lam, D., Kuzma, R., McGee, K., Dooley, S., Laielli, M., Klaric, M., Bulatov, Y. and McCord, B., 2018. xView: Objects in context in overhead imagery. arXiv preprint arXiv:1802.07856.

[7] Taha, A., Truong Vu, Y.N., Mombourquette, B., Matthews, T.P., Su, J. and Singh, S., 2022, September. Deep is a Luxury We Don’t Have. In Medical Image Computing and Computer Assisted Intervention–MICCAI 2022: 25th International Conference.

[8] https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py

[9] https://github.com/lucidrains/performer-pytorch
