Sharpness-Aware Minimization for Efficiently Improving Generalization

Figure 1: Two different minima with the same loss value but different sharpness (curvature). Left minimum: a local minimum with high curvature and, hence, an expected high generalization error. Right minimum: a local minimum with low curvature and an expected low generalization error. L_s(w) is the loss value on the training dataset; L_s(w+ϵ) serves as a surrogate for the loss curvature.
Figure 2: SAM's two-step optimization from two different perspectives. (Left) SAM visualized in 2D, where the red curve is the loss curve and the arrows denote gradients. (Right) SAM visualized on a 3D loss surface. Starting at the yellow circle, SGD would reach the red circle after one optimization step, while SAM would reach the green circle.
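
The two-step procedure in Figure 2 maps almost directly to code. Below is a minimal PyTorch-style sketch of one SAM update, written from the description above; model, loss_fn, batch, and base_optimizer are placeholder names, rho is SAM's neighborhood-size hyperparameter, and this is an illustrative sketch rather than the authors' official implementation.

```python
import torch

def sam_step(model, loss_fn, batch, base_optimizer, rho=0.05):
    """One SAM update: ascend to the worst-case neighbor, then descend."""
    inputs, targets = batch

    # Step 1 (ascent): gradient at w, then move to w + eps,
    # where eps = rho * g / ||g|| approximates the worst-case perturbation.
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack([g.norm() for g in grads]))
    perturbations = []
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                continue
            eps = rho * p.grad / (grad_norm + 1e-12)
            p.add_(eps)                      # w -> w + eps
            perturbations.append((p, eps))
    model.zero_grad()

    # Step 2 (descent): gradient at w + eps, restore the original weights,
    # then let the base optimizer (e.g., SGD) update w with that gradient.
    loss_fn(model(inputs), targets).backward()
    with torch.no_grad():
        for p, eps in perturbations:
            p.sub_(eps)                      # back to w
    base_optimizer.step()
    base_optimizer.zero_grad()
    return loss.item()
```
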
Table 1: SAM evaluation using state-of-the-art models on CIFAR-{10, 100} (WRN = WideResNet; AA = AutoAugment; SGD is the standard non-SAM procedure used to train these models).
Table 2: SAM evaluation using ResNets on ImageNet.
Table 3: Top-1 error rates when fine-tuning EfficientNet on various downstream tasks.
Table 4: Test accuracy on a clean test split for models trained on CIFAR-10 with noisy labels.
  • [W] I wish the paper emphasized that SAM achieves significant improvements when coupled with over-parameterized networks, not with every network. I found this empirically by evaluating SAM on two networks, ResNet18 and WideResNet50: the improvement is marginal on ResNet18 but significant on WideResNet50, as shown in the following table. A similar finding has been reported in [3]: “the degree of improvement negatively correlates with the level of inductive biases built into the architecture.”
Table 5: SAM evaluation on two architectures: a compact ResNet18 and an over-parameterized WideResNet50. Minimal augmentation is used in this evaluation. SAM brings a marginal improvement to ResNet18 but boosts performance significantly on WideResNet50.
  • [S] The thing I like most about this paper is that it reminds me that we use gradient descent because it is computationally feasible, not because it is optimal. In 2022, we can only afford to compute the first derivative (gradient) with respect to the network’s weights, while the second derivative (curvature) remains computationally prohibitive. Once this computational limitation is unlocked, much more can be explored [2].
  • For those interested in the topic, a follow-up paper [4] proposes Adaptive SAM (ASAM). Developed by Samsung Research, ASAM makes tuning the ρ hyperparameter easier. I found ASAM useful in a multi-stage training procedure (pre-training -> fine-tuning): a fixed ρ achieved inferior performance with SAM but superior performance with ASAM. A minimal sketch of how ASAM changes SAM’s perturbation is shown below.
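
The core difference is in how the perturbation ε is built: ASAM scales it element-wise by the parameter magnitudes, so it is invariant to parameter re-scaling and a single ρ transfers better across training stages. The snippet below is a simplified sketch of ASAM's ascent step, following my reading of the ASAM formulation ε = ρ·|w|²·g / ‖|w|⊙g‖; it reuses the placeholder names from the SAM sketch above and is not Samsung Research's reference implementation.

```python
# ASAM's perturbation (replaces the eps computation in Step 1 of the SAM
# sketch; the gradient at w is still computed first, and Step 2 is unchanged).
# Scaling by |w| makes the perturbation invariant to parameter re-scaling,
# which is why a single rho is easier to reuse across training stages.
with torch.no_grad():
    scaled = [(p, p.abs() * p.grad)
              for p in model.parameters() if p.grad is not None]
    scale_norm = torch.norm(torch.stack([s.norm() for _, s in scaled]))
    perturbations = []
    for p, s in scaled:
        eps = rho * p.abs() * s / (scale_norm + 1e-12)  # rho * |w|^2 * g / ||(|w| * g)||
        p.add_(eps)                                     # w -> w + eps
        perturbations.append((p, eps))
```
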

I write reviews on computer vision papers. Writing tips are welcomed.
