Sharpness-Aware Minimization for Efficiently Improving Generalization
When training a deep network, picking the right optimizer has become an important design choice. Standard optimizers (e.g., SGD, Adam) seek a minimum of the loss without regard for its curvature, i.e., the second derivative of the loss. Curvature measures how flat the loss is around a point: low curvature means a flat (wide) minimum, while high curvature means a sharp one. This paper [1] proposes Sharpness-Aware Minimization (SAM), an efficient optimization procedure that seeks wide minima by simultaneously minimizing the loss value and the loss sharpness.
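For reference, the objective can be written roughly as follows. This is a sketch in the notation of [1], where L_s is the training loss, ρ is the neighborhood radius, and λ is a weight-decay coefficient; the second line only restates the first as "sharpness plus loss value".

```latex
\min_{w} \; L_s^{SAM}(w) + \lambda \lVert w \rVert_2^2,
\qquad
L_s^{SAM}(w) \triangleq \max_{\lVert \epsilon \rVert_p \le \rho} L_s(w + \epsilon)

L_s^{SAM}(w) =
\underbrace{\Big[ \max_{\lVert \epsilon \rVert_p \le \rho} L_s(w + \epsilon) - L_s(w) \Big]}_{\text{sharpness}}
+ \underbrace{L_s(w)}_{\text{loss value}}
```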
There is a connection between the geometry of the loss landscape and the generalization of a trained network. Figure 1 shows two minima with the same loss value: the one on the left has high curvature, while the one on the right has low curvature. Both have zero training loss, i.e., L_s(w)=0, so standard optimizers (e.g., SGD) treat them as equally good. SAM, in contrast, prefers the flat minimum on the right.
While standard optimizers employ gradient descent only, SAM employs both gradient ascent and gradient descent. Accordingly, SAM is a two-step optimizer, and each update requires two gradient computations. Starting at point A, SAM computes the direction of gradient ascent (d1). From point A, SAM takes a step of size ρ in this direction (d1) to reach point B; this is the first step. At point B, SAM computes the direction of gradient descent (d2). Finally, going back to point A, SAM takes a step of size η (the learning rate) in the direction d2. Thus, ρ is an extra hyperparameter that SAM requires; it plays the role of a step size for the ascent step. Figure 2 illustrates this procedure from two perspectives.
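The two steps above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the authors' implementation: the toy quadratic loss, the function names (`loss`, `grad`, `sam_step`), and the default values of `rho` and `lr` are all assumptions made for the example.

```python
import numpy as np

def loss(w):
    # Toy quadratic loss, chosen only for illustration.
    return 0.5 * np.sum(w ** 2)

def grad(w):
    # Analytic gradient of the toy loss above.
    return w

def sam_step(w, grad_fn, rho=0.05, lr=0.1, eps=1e-12):
    """One SAM-style update: ascend from A to B, then descend from A using B's gradient."""
    g_a = grad_fn(w)                         # gradient at point A
    d1 = g_a / (np.linalg.norm(g_a) + eps)   # unit-length ascent direction
    w_b = w + rho * d1                       # point B: a step of size rho along d1
    g_b = grad_fn(w_b)                       # gradient at point B (defines d2)
    return w - lr * g_b                      # the descent step is taken from A, not from B

w = np.array([1.0, -2.0])
w_next = sam_step(w, grad)
print(loss(w), loss(w_next))
```

Note that the weights at point B are only used to compute a gradient; the update itself is applied to the weights at point A.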
If the starting point A lies in a sharp region of the loss, the gradient at B differs substantially from the gradient at A, so SAM is likely to leave this region and seek an alternative minimum. In contrast, if A lies in a flat region, SAM descends much like SGD, because points A and B have similar gradients.
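A tiny numerical check makes this intuition concrete. The sketch below assumes a one-dimensional quadratic loss L(w) = 0.5 * c * w^2, where the curvature c, the starting point `a`, and the helper name `gradient_gap` are hypothetical choices for illustration. For this loss, the gap between the gradients at A and B equals c * ρ, so it grows with curvature.

```python
import numpy as np

def gradient_gap(curvature, a=1.0, rho=0.05):
    """|grad(B) - grad(A)| for the 1-D loss L(w) = 0.5 * curvature * w**2."""
    g_a = curvature * a               # gradient at point A
    b = a + rho * np.sign(g_a)        # point B: step of size rho along the ascent direction
    g_b = curvature * b               # gradient at point B
    return abs(g_b - g_a)

sharp = gradient_gap(curvature=100.0)  # sharp region: gradients at A and B differ a lot
flat = gradient_gap(curvature=0.1)     # flat region: gradients at A and B nearly coincide
print(sharp, flat)
```

With the same learning rate, the SAM update therefore deviates from the plain SGD update far more in the sharp case than in the flat one.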
SAM is evaluated using computer vision datasets. Tables 1 and 2 present a quantitative evaluation using randomly initialized networks trained on CIFAR-{10,100} and ImageNet, respectively.
Table 3 evaluates SAM using networks pretrained on ImageNet and fine-tuned on small datasets (e.g., FGVC Aircraft, Flowers).
Finally, SAM is evaluated on a noisy CIFAR-10 dataset, in which a fraction of the training set’s labels are randomly flipped. Table 4 presents a quantitative evaluation of SAM against noise-robust approaches (e.g., Bootstrap). SAM delivers competitive performance against these noise-specific approaches.
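To make the noisy-label setup concrete, here is a minimal sketch of how such a training set could be built. The exact noise protocol used in [1] is not described above, so this assumes symmetric noise in which each selected label is replaced by a different, uniformly chosen class; `flip_labels` and its parameters are hypothetical names.

```python
import numpy as np

def flip_labels(labels, noise_fraction, num_classes, seed=0):
    """Return a copy of `labels` in which a fraction is reassigned to a different class."""
    rng = np.random.default_rng(seed)
    noisy = labels.copy()
    n_flip = int(round(noise_fraction * len(labels)))
    # Pick which examples get corrupted, without repetition.
    idx = rng.choice(len(labels), size=n_flip, replace=False)
    # An offset in [1, num_classes - 1] guarantees the new label differs from the old one.
    offsets = rng.integers(1, num_classes, size=n_flip)
    noisy[idx] = (labels[idx] + offsets) % num_classes
    return noisy

labels = np.arange(1000) % 10   # stand-in for CIFAR-10 training labels
noisy = flip_labels(labels, noise_fraction=0.4, num_classes=10)
print((noisy != labels).mean())
```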
My Comments
- [W] I wish the paper emphasized that SAM achieves significant improvements when coupled with over-parameterized networks, not with every network. I found this empirically by evaluating SAM on two networks, ResNet18 and WideResNet50: the improvement is marginal on ResNet18 but significant on WideResNet50, as shown in the following table. A similar finding has been reported in [3]: “the degree of improvement negatively correlates with the level of inductive biases built into the architecture.”
- [S] The thing I like most about this paper is that it reminds me that we use gradient descent because it is computationally feasible, not because it is optimal. In 2022, it is generally practical to compute only the first derivative with respect to the network’s weights, while the second derivative (curvature) is computationally prohibitive for large networks. Once this computational limitation is lifted, more can be explored [2].
- For those interested in the topic, a follow-up paper [4] proposes Adaptive SAM (ASAM). Developed by Samsung Research, ASAM makes tuning the ρ hyperparameter easier. I found ASAM useful in a multi-stage training procedure (pre-training -> fine-tuning). With a fixed ρ, SAM achieved inferior performance, while ASAM achieved superior performance.
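To give a feel for why ASAM is less sensitive to ρ, the sketch below contrasts the two ascent perturbations. It reflects my reading of the element-wise variant in [4], where the perturbation is rescaled by the weight magnitudes (T_w = |w| + eta); the function names, the value of `eta`, and the toy inputs are assumptions for illustration, not the official implementation.

```python
import numpy as np

def sam_perturbation(g, rho, eps=1e-12):
    # SAM: a step of fixed length rho along the normalized gradient.
    return rho * g / (np.linalg.norm(g) + eps)

def asam_perturbation(w, g, rho, eta=0.01, eps=1e-12):
    # ASAM (element-wise sketch): scale the step by the magnitude of each weight,
    # so large weights receive proportionally larger perturbations.
    t_w = np.abs(w) + eta
    scaled = t_w * g
    return rho * t_w * scaled / (np.linalg.norm(scaled) + eps)

w = np.array([10.0, 0.1])   # two weights with very different scales
g = np.array([1.0, 1.0])    # identical gradient entries
e_sam = sam_perturbation(g, rho=0.5)
e_asam = asam_perturbation(w, g, rho=0.5)
print(e_sam, e_asam)
```

In this toy case SAM perturbs both weights equally, whereas the ASAM-style perturbation follows the weight scale, which is one plausible reason a single ρ transfers better across training stages.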
References
[1] Foret, P., Kleiner, A., Mobahi, H. and Neyshabur, B., 2020. Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412.
[2] LeCun, Y., Denker, J.S. and Solla, S.A., 1990. Optimal brain damage. In Advances in neural information processing systems.
[3] Chen, X., Hsieh, C.J. and Gong, B., 2021. When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations. arXiv preprint arXiv:2106.01548.
[4] Kwon, J., Kim, J., Park, H. and Choi, I.K., 2021. ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks. arXiv preprint arXiv:2102.11600.