Sharpness-Aware Minimization for Efficiently Improving Generalization

Figure 1: Two minima with the same training loss value but different sharpness (curvature). Left: a sharp minimum with high curvature, expected to generalize poorly. Right: a flat minimum with low curvature, expected to generalize well. L_s(w) is the loss on the training dataset; L_s(w+ϵ), the loss at a nearby perturbed point, serves as a surrogate for the curvature, since it rises quickly around a sharp minimum and slowly around a flat one.
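In the paper's notation, SAM turns this picture into a min-max objective: rather than minimizing L_s(w) alone, it seeks weights whose entire ρ-neighborhood has uniformly low loss,

min_w max_{‖ϵ‖₂ ≤ ρ} L_s(w+ϵ),

so the gap max_{‖ϵ‖₂ ≤ ρ} L_s(w+ϵ) − L_s(w) is exactly the sharpness the figure illustrates.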
Figure 2: SAM's two-step optimization from two perspectives. (Left) SAM in 2D, where the red curve is the loss curve and the arrows denote gradients. (Right) SAM on a 3D loss surface. Starting at the yellow circle, SGD would reach the red circle after one optimization step, while SAM would reach the green circle.
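The caption compresses the whole algorithm: first perturb the weights toward the locally worst direction (the first-order solution of the inner max is ϵ̂ = ρ·g/‖g‖₂), then take the descent step using the gradient measured at the perturbed point. Here is a minimal NumPy sketch of that two-step update on a toy quadratic loss; the `rho` and `lr` values are illustrative choices, not values from the paper:

```python
import numpy as np

def loss(w):
    return 0.5 * np.sum(w ** 2)  # toy stand-in for the training loss L_s(w)

def grad(w):
    return w  # analytic gradient of the toy loss

def sam_step(w, rho=0.05, lr=0.1):
    g = grad(w)                                  # gradient at the current weights (yellow circle)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # step 1: ascend to the worst-case neighbor w + eps
    g_sam = grad(w + eps)                        # step 2: gradient at the perturbed point
    return w - lr * g_sam                        # descend using the SAM gradient (green circle)

w = np.array([1.0, -2.0])
for _ in range(20):
    w = sam_step(w)
print(w, loss(w))
```

Note that each update costs two forward-backward passes instead of one, which is the main overhead SAM adds over plain SGD.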
Table 1: SAM evaluation using state-of-the-art models on CIFAR-{10, 100} (WRN = WideResNet; AA = AutoAugment; SGD is the standard non-SAM procedure used to train these models).
Table 2: SAM evaluation using ResNets on ImageNet.
Table 3: Top-1 error rates for finetuning EfficientNet on various downstream tasks.
Table 4: Test accuracy on a clean test split for models trained on CIFAR-10 with noisy labels.
Table 5: SAM evaluation on two architectures, a compact ResNet18 and an overparameterized WideResNet50, trained with minimal augmentation. SAM brings only a marginal improvement to ResNet18 but boosts performance significantly on WideResNet50.
