Noah Golmant

Dynamical Systems and a Suitably Abstract View of Autoregressive Transformers

The basic premise of this post is that autoregressive transformers with stacks of self-attention modules, interspersed with various non-linear transformations, interpolate between observations in a high-dimensional dynamical system. A lot of the work to see this comes from more of a “perspective shift” than anything proof-based. But, if we squint really hard, we can understand a properly trained transformer to be simulating Stochastic Gradient Langevin Dynamics (SGLD) on a suitable potential function. Much of the work will be in defining this potential and understanding where the noise comes in. The rest of the work will be understanding why this is possibly a useful perspective for the system. When we start taking this more dynamics-focused view of the system, we reveal new perspectives on popular practices, such as using linear classification heads for transfer learning.

Let’s define our terms. Let \(\boldsymbol \tau = \{\tau_1, \ldots, \tau_T\}\) be the input sequence of \(T\) tokens. Let \(\boldsymbol e = (E(\tau_1), \ldots, E(\tau_T))\) where \(E\) is the dictionary map from tokens to \(D\)-dimensional embedding vectors. This is usually implemented directly through something like PyTorch’s nn.Embedding. If we regard \(\boldsymbol e\) as a random variable, we are essentially dealing with sample trajectories of a system in some \(D\)-dimensional state space.
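As a minimal sketch of that embedding step (the vocabulary size, embedding dimension, and sequence length below are placeholders, not values from any particular model):

```python
import torch
import torch.nn as nn

VOCAB_SIZE, D, T = 50_000, 768, 16        # placeholder sizes
E = nn.Embedding(VOCAB_SIZE, D)           # the dictionary map E from token ids to D-dim vectors

tau = torch.randint(0, VOCAB_SIZE, (T,))  # a sequence of T token ids
e = E(tau)                                # shape (T, D): one sample trajectory in the state space
```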

The self-attention operation can be written as follows. This particular perspective looks like the self-attention module located right after the initial embedding layer. Subsequent self-attention modules deal with some transformed sequence other than \(\boldsymbol e\). We will ignore multi-headed modules for now, although they represent a particular factorization of the problem. Let \(L_Q, L_K, L_V\) be the query, key, and value linear transformations applied to \(\boldsymbol e\). These are usually implemented as 1-D convolutions in code, and they are usually actually affine because they include bias terms. I am not simply reducing them to matrices because keeping them in the abstract “\(L\)” notation may help with analysis of particular properties in the future (e.g. 1-D convolutional form, or multi-headed magic), and I try not to be a pedant.

Then, considering \(\boldsymbol e\) as a matrix, let \(Q = L_Q(\boldsymbol e), K = L_K(\boldsymbol e), V = L_V(\boldsymbol e)\) be new random variables. Again, they are really like sample trajectories from their respective dynamical systems, each of which is obtained by applying the respective linear transformation to each row vector in \(\boldsymbol e\).

Zooming in on the role of the softmax

Now, I will be investigating a particular object that is usually ignored in most conceptual analyses of self-attention. Consider \(S = \sigma(M(QK^{\intercal}))\), which is obtained by applying a row-wise softmax to \(QK^{\intercal}\), possibly after applying some autoregressive (“causal”) masking operation \(M\). \(S\) is a \(T \times T\) matrix. The output of attention, \(Y\), is then \(Y = SV\). For those who are newer to transformers, I hope this last operation makes it more obvious why longer contexts are more computationally intensive for vanilla self-attention. There is a matrix multiplication with two “\(T\)”s in the dimension arguments. Ouch.
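Here is a rough sketch of the whole operation, with the embeddings faked by a random tensor and the query/key/value maps written as plain affine layers rather than the 1-D convolutional form mentioned above (and, matching the notation in this post, without the usual \(1/\sqrt{D}\) scaling of the scores):

```python
import torch
import torch.nn as nn

T, D = 16, 768
e = torch.randn(T, D)                                # stand-in for the embedded sequence

L_Q, L_K, L_V = nn.Linear(D, D), nn.Linear(D, D), nn.Linear(D, D)
Q, K, V = L_Q(e), L_K(e), L_V(e)                     # each of shape (T, D)

scores = Q @ K.T                                     # (T, T): the two "T"s that hurt at long context
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))  # autoregressive ("causal") masking
S = torch.softmax(scores, dim=-1)                    # row-wise softmax: each row lies in the simplex
Y = S @ V                                            # (T, D)
```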

Most people force this mathematical object \(S\) to play second fiddle to \(V\) in that last equation. They treat the elements of \(S\) as weights applied to sum up the various candidate value vectors in \(V\). Well, I am here to put a stop to that: let \(S\) play second fiddle no longer. Ignore the particulars of causal masking for a moment, and let’s take a look at the structure of an element of \(S\). Say, \(S_{ij}\). We should have that

\[S_{ij} = \frac1Z \exp\{ \langle q_i, k_j \rangle \}\]

for suitable query and key vectors \(q_i\) and \(k_j\), which were rows of \(Q\) and \(K\) before. I am now going to perform a trick that I once saw referred to as “the fundamental theorem of optimization”. For simplicity, rename \(q_i\) as \(q\) and \(k_j\) as \(k\).

\[\langle k, q \rangle = \frac12 [ \|k\|^2 + \|q\|^2 - \|k - q\|^2 ]\]

Now we can make a fun swap: by absorbing a bunch of normalization constants into and out of the subsequent equations, letting \(\mathcal{N}(x \vert \mu)\) be the density of \(x\) under a unit-variance isotropic multivariate normal centered at \(\mu\), and, again, only really looking at causally appropriate indices, we actually find that

\[S_{ij} \propto \frac{\mathcal{N}(q_i \vert \mu = k_j)}{\mathcal{N}(q_i \vert \mu = \vec{0})\mathcal{N}(k_j \vert \mu = \vec{0})}\]

That is, each entry is an odds ratio between an assertion of the form “query \(q\) is from a normal centered about key \(k\)” and a null hypothesis of the form “\(k\) and \(q\) are drawn independently from normals centered about the origin”.
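Spelling out where the constants go: with unit-variance isotropic Gaussians, each density is proportional to \(\exp\{-\frac12 \|x - \mu\|^2\}\), so

\[\frac{\mathcal{N}(q_i \vert \mu = k_j)}{\mathcal{N}(q_i \vert \mu = \vec{0})\,\mathcal{N}(k_j \vert \mu = \vec{0})} \propto \exp\Big\{ \tfrac12 \big[ \|q_i\|^2 + \|k_j\|^2 - \|q_i - k_j\|^2 \big] \Big\} = \exp\{ \langle q_i, k_j \rangle \} \propto S_{ij}\]

with the remaining constant being the row-normalization \(Z\) from the softmax.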

Although this process is symmetric in \(k\) and \(q\), it is useful to use the above language, which basically amounts to a fuzzy key-value lookup process. One could further simplify the analysis by setting \(Q = K\) so that \(S = \sigma(Q Q^{\intercal})\). This would keep us firmly in “eigen” land rather than looking at singular values/vectors. It would also allow us to interpret the softmax with more of a contrastive feel. One interesting question for the future is: how does \(\sigma\), the row-wise softmax, affect the “eigenthings” of \(S\)? One way to analyze this might be to regard the elements of each row as outgoing edge weights for a particular vertex on a directed graph, and in turn have each edge weight represent a transition probability for a random walk on this graph. The stationary distributions of this random walk are the left eigenvectors of \(S\) with eigenvalue one.
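As a toy numerical check of that last idea (the matrix below is just a random row-stochastic stand-in for an attention matrix, not the output of any real model): a stationary distribution is a left eigenvector of \(S\) with eigenvalue one, and power iteration on \(S^{\intercal}\) finds it.

```python
import torch

T = 8
S = torch.softmax(torch.randn(T, T), dim=-1)  # random row-stochastic stand-in for attention

# A stationary distribution pi satisfies pi^T S = pi^T, i.e. pi is a left eigenvector of S
# with eigenvalue 1. Since every entry of S is positive here, the walk has a unique one.
pi = torch.full((T,), 1.0 / T)
for _ in range(1000):
    pi = S.T @ pi                             # power-iteration step on S^T
    pi = pi / pi.sum()                        # guard against floating-point drift

print(pi)                                     # approximate stationary distribution
```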

It is interesting to note at this point that all of the rows of \(S\) are vectors lying in the \((T-1)\)-simplex, since each row has \(T\) non-negative elements that sum to one. Any eigenvectors would represent typical conjunctions of certain hypotheses about relationships between the tokens, the relationships we expressed earlier as odds ratios about key-query similarity. It may be that the stability of the system (stability in the sense of “Jacobian eigenvalues are close to the unit disk”) can be found by looking at these eigenvectors in particular, ignoring \(V\) for the most part. This may happen if, for example, most of the numerical dynamic range is accounted for by the exponential in the softmax, compared to the linear change in magnitudes from the final multiplication with \(V\).

Anyways, the final step: \(Y = S V\). Easy, right? You could say that the rows of \(S\) represent barycentric coordinates for a simplex with masses placed at the vertices. These vertices are to be found in the rows of \(V\). Let’s investigate that idea further and understand a single row of \(Y\), say \(Y_t\). Basically, in column vector notation, \(Y_t = V^{\intercal} S_t\). This output vector is a convex combination of (at most, considering causal masking) \(T\) \(D\)-dimensional basis vectors, one for each time step. \(S_t\) is thus a coordinate vector, and normalized. So it is a barycentric coordinate. But, remember that \(V\) is itself the result of applying a linear transformation (really, an affine one) to the embedded input sequence \(\boldsymbol e\). So any estimates being done with this output are made by evaluating whether the target can be expressed as a convex combination of the affinely transformed input.

Quadratics, Gaussians, and bears, oh my

Now, we can throw in the language of optimization to see what this system “should” do when it maximizes a likelihood. The actual output vector \(Y_t\) will be used to predict some target quantity in the future. So, really, a trained self-attention module is “good” when the module outputs some \(Y_t\) close to the target. What is the target vector in our case? Well, the overarching problem is token inference, so we usually think about this in terms of token likelihoods. But when the linear embedding is fixed, so that each token maps to a unique dictionary embedding vector, we are maximizing the inner product between the output and the target token’s embedding vector. Considering the additional implicit constraints on feature norm imposed by regularization tools like LayerNorm and weight decay, we can equivalently minimize the distance between the output and the target embedding vector.
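The equivalence is just the same norm expansion as before. Writing \(y\) for the output and \(e^\star\) for the target embedding (just local names for this one line),

\[\| y - e^\star \|^2 = \|y\|^2 + \|e^\star\|^2 - 2\,\langle y, e^\star \rangle\]

so when both norms are (approximately) pinned by those constraints, maximizing the inner product and minimizing the squared distance are the same problem.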

This means that for a properly trained transformer, with a high-dimensional embedding space, an MLP, and a residual connection thrown in the mix, this dictionary embedding vector can be written as a linear combination of our current value vectors to a sufficient degree of accuracy.

Let’s narrow the problem down for a second to see this more clearly. We will rename the \(t\)-th row of \(Y\) from before as the output \(\hat{Y}_{t+1}\). It is a vector which will, in effect, serve to estimate the \((t+1)\)-st token in the sequence. In theory, gradient descent should optimize this thing to be close to the embedding vector \(E(\tau_{t+1})\) through likelihood maximization, right?

Cool, let’s take that last embedding vector as the target, call it \(e_{t+1}\), and write it as a convex linear combination of value vectors like I mentioned above. But, hold up… the coefficient vector of that linear combination, call it \(A_t\), is also a barycentric coordinate for the simplex!

So we can do all of our math and analysis on that simplex. In fact, the distance can equivalently be written as a general quadratic form. So if \(\hat{Y}_{t+1} = V^{\intercal} S_t\), and \(e_{t+1} \approx V^{\intercal} A_t\),

\[\| \hat{Y}_{t+1} - e_{t+1} \|^2 \approx \| V^{\intercal}S_t - V^{\intercal}A_t \|^2 = (S_t - A_t)^\intercal V V^{\intercal} (S_t - A_t)\]

This is a quadratic form on the simplex! Or maybe a Mahalanobis distance, when we regard \((V V^{\intercal})^{-1}\) as a covariance matrix. That sounds like a cool, math-y statement you would hear in a movie. I’m sure you could do something cool with that knowledge. Or maybe you could do some kind of optimization directly on that surface.
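To make “optimization directly on that surface” a little more concrete, here is a rough sketch of recovering a barycentric coordinate by projected gradient descent on the simplex: minimize \(\|V^{\intercal} a - e_{t+1}\|^2\) over coordinate vectors \(a\), projecting back onto the simplex after every step. The projection uses the standard sorting-based routine, the tensors are random stand-ins, and the step size and iteration count are arbitrary.

```python
import torch

def project_to_simplex(v: torch.Tensor) -> torch.Tensor:
    """Euclidean projection of a vector onto the probability simplex (sorting-based method)."""
    u, _ = torch.sort(v, descending=True)
    css = torch.cumsum(u, dim=0) - 1.0
    idx = torch.arange(1, v.numel() + 1, dtype=v.dtype)
    rho = torch.nonzero(u - css / idx > 0).max()   # last index where the threshold test passes
    theta = css[rho] / (rho + 1.0)
    return torch.clamp(v - theta, min=0.0)

T, D = 16, 64
V = torch.randn(T, D)              # rows are value vectors
e_next = torch.randn(D)            # stand-in for the target embedding e_{t+1}

a = torch.full((T,), 1.0 / T)      # start at the barycenter of the simplex
lr = 5e-3
for _ in range(2000):
    grad = V @ (V.T @ a - e_next)  # gradient of 0.5 * ||V^T a - e_next||^2
    a = project_to_simplex(a - lr * grad)

# a plays the role of A_t: V.T @ a is (approximately) the closest point to e_next
# in the convex hull of the value vectors.
```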

Again, the important part here is that we are at this stage of analysis as a result of a local “good basis” assumption. That is, we assumed that the target was approximately in the span of the value vectors, and in their convex hull in particular. That is just what it means for this transformer to be trained to maximize the likelihood of the token sequence with a fixed embedding. The convexity part may only be affordable in our “expressivity” budget thanks to the addition of a non-linearity and residual term, as well as the work of preceding layers. But still, it says that there exists a barycentric coordinate \(A_t\) that approximately minimizes the above quadratic.

Now, I have yet to make any claims about how gradient descent or a variant might arrive at the optimizer of a likelihood objective in this model. But suppose it did, just for a second. That would mean that, averaged over some time interval (to get our autoregressive niche in), and integrated over some volume of trajectories, the minimum quadratic error is low. That is, a maximizer of the likelihood yields a system in which the resulting embedding vector can be consistently estimated using this least-squares-y process.

So the likelihood maximization problem in token space is equivalent to maximizing the likelihood of certain Gaussians. You can think about these Gaussians in either the embedding space or in the barycentric coordinate space. For example, if you are dealing with an isotropic Gaussian likelihood centered around \(e_{t+1}\), one could just as easily consider \(S_t\) to be a maximizer of the likelihood \(\mathcal{N}(\cdot \vert \mu = A_t, \Sigma = (VV^{\intercal})^{-1})\) (just take the pseudoinverse if it’s really not invertible; you probably get the point). Some people may take issue with the fact that I am treating the linear embedding as fixed here, but any modifications to it could probably be absorbed into the learning dynamics of the actual residual attention blocks of the transformer, since they are so expressive. There are definitely good and bad embeddings, but that’s not so pertinent for the present analysis.

My hypothesis here is that, at any given moment, we can interpret the trained transformer self-attention modules as approximately optimizing these Gaussian likelihoods. The inherent noise in this process comes from the potential modeling error that accrues when moving from \(e_{t+1}\) to \(A_t\), which is the least squares solution to approximating \(e_{t+1}\) in the convex hull of \(V\). But overall, the solution (left-multiplied by \(V^{\intercal}\)) is approximately performing Langevin dynamics about the embedding vector in order to maximize this likelihood. The SGLD potential we were looking to noisily minimize is just the corresponding negative log-likelihood.

Deeper networks and non-linear interpolation

Most of the above thinking takes place around the “pre-logit” feature space of the network. The optimization goals are most obvious here since we are dealing with dictionary embedding vectors as targets. The self-attention module is only “good” at next-token prediction when it produces an estimate close to the next dictionary embedding vector in the input sequence. This is why people fix the “unembedding” layer to be the transpose of the linear embedding layer weights.
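For reference, that tying trick is just a parameter-sharing assignment; here is a bare-bones version (the module names are mine, and real implementations differ in detail):

```python
import torch.nn as nn

VOCAB_SIZE, D = 50_000, 768
embedding = nn.Embedding(VOCAB_SIZE, D)               # token id -> embedding vector
unembedding = nn.Linear(D, VOCAB_SIZE, bias=False)    # pre-logit feature -> token logits

# Share the weights. Since nn.Linear computes x @ W^T, the logit for token v is exactly
# the inner product between the pre-logit feature and that token's embedding vector E(v).
unembedding.weight = embedding.weight
```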

If we take a step back from this point, the significance of depth will make more sense. We start with an initial token sequence \(\boldsymbol \tau\). We map that to an embedding sequence \(\boldsymbol e\). We apply an autoregressive transformer

\[F = f_L \circ f_{L-1} \circ \ldots \circ f_1\]

where \(f_i\) denotes the \(i\)-th residual attention block, and we get some pre-logit hidden state sequence \(\boldsymbol z = F(\boldsymbol e)\). You could also break this down on a per-layer basis and investigate

\[\boldsymbol z^{(\ell)} = f_{\ell} \circ \ldots \circ f_1(\boldsymbol e) = f_{\ell}(\boldsymbol z^{(\ell - 1)})\]

where \(\ell\) indexes over layers. These are the dynamics of the intermediate hidden states. In the same way that we broke down \(Y\) into a sequence of vectors and homed in on the particular approximation capabilities of \(\hat{Y}_{t+1}\) earlier, we can isolate the state space trajectory for the estimate of a single token as it passes through the network. This trajectory has length \(L+1\): the depth of the transformer, plus one to keep track of the initial embedding. It looks like:

\[\text{depth-wise trajectory for estimate at timestep $t$} = \{ e_t, z^{(1)}_t, z^{(2)}_t, \ldots, z^{(L)}_t = \hat{Y}_{t+1}\}\]

Now, the final layer of the network is simply optimizing over this trajectory so that \(z^{(L)}_t \approx e_{t+1}\). This is the more precise sense in which autoregressive transformers interpolate the dynamics in the embedding space. Given sequential vector-valued observations \(e_t\) and \(e_{t+1}\), it is using stacks of residual attention blocks to iteratively approximate the steps for unobserved states of the system between these observations. Unlike an RNN, the hidden states for all timesteps are computed in parallel. But the similarity becomes more obvious when you think of the autoregressive transformer in its natural habitat, functioning as a generative model that decodes the output step-by-step.
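Before moving on, here is a rough bookkeeping sketch of pulling out that depth-wise trajectory. The “residual blocks” below are trivial stand-ins (a real \(f_\ell\) is attention plus an MLP plus normalization); the point is just how the states are collected.

```python
import torch
import torch.nn as nn

T, D, L = 16, 64, 12
blocks = nn.ModuleList([nn.Linear(D, D) for _ in range(L)])   # stand-ins for f_1, ..., f_L

e = torch.randn(T, D)              # embedded input sequence
states = [e]                       # index 0 holds the initial embedding
z = e
for f in blocks:
    z = z + f(z)                   # toy stand-in for z^(l) = f_l(z^(l-1)), residual written explicitly
    states.append(z)

t = 4
depthwise_trajectory = [s[t] for s in states]   # [e_t, z^(1)_t, ..., z^(L)_t], length L + 1
```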

What are the attractors?

Now that we have interpreted the autoregressive transformer as a full-fledged dynamical system with an interesting state space structure, the most interesting question becomes: what are its attractors? Since it is a system trained to maximize the likelihood of embedding vectors in the pre-logit feature space, the most obvious candidates for possible attractors there are: the embedding vectors themselves!

However, there is possibly more to the story. We still know very little about the dynamics of the intermediate layers. Remember, \(\boldsymbol z^{(\ell)}\) is itself a trajectory, but we don’t know much about its behavior. We just know that for \(\ell = L\), at the final layer, that trajectory should look like the input sequence, “rolled” over by one timestep. But that doesn’t tell us much about the middle layers. Empirically, it seems like these middle layers sometimes contain more “useful” features for other prediction tasks. But there is also the trajectory I labeled above, which has its own behavior, and involves breaking down time along the “depth” dimension.

In the end, that ambiguity is not super useful. There is a way to get around this problem by flattening the feature vectors: you just iterate along the depth dimension first. The resulting trajectory starts like:

\[\text{flattened trajectory} = \{e_1, z^{(1)}_{1}, z^{(2)}_1, \ldots, z^{(L)}_1, e_2, z^{(1)}_2, z^{(2)}_2, \ldots\}\]

It’s at this point that mathematical notation can feel a bit cumbersome. But I promise that it’s useful. We end up with a single, ordered sequence of \(D\)-dimensional vectors. The embedding trajectory and each per-layer trajectory reemerge when we sub-sample this sequence at a fixed period of \(L+1\), and the “depth-wise” trajectories are the contiguous blocks of length \(L+1\) in between. So we can recover all the older trajectories of interest from this one longer sequence.
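In code, the flattening and the sub-sampling are a couple of list manipulations (again with random stand-ins for the actual hidden states):

```python
import torch

T, D, L = 16, 64, 12
# states[l][t] = hidden vector for timestep t at depth l, with states[0] holding the embeddings.
states = [torch.randn(T, D) for _ in range(L + 1)]   # random stand-ins

# Flatten depth-first: e_1, z^(1)_1, ..., z^(L)_1, e_2, z^(1)_2, ...
flat = [states[l][t] for t in range(T) for l in range(L + 1)]

period = L + 1
embedding_trajectory = flat[0::period]                # recovers {e_1, e_2, ...}
layer_3_trajectory   = flat[3::period]                # recovers {z^(3)_1, z^(3)_2, ...}
depthwise_at_t5      = flat[4 * period:5 * period]    # depth-wise trajectory for the 5th timestep
```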

How do transfer learning and linear classification “heads” work?

My closing hypothesis is that these trajectories have attractors that represent semantically meaningful concepts. There is some basic empirical evidence for this claim. After all, how do people utilize feature vectors from these models for non-generative tasks? They can either take the GPT-3 approach and do in-context meta-learning (this deserves its own article), or they train a linear classification layer on top of the “context-pooled” features. This last term just means they average the feature vectors over the time dimension and plug it into a linear layer, just like people do with ResNet/VGG features for image classification problems. In some models, like in iGPT, they may concatenate or pool features over multiple layers, too.

Suppose we are training a linear classifier for image classification on top of our network’s features. Concretely, we trained iGPT, and are now using the sequence I defined above as the possible set of features for a CIFAR-10 model. In general, say we have \(C\) classes. To do classification, we normally think of the operation at the abstraction level of a sentence like “train logistic regression on top of these features”. But we can take the linear basis view again, like we did above, to dive one level deeper in the abstraction hierarchy. The linear layer is really more like a set of \(C\) basis vectors in the feature space. The softmax preserves the rank order of the arguments, and the arguments are just inner products between these class basis vectors and the feature vectors. So, with the assumption I mentioned before, based on the regularization constraints on the norms of these things, we can see that the classification objective amounts to finding basis vectors that minimize the average distance between the feature vectors and class basis vectors in this state space.
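A minimal sketch of that recipe, with the frozen features faked by a random tensor (the batch size, sequence length, feature dimension, and class count are all placeholders):

```python
import torch
import torch.nn as nn

B, T, D, C = 32, 256, 512, 10          # batch, sequence length, feature dim, classes (e.g. CIFAR-10)
features = torch.randn(B, T, D)        # stand-in for frozen pre-logit features from the transformer
labels = torch.randint(0, C, (B,))

pooled = features.mean(dim=1)          # "context-pooled": average over the time dimension, (B, D)

head = nn.Linear(D, C)                 # the C rows of head.weight are the class basis vectors
logits = pooled @ head.weight.T + head.bias   # inner products with each class basis vector
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()                        # only the head receives gradients; the features stay frozen
```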

Now suppose these pooled features are actually useful for classification. What does that imply? It means that one can predict the reported class of the object by seeing if the time-average of the system is close to the basis vector corresponding to a particular class. That is what a linear classification layer is doing, after all. This is the same thing as asserting that there exists an attractor point at this basis vector. From a likelihood maximization perspective, if you can truly move all of the categorical likelihood “thinking” back into Gaussians in the feature space, this would be like saying that the state space trajectories for images in the given class approximately behave like a Gaussian around the class basis vector.

This would explain why iGPT does so well on classification tasks. There is still a lot of work to be done to explain why this linear separability should necessarily emerge. But I don’t think it’s controversial to call the solution an attractor. Part of the answer to the “why” question may just come down to the fact that these algorithms happen to be really good at empirical prediction. And if you buy the whole “prediction and compression are equivalent problems” story (we could always encode the likelihood using arithmetic coding), the model must be a really good compressor. Surely, a compressor wouldn’t be very good if it couldn’t compress easily expressed semantic concepts like “cat” and “dog”, right? In fact, the inability to do so would sort of, by definition, mark it as a bad compressor. Fundamentally, had the algorithm known the class beforehand, there would have been a more efficient way to estimate the sequence of pixels in the image. I hope to expand a bit on this point in a future post, since it hints at some subtler issues that occur whenever we endeavor to construct a language to describe the internal representation of a predictive system.

Maybe the state space trajectory would have been constrained to certain low-rank subspaces, or limited and distinct volumes of the state space. I don’t know the precise mechanisms, and the geometry of such an attractor may end up being more complex than I thought. But in the same way that dynamical systems theory helped scientists understand empirical phenomena like binocular rivalry through the lens of enactivism, I hope that we can use the same theory to understand complex qualitative behaviors in these models. It seems like we have some basic language to understand implications of linear separability in terms of attractor sets, but there is an untamed jungle of other topics such as meta-learning in language models or cross-modal alignment of representations in multimodal systems. Time for some bifurcation analysis?
