Rethinking Attention with Performers — Part II & Final
This article’s objective is to summarize the Performers [1] paper. The article highlights key details and documents some personal comments at the end. A previous article presents a hand-wavy understanding of Performers using a hashing analogy.
Vanilla Transformers leverage self-attention layers defined as follows
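For reference, this is the paper's Eq. 1 (in its notation, with queries Q, keys K, and values V stacked as L rows each):

$$\mathrm{Att}(Q, K, V) = D^{-1} A V, \qquad A = \exp\!\left(\frac{QK^{\top}}{\sqrt{d}}\right), \qquad D = \mathrm{diag}(A\,\mathbf{1}_L),$$

where exp is applied element-wise and A is the L x L attention matrix responsible for the quadratic cost.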
This formula has quadratic space complexity O(L²), where L is the input sequence length. This hinders Transformers on long input sequences, e.g., long text, high-resolution images, and long protein sequences.
To tackle this problem, Choromanski et al. [1] have proposed Performers to estimate full-rank-attention through Fast Attention Via Positive Orthogonal Random features (FAVOR+). The term FAVOR+ comprises three parts:
- Fast-Attention (FA),
- Positive Random Feature (+/PRF),
- Orthogonal Random features (ORF).
This article presents each part separately and then covers the quantitative evaluations.
1. Fast-Attention (FA)
Choromanski et al. [1] regard the attention-softmax mechanism through kernelization. Concretely, the attention-softmax SM(Q, K) is a kernel that can be approximated using random feature maps ϕ. Fig. 2 summarizes this idea
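In symbols, after folding the 1/sqrt(d) scaling into the queries and keys, the kernelization reads (for a suitable ϕ):

$$\mathrm{SM}(q_i, k_j) = \exp(q_i^{\top} k_j) = \mathbb{E}_{w \sim \mathcal{D}}\big[\phi(q_i)^{\top}\phi(k_j)\big] \approx \phi(q_i)^{\top}\phi(k_j),$$

i.e., the softmax-kernel value is an expectation over random features and is estimated by a finite-dimensional dot product.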
With the right ϕ, the original self-attention formulation (Fig. 1) can be reformulated as follows
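In the paper's notation, with Q' and K' obtained by applying ϕ to the rows of Q and K, the approximation is:

$$\widehat{\mathrm{Att}}(Q, K, V) = \hat{D}^{-1}\big(Q'\,((K')^{\top} V)\big), \qquad \hat{D} = \mathrm{diag}\!\big(Q'\,((K')^{\top}\mathbf{1}_L)\big),$$

where the parentheses mark the multiplication order that avoids ever forming an L x L matrix.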
By changing the order of matrix multiplication, this formulation escapes the O(L²) space complexity and reduces time complexity from O(L² d) to O(Lrd) as shown in Fig. 4.
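As an illustration, here is a minimal NumPy sketch of this reordering, with a generic positive feature map used only as a placeholder (the paper's specific choices of ϕ are discussed in the next sections; the names `linear_attention` and `phi` are mine):

```python
import numpy as np

def linear_attention(Q, K, V, phi):
    """Attention via the kernel trick: O(L*r*d) time and no L x L matrix.

    Q, K, V: (L, d) arrays; phi maps (L, d) -> (L, r) feature matrices.
    """
    Q_prime, K_prime = phi(Q), phi(K)            # (L, r) each
    KV = K_prime.T @ V                           # (r, d), computed once
    normalizer = Q_prime @ K_prime.sum(axis=0)   # (L,), row sums of the implicit attention matrix
    return (Q_prime @ KV) / normalizer[:, None]  # (L, d)

# Toy usage with a placeholder positive feature map.
rng = np.random.default_rng(0)
L, d, r = 8, 4, 16
W = rng.normal(size=(d, r))
phi = lambda X: np.exp(X @ W - 0.5 * np.sum(X**2, axis=-1, keepdims=True)) / np.sqrt(r)
Q, K, V = rng.normal(size=(3, L, d))
print(linear_attention(Q, K, V, phi).shape)      # (8, 4)
```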
Given this kernelization trick, the next question becomes: how to approximate the softmax-kernel? Concretely, what is the right ϕ?
2. Positive Random Feature (PRF)
Choromanski et al. [1] propose the following formula for ϕ
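A reconstruction of the general form (Eq. 5 in the paper), with m random vectors w_1, …, w_m, l scalar functions f_1, …, f_l, and a scaling function h:

$$\phi(x) = \frac{h(x)}{\sqrt{m}}\Big(f_1(w_1^{\top}x), \ldots, f_1(w_m^{\top}x), \ldots, f_l(w_1^{\top}x), \ldots, f_l(w_m^{\top}x)\Big).$$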
This formula can model most kernels used in practice; the w_1, …, w_m are random vectors sampled from a distribution D. By choosing h, l, and f_1, …, f_l, this formula approximates different kernels. Tab. 1 summarizes some of the proposed choices and the corresponding kernels.
The softmax-kernel SM^{trig} (the third kernel in Tab. 1) should approximate the softmax used in self-attention. However, there is a caveat. This approximation uses sin and cos functions, which can generate negative values. This contradicts the vanilla softmax function, which always returns non-negative values, as shown in Fig. 2. Choromanski et al. [1] found SM^{trig} unstable, especially when kernel scores are close to 0, i.e., for low-relevance tokens. Low relevance indicates a large angle (low similarity) between tokens, while high relevance indicates a small angle (high similarity) between tokens.
To resolve the aforementioned stability issue, Choromanski et al. [1] approximate the softmax-kernel using positive random features (PRF), i.e., exp functions instead of sin/cos. This leads to SM^{+} and SM^{hyp+} as shown in Tab. 1. To evaluate these positive random features quantitatively, the paper plots the mean squared error (MSE) ratio between SM^{trig} and SM^{+}, as shown in Fig. 6.
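As a quick sanity check, here is a minimal NumPy sketch of a positive feature map in the spirit of SM^{+} (h(x) = exp(-||x||²/2), f = exp), compared against the exact softmax-kernel value exp(x·y); the function name `positive_random_features` is mine:

```python
import numpy as np

def positive_random_features(X, W):
    """Positive features in the spirit of SM^{+}: exp(W x - ||x||^2 / 2) / sqrt(m).

    X: (L, d) inputs; W: (m, d) with rows drawn from N(0, I_d). Output: (L, m), all entries > 0.
    """
    m = W.shape[0]
    return np.exp(X @ W.T - 0.5 * np.sum(X**2, axis=-1, keepdims=True)) / np.sqrt(m)

rng = np.random.default_rng(0)
d, m = 8, 4096
x, y = rng.normal(size=(2, d)) * 0.3
W = rng.normal(size=(m, d))
approx = positive_random_features(x[None], W) @ positive_random_features(y[None], W).T
print(approx.item(), np.exp(x @ y))  # the two numbers should be close
```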
3. Orthogonal Random features (ORF)
Now, the vanilla softmax can be approximated using both the kernelization trick (FA) and positive random features (PRF). Yet, how many random features r are needed? Can we get away with a small number r?
To reduce both the kernel estimator's variance and the number of random features r, Choromanski et al. [1] leverage orthogonal random features (ORF). Since w_1, …, w_m are randomly sampled vectors, Choromanski et al. [1] orthogonalize them using the Gram-Schmidt procedure, a linear-algebra procedure for orthonormalizing a set of vectors. Fig. 7 shows a toy Gram-Schmidt example where two random vectors (v_1, v_2) are orthonormalized into (u_1, u_2).
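For illustration, here is a minimal NumPy sketch that produces orthogonal random features; it uses a QR decomposition, which is numerically equivalent to Gram-Schmidt, and rescales the rows so their norms match those of i.i.d. Gaussian vectors (the function name `orthogonal_random_matrix` is mine, and m ≤ d is assumed for simplicity):

```python
import numpy as np

def orthogonal_random_matrix(m, d, rng):
    """Return an (m, d) matrix with mutually orthogonal rows (assumes m <= d)."""
    G = rng.normal(size=(d, d))
    Q, _ = np.linalg.qr(G)                                   # Gram-Schmidt-equivalent orthogonalization
    norms = np.linalg.norm(rng.normal(size=(m, d)), axis=1)  # norms of i.i.d. Gaussian rows
    return Q[:m] * norms[:, None]

rng = np.random.default_rng(0)
W = orthogonal_random_matrix(4, 16, rng)
print(np.round(W @ W.T, 3))  # off-diagonal entries are ~0, i.e., the rows are orthogonal
```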
While (1) FA, (2) PRF, and (3) ORF are the paper's primary ideas, the paper includes secondary tricks to increase robustness: (1) random feature maps are periodically redrawn; (2) a regularized softmax-kernel SMREG. For those interested in these tricks, please read the paper.
Now that we have presented FAVOR+ (a.k.a. Performers), we turn to the experiments. We start with the computational cost of the forward and backward passes. Fig. 8 compares vanilla Transformers with Performers on long-range sequence modeling. The Performer reaches nearly linear time and sub-quadratic memory consumption. For large sequence lengths L, the memory and backward-pass efficiencies allow larger-batch training and lower wall-clock time, respectively.
Performers are first evaluated on two language benchmarks: (1) the small-range LM1B benchmark and (2) the large-range PG-19 benchmark. Performers cannot substitute for vanilla attention out of the box. Fig. 9 (left) shows that the Performer achieves an inferior accuracy (0.07) when substituted for vanilla attention, even on the small-range LM1B benchmark. Fortunately, accuracy can be recovered by fine-tuning for a small number of gradient steps. On the large-range PG-19 benchmark (Fig. 9, right), the trigonometric softmax SM^{trig} becomes highly unstable. Even with positive feature maps, the positive softmax (SM^{+}) plateaus at an inferior perplexity. To match vanilla attention, the positive softmax (SM^{+}) needs the other tricks: (1) redrawing and (2) SMREG.
Besides language models, Performers are valuable for protein sequences with L = 1024. Using the TrEMBL benchmark, Fig. 10 evaluates Performers with a 36-layer model. Linformer, an alternative efficient transformer, suffers significantly, while Performer-Softmax and Performer-ReLU achieve the highest accuracy. It is worth noting that this is the first experiment to introduce and evaluate Performer-ReLU.
Finally, Choromanski et al. [1] evaluate the Performer on an image-generation task using the ImageNet64 benchmark. In this task, 64x64x3 images are generated, i.e., L = 12288, which is infeasible for regular Transformers. Thus, Fig. 11 evaluates Performers and Reformers only. The 6-layer Performer matches the 12-layer Reformer, while the 12-layer Performer matches the 24-layer Reformer.
Fig. 12 evaluates different Performer kernels (ReLU and Softmax) on the same ImageNet64 benchmark. Performer-Softmax with positive features achieves the same result as Performer-ReLU.
Comments
1. I like this paper, but it could have been written/presented better:
a. The same symbols (ϕ, l, r) have multiple meanings; ϕ denotes random feature maps in Sec 2.3, but also denotes the angle between tokens in Fig. 2; l denotes the number of functions in Eq. 5, but also denotes the vector length in Fig. 2; r denotes the dimension of the random feature maps in Sec 2.2, but also denotes the ratio of the mean squared errors (MSEs) of estimators in Fig. 2.
b. The paper elaborates on the softmax-kernel approximations and their estimated MSE (Sec. 2 & 3). Yet, the paper seems to conclude that the ReLU kernel is better (Sec. 4.4), right before the conclusion section! If ReLU is indeed better, it would have been worth elaborating on it. The ReLU kernel is also vaguely defined, i.e., h(x) is not explicitly mentioned. Fortunately, the paper's code has been released for both JAX [8] and PyTorch [9], so one can infer the missing details; a sketch of one possible reading appears after this comments list.
2. Do we need Performers for high-resolution spatial data? CoAtNet [2] argues that downsampling (pooling) a high-resolution spatial input (image) into a low resolution and then performing vanilla attention on the low-resolution input is good enough. While this downsampling trick works for natural images, downsampling a sentence or a protein sequence is non-trivial. Furthermore, I am concerned about this downsampling approach when an image contains a tiny object of interest, e.g., a malignant tissue in a medical image or a ship in a satellite image.
3. In terms of similar methods, the work by Prof. Chris Ré’s Lab [3–4] looks interesting with promising results.
4. In terms of vision applications, these long-range modeling approaches deliver great value for satellite imagery [5–6] and medical imaging [7] applications.
5. Funny joke: the Gram-Schmidt method is a simple method with a single core idea. If Gram came up with this idea, what did Schmidt do?! :)
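Following up on comment 1b, here is a minimal sketch of what the ReLU feature map could look like under the generic random-feature form, assuming h(x) = 1 and f = ReLU. This is my reading of the released code, not the paper's official definition, and the function name `relu_features` is mine:

```python
import numpy as np

def relu_features(X, W):
    """A plausible Performer-ReLU feature map: ReLU(W x) / sqrt(m).

    Assumes h(x) = 1 and f = ReLU in the generic random-feature form;
    X: (L, d) inputs, W: (m, d) random projection matrix.
    """
    return np.maximum(X @ W.T, 0.0) / np.sqrt(W.shape[0])
```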
References
[1] Choromanski, K., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J., Mohiuddin, A., Kaiser, L. and Belanger, D., 2020. Rethinking attention with performers. arXiv preprint arXiv:2009.14794.
[2] Dai, Z., Liu, H., Le, Q.V. and Tan, M., 2021. CoAtNet: Marrying convolution and attention for all data sizes. Advances in Neural Information Processing Systems.
[3] Gu, A., Goel, K. and Ré, C., 2021. Efficiently modeling long sequences with structured state spaces. arXiv preprint arXiv:2111.00396.
[4] Dao, T., Fu, D.Y., Saab, K.K., Thomas, A.W., Rudra, A. and Ré, C., 2022. Hungry Hungry Hippos: Towards Language Modeling with State Space Models. arXiv preprint arXiv:2212.14052.
[5] Zhang, Z., Zhang, L., Wang, Y., Feng, P. and He, R., 2021. ShipRSImageNet: A large-scale fine-grained dataset for ship detection in high-resolution optical remote sensing images. IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing
[6] Lam, D., Kuzma, R., McGee, K., Dooley, S., Laielli, M., Klaric, M., Bulatov, Y. and McCord, B., 2018. xView: Objects in context in overhead imagery. arXiv preprint arXiv:1802.07856.
[7] Taha, A., Truong Vu, Y.N., Mombourquette, B., Matthews, T.P., Su, J. and Singh, S., 2022, September. Deep is a Luxury We Don’t Have. In Medical Image Computing and Computer Assisted Intervention–MICCAI 2022: 25th International Conference.