Averaging Weights Leads to Wider Optima and Better Generalization
This paper introduces a method that closely approximates the solution an ensembling technique would produce, while requiring only a single model. In return, it gives us not only faster inference but also wider optima, and therefore a better-generalized model.
Article Overview
The idea behind any ensembling technique is to use multiple learning algorithms to obtain better predictive performance than any of the constituent algorithms could achieve alone. And to be honest, I'm not a huge fan of them. Such techniques work great for Kaggle competitions, or anywhere deployment is not a concern. Otherwise, they are not easy to work with, because the accuracy improvement rarely justifies the cost in inference performance. In this article, we'll discover a technique that follows the same idea that multiple models can provide better accuracy, but with almost zero computational overhead 🙂.
Prerequisites
What’s up with ensembling?
Ensemble methods combine several models into a single predictive model. Basically, for one input you make k predictions from k different models, then combine those predictions by, say, taking their average, and call that average your final prediction.
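As a minimal sketch of that averaging step (assuming we already have a list of trained Keras models, here called `models`, and a batch of inputs `x`):

```python
import numpy as np

def ensemble_predict(models, x):
    """Average the class probabilities predicted by k trained models.

    `models` is assumed to be a list of already-trained Keras models
    (anything with a .predict() method returning probabilities works).
    """
    preds = np.stack([m.predict(x) for m in models], axis=0)  # shape: (k, batch, classes)
    return preds.mean(axis=0)                                  # shape: (batch, classes)
```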
Typically, ensembles are applied to classical ML algorithms, because you can easily have 10 decision trees making 10 predictions. But in DL, having 10 different models, each with 20 million parameters or more, would be a terrible idea. And when it comes to deployment, there is essentially no way you can ship an ensemble like that. Even if there were a way, in DL it is still better, and in fact easier, to improve a single model.
Such techniques are most heavily used in competitions; in fact, most of the top Kaggle winners use ensembling in one way or another.
So the issues with these techniques are (mostly for DL models):
- Not easily deployable.
- Longer training time, since every model has to be trained.
- Horrible inference time. (With an ensemble of 10 models you’d have to make 10 predictions 😢)
Width of a local optimum
Now let's understand why ensembling techniques work in the first place. One way to satisfy ourselves is a simple intuition: since we have k different models, and each model interprets the training data differently, averaging all of those interpretations should result in a better-generalized prediction, right? Pretty close, I'd say.
After training, any model ends up in some local minimum. Based on the shape of the loss surface around that minimum, we can tell whether our model is well generalized or simply overfitted. The idea is that if the second-order derivative (the curvature) in the region where the model has finally converged is very low, then the model is good; otherwise it is likely overfitted.
In other words, if the local region where the model has converged is very wide or flat, then the model is good; otherwise it is overfitted. I know that is not very clear right now, so let me put it in layman's terms.
For a well-generalized model, a slight change in the inputs shouldn't change the outputs much. If the outputs change too much under such a small change, then your model is simply overfitted to that input.
One more idea I’d like to present before visualizing the width of an optimum.
The idea is that the training dataset and the testing dataset (or any other data drawn from the same distribution we are dealing with) produce similar loss surfaces, at least around the region where the model has finally converged.
So the shift between the loss surfaces for the training and testing data won't be large, and you can simply imagine the loss surface for the testing dataset as a slightly shifted copy of the training one.
And now, the figure above should make perfect sense. The blue line is the loss surface for the training data and the green one is for the testing data. First look at the sharp minimum: under a slight shift of the surface, which you can think of as moving from training to test data, the model's loss changes drastically. For the flat minimum, that is of course not the case. This is why flat or wide minima are better than sharp ones.
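To make that argument slightly more concrete, here is a hedged back-of-the-envelope version (the shift δ and the Hessian H are my own notation, not the paper's):

```latex
% The test loss surface is treated as a slightly shifted copy of the training one,
% so for a small shift \delta and a training minimum w^* (where the gradient vanishes):
\[
  L_{\text{test}}(w^*) \;\approx\; L_{\text{train}}(w^* + \delta)
  \;\approx\; L_{\text{train}}(w^*) + \tfrac{1}{2}\,\delta^{\top} H\,\delta,
  \qquad H = \nabla^2 L_{\text{train}}(w^*).
\]
```

So the train-to-test gap is governed by the curvature H: a flat minimum (small H) stays nearly optimal under the shift, while a sharp minimum (large H) does not.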
Why ensembling techniques work like a charm
Before reading the next paragraph, try to come up with an explanation yourself for a few minutes. It is doable for sure, so take your time.
In the case of an ensembled model, since we have trained more than one model, we basically have more than one blue line in the figure above, which together cover quite a lot of the surface. So even if a few models are somewhat poorly trained and overfitted (sharp minima), the regions covered by the other models mean that the average of the blue lines is much wider than any single one of them.
And when most of the models are well generalized, using all of them together makes the already good predictions even better.
So ensembling techniques work because using more than one model reduces the chance of ending up in a sharp minimum after training, and hence we get a better-generalized model.
The general explanation for the importance of width is that the surfaces of train loss and test error are shifted with respect to each other and it is thus desirable to converge to the modes of broad optima, which stay approximately optimal under small perturbations.
Introducing the paper
Sorry if I explained too much in the prerequisites above 😅. Let's have a look at our paper now. The idea behind it is actually pretty intuitive and easy to understand.
Instead of training k models as in any ensembling technique, this method requires only a single model and still lands in a flatter, wider minimum after training. So essentially, this paper solves all the issues I discussed above.
We show that simple averaging of multiple points along the trajectory of SGD, with a cyclical or constant learning rate, leads to better generalization than conventional training. We also show that this Stochastic Weight Averaging (SWA) procedure finds much flatter solutions than SGD, and approximates the recent Fast Geometric Ensembling (FGE) approach with a single model.
FGE is one variant of ensembling, which I'm not going to discuss in detail, but it essentially provides a way to train an ensemble faster. The FGE authors found that local optima can be connected by a simple curve (for example, a chain of two line segments) of near-constant loss, so after discovering 2 optima, we can move a model along that curve to discover a completely new optimum within only a few epochs, ultimately speeding up training. But again, you would still have to make multiple predictions for a single input and average them.
The Nub
From the very beginning, the goal of any ensembling technique has been to produce a system that behaves as if it sits in a flatter, wider optimum. The paper we are discussing gives us a methodology to train a single model that converges to such a flat minimum, rather than using multiple models.
The authors essentially ask this question: instead of using multiple models to obtain multiple loss surfaces and thereby a wider minimum, why not find a single point in weight space that corresponds to it?
I mean, if we can somehow find a single model whose predictions approximately match the output of an ensembled model, then it's just a matter of using that model during inference, solving all of our problems 🙂.
Sounds interesting? Indeed.
So you get the basic idea; now let's see how exactly SWA allows us to train a single model that gets (approximately) ensemble-level accuracy.
The Working
We show that SGD with cyclical and constant learning rates traverses regions of weight space corresponding to high-performing networks. We find that while these models are moving around this optimal set they never reach its central points. We show that we can move into this more desirable space of points by averaging the weights proposed over SGD iterations. SGD generally converges to a point near the boundary of the wide flat region of optimal points. SWA on the other hand is able to find a point centered in this region, often with slightly worse train loss but with substantially better test error. SWA is extremely easy to implement and has virtually no computational overhead compared to the conventional training schemes.
So as you can see, it's pretty straightforward. SWA is based on averaging the weights proposed by SGD, using a learning rate schedule that lets SGD explore the region of weight space corresponding to high-performing networks.
The cyclical learning rate lets the model discover many local minima that lie along the boundary of the region of the loss surface surrounding the central location where the flatter, wider optimum is present. So the average of the parameters corresponding to the local minima found at the end of each cycle can approximate that flat central location very well.
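As a rough sketch, the cyclical schedule in the paper decays the learning rate linearly from α1 to α2 within each cycle of c iterations, then jumps back up (the function name and default values below are illustrative placeholders, not the paper's settings):

```python
def cyclical_lr(i, c, alpha1=0.05, alpha2=0.0005):
    """Piecewise-linear cyclical schedule: within each cycle of c iterations,
    decay the learning rate from alpha1 down to alpha2, then jump back to
    alpha1 at the start of the next cycle. i is the 1-indexed iteration number.
    """
    t = ((i - 1) % c + 1) / c            # position inside the current cycle, in (0, 1]
    return (1 - t) * alpha1 + t * alpha2
```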
And here’s the full working:
Here are the steps in layman's terms:
- c is the period of the cyclical learning rate schedule, which is 1 if we are using a constant learning rate.
- W_swa represents the final model, which we’ll keep updating as the training progresses.
- At the end of each cycle (when the learning rate is at its minimum for that cycle), the weights of the current model are used to update our final model W_swa, by taking a weighted mean of the previous W_swa and the current weights (see the sketch after this list).
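Here is a minimal sketch of that running average in Keras terms, assuming `model` is the network being trained with SGD and `swa_weights` holds the average so far (the variable names and the helper itself are mine, not the paper's):

```python
def update_swa_weights(swa_weights, model, n_models):
    """Running average from the SWA algorithm:
        W_swa <- (W_swa * n_models + W) / (n_models + 1),
    applied to every weight tensor of the network.
    Returns the new average and the updated model count.
    """
    current = model.get_weights()              # list of numpy arrays (standard Keras API)
    if swa_weights is None:                    # first contribution: just copy the weights
        return [w.copy() for w in current], 1
    new = [(swa * n_models + w) / (n_models + 1)
           for swa, w in zip(swa_weights, current)]
    return new, n_models + 1
```

At the end of training you would write the averaged weights back with `model.set_weights(swa_weights)`, and then recompute the batchnorm statistics as the note below explains.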
The only additional time is spent updating the aggregated weight average, and that only requires computing a weighted sum of the weights of two models. Also, since we apply this operation at most once per epoch, SWA and SGD require practically the same amount of computation.
Note
In the figure above, you can see that the batchnorm layers get special treatment. Since the batchnorm running means and standard deviations of the averaged model are not collected during training, we pass our training data through the network with the W_swa weights after training is finished, to compute these statistics for each batchnorm layer.
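A rough way to do this in Keras is to load the averaged weights and run the training data through the model in training mode, so the BatchNorm moving statistics are refreshed. Note this relies on the layers' exponential moving averages rather than the exact activation statistics the paper computes, so treat it as an approximation; `swa_weights` and `train_ds` are assumed from earlier:

```python
def refresh_batchnorm_stats(model, swa_weights, train_ds):
    """Load the SWA-averaged weights into a tf.keras model, then push the
    training data through it in training mode so every BatchNormalization
    layer updates its moving_mean / moving_variance for the averaged weights.
    """
    model.set_weights(swa_weights)
    for x_batch, _ in train_ds:          # train_ds: a tf.data.Dataset of (inputs, labels)
        model(x_batch, training=True)    # updates BN statistics; no gradient step is taken
```

If you happen to be using PyTorch instead, `torch.optim.swa_utils.update_bn` does this bookkeeping for you.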
Results
In the figure above, you can see the three local optima (from three cycles) that lie along the boundary of the wider, flatter optimum. And the final model W_swa, obtained from the algorithm above, is indeed located at the center of that flat optimum.
In the figure above, it is interesting to note that the training loss for W_swa is higher than that of a model trained with plain SGD, while the test error for W_swa is noticeably lower. This once again confirms the idea that SWA gives us a wider optimum, which is better generalized and less overfitted.
My Observation
After reading the paper, I just couldn't wait to test out this awesome technique. But the results I got made no sense to me at first. In my experiment I used the simple Fashion MNIST dataset, trained with a simple 5-layer CNN from scratch.
And here are my results after 15 epochs with multiple runs:
- SGD Testing Accuracy ~ 89%
- SWA Testing Accuracy ~ 86%
At first I thought I must have made some mistake, but that was not the case. So I looked at the paper again 🙂 and found a very important note from the authors at the very end:
While being able to train a DNN with a fixed learning rate is a surprising property of SWA, for practical purposes we recommend initializing SWA from a model pretrained with conventional training (possibly for a reduced number of epochs), as it leads to faster and more stable convergence than running SWA from scratch.
So basically, I was not getting good results because, with so few epochs of training from scratch, the cycles could not reach significantly different, well-converged local optima. Averaging the weights in that situation is a bad idea, since the poorly converged early weights drag the average away from the better later ones.
Following up on this, in my later experiments on CIFAR10 based on transfer learning with SWA, I was getting a 2% to 5% improvement.
And here are some of my own suggestions from the experiments:
- Always use a cyclic learning rate.
- Whenever you update the final model W_swa at the end of a cycle, make sure the training loss has already converged reasonably well within that cycle.
- With each update of W_swa, the training loss should not vary too much between cycles. For instance, if at the end of the 4th cycle you were getting a loss of 0.0258, then for the next updates, the loss at the end of each cycle should also be in the range of approximately 0.02 to 0.03.
Also, to get really good results from this technique, you need to know, for your model, your dataset, and your cyclical learning rate configuration, after how many training iterations the loss converges. (If you're using TensorFlow, you can simply use the EarlyStopping callback to find this.) Once you have that number, set it as the period of the cyclical learning rate.
Doing this will make it easy to follow my 2nd and 3rd pieces of advice above.
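As a hedged sketch of that calibration step (the `model`, `train_ds`, and hyperparameter values here are placeholders, not from the paper or my exact experiments): first find how many epochs conventional training needs to converge, then reuse that number as the cycle length.

```python
import tensorflow as tf

# Step 1: find how many epochs conventional training needs before the loss converges.
early_stop = tf.keras.callbacks.EarlyStopping(monitor="loss", patience=3,
                                              restore_best_weights=True)
history = model.fit(train_ds, epochs=100, callbacks=[early_stop])
epochs_to_converge = len(history.history["loss"])   # epochs actually run before stopping

# Step 2: reuse that number as the period c of the cyclical learning rate, so each
# cycle has enough time to settle into a local minimum before its weights are averaged.
c = epochs_to_converge
```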
Conclusion
Since the loss surface for the testing data very closely resembles that of the training data, at the end of training we want to land in a wider, flatter optimum of the training loss, so that the model generalizes better.
With plain SGD training, the optimization usually misses these wide optima and converges to a sharp one instead.
SWA takes several of these individual minima and averages their weights; that average lies in the central location where a much wider minimum is present, making the model more generalized and less overfitted, with almost no extra training time or computational cost.
Thanks for reading.