3/26 Optimizers
Now that we have
- MLPs, a class of models that can approximate any continuous function on a compact domain to arbitrary precision, and
- PBT, a hyperparameter optimization method that can find hyperparameter schedules on its own,
let's consider optimization algorithms.
SGD
So far, we have been using stochastic gradient descent (SGD) to train our models. An SGD training step is
$$ \theta \leftarrow \theta - \eta\nabla_\theta, $$ where:
- \(\theta\in\mathbf R^M\) is the collection of parameters that we optimize.
- \(\nabla_\theta\in\mathbf R^M\) is the gradient of the loss function with respect to the parameters, evaluated on a minibatch.
- \(\eta\in\mathbf R\) is a hyperparameter, the learning rate.
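A minimal sketch of one SGD step, assuming the parameters and the gradient are NumPy arrays. The function name `sgd_step` and the toy quadratic loss below are illustrative choices, not part of the original text:

```python
import numpy as np


def sgd_step(theta: np.ndarray, grad: np.ndarray, lr: float) -> np.ndarray:
    """Return theta - lr * grad, i.e. one plain SGD step with learning rate lr (eta)."""
    return theta - lr * grad


# Toy loss 0.5 * ||theta||^2, whose gradient is theta itself.
theta = np.array([1.0, -2.0])
for _ in range(3):
    theta = sgd_step(theta, grad=theta, lr=0.5)
# With lr = 0.5 every step halves theta, so theta ends at [0.125, -0.25].
```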
Stateful Optimizers
The optimizers we consider from now on keep a state of their own, that is, they are stateful. Let us formalize this. Note that some optimizers do not follow this pattern.
- First, we initialize the optimizer state with \(\mathtt{init}(\theta)\).
- Then a training step \(\mathtt{step}(\theta, \nabla_\theta)\) is the following sequence of function calls:
  - We update the optimizer state with \(\mathtt{update\_state}(\theta,\nabla_\theta)\).
  - Afterwards, we get an update vector \(\Delta\theta\) with \(\mathtt{get\_update}(\theta,\nabla_\theta)\).
  - Finally, we apply the update vector with \(\mathtt{apply\_update}(\theta,\Delta\theta)\).
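The protocol above can be sketched as a small Python base class. This is only one possible encoding, assuming NumPy arrays and an in-place \(\mathtt{apply\_update}\); the class `StatefulOptimizer` and the toy subclass `ConstantStep` are hypothetical names introduced for illustration:

```python
from abc import ABC, abstractmethod

import numpy as np


class StatefulOptimizer(ABC):
    """Hypothetical base class: init, update_state, get_update, apply_update, step."""

    def init(self, theta: np.ndarray) -> None:
        """Initialize the optimizer state; stateless optimizers leave this empty."""

    def update_state(self, theta: np.ndarray, grad: np.ndarray) -> None:
        """Update the optimizer state; stateless optimizers leave this empty."""

    @abstractmethod
    def get_update(self, theta: np.ndarray, grad: np.ndarray) -> np.ndarray:
        """Return the update vector Delta theta."""

    def apply_update(self, theta: np.ndarray, delta: np.ndarray) -> None:
        """Apply the update in place: theta <- theta + Delta theta."""
        theta += delta

    def step(self, theta: np.ndarray, grad: np.ndarray) -> None:
        """One training step: update_state, then get_update, then apply_update."""
        self.update_state(theta, grad)
        delta = self.get_update(theta, grad)
        self.apply_update(theta, delta)


class ConstantStep(StatefulOptimizer):
    """Toy (hypothetical) optimizer: always moves every parameter by -0.1."""

    def get_update(self, theta: np.ndarray, grad: np.ndarray) -> np.ndarray:
        return np.full_like(theta, -0.1)


theta = np.array([1.0, 2.0])
opt = ConstantStep()
opt.init(theta)
opt.step(theta, grad=np.zeros_like(theta))
# theta was modified in place and is now [0.9, 1.9].
```

Keeping the state on the object lets \(\mathtt{update\_state}\) and \(\mathtt{get\_update}\) keep the two-argument signatures used in the text; a purely functional design would instead pass the state in and return it explicitly.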
Example: SGD
To recast SGD in this framework:
- As SGD is a stateless optimizer, the functions \(\mathtt{init}\) and \(\mathtt{update\_state}\) don't do anything.
- The gradient descent displacement is the update vector: \(\mathtt{get\_update}(\theta,\nabla_\theta)=-\eta\nabla_\theta\).
- In \(\mathtt{apply\_update}\), we perform the in-place operation \(\theta\leftarrow \theta + \Delta\theta\). In all our examples, this function will stay the same.
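A sketch of SGD in this framework, again assuming NumPy arrays. The class name `SGD` and its methods are our own illustrative names, not a particular library's API:

```python
import numpy as np


class SGD:
    """Plain SGD expressed through init / update_state / get_update / apply_update."""

    def __init__(self, lr: float) -> None:
        self.lr = lr  # learning rate eta

    def init(self, theta: np.ndarray) -> None:
        pass  # stateless: nothing to initialize

    def update_state(self, theta: np.ndarray, grad: np.ndarray) -> None:
        pass  # stateless: nothing to update

    def get_update(self, theta: np.ndarray, grad: np.ndarray) -> np.ndarray:
        return -self.lr * grad  # Delta theta = -eta * grad

    def apply_update(self, theta: np.ndarray, delta: np.ndarray) -> None:
        theta += delta  # in place: theta <- theta + Delta theta

    def step(self, theta: np.ndarray, grad: np.ndarray) -> None:
        self.update_state(theta, grad)
        self.apply_update(theta, self.get_update(theta, grad))


theta = np.array([1.0, -2.0])
opt = SGD(lr=0.1)
opt.init(theta)
opt.step(theta, grad=np.array([10.0, -10.0]))
# theta is now [0.0, -1.0].
```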
\(L_p\) Regularization and Weight Decay
Regularization methods help prevent overfitting. We'll consider two such methods that can be implemented by modifying the optimization algorithm.
\(L_p\) Regularization
We can modify the loss function by penalizing large parameter values:
$$ \ell \leftarrow \ell(f_\theta(\mathbf x), \mathbf y) + \lambda_p\frac{1}{p}\|\theta\|_p^p. $$ Here:
- We have a new hyperparameter: the \(L_p\) regularization coefficient \(\lambda_p\).
- The notation \(\|\theta\|_p=\left(\sum_{i=1}^M|\theta_i|^p\right)^{1/p}\) stands for the \(p\)-norm.
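This modifies the gradient as follows:
$$ \nabla_\theta\leftarrow\nabla_\theta+\lambda_p\,\mathrm{sgn}(\theta)\,|\theta|^{\odot(p-1)}, $$ where
- \(\odot\) denotes taking powers componentwise,
- the absolute value and the product of the two vectors are also taken componentwise and
- \(\mathrm{sgn}\) denotes the sign function $$ \mathrm{sgn}(x)=\begin{cases} 1 & x > 0,\\ 0 & x = 0,\\ -1 & x < 0, \end{cases} $$ applied componentwise.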
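The modified gradient can be computed componentwise with NumPy. This sketch assumes \(p\ge1\) (for \(p<1\) the factor \(|\theta|^{p-1}\) blows up at zero); `np.sign` returns 0 at 0, matching \(\mathrm{sgn}\), and NumPy evaluates `0.0 ** 0` as `1.0`. The function name `lp_regularized_grad` is hypothetical:

```python
import numpy as np


def lp_regularized_grad(grad: np.ndarray, theta: np.ndarray, lam_p: float, p: float) -> np.ndarray:
    """Return grad + lam_p * sgn(theta) * |theta|^(p - 1), all componentwise (assumes p >= 1).

    This is the gradient of loss + lam_p * (1 / p) * ||theta||_p^p.
    """
    return grad + lam_p * np.sign(theta) * np.abs(theta) ** (p - 1)


theta = np.array([0.5, -2.0, 0.0])
zero_grad = np.zeros_like(theta)
g2 = lp_regularized_grad(zero_grad, theta, lam_p=0.1, p=2)  # L2: equals 0.1 * theta
g1 = lp_regularized_grad(zero_grad, theta, lam_p=0.1, p=1)  # L1: equals 0.1 * sgn(theta)
```

For \(p=2\) the extra term is simply \(\lambda_2\theta\), which is what makes the comparison with weight decay below work out.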
\(L_1\) regularization for sparse parameters
In particular, if \(p=1\), then we modify the gradient as follows:
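$$ \nabla_\theta\leftarrow\nabla_\theta+\lambda_1\,\mathrm{sgn}(\theta)\,\mathbf 1_{\theta\neq0}, $$ where the so-called indicator function \(\mathbf 1_{\theta\neq0}\), equal to 1 where \(\theta_i\neq0\) and 0 otherwise, is applied componentwise. Since \(\mathrm{sgn}(0)=0\), this amounts to adding \(\lambda_1\mathrm{sgn}(\theta)\) to the gradient.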
This leads to sparse parameters. Besides regularization, this is also useful for interpretability: with fewer nonzero parameters, it is easier to figure out which formulas the model bases its decisions on.
Weight Decay [1]
In weight decay, we penalize parameters with large absolute values by modifying the update vector: $$ \Delta\theta\leftarrow\Delta\theta - \lambda\eta\theta. $$ The hyperparameter \(\lambda\) is the weight decay rate.
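Because weight decay only touches the update vector, it can be sketched as a wrapper around any \(\mathtt{get\_update}\) function. The names `with_weight_decay` and `sgd_get_update` are hypothetical, and the learning rate \(\eta\) is assumed constant:

```python
from typing import Callable

import numpy as np

# Signature of a get_update function: (theta, grad) -> Delta theta
GetUpdate = Callable[[np.ndarray, np.ndarray], np.ndarray]


def with_weight_decay(get_update: GetUpdate, lr: float, wd: float) -> GetUpdate:
    """Wrap get_update so that Delta theta <- Delta theta - wd * lr * theta."""

    def decayed(theta: np.ndarray, grad: np.ndarray) -> np.ndarray:
        return get_update(theta, grad) - wd * lr * theta

    return decayed


lr, wd = 0.1, 0.01


def sgd_get_update(theta: np.ndarray, grad: np.ndarray) -> np.ndarray:
    return -lr * grad  # plain SGD update vector


get_update = with_weight_decay(sgd_get_update, lr=lr, wd=wd)
theta = np.array([1.0, -1.0])
delta = get_update(theta, np.zeros_like(theta))
# With a zero gradient only the decay term remains: delta = [-0.001, 0.001].
```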
\(L_2\) Regularization = Weight Decay for SGD
In the case of SGD, with \(\lambda_2=\lambda\), \(L_2\) regularization is the same as weight decay: both lead to the update $$ \theta\leftarrow\theta-\eta(\nabla_\theta+\lambda\theta)=\theta-\lambda\eta\theta-\eta\nabla_\theta. $$ With the more advanced optimizers we will introduce today, this is no longer the case.
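For SGD, the identity can also be checked numerically. A small sketch with random NumPy vectors (the seed, sizes and hyperparameter values are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(size=5)
grad = rng.normal(size=5)
lr, lam = 0.1, 0.01  # eta and lambda = lambda_2

# L2 regularization: add lam * theta to the gradient, then take a plain SGD step.
theta_l2 = theta - lr * (grad + lam * theta)

# Weight decay: subtract lam * lr * theta from the SGD update vector -lr * grad.
delta = -lr * grad - lam * lr * theta
theta_wd = theta + delta

same = np.allclose(theta_l2, theta_wd)
# same is True: for SGD the two rules coincide up to floating-point rounding.
```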
SGD with Momentum [2]
One can smoothen and potentially accelerate the progress of the parameter vector \(\theta\) down the loss landscape by accumulating displacements with momentum. We denote the accumulated velocity vector by \(\mathbf v\in\mathbf R^M\).
- \(\mathtt{init}\): We let \(\mathbf v=\boldsymbol0\).
- \(\mathtt{update\_state}\): We let \(\mathbf v\leftarrow\mu\mathbf v - \eta\nabla_\theta\). Here, we have a new hyperparameter \(\mu\): the momentum coefficient.
- \(\mathtt{get\_update}\): The update vector is the velocity vector \(\mathbf v\).
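A sketch of SGD with momentum in the same framework, assuming NumPy arrays; the class name `SGDMomentum` is hypothetical:

```python
import numpy as np


class SGDMomentum:
    """SGD with momentum via init / update_state / get_update / apply_update."""

    def __init__(self, lr: float, mu: float) -> None:
        self.lr = lr  # learning rate eta
        self.mu = mu  # momentum coefficient mu

    def init(self, theta: np.ndarray) -> None:
        self.v = np.zeros_like(theta)  # velocity starts at zero

    def update_state(self, theta: np.ndarray, grad: np.ndarray) -> None:
        self.v = self.mu * self.v - self.lr * grad  # v <- mu * v - eta * grad

    def get_update(self, theta: np.ndarray, grad: np.ndarray) -> np.ndarray:
        return self.v  # the update vector is the velocity

    def apply_update(self, theta: np.ndarray, delta: np.ndarray) -> None:
        theta += delta  # in place: theta <- theta + Delta theta

    def step(self, theta: np.ndarray, grad: np.ndarray) -> None:
        self.update_state(theta, grad)
        self.apply_update(theta, self.get_update(theta, grad))


# With a constant gradient the steps grow: -0.1, -0.19, -0.271, ...
theta = np.array([0.0])
opt = SGDMomentum(lr=0.1, mu=0.9)
opt.init(theta)
for _ in range(3):
    opt.step(theta, grad=np.array([1.0]))
# theta is now [-0.561].
```

With a constant gradient the velocity approaches \(-\frac{\eta}{1-\mu}\nabla_\theta\), i.e. a step \(\frac{1}{1-\mu}\) times larger than plain SGD, which illustrates the potential acceleration.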
Adaptive Moment Estimation (Adam) [3]
The approach of Adam is to track moving averages
- \(\mathbf m\) of the gradient, a biased estimate of its first moment, and
- \(\mathbf v\) of the elementwise square of the gradient, a biased estimate of its second raw moment, that is, of the uncentered variance,
and to take descent steps along \(\frac{\mathbf m}{\sqrt{\mathbf v}}\) (componentwise) instead of the gradient. The idea is that besides smoothing the optimization path via \(\mathbf m\), dividing by \(\sqrt{\mathbf v}\) rescales each coordinate separately, so that parameters with small or infrequent (sparse) gradients still receive large updates.
- \(\mathtt{init}\): We let \(\mathbf m,\mathbf v=\boldsymbol0\). We also track the step counter \(t=0\).
- \(\mathtt{update\_state}\): We increment \(t\leftarrow t+1\) and let $$ \mathbf m \leftarrow \beta_1\mathbf m + (1 - \beta_1)\nabla_\theta $$ and $$ \mathbf v \leftarrow \beta_2\mathbf v + (1 - \beta_2)(\nabla_\theta)^{\odot2}. $$ The hyperparameters \(\beta_1\) and \(\beta_2\) are the decay rates of these exponential moving averages.
- \(\mathtt{get\_update}\): We let $$ \hat{\mathbf m}\leftarrow\frac{\mathbf m}{1-\beta_1^t} $$ and $$ \hat{\mathbf v}\leftarrow\frac{\mathbf v}{1-\beta_2^t}, $$ with which the update vector is $$ -\eta\frac{\hat{\mathbf m}}{\sqrt{\hat{\mathbf v}} + \epsilon}, $$ where the square root and the division are taken componentwise. We use these bias-corrected first and second moment estimates to counteract the bias toward zero caused by the zero initialization. The hyperparameter \(\epsilon\) is used for numerical stability.
The paper recommends setting \(\eta=10^{-3}\), \(\beta_1=0.9\), \(\beta_2=0.999\) and \(\epsilon=10^{-8}\). We'll tune these hyperparameters using PBT, but will base the initial and perturbation distributions on these defaults.
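A sketch of Adam with these recommended defaults, assuming NumPy arrays. The class name `Adam` here is our own illustration, not a library class; the step counter is incremented in \(\mathtt{update\_state}\) so that \(t\ge1\) in the bias correction:

```python
import numpy as np


class Adam:
    """Adam via init / update_state / get_update / apply_update (illustrative sketch)."""

    def __init__(self, lr: float = 1e-3, beta1: float = 0.9,
                 beta2: float = 0.999, eps: float = 1e-8) -> None:
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps

    def init(self, theta: np.ndarray) -> None:
        self.m = np.zeros_like(theta)  # first moment estimate
        self.v = np.zeros_like(theta)  # second raw moment estimate
        self.t = 0                     # step counter

    def update_state(self, theta: np.ndarray, grad: np.ndarray) -> None:
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad**2

    def get_update(self, theta: np.ndarray, grad: np.ndarray) -> np.ndarray:
        m_hat = self.m / (1 - self.beta1**self.t)  # bias-corrected first moment
        v_hat = self.v / (1 - self.beta2**self.t)  # bias-corrected second moment
        return -self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

    def apply_update(self, theta: np.ndarray, delta: np.ndarray) -> None:
        theta += delta  # in place: theta <- theta + Delta theta

    def step(self, theta: np.ndarray, grad: np.ndarray) -> None:
        self.update_state(theta, grad)
        self.apply_update(theta, self.get_update(theta, grad))


# Two coordinates whose gradients differ by a factor of 300.
theta = np.zeros(2)
opt = Adam()
opt.init(theta)
opt.step(theta, grad=np.array([3.0, -0.01]))
# After bias correction both coordinates move by roughly lr: theta is approximately [-1e-3, 1e-3].
```

The first step has magnitude close to \(\eta\) in every coordinate regardless of the gradient's scale, which is the per-coordinate normalization described above.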
A testament to the importance of hyperparameter tuning: to prove that the Adam algorithm converges, you need \(\beta_2 < 1\) to be big enough in a task-dependent manner and \(\beta_1 < \sqrt{\beta_2}\) [4].
If we add weight decay to Adam, the resulting optimizer is commonly referred to as AdamW, for Adam with (decoupled) weight decay [5].
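Following the weight decay formula \(\Delta\theta\leftarrow\Delta\theta-\lambda\eta\theta\), AdamW can be sketched by subtracting \(\lambda\eta\theta\) from Adam's update vector. The class name `AdamW` and the parameter name `wd` (for \(\lambda\)) are our own illustrative choices, not a particular library's API:

```python
import numpy as np


class AdamW:
    """Adam plus weight decay applied to the update vector (illustrative sketch)."""

    def __init__(self, wd: float, lr: float = 1e-3, beta1: float = 0.9,
                 beta2: float = 0.999, eps: float = 1e-8) -> None:
        self.wd, self.lr, self.beta1, self.beta2, self.eps = wd, lr, beta1, beta2, eps

    def init(self, theta: np.ndarray) -> None:
        self.m = np.zeros_like(theta)
        self.v = np.zeros_like(theta)
        self.t = 0

    def update_state(self, theta: np.ndarray, grad: np.ndarray) -> None:
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad**2

    def get_update(self, theta: np.ndarray, grad: np.ndarray) -> np.ndarray:
        m_hat = self.m / (1 - self.beta1**self.t)
        v_hat = self.v / (1 - self.beta2**self.t)
        adam_delta = -self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
        # Weight decay acts on the update vector, not on the gradient,
        # so it is not rescaled by 1 / sqrt(v_hat).
        return adam_delta - self.wd * self.lr * theta

    def apply_update(self, theta: np.ndarray, delta: np.ndarray) -> None:
        theta += delta  # in place: theta <- theta + Delta theta

    def step(self, theta: np.ndarray, grad: np.ndarray) -> None:
        self.update_state(theta, grad)
        self.apply_update(theta, self.get_update(theta, grad))


# With a zero gradient only the decay acts: theta shrinks by a factor (1 - wd * lr).
theta = np.array([1.0, -2.0])
opt = AdamW(wd=0.01)
opt.init(theta)
opt.step(theta, grad=np.zeros_like(theta))
# theta is now [0.99999, -1.99998].
```

Because the decay term bypasses the division by \(\sqrt{\hat{\mathbf v}}\), it shrinks all parameters at the same relative rate, whereas an \(L_2\) term added to the gradient would be rescaled per coordinate; this is why \(L_2\) regularization and weight decay differ for Adam.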
References
[1] Stephen José Hanson and Lorien Y. Pratt, Comparing biases for minimal network construction with back-propagation, Proceedings of the 1st International Conference on Neural Information Processing Systems (NIPS), pp. 177–185, 1988.
[2] Boris T. Polyak, Some methods of speeding up the convergence of iteration methods, USSR Computational Mathematics and Mathematical Physics, 4(5):1–17, 1964.
[3] Diederik Kingma and Jimmy Ba, Adam: A method for stochastic optimization.
[4] Yushun Zhang, Congliang Chen, Naichen Shi, Ruoyu Sun and Zhi-Quan Luo, Adam Can Converge Without Any Modification On Update Rules, Advances in Neural Information Processing Systems (NeurIPS) 35, pp. 28386–28399, 2022.
[5] Ilya Loshchilov and Frank Hutter, Decoupled Weight Decay Regularization, International Conference on Learning Representations (ICLR), 2019.