Noah Golmant

Meta-Learning and Optimization

In this post, I’ll try to delve into meta-learning from an optimization perspective. I’ll be asking some small questions about an interesting meta-learning algorithm called Model-Agnostic Meta-Learning (MAML). The objective of MAML is to adapt to new tasks from only a few examples. If I train a model on a bunch of tasks, I’d like to end up with something that can do well on a new task after only a few steps of gradient descent to adapt the parameters. MAML does this by finding parameters such that when you take a gradient descent step from those parameters, you end up close to the optimum for your task. In the end, we obtain some parameters called a meta-model. If I receive examples for a new task, I can run gradient descent to minimize that task’s loss, initializing the process from the meta-model. If this task is similar to the training tasks I used to produce the meta-model, the “fine-tuned” model I end up with should do pretty well.

I’m interested in studying what the meta-model really is because I think it could provide some interesting insights into how MAML implicitly gauges the similarity between tasks. I also think it could help us make some stronger theoretical statements about how good the meta-model is as a starting point for gradient descent. The only MAML theory I’m aware of focuses on probabilistic interpretations and universality (which are super cool!).

This post will basically be a bit of a case study for MAML where I look into what it does on very simple objectives that are quadratic in the model parameters. I’ll remove the stochasticity from play, focusing on vanilla gradient descent. Then, I’ll derive the fixed point for the MAML gradient descent update equation. There’s a kind of interesting interpretation of this fixed point as a curvature-weighted average of the optima for the objectives. Then I’ll calculate the loss of the fine-tuned models and prove that MAML really does its job. I’ll close with some thoughts about how this picture might relate to the more general case, e.g. when the objectives are strongly convex.

The MAML objective

Let’s start out by stating the MAML objective. Since I’m not thinking about stochasticity right now, a task will just consist of a loss function on a domain \(\mathcal{X}\). So we’ll consider running MAML on two equally important objectives \(f, g: \mathcal{X} \rightarrow \mathbb{R}\). For any objective \(h\) and step size \(\alpha > 0\), I’ll call its gradient descent update \(T_h(x) = x - \alpha \nabla h(x)\). Now we’re ready to state the MAML objective:

\[\underset{x \in \mathcal{X}}{\text{minimize}} \hspace{1ex} L(x) := \frac{1}{2}[(f \circ T_f)(x) + (g \circ T_g)(x)]\]

We can minimize this by running gradient descent on \(L\) with some step size \(\beta > 0\). The gradient is

\[\nabla L(x) = \frac{1}{2}[(I - \alpha \nabla^2 f(x))\nabla f(T_f(x)) + (I - \alpha \nabla^2 g(x))\nabla g(T_g(x)) ]\]

where \(\nabla^2 f(x)\) is the Hessian of \(f\) at \(x\). As usual, we get an update equation \(T_L(x) = x - \beta \nabla L(x)\).
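As a concrete sketch, here’s how the MAML gradient above might be computed in numpy. The callables `grad_f`, `hess_f`, `grad_g`, and `hess_g` are hypothetical stand-ins for the task derivatives; nothing here is specific to quadratics yet:

```python
import numpy as np

def maml_grad(x, grad_f, hess_f, grad_g, hess_g, alpha):
    """Gradient of L(x) = (1/2)[f(T_f(x)) + g(T_g(x))] via the chain rule."""
    I = np.eye(len(x))
    Tf = x - alpha * grad_f(x)  # inner gradient step on task f
    Tg = x - alpha * grad_g(x)  # inner gradient step on task g
    term_f = (I - alpha * hess_f(x)) @ grad_f(Tf)
    term_g = (I - alpha * hess_g(x)) @ grad_g(Tg)
    return 0.5 * (term_f + term_g)
```

For a quadratic task \(f(x) = \frac{1}{2}(x-a)^T A (x-a)\), you would pass `grad_f = lambda x: A @ (x - a)` and `hess_f = lambda x: A`.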

Quadratic Forms

Like I said, we’ll be considering some very simple functions here. Let \(A, B \in \mathbb{R}^{d \times d}\) be symmetric, positive definite matrices, and let \(a, b \in \mathbb{R}^d\). Then we’ll set

\[f(x) = \frac{1}{2} (x-a)^T A (x-a), g(x) = \frac{1}{2} (x-b)^T B (x-b)\]

Clearly \(f\) is minimized at \(a\) while \(g\) is minimized at \(b\). The gradients are \(\nabla f(x) = A(x-a), \nabla g(x) = B(x-b)\). The Hessians are \(A\) and \(B\), respectively. I’m pretty sure this is the simplest setup I could look at with any interesting behavior. For simplicity, let’s define \(D_A = (I - \alpha A)A(I - \alpha A), D_B = (I - \alpha B)B(I - \alpha B)\). Plugging things in, we can calculate

\[\nabla L(x) = \frac{1}{2}[D_A (x-a) + D_B (x-b)]\]

This is kind of interesting, because the gradient looks like the sum of the gradients of some modified quadratic objectives. By applying the spectral theorem, we can see that \(D_A\) sort of attenuates the eigenvalues of \(A\) in a non-linear manner. The \(i\)th eigenvalue of \(D_A\) is given by \(\lambda_i(D_A) = \lambda_i(A)(1 - \alpha \lambda_i(A))^2\). When you look at this as a quadratic function of the step size, it is decreasing from \(\lambda_i(A)\) to \(0\) on the range \([0, 1/\lambda_i(A)]\) (which is the largest we could go due to regularity conditions anyways).
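We can double-check this attenuation numerically. Here’s a quick numpy sketch with a randomly generated symmetric positive definite matrix; since \(D_A\) shares \(A\)’s eigenvectors, its spectrum should be exactly the attenuated one:

```python
import numpy as np

rng = np.random.default_rng(0)
d, alpha = 4, 0.01  # small alpha so alpha * lambda_i(A) < 1

# Build a random symmetric positive definite A.
M = rng.standard_normal((d, d))
A = M @ M.T + d * np.eye(d)

I = np.eye(d)
D_A = (I - alpha * A) @ A @ (I - alpha * A)

lam = np.linalg.eigvalsh(A)
attenuated = lam * (1 - alpha * lam) ** 2

# D_A shares eigenvectors with A, so its eigenvalues are lambda * (1 - alpha*lambda)^2.
assert np.allclose(np.sort(np.linalg.eigvalsh(D_A)), np.sort(attenuated))
```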

We can solve for the fixed point of the MAML update by setting \(\nabla L(x) = 0\). In the end, we get the clean solution

\[x^* = (D_A + D_B)^{-1}(D_A a + D_B b)\]
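This closed form is easy to verify numerically. Below is a small numpy sketch with randomly generated tasks that checks both that the MAML gradient vanishes at \(x^*\) and that the outer gradient descent actually converges there:

```python
import numpy as np

rng = np.random.default_rng(1)
d, alpha, beta = 3, 0.02, 0.1
I = np.eye(d)

def random_spd():
    M = rng.standard_normal((d, d))
    return M @ M.T + d * np.eye(d)

A, B = random_spd(), random_spd()
a, b = rng.standard_normal(d), rng.standard_normal(d)

D_A = (I - alpha * A) @ A @ (I - alpha * A)
D_B = (I - alpha * B) @ B @ (I - alpha * B)

# Closed-form fixed point of the MAML update.
x_star = np.linalg.solve(D_A + D_B, D_A @ a + D_B @ b)

def grad_L(x):
    # The 1/2 comes from averaging the two tasks.
    return 0.5 * (D_A @ (x - a) + D_B @ (x - b))

# The MAML gradient vanishes at the fixed point.
assert np.allclose(grad_L(x_star), 0)

# Outer gradient descent lands at the same point.
x = np.zeros(d)
for _ in range(2000):
    x = x - beta * grad_L(x)
assert np.allclose(x, x_star, atol=1e-5)
```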

What is this fixed point?

This is kind of a matrix weighted average of the optima \(a\) and \(b\). There are cool ways to look at the spectrum of \(D_A + D_B\) to get an idea of what’s going on, but I think the simplest case to look at is when both \(A\) and \(B\) are diagonal, i.e. when their eigenbases “align” both with each other and with the coordinate system of \(a\) and \(b\). When this happens, \(x^*\) is a coordinate-wise weighted average of \(a\) and \(b\), where the weights are simply given by the attenuated eigenvalues of \(A\) and \(B\). In the general case, this scaling happens with respect to the coordinates of the vector \(a + b\) in the eigenbasis of the matrices. So we will move closer to the space generated by the eigenvectors corresponding to the largest eigenvalues.

How good is this fixed point?

Now I’ll derive the MAML loss of \(x^*\). We would expect this to be some function of the distance between the optima and of the curvature of the objectives, but what is MAML improving on? Let’s calculate the loss of a baseline approach: why not just take the average of the optima, \(\frac{1}{2}(a+b)\)? Plugging this into \(L\):

\[L(\tfrac{1}{2}(a+b)) = \frac{1}{2}\left[ f\left(\tfrac{1}{2}[(I + \alpha A)a + (I - \alpha A)b]\right) + g\left(\tfrac{1}{2}[(I - \alpha B)a + (I + \alpha B)b]\right) \right]\] \[= \frac{1}{16}\left[(b-a)^T D_A (b-a) + (b-a)^T D_B (b-a)\right] = \frac{1}{16}(b-a)^T (D_A + D_B)(b-a)\]

This is essentially a squared Mahalanobis distance between the optima. The thing is, as I increase the distance between the optima, the loss can skyrocket, especially if I move them apart along the direction of one of the principal eigenvectors of \(A\) or \(B\). And, since we gave \(a\) and \(b\) equal weight, a “low-risk” objective with smaller eigenvalues just hurts how well we do on a harder objective with larger eigenvalues, since we didn’t penalize that difference.

Knowing what we do about the loss function, we can extract the fine-tuned models for the tasks given the midpoint as an initialization. For example, the fine-tuned model for \(f\) is \(\frac{1}{2}[(I + \alpha A)a + (I - \alpha A)b]\). So we increase the contribution of \(a\) a bit in our weighted average: since \(A\) is positive definite, the eigenvalues of \(I + \alpha A\) are all greater than one, while the eigenvalues of \(I - \alpha A\) are all less than one, so in each eigendirection of \(A\) the \(a\)-component is scaled up by a factor of \(1 + \alpha \lambda_i(A)\) while the \(b\)-component is scaled down by \(1 - \alpha \lambda_i(A)\). In other words, the inner update pulls the fine-tuned model away from the meta-model and toward \(a\), scaling the components of \(a\) and \(b\) away from each other in the eigenbasis of \(A\).

Now, let’s calculate the losses when we arrive at our fine-tuned models for the respective tasks. I’ll call \(D = D_A + D_B\). We get

\[(f \circ T_f)(x^*) = f(x^* - \alpha A(x^* - a)) = \frac{1}{2} ((I - \alpha A)(x^* - a))^T A ((I - \alpha A)(x^* - a)) = \frac{1}{2} (x^* - a)^T D_A (x^* - a)\]

We get a similar thing for \((g \circ T_g)(x^*)\). Hence, our overall loss using these fine-tuned models is

\[L(x^*) = \frac{1}{4}[ (x^* - a)^T D_A (x^* - a) + (x^* - b)^T D_B (x^* - b)]\]

Like before, we can extract out the fine-tuned models for each task starting from this initialization. This looks the same as before: we simply move away from the meta-model and towards \(a\) by taking a “mixture” of the two components weighted by \(A\). When is this doing better than \(\frac{1}{2}(a + b)\)? Precisely when \(x^*\) lies closer to the subspace spanned by the eigenvectors corresponding to the largest eigenvalues of \(A\). This is because multiplying by \(I - \alpha A\) will scale the components corresponding to these eigenvectors down to zero in only a few iterations. And \(x^*\) definitely does lie closer, since when we took the “weighted average” of \(a\) and \(b\), we gave more weight to those components.
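As a sanity check on this comparison, here is a small numpy experiment with randomly generated tasks confirming that the fixed point does at least as well as the naive midpoint under \(L\) (it must, since \(x^*\) is the minimizer of a convex quadratic):

```python
import numpy as np

rng = np.random.default_rng(2)
d, alpha = 3, 0.02
I = np.eye(d)

def random_spd():
    M = rng.standard_normal((d, d))
    return M @ M.T + d * np.eye(d)

A, B = random_spd(), random_spd()
a, b = rng.standard_normal(d), rng.standard_normal(d)

def L(x):
    # One inner gradient step per task, then average the post-update losses.
    xf = x - alpha * A @ (x - a)
    xg = x - alpha * B @ (x - b)
    f = 0.5 * (xf - a) @ A @ (xf - a)
    g = 0.5 * (xg - b) @ B @ (xg - b)
    return 0.5 * (f + g)

D_A = (I - alpha * A) @ A @ (I - alpha * A)
D_B = (I - alpha * B) @ B @ (I - alpha * B)
x_star = np.linalg.solve(D_A + D_B, D_A @ a + D_B @ b)

# The MAML fixed point is at least as good as the plain midpoint of the optima.
assert L(x_star) <= L(0.5 * (a + b)) + 1e-12
```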

Finishing up

In what sense is this “the best we could’ve done”? Well, MAML found the unique point that minimizes the expected loss of the “fine-tuned” models obtained by running gradient descent from that point for the respective tasks. For quadratics, this minimizer involves an interesting “curvature-weighted average” of the optima of the respective tasks. This minimizer has a kind of “accelerated path” towards a particular task’s optimum, since we have already pushed it closer to the eigenspaces corresponding to the larger eigenvalues. In a sense, we pushed the midpoint to the “top of the cliff” of the objective’s loss surface, so that we could tumble down quickly with a little push.

For the next steps, I’d like to investigate convergence for smooth, strongly convex objectives. This basically requires checking if the function \(f(x - \alpha \nabla f(x))\) is convex and smooth. Convexity is easy. A nice result would be to bound the number of iterations to achieve some \(\epsilon\) error for a task as a function of the distance between the two optima. Looking back on this quadratic stuff, I expect the distance to be with respect to something like a Mahalanobis metric based on the Hessians of the two tasks. This would give some insight into how MAML measures “task similarity” based on curvature information.

Edit: Convex, smooth objectives imply unique one-step fixed point

I thought that the case with smooth, convex objectives would end up being more interesting. Let’s investigate the convexity properties of the one-step gradient descent version of \(f\), \(F(x) = f(x - \alpha \nabla f(x))\). First of all, if we assume \(\alpha < 1/\lambda_{max}(\nabla^2 f)\) (which makes sense since \(f\) is smooth), then the only minimizer of \(F\) is the original minimizer \(x'\) of \(f\). Since \(\nabla f(x') = 0\), we get \(\|\alpha \nabla f(x)\| \leq \alpha \lambda_{max}(\nabla^2 f) \|x - x'\| < \|x - x'\|\). Moreover, you can show that the gradient update for \(F\) is a contraction, which means gradient descent converges to the same minimum as before.
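Here is a quick numpy illustration of this claim for a quadratic \(f\) (so \(\nabla f(x) = A(x-a)\) and \(\nabla^2 f = A\)), with \(\alpha\) chosen below \(1/\lambda_{max}(A)\): gradient descent on \(F\) recovers the minimizer of \(f\):

```python
import numpy as np

A = np.array([[3.0, 1.0], [1.0, 2.0]])  # symmetric positive definite
a = np.array([1.0, -2.0])               # minimizer of f
alpha = 0.1                             # below 1 / lambda_max(A) ~= 0.276

def grad_f(x):
    return A @ (x - a)

def grad_F(x):
    # Gradient of F(x) = f(x - alpha * grad_f(x)) via the chain rule.
    I = np.eye(2)
    return (I - alpha * A) @ grad_f(x - alpha * grad_f(x))

x = np.array([5.0, 5.0])
for _ in range(500):
    x = x - 0.1 * grad_F(x)

# F is minimized at the same point as f.
assert np.allclose(x, a, atol=1e-8)
```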

I’d like to show that for a convex combination of smooth, convex objectives, the MAML objective is convex, too. I’m not sure I have the right proof of that yet, though. But I do think there could be some useful ideas in the theory of multi-objective optimization and Pareto optimality. Besides that, we at least know that MAML converges in this case, because if two functions converge under gradient descent, then any convex combination of them also converges. However, even if the two functions converge to their minimizers under gradient descent, I’m not sure the convex combination has a unique minimizer. For example, if \(f, g\) are convex with minimum values \(y_f, y_g\), then \(\alpha f(x) + (1-\alpha)g(x) \geq \alpha y_f + (1-\alpha)y_g\) for \(0 \leq \alpha \leq 1\), but this does not imply that some \(x\) actually achieves this lower bound. This is also a different question from convexity of the combination, since gradient descent can converge to minimizers even without a convex objective.
