Skip to content

Commit 4c82722

Browse files
Update 2020-08-08-pytorch-1.6-now-includes-stochastic-weight-averaging.md
Updating the Author and image path for images 2,6 and 8
1 parent ef47fd4 commit 4c82722

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

_posts/2020-08-08-pytorch-1.6-now-includes-stochastic-weight-averaging.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
---
22
layout: blog_detail
33
title: 'PyTorch 1.6 now includes Stochastic Weight Averaging'
4-
author: Pavel Izmailov and Andrew Gordon Wilson
4+
author: Pavel Izmailov, Andrew Gordon Wilson and Vincent Queneneville-Belair
55
---
66

77
Do you use stochastic gradient descent (SGD) or Adam? Regardless of the procedure you use to train your neural network, you can likely achieve significantly better generalization at virtually no additional cost with a simple new technique now natively supported in PyTorch 1.6, Stochastic Weight Averaging (SWA) [1]. Even if you have already trained your model, it’s easy to realize the benefits of SWA by running SWA for a small number of epochs starting with a pre-trained model. [Again](https://twitter.com/MilesCranmer/status/1282140440892932096) and [again](https://twitter.com/leopd/status/1285969855062192129), researchers are discovering that SWA improves the performance of well-tuned models in a wide array of practical applications with little cost or effort!
@@ -37,7 +37,7 @@ By contrast, SWA uses an **equal average** of SGD iterates with a modified **cyc
3737
There are two important ingredients that make SWA work. First, SWA uses a **modified learning rate** schedule so that SGD (or other optimizers such as Adam) continues to bounce around the optimum and explore diverse models instead of simply converging to a single solution. For example, we can use the standard decaying learning rate strategy for the first 75% of training time and then set the learning rate to a reasonably high constant value for the remaining 25% of the time (see Figure 2 below). The second ingredient is to take an average of the weights **(typically an equal average)** of the networks traversed by SGD. For example, we can maintain a running average of the weights obtained at the end of every epoch within the last 25% of training time (see Figure 2). After training is complete, we then set the weights of the network to the computed SWA averages.
3838

3939
<div class="text-center">
40-
<img src="{{ site.url }}/assets/images/swa/figure2-highres.png" width="100%">
40+
<img src="{{ site.url }}/assets/images/nswapytorch2.jpg" width="100%">
4141
</div>
4242

4343
**Figure 2**. *Illustration of the learning rate schedule adopted by SWA. Standard decaying schedule is used for the first 75% of the training and then a high constant value is used for the remaining 25%. The SWA averages are formed during the last 25% of training*.
@@ -225,7 +225,7 @@ In another follow-up [paper](http://www.gatsby.ucl.ac.uk/~balaji/udl-camera-read
225225
We can filter through quantization noise by combining weights that have been rounded down with weights that have been rounded up. Moreover, by averaging weights to find a flat region of the loss surface, large perturbations of the weights will not affect the quality of the solution (Figures 9 and 10). Recent [work](https://arxiv.org/abs/1904.11943) shows that by adapting SWA to the low precision setting, in a method called SWALP, one can match the performance of full-precision SGD even with all training in 8 bits [5]. This is quite a practically important result, given that (1) SGD training in 8 bits performs notably worse than full precision SGD, and (2) low precision training is significantly harder than predictions in low precision after training (the usual setting). For example, a ResNet-164 trained on CIFAR-100 with float (16-bit) SGD achieves 22.2% error, while 8-bit SGD achieves 24.0% error. By contrast, SWALP with 8 bit training achieves 21.8% error.
226226

227227
<div class="text-center">
228-
<img src="{{ site.url }}/assets/images/swapytorch6.png" width="100%">
228+
<img src="{{ site.url }}/assets/images/nswapytorch6.png" width="100%">
229229
</div>
230230
**Figure 9**. *Quantizing a solution leads to a perturbation of the weights which has a greater effect on the quality of the sharp solution (left) compared to wide solution (right)*.
231231

@@ -244,7 +244,7 @@ By finding a centred solution in the loss, SWA can also improve calibration and
244244
SWA can be viewed as taking the first moment of SGD iterates with a modified learning rate schedule. We can directly generalize SWA by also taking the second moment of iterates to form a Gaussian approximate posterior over the weights, further characterizing the loss geometry with SGD iterates. This approach,[SWA-Gaussian (SWAG)](https://arxiv.org/abs/1902.02476) is a simple, scalable and convenient approach to uncertainty estimation and calibration in Bayesian deep learning [4]. The SWAG distribution approximates the shape of the true posterior: Figure 6 below shows the SWAG distribution and the posterior log-density for ResNet-20 on CIFAR-10.
245245

246246
<div class="text-center">
247-
<img src="{{ site.url }}/assets/images/swapytorch8.jpg" width="100%">
247+
<img src="{{ site.url }}/assets/images/nswapytorch8.png" width="100%">
248248
</div>
249249
**Figure 6**. *SWAG posterior approximation and the loss surface for a ResNet-20 without skip-connections trained on CIFAR-10 in the subspace formed by the two largest eigenvalues of the SWAG covariance matrix. The shape of SWAG distribution is aligned with the posterior: the peaks of the two distributions coincide, and both distributions are wider in one direction than in the orthogonal direction. Visualization created in collaboration with* [Javier Ideami](https://losslandscape.com/).
250250

0 commit comments

Comments
 (0)