Masked Autoencoders Are Scalable Vision Learners
Annotated data is a vital pillar of deep learning. Yet, annotated data is scarce in certain domains (e.g., medical imaging and robotics). To reduce the need for annotations, self-supervised learning pre-trains deep networks on unannotated data to learn useful representations. Different self-supervised approaches propose different objectives for training a deep network on unannotated data. This paper [1] leverages the masked autoencoding objective to pre-train ViT models on images.
While the masked autoencoding objective was proposed a long time ago, it became prominent thanks to BERT. BERT is a language model pre-trained on abundant unlabeled text from the web. During pre-training, BERT takes a sentence, masks some of its words, and is trained to predict the masked words, as shown in Fig. 1.
BERT has been a success in natural language processing (NLP) because it eliminates the cost of collecting and annotating labeled datasets. Despite this success, replicating BERT for vision applications has been challenging for the following three reasons:
- The architecture gap between vision and NLP: while Transformers dominate NLP, CNNs used to dominate vision. Transformer blocks operate on discrete tokens, which makes it easy to mask an input unit (e.g., a set of random words). In contrast, CNNs operate on overlapping regions of an image, which makes it unnatural to mask out an input unit (e.g., a set of random pixels/patches).
- The information-density difference between words and pixels: a single word, even without context, carries valuable information, while a single image pixel, without context, carries almost nothing. Predicting a masked word requires sophisticated language understanding. In contrast, a masked pixel can often be predicted from its neighbors with little high-level understanding of parts, objects, and scenes.
- The reconstruction target: predicting words is technically simple (e.g., using an MLP over a fixed vocabulary), but predicting pixels is computationally expensive since there are many pixels per image. Beyond the computational cost, predicting exact pixel values makes little sense! For instance, if the predicted image is shifted left or right by one pixel, the model suffers a high loss despite getting the image semantics right. Likewise, if the model predicts the semantics correctly but with the wrong pixel values (a green apple instead of a yellow one), it is penalized unfairly.
To tackle these three challenges, the paper:
- Uses a ViT model (a Transformer-based model), which has been gaining momentum in vision.
- Masks a high portion of random patches. This reduces redundancy and creates a challenging self-supervisory task that requires a holistic understanding beyond low-level image statistics.
- Leverages a lightweight decoder to reduce the computational complexity of predicting many pixels.
Fig. 2 presents the MAE architecture and highlights the paper's three key ideas: (1) use a ViT encoder, (2) mask many patches, and (3) leverage a lightweight decoder.
MAE uses a mean squared error (MSE) loss between the reconstructed and original pixels, computed only on the masked patches. Because MAE reconstructs the RGB values of each masked patch, its outputs can be rendered as images that serve as a sanity check and provide qualitative results, as shown in Fig. 3.
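For concreteness, here is a minimal sketch of such a loss in PyTorch. The tensor shapes and variable names are my own assumptions, not the paper's code; the key point is that the average is taken over masked patches only.

```python
import torch

def masked_mse_loss(pred, target, mask):
    """MSE between predicted and true pixels, averaged over masked patches only.

    pred, target: (batch, num_patches, patch_dim) per-patch pixel values.
    mask:         (batch, num_patches) with 1 for masked patches, 0 for visible ones.
    """
    per_patch = ((pred - target) ** 2).mean(dim=-1)  # (batch, num_patches)
    return (per_patch * mask).sum() / mask.sum()     # visible patches contribute nothing
```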
Despite these aesthetically pleasing outputs, using raw RGB pixels as a target is not ideal. If the model predicts the image semantics correctly but with the wrong RGB values (the two right-most examples in Fig. 3), it suffers a high loss unfairly. Accordingly, the paper explores other reconstruction targets. For instance, it evaluates per-patch normalized pixel values as a reconstruction target: every 16×16 patch is standardized using the mean and standard deviation of its own pixels. Tab. 1 shows that using normalized pixels improves representation quality.
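A sketch of this per-patch normalization, under the same assumed shapes as above (the small epsilon is my choice for numerical stability, not necessarily the paper's):

```python
def normalize_targets(target, eps=1e-6):
    """Standardize each patch with its own pixel mean and standard deviation.

    target: (batch, num_patches, patch_dim) raw RGB values of each 16x16 patch.
    """
    mean = target.mean(dim=-1, keepdim=True)
    std = target.std(dim=-1, keepdim=True)
    return (target - mean) / (std + eps)  # used as the reconstruction target
```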
Besides its simple loss function, MAE is designed to be computationally efficient during pre-training. Because a large portion (e.g., 75%) of the image is masked, the MAE encoder (a ViT) processes only the small visible portion; the masked patches are dropped and never fed to the encoder. Thus, very large encoders (e.g., ViT-Huge) can be pre-trained with only a fraction of the compute and memory. Tab. 2 shows that feeding masked tokens to the encoder not only degrades performance but also increases the computational cost (FLOPs).
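The sketch below illustrates the idea: each image's patch tokens are randomly shuffled and only the first 25% are kept for the encoder. The argsort-based shuffling mirrors common MAE implementations, but the function itself is an illustration, not the paper's code.

```python
import torch

def random_masking(tokens, mask_ratio=0.75):
    """Keep a random subset of patch tokens; the rest are dropped before the encoder.

    tokens: (batch, num_patches, dim) patch embeddings.
    Returns the visible tokens, a binary mask (1 = masked), and the indices
    needed to restore the original patch order for the decoder.
    """
    batch, num_patches, dim = tokens.shape
    num_keep = int(num_patches * (1 - mask_ratio))

    noise = torch.rand(batch, num_patches, device=tokens.device)
    ids_shuffle = torch.argsort(noise, dim=1)       # a random permutation per sample
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    ids_keep = ids_shuffle[:, :num_keep]
    visible = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).repeat(1, 1, dim))

    mask = torch.ones(batch, num_patches, device=tokens.device)
    mask[:, :num_keep] = 0                          # kept positions (in shuffled order)
    mask = torch.gather(mask, 1, ids_restore)       # back to the original patch order
    return visible, mask, ids_restore
```

With the default 75% ratio, the encoder runs self-attention over only a quarter of the tokens, which is where the compute and memory savings come from.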
On top of the efficient encoder, MAE uses a lightweight decoder to reconstruct the masked pixels. The decoder processes the full set of tokens (visible patches plus mask tokens), so its efficiency is vital for MAE. Accordingly, the authors propose a shallow and narrow decoder, i.e., a small number of blocks with a small embedding dimension, as shown in Tab. 3. MAE's decoder has less than 10% of the computation per token compared to the encoder.
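Below is a rough sketch of such a decoder. The default depth and width roughly follow the paper's default choice (8 blocks, 512-d), but the module structure, the use of PyTorch's generic Transformer layers, and the omission of positional embeddings are simplifications on my part, not the paper's implementation.

```python
import torch
import torch.nn as nn

class LightweightDecoder(nn.Module):
    """Shallow, narrow Transformer decoder that reconstructs every patch."""

    def __init__(self, enc_dim=1024, dec_dim=512, depth=8, num_heads=16,
                 patch_dim=16 * 16 * 3):
        super().__init__()
        self.proj = nn.Linear(enc_dim, dec_dim)           # map encoder width down to decoder width
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_dim))
        layer = nn.TransformerEncoderLayer(dec_dim, num_heads, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, depth)
        self.head = nn.Linear(dec_dim, patch_dim)         # predict the pixels of each patch

    def forward(self, visible, ids_restore):
        # visible: (batch, num_visible, enc_dim) encoder outputs for visible patches.
        # ids_restore: (batch, num_patches) indices that undo the random shuffle.
        batch, num_visible, _ = visible.shape
        x = self.proj(visible)
        num_masked = ids_restore.shape[1] - num_visible
        mask_tokens = self.mask_token.repeat(batch, num_masked, 1)
        x = torch.cat([x, mask_tokens], dim=1)
        # Put every token back at its original patch position
        # (positional embeddings, which the real model needs, are omitted for brevity).
        x = torch.gather(x, 1, ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        return self.head(self.blocks(x))                  # (batch, num_patches, patch_dim)
```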
Simplicity is a key feature of MAE. Instead of proposing a fancy masking strategy, the paper uses uniform random masking. In addition, it relies on simple augmentations (e.g., random resized cropping) during pre-training. Tab. 4 compares these simple choices against other alternatives.
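As a sketch, such a pre-training input pipeline can be as short as the following torchvision snippet; the crop scale, the horizontal flip, and the normalization statistics are common ImageNet choices and assumptions on my part, not necessarily the paper's exact settings.

```python
from torchvision import transforms

# Simple pre-training augmentation: a random resized crop plus a horizontal flip.
pretrain_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),  # assumed crop scale range
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],      # standard ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])
```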
By default, MAE masks 75% of the patches. Yet, Fig. 3 shows that MAE maintains strong performance across a wide range of masking ratios.
By masking a large portion of the input image, MAE achieves two goals: (1) it largely reduces redundancy and creates a challenging self-supervisory task that requires holistic understanding beyond low-level image statistics, (2) it reduces the wall-clock time to pre-train a given architecture for a given number of epochs as shown in Tab. 5.
MAE is quantitatively evaluated against SOTA pre-training methods (e.g., DINO) as shown in Tab. 6.
Fig. 4 compares MAE with fully-supervised methods that use large datasets (e.g., JFT-300M). It is worth noting that all these evaluations rely on end-to-end fine-tuning. Linear probing evaluations are reported for hyper-parameter tuning only, i.e., tuning the decoder depth and width. Thus, the paper rarely reports linear probing comparisons against SOTA methods. I will return to this point at the end of the article.
Finally, the paper evaluates MAE pre-training on object detection (Tab. 7), semantic segmentation (Tab. 8), and transfer learning (Tab. 9). MAE achieves competitive performance on all these benchmarks. Further evaluations are reported in the paper.
My Comments
- [S] This is a well-written and well-presented paper. The MAE approach is simple and a great starting point for those interested in self-supervised learning. Kudos to the authors for releasing the code and pre-trained checkpoints.
- [S] By masking 75% of the input image, MAE is computationally cheap, i.e., it can be pre-trained on small GPUs. In addition, MAE's loss is computed per sample and is independent of the batch size. So MAE works with small batch sizes, and there is no need to synchronize features/losses across GPUs, i.e., no need for distributed data-parallel tricks (e.g., gather/reduce).
- [W] Indeed, MAE is both computationally cheap and batch-size independent, which makes it an ideal approach for self-supervised learning. Unfortunately, MAE needs a large number of epochs during pre-training. The paper [1] uses 800 epochs by default and pre-trains some models for 1600 epochs. This large number of epochs compensates for the large masking ratio (e.g., 75%): if 75% of the patches are masked in a pre-training epoch, the model sees only 25% of each image's patches within that epoch. Accordingly, roughly four epochs are required to "see" the entire dataset once, i.e., four MAE epochs are roughly equivalent to a single epoch of other pre-training approaches.
- [W] The proposed MAE [1] is tightly coupled to ViT models. Further modifications are required to enable MAE with CNN models. Fortunately, follow-up literature [2, 3] addresses this problem.
- [W] The proposed MAE [1] struggles with linear probing evaluations. Accordingly, the paper mostly reports fine-tuning evaluations. Furthermore, the authors argue against linear probing, stating that "linear probing is not well correlated with transfer learning performance, e.g., for object detection" [1]. Fortunately, a follow-up paper [4] proposes a contrastive loss term that boosts MAE's linear probing performance.
- [W] I don’t think this is how humans (e.g., babies) learn. Yet, no one cares about this in 2023.
References
- [1] Masked Autoencoders Are Scalable Vision Learners. CVPR 2022.
- [2] ConvMAE: Masked Convolution Meets Masked Autoencoders.
- [3] ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders.
- [4] A Simple, Efficient and Scalable Contrastive Masked Autoencoder for Learning Visual Representations.