Noah Golmant

Fine-tuned Noise

When we decide to train a cool machine learning model, the most common machinery to get us there is stochastic gradient descent. From its humble beginnings as “cheapskate gradient descent”, SGD blows away its vanilla counterpart on highly non-convex loss surfaces. To me, the fact that such a simple hill-climbing technique has dominated modern non-convex optimization is fascinating. However, the presence of spurious local minima and saddle points makes analysis more complex. It’s important to understand how our intuitions about SGD dynamics can change when we remove our classical convex assumptions. The jump to the non-convex setting invites the use of frameworks like dynamical systems theory and stochastic differential equations, which provide models for thinking about long-term dynamics and short-term stochasticity in the optimization landscape.

Today I will be talking about what appears at first to be a simple nuisance in the world of gradient descent: noise. The only difference between stochastic gradient descent and vanilla gradient descent is that the former uses a noisy approximation of the gradient. The structure of this noise ends up being the driving force behind the “exploratory” behavior of SGD on non-convex problems. I’ll talk a bit about the structure of this noise, connect it with some ideas from information geometry and stochastic nonlinear systems, and conclude with a discussion of some practical implications of these results.

The covariance structure of mini-batch noise

Let’s introduce our problem setup. Suppose I would like to minimize a loss function \(f: \mathbb{R}^n \rightarrow \mathbb{R}\) over a finite dataset of \(N\) examples. For parameters \(x \in \mathbb{R}^n\), I will call the loss on the \(i\)th example \(f_i(x)\), so that \(f(x) = \frac1N \sum_{i=1}^N f_i(x)\). Now, \(N\) is presumably quite large, so we’ll estimate the dataset gradient \(g_N = \frac1N \sum_{i=1}^N \nabla f_i(x) = \nabla f(x)\) with a mini-batch estimate \(g_B = \frac1m \sum_{i \in B} \nabla f_i(x)\), where \(B\) is a mini-batch of \(m\) indices drawn from \(\{1,2,\ldots,N\}\). Although \(g_N\) is itself only a noisy estimate of the gradient of the expected loss over the underlying data distribution, it turns out that mini-batch sampling adds noise with an interesting covariance structure of its own.

Lemma 1 (Chaudhari & Soatto): When the mini-batch is sampled with replacement, the covariance of the mini-batch gradient for batch size \(m\) is \(\text{Var}(g_B) = \frac1m D(x)\), where

\[D(x) = \frac1N \sum_{i=1}^N \nabla f_i(x) \nabla f_i(x)^T - \nabla f(x) \nabla f(x)^T\]
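A quick numerical check of Lemma 1 may help. The sketch below assumes synthetic per-example gradients (random Gaussian vectors standing in for \(\nabla f_i(x)\)), samples many mini-batches with replacement, and compares the empirical covariance of \(g_B\) with \(\frac1m D(x)\). The function name and sizes are illustrative choices.

```python
import numpy as np

def minibatch_covariance_check(N=50, n=3, m=8, trials=100_000, seed=0):
    """Return (empirical Var(g_B), D(x)/m) for synthetic per-example gradients."""
    rng = np.random.default_rng(seed)
    # Hypothetical per-example gradients: row i stands in for grad f_i(x).
    grads = rng.normal(size=(N, n))
    g_N = grads.mean(axis=0)  # full-dataset gradient
    # D(x) = (1/N) sum_i grad f_i grad f_i^T - g_N g_N^T
    D = grads.T @ grads / N - np.outer(g_N, g_N)
    # Each row of idx is one mini-batch of m indices drawn with replacement.
    idx = rng.integers(0, N, size=(trials, m))
    g_B = grads[idx].mean(axis=1)          # shape (trials, n)
    empirical = np.cov(g_B, rowvar=False)  # sample covariance of g_B
    return empirical, D / m

empirical, predicted = minibatch_covariance_check()
print(np.round(empirical, 3))
print(np.round(predicted, 3))
```

Sampling without replacement instead would multiply this covariance by the finite-population correction \(\frac{N-m}{N-1}\), which is negligible when \(m \ll N\).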

What does this result mean? Let’s put it in context. In a lot of optimization problems, our underlying goal is to find the parameters that maximize the likelihood of the data, so our loss is a negative log-likelihood; for classification problems, this is the cross-entropy. In this case, the first term \(\hat{I}(x) = \frac1N \sum_{i=1}^N \nabla f_i(x) \nabla f_i(x)^T\) is an estimate of the covariance of the gradient of the log-likelihood (the minus sign relating the loss to the log-likelihood cancels in the outer product). This is the empirical Fisher information. As \(N \rightarrow \infty\), and assuming the model is well specified, it approaches the Fisher information matrix. The Fisher information is precisely the Hessian of the relative entropy (KL divergence) between the true model and the model at \(x\), evaluated at the true parameters. And the KL divergence differs from the cross-entropy loss (negative log-likelihood) we were attempting to minimize only by an additive constant, the entropy of the data distribution.

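Hence, the covariance of mini-batch noise is asymptotically related to the Hessian of our loss. In fact, as \(x\) approaches a local minimum, \(\nabla f(x) \rightarrow 0\), so the second term of \(D(x)\) vanishes and \(D(x) \approx \hat{I}(x)\). The mini-batch covariance \(\frac1m D(x)\) then approaches a scaled version of the Hessian.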
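To see the Fisher/Hessian connection concretely, here is a sketch assuming a well-specified logistic-regression model, where the labels are drawn from the model itself at some hypothetical true parameters `w_true`. At those parameters it compares \(D(x)\), built from per-example cross-entropy gradients, with the exact Hessian of the mean loss. Under misspecification, or far from the true parameters, the two need not agree.

```python
import numpy as np

def fisher_vs_hessian(N=200_000, seed=0):
    """Compare D(x) with the Hessian for logistic regression at the true weights."""
    rng = np.random.default_rng(seed)
    w_true = np.array([1.0, -0.5, 0.25])   # hypothetical true parameters
    X = rng.normal(size=(N, w_true.size))
    p = 1.0 / (1.0 + np.exp(-X @ w_true))  # model probabilities P(y=1 | x)
    y = rng.binomial(1, p)                 # labels sampled from the model itself
    # Per-example cross-entropy gradient: grad f_i(w) = (p_i - y_i) x_i
    grads = (p - y)[:, None] * X
    g_N = grads.mean(axis=0)
    empirical_fisher = grads.T @ grads / N     # the first term of D(x)
    D = empirical_fisher - np.outer(g_N, g_N)  # mini-batch noise covariance
    # Exact Hessian of the mean loss: (1/N) sum_i p_i (1 - p_i) x_i x_i^T
    hessian = (X * (p * (1 - p))[:, None]).T @ X / N
    return D, hessian

D, hessian = fisher_vs_hessian()
print(np.round(D, 3))
print(np.round(hessian, 3))
```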

A detour into the Fisher information

Before I continue with the SGD-specific analysis, let’s take a moment to consider the connection between the Fisher information matrix \(I(x)\) and the Hessian \(\nabla^2 f(x)\). \(I(x)\) is the covariance of the log-likelihood gradient. How does this covariance relate to the curvature of the loss surface? Suppose we’re at a strict local minimum \(x^*\) of \(f\) where \(I(x^*) = \nabla^2 f(x^*)\) and this matrix is positive definite. Then \(I(x^*)\) induces a metric near \(x^*\), the (local) Fisher-Rao metric: \(d(x,y) = \sqrt{(x-y)^T I(x^*)(x-y)}\). Interestingly, the Fisher-Rao norm of the parameters can be used to upper-bound the generalization error. Flatter minima have smaller Hessian eigenvalues, and hence a smaller Fisher information, so they also tend to have a smaller Fisher-Rao norm. This means we can be more confident about the generalization ability of flatter minima.
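The local Fisher-Rao distance is easy to compute once \(I(x^*)\) is known. The sketch below uses two hypothetical diagonal Fisher matrices, one for a sharp minimum and one for a flat minimum, and measures the same Euclidean displacement under each.

```python
import numpy as np

def fisher_rao_distance(x, y, fisher):
    """Local Fisher-Rao distance sqrt((x - y)^T I(x*) (x - y))."""
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return float(np.sqrt(diff @ fisher @ diff))

x_star = np.zeros(2)
step = np.array([0.1, 0.1])       # the same Euclidean move in both cases
I_sharp = np.diag([100.0, 1.0])   # hypothetical sharp minimum
I_flat = np.diag([1.0, 0.01])     # hypothetical flat minimum
d_sharp = fisher_rao_distance(x_star + step, x_star, I_sharp)
d_flat = fisher_rao_distance(x_star + step, x_star, I_flat)
print(round(d_sharp, 4), round(d_flat, 4))  # 1.005 0.1005
```

Here the same step is ten times "longer" near the sharp minimum, which is one way to read the claim that sharp minima are more sensitive to perturbations of the parameters.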

Back to the story

This immediately suggests some interesting conjectures about the dynamics of SGD. Let’s make a CLT-type assumption for a second and suppose we can decompose our estimate \(g_B\) into the “true” dataset gradient and a noise term: \(g_B = g_N + \frac{1}{\sqrt{m}}n(x)\), where \(n(x) \sim \mathcal{N}(0, D(x))\). Moreover, for simplicity, assume we’re close to a minimum so that \(D(x) \approx \nabla^2 f(x)\). Then \(n(x)\) has a density \(\rho (z)\) whose exponent is a quadratic form:

\[\log \rho (z) = -\tfrac12 z^T D(x)^{-1} z + \text{const} \approx -\tfrac12 z^T \left(\nabla^2 f(x)\right)^{-1} z + \text{const}\]
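To make the quadratic form concrete, the sketch below evaluates the \(\mathcal{N}(0, D)\) log-density, with a hypothetical diagonal Hessian standing in for \(D(x)\), at unit-length perturbations along a sharp and a flat eigendirection.

```python
import numpy as np

def gaussian_log_density(z, cov):
    """Log-density of N(0, cov) at z, including the normalizing constant."""
    z = np.asarray(z, dtype=float)
    _, logdet = np.linalg.slogdet(cov)
    quad = z @ np.linalg.solve(cov, z)  # z^T cov^{-1} z without forming the inverse
    return -0.5 * (quad + logdet + z.size * np.log(2 * np.pi))

H = np.diag([10.0, 0.1])  # hypothetical Hessian: one sharp, one flat direction
lp_sharp = gaussian_log_density([1.0, 0.0], H)
lp_flat = gaussian_log_density([0.0, 1.0], H)
print(round(lp_sharp - lp_flat, 2))  # 4.95
```

The density of a unit-length kick along the sharp direction is \(e^{4.95} \approx 141\) times higher than along the flat one.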

The quadratic form above makes it clear that the eigenvalues of the Hessian play an important role in determining which minima SGD considers “stable.” When I’m at a sharp minimum where there are many large, positive eigenvalues, I am more likely to add noise that “pushes me out” of what would have been a basin of attraction in the vanilla gradient case. Similarly, at flat minima, I am more likely to “settle down.” We can make this precise using the following trick:

Lemma 2: Let \(v \in \mathbb{R}^n\) be a random vector with mean \(0\) and covariance \(D\). Then \(\mathbb{E}\left[ \|v\|^2 \right ] = \text{Tr}(D)\).

Combining this lemma with Markov’s inequality gives \(P(\|n(x)\|^2 \geq t) \leq \text{Tr}(D(x))/t \approx \text{Tr}(\nabla^2 f(x))/t\) for any \(t > 0\). At a flat minimum the trace is small, so large perturbations are provably rare; at a sharp minimum the bound is loose, and large perturbations become possible. We can also consider a “radius of stability” around a local minimum \(x^*\): for a fixed \(\epsilon \in (0,1)\), suppose there exists some \(r(x^*) > 0\) such that whenever our starting point \(x_0\) satisfies \(\|x_0 - x^*\| < r(x^*)\), the probability that the iterates satisfy \(\|x_t - x^*\| < r(x^*)\) for all \(t \geq 0\) is at least \(1 - \epsilon\). In this case, we say \(x^*\) is stochastically stable with radius \(r(x^*)\). Combining this notion of stability with our previous informal argument, we get the following picture:

Micro-theorem 1: the radius of stability \(r(x^*)\) for a strict local minimum \(x^*\) is inversely proportional to the spectral radius of \(\nabla^2 f(x^*)\).
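Lemma 2 and the Markov bound can both be checked by Monte Carlo. The sketch below draws Gaussian noise with two hypothetical covariances, a flat one and a sharp one, and compares the mean squared norm with \(\text{Tr}(D)\) and the empirical tail probability with the Markov bound \(\text{Tr}(D)/t\). The Gaussian assumption matches the CLT-type model above; Lemma 2 itself only needs mean zero and covariance \(D\).

```python
import numpy as np

def tail_vs_markov(D, t, samples=200_000, seed=0):
    """Return (mean ||v||^2, Tr(D), empirical P(||v||^2 >= t), Markov bound)."""
    rng = np.random.default_rng(seed)
    v = rng.multivariate_normal(np.zeros(len(D)), D, size=samples)
    sq_norms = np.sum(v**2, axis=1)
    trace = float(np.trace(D))
    return sq_norms.mean(), trace, float(np.mean(sq_norms >= t)), trace / t

D_flat = np.diag([0.01, 0.01])   # hypothetical flat minimum
D_sharp = np.diag([5.0, 1.0])    # hypothetical sharp minimum
flat = tail_vs_markov(D_flat, t=1.0)
sharp = tail_vs_markov(D_sharp, t=1.0)
print("flat: ", np.round(flat, 4))
print("sharp:", np.round(sharp, 4))
```

For the flat covariance, the Markov bound (0.02) already rules out large kicks; for the sharp one, the bound exceeds 1 and says nothing, and most draws do exceed \(t\).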

Let’s connect this to what we know about the Fisher information. If flatter minima are more stable under the dynamics of SGD, this means that SGD implicitly provides a form of regularization. It does this by injecting anisotropic noise to push us out of regions where the Fisher-Rao norm hints at unfavorable generalization conditions.

Implications for deep learning: degeneracy of the Hessian and “wide valleys”

An interesting phenomenon in deep learning is overparameterization: we often have many more parameters than we do examples (\(n \gg N\)). In this case, \(D(x) = \frac1N \sum_{i=1}^N (\nabla f_i(x) - g_N)(\nabla f_i(x) - g_N)^T\) is a sum of \(N\) rank-one terms, so its rank is less than \(N\) and it is highly degenerate, i.e. it has a lot of zero (or near-zero) eigenvalues. Near a minimum, where \(D(x)\) approximates the Hessian, this means there are a lot of directions along which the loss function is locally flat. This paints an interesting picture of the optimization landscape for these networks: SGD spends most of its time traversing “wide valleys.” The noise is concentrated along the few directions of higher curvature, where it counteracts \(g_N\)’s push towards the bottom of the valley.
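The rank argument is easy to check numerically. The sketch below uses random vectors as hypothetical per-example gradients with \(n = 500\) parameters and \(N = 20\) examples. Because the centered gradients sum to zero, \(D(x)\) has rank \(N - 1 = 19\), so 481 of its 500 eigenvalues are zero.

```python
import numpy as np

def noise_covariance_rank(n=500, N=20, seed=0):
    """Rank of D(x) built from N hypothetical per-example gradients in R^n."""
    rng = np.random.default_rng(seed)
    grads = rng.normal(size=(N, n))
    g_N = grads.mean(axis=0)
    D = grads.T @ grads / N - np.outer(g_N, g_N)  # n x n, but rank < N
    return int(np.linalg.matrix_rank(D))

print(noise_covariance_rank())  # 19
```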

Contemporary concerns: batch size, learning rate, and the generalization gap

Since we scale \(n(x)\) by a factor of \(\frac{1}{\sqrt{m}}\) before adding it to our vanilla gradient, increasing the batch size scales down the overall variance of our mini-batch estimate. This is a problem, because large batches are exactly what we want for training models at scale. They are faster in two important ways: training error tends to converge in fewer gradient updates, and large batches let us take advantage of data parallelism. However, increasing the batch size without any tricks can cause the test error to increase. This phenomenon is known as the “generalization gap,” and there are a few working hypotheses for why it occurs. One popular explanation is that our “exploratory noise” is no longer powerful enough to push us out of the basin of attraction of a sharp minimum. One solution is to simply scale up the learning rate to increase the contribution of this noise. Scaling rules of this kind, such as growing the learning rate linearly with the batch size, have been pretty successful.
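Write \(\eta\) for the learning rate (a symbol not used elsewhere in this post). The update \(-\eta g_B = -\eta g_N - \frac{\eta}{\sqrt{m}} n(x)\) then scales the noise by \(\eta/\sqrt{m}\). The sketch below tabulates that factor for a hypothetical base setting (\(\eta = 0.1\), \(m = 256\)) under three choices: keeping \(\eta\) fixed, growing \(\eta\) like \(\sqrt{m}\), and growing \(\eta\) linearly with \(m\).

```python
import math

def noise_scale(lr, batch_size):
    """Factor lr / sqrt(m) multiplying n(x) in the update -lr * g_B."""
    return lr / math.sqrt(batch_size)

def scaling_table(base_lr=0.1, base_m=256, batch_sizes=(256, 1024, 4096)):
    """Rows of (m, fixed-lr scale, sqrt-rule scale, linear-rule scale)."""
    rows = []
    for m in batch_sizes:
        k = m / base_m
        rows.append((
            m,
            noise_scale(base_lr, m),                 # learning rate held fixed
            noise_scale(base_lr * math.sqrt(k), m),  # square-root scaling
            noise_scale(base_lr * k, m),             # linear scaling
        ))
    return rows

for row in scaling_table():
    print(row[0], *(f"{s:.5f}" for s in row[1:]))
```

Square-root scaling keeps the per-step noise scale constant. Linear scaling keeps the ratio \(\eta/m\) fixed, which is the quantity that sets the noise level in stochastic differential equation approximations of SGD; per step, its noise factor actually grows like \(\sqrt{m}\).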

Long-term concerns: escaping saddle points

While the generalization gap has become a popular topic of discussion lately, there is also a lot of prior work studying how saddle points affect optimization. Although we don’t converge to saddle points asymptotically, we can still get stuck near them for a really long time. And even though larger batch sizes seem to encourage sharper minima, really large batch sizes lead to nearly deterministic trajectories that can get stuck near saddle points. This work shows that injecting sufficiently large isotropic noise can help us escape saddle points. I’m willing to bet that mini-batch SGD provides enough noise along the problematic dimensions to escape them too, provided the noise is sufficiently “amplified.”
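Here is a toy illustration of the isotropic-noise idea on a hypothetical two-dimensional function, \(f(x, y) = \frac12 x^2 - \frac12 y^2 + \frac14 y^4\), which has a saddle at the origin and minima at \((0, \pm 1)\). Started on the saddle's stable manifold (\(y = 0\)), plain gradient descent converges to the saddle. Adding a small isotropic Gaussian perturbation at every step lets the iterate escape. This is a simplified sketch in the spirit of perturbed gradient descent, not any particular paper's algorithm.

```python
import numpy as np

def grad(p):
    """Gradient of f(x, y) = x^2/2 - y^2/2 + y^4/4."""
    x, y = p
    return np.array([x, -y + y**3])

def descend(noise_std, steps=500, lr=0.1, seed=0):
    """Gradient descent with optional isotropic Gaussian noise added each step."""
    rng = np.random.default_rng(seed)
    p = np.array([1.0, 0.0])  # on the stable manifold of the saddle at (0, 0)
    for _ in range(steps):
        # Gradient step plus isotropic noise (exactly zero when noise_std == 0).
        p = p - lr * grad(p) + noise_std * rng.normal(size=2)
    return p

print(descend(noise_std=0.0))   # converges to the saddle
print(descend(noise_std=0.01))  # ends near one of the minima (0, +/-1)
```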

Once we manage to solve the “sharp minima” problem, saddle points seem to be the next big barrier to large-scale optimization. For example, I trained ResNet34 on CIFAR-10 with plain SGD. As I increased the batch size up to 4096, the generalization gap appeared. Beyond that point (I tested batch sizes up to 32k, with 50k training examples), performance degraded significantly: training error and test error both plateaued within only a few epochs, and the network failed to converge to a useful solution. Here’s a preliminary (i.e. ugly and in-progress) plot of these results:


Next steps?

Most of the proposed solutions to deal with sharp minima/saddle points revolve around (a) injecting isotropic noise or (b) maintaining a particular “learning rate to batch size” ratio. I don’t think these will be enough in the long run. Isotropic noise does a poor job of taking into account the “wide valley” structure of our loss landscape. Increasing the learning rate also increases the contribution of the vanilla gradient, which makes our weight update bigger. I think the right approach is to come up with efficient ways to simulate the anisotropy of mini-batch noise in a way that’s “decoupled” from the learning rate and batch size. There are efficient ways to do this using sub-sampled gradient information and the Hessian-vector product, which I’m experimenting with right now. I would love to hear other ideas on how to solve this issue. In the meantime, there’s a lot of theoretical work to be done to understand these dynamics in more detail, especially in a deep learning context.
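One possible way to simulate mini-batch-like anisotropic noise without ever forming \(D(x)\) is to take a random subsample of \(k\) per-example gradients, center them, and combine them with independent standard-normal weights. Given the subsample, the result has covariance \(\frac1k \sum_j c_j c_j^T\), where \(c_j\) are the centered gradients. This is a subsample estimate of \(D(x)\), and it equals \(D(x)\) exactly when \(k = N\). The noise could then be scaled by its own hypothetical knob `sigma`, separate from the learning rate and batch size. This sketch uses only sub-sampled gradients, not Hessian-vector products, and should not be read as the approach mentioned above; the function and argument names are illustrative.

```python
import numpy as np

def subsampled_noise(grads, k, rng):
    """One draw of noise whose covariance (given the subsample) estimates D(x).

    grads: (N, n) array of per-example gradients (hypothetical input).
    k: number of examples in the subsample, 2 <= k <= N.
    """
    idx = rng.choice(len(grads), size=k, replace=False)
    centered = grads[idx] - grads[idx].mean(axis=0)  # c_j = g_j - subsample mean
    weights = rng.standard_normal(k)                 # one N(0, 1) weight per example
    return centered.T @ weights / np.sqrt(k)         # Cov = (1/k) sum_j c_j c_j^T

rng = np.random.default_rng(0)
grads = rng.normal(size=(30, 3))  # hypothetical per-example gradients
noise = subsampled_noise(grads, k=10, rng=rng)
# A noisy update with a separate, hypothetical noise scale `sigma`:
# x_next = x - lr * grads.mean(axis=0) - sigma * noise
```

The injected noise lies in the span of the centered gradients, so it inherits the same degenerate, "wide valley" structure as true mini-batch noise.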
