Improving Recurrent Models with Group Theory

Today’s article is about understanding the combination of the following papers: *Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues* and *DeltaProduct: Increasing the Expressivity of DeltaNet Through Products of Householders*, by Riccardo Grazzi, Julien Siems, and their coauthors. I came across these papers through the ASAP Seminar.

Motivation: First, to put down some notation, we will denote a recurrent model as

$$H_i = A(x_i)H_{i-1} + B(x_i), \quad \hat{y}_i = \text{dec}(H_i, x_i)$$

In the DeltaNet formulation,

$$A(x_t) = I - \beta_t k_t k_t^\top, \quad B(x_t) = \beta_t k_t v_t^\top, \quad \text{dec}(H_t, x_t) = H_t^\top q_t$$

This can also be seen as one step of gradient descent on $\frac{1}{2}\lVert H^\top k_t - v_t \rVert_2^2$. Importantly, since $\lVert k_t \rVert = 1$ and $\beta_t \in [0,1]$, the eigenvalues of $A(x_t)$ are bounded between $[0,1]$. Therefore, if we modify the update to be $A(x_t) = I - 2\beta_t k_t k_t^\top$, the eigenvalues instead lie in the range $[-1,1]$. Interestingly, when $\beta_t = 1$ this is exactly a Householder matrix! To refresh our memory, Householder transformations are reflections across a hyperplane.
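As a quick sanity check (a NumPy sketch of my own, not code from the papers), we can verify both claims: with the factor of 2 the eigenvalues stay in $[-1, 1]$, and $\beta_t = 1$ yields an exact Householder reflection (orthogonal, determinant $-1$):

```python
import numpy as np

rng = np.random.default_rng(0)

def delta_transition(k, beta, scale=2.0):
    """A(x_t) = I - scale * beta * k k^T for a unit-norm key k."""
    return np.eye(len(k)) - scale * beta * np.outer(k, k)

k = rng.standard_normal(4)
k /= np.linalg.norm(k)  # keys are unit norm

# With the factor of 2, one eigenvalue is 1 - 2*beta (along k), the rest
# are 1 -- so all eigenvalues lie in [-1, 1] for beta in [0, 1].
for beta in [0.0, 0.5, 1.0]:
    eigs = np.linalg.eigvalsh(delta_transition(k, beta))
    assert eigs.min() >= -1 - 1e-9 and eigs.max() <= 1 + 1e-9

# beta = 1 gives an exact Householder reflection: orthogonal, det = -1.
H = delta_transition(k, 1.0)
print(np.allclose(H @ H.T, np.eye(4)))   # True
print(np.isclose(np.linalg.det(H), -1))  # True
```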

In this article, we focus on how a particular extension enables recurrent models to solve state tracking problems, especially those that arise in structured domains like group theory. To set the stage, recall that a group is a set equipped with an operation that satisfies four key properties: closure, associativity, the existence of an identity element, and the existence of inverses. A powerful result in abstract algebra is Cayley’s theorem, which states that every finite group is isomorphic to a subgroup of the symmetric group Sn S_n, the group of all permutations of nn elements. This means that, in principle, any finite group computation can be modeled using permutations.
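To make Cayley's theorem concrete, here is a tiny illustration (my own sketch) of the standard construction for $\mathbb{Z}_3$: each group element acts on the group's own elements by left addition, and the resulting permutations embed $\mathbb{Z}_3$ into $S_3$:

```python
# Cayley embedding of Z_3 into S_3: each element g acts by left addition,
# producing a permutation of the group's own elements.
elements = [0, 1, 2]

def cayley(g):
    """Permutation of Z_3 induced by x -> (g + x) mod 3."""
    return tuple((g + x) % 3 for x in elements)

table = {g: cayley(g) for g in elements}
print(table)  # {0: (0, 1, 2), 1: (1, 2, 0), 2: (2, 0, 1)}
```

The three permutations are distinct, so the map is injective, exactly as the theorem promises.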

One state tracking problem we consider involves the symmetric group Sn S_n. The input is a sequence of permutations x1,x2,,xtx_1,x_2, \dots ,x_t and the goal is for the model to output the cumulative group element at each step:

$$y_t = x_t \cdot x_{t-1} \cdot \ldots \cdot x_1$$

where \cdot denotes composition of permutations, and the product is taken from right to left (i.e., x1x_1 is applied first). This task requires the model to maintain a hidden state that effectively encodes the current group element, updating it with each new permutation.

As a concrete example, let $x_1 = (1,2)(3,4)$ and $x_2 = (1,3)(2,4)$, both elements of $S_4$. Then the desired output after the second step is $y_2 = x_2 \cdot x_1 = (1,4)(2,3)$, which is the result of composing the two permutations.
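We can check this composition directly. In this sketch (my own notation, not the paper's), a permutation is a dict mapping each element to its image, and composition applies the right factor first:

```python
# Permutations as dicts mapping i -> sigma(i).
def compose(q, p):
    """Return q . p (apply p first, then q)."""
    return {i: q[p[i]] for i in p}

x1 = {1: 2, 2: 1, 3: 4, 4: 3}   # (1,2)(3,4)
x2 = {1: 3, 3: 1, 2: 4, 4: 2}   # (1,3)(2,4)

y2 = compose(x2, x1)
print(y2)  # {1: 4, 2: 3, 3: 2, 4: 1}  i.e. (1,4)(2,3)
```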

So how does this relate to Householder transformations? When permutations are viewed as transformations acting on a space, swaps can be naturally expressed using Householder reflections—a class of orthogonal transformations that reflect vectors across hyperplanes. This connection allows us to embed group operations directly into the linear dynamics of recurrent models, facilitating state tracking. The paper provides a helpful figure that visually demonstrates how permutations correspond to sequences of Householder reflections.
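A minimal NumPy sketch of this connection (my own example): the permutation matrix that swaps coordinates $i$ and $j$ is exactly the Householder reflection across the hyperplane orthogonal to $e_i - e_j$:

```python
import numpy as np

n = 4
e = np.eye(n)

# The swap (i, j) on coordinates is a reflection across the hyperplane
# orthogonal to e_i - e_j, i.e. a Householder matrix with that unit normal.
i, j = 0, 1
k = (e[i] - e[j]) / np.sqrt(2.0)              # unit normal
householder = np.eye(n) - 2 * np.outer(k, k)

# Permutation matrix of the transposition swapping coordinates i and j.
P = np.eye(n)
P[[i, j]] = P[[j, i]]

print(np.allclose(householder, P))  # True
```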

This also implies that if I am operating on $S_n$, I would need $n-1$ Householder reflections, since any permutation of $n$ elements can be written as a product of at most $n-1$ transpositions. And this brings us to the DeltaProduct formulation: for each token $x_i$, generate $n_h$ keys, values, and step sizes (the betas):

$$\begin{aligned} k_{i,j} &= \frac{W_j x_i}{\lVert W_j x_i \rVert_2}, \quad v_{i,j} = V_j x_i, \quad \beta_{i,j} = \phi(U_j x_i), \quad \text{where } \phi = 2 \times \text{sigmoid}, \\ H_{i,j} &= (I - \beta_{i,j} k_{i,j} k_{i,j}^\top) H_{i,j-1} + \beta_{i,j} k_{i,j} v_{i,j}^\top, \\ &\text{with } H_{i,0} = H_{i-1}, \quad H_{i,n_h} = H_i, \\ H_i &= A(x_i) H_{i-1} + B(x_i) \end{aligned}$$

where:

$$\begin{aligned} A(x_i) &= \prod_{j=1}^{n_h} \left( I - \beta_{i,j} k_{i,j} k_{i,j}^\top \right) \\ B(x_i) &= \sum_{j=1}^{n_h} \left( \prod_{k=j+1}^{n_h} \left( I - \beta_{i,k} k_{i,k} k_{i,k}^\top \right) \right) \beta_{i,j} k_{i,j} v_{i,j}^\top \end{aligned}$$

To put it in words, we expand the hidden state update to perform $n_h$ updates per token, each step having its own dedicated weights to produce a (generalized) Householder transformation. This can also be seen as doing $n_h$ steps of gradient descent on the associative recall objective. This is what allows the model to perform state tracking. Interestingly, for $S_4$ and $A_5$ you only need $n_h = 2$, instead of the theoretical value of 3. See the following Figure 3.
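To convince ourselves that the unrolled $A(x_i)$, $B(x_i)$ form matches the $n_h$ sequential rank-1 updates, here is a small NumPy sketch (random keys, values, and betas; my own code, not the paper's implementation):

```python
import numpy as np

rng = np.random.default_rng(1)
d, n_h = 5, 3

# Random unit-norm keys, values, and betas in [0, 2] (phi = 2 * sigmoid).
ks = rng.standard_normal((n_h, d))
ks /= np.linalg.norm(ks, axis=1, keepdims=True)
vs = rng.standard_normal((n_h, d))
betas = 2 * rng.random(n_h)

H_prev = rng.standard_normal((d, d))

# Sequential form: n_h generalized-Householder updates within one token.
H = H_prev.copy()
for j in range(n_h):
    H = (np.eye(d) - betas[j] * np.outer(ks[j], ks[j])) @ H \
        + betas[j] * np.outer(ks[j], vs[j])

# Unrolled form: accumulate A as a product and B as the weighted sum.
A = np.eye(d)
B = np.zeros((d, d))
for j in range(n_h):
    Gj = np.eye(d) - betas[j] * np.outer(ks[j], ks[j])
    A = Gj @ A
    B = Gj @ B + betas[j] * np.outer(ks[j], vs[j])

print(np.allclose(H, A @ H_prev + B))  # True
```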

The reason is that they are isomorphic to subgroups of $\text{SO}(3, \mathbb{R})$. To see this, first consider the $S_4$ setting, where we have an ordering of 4 elements. Interestingly, we can represent these 4 elements as the diagonals of a 3D cube. The operations on this cube are orientation-preserving rotations, and you can record the ordering of the 4 diagonals after each transformation as the new group element. Things for $A_5$ are a bit more involved. $A_5$ is the alternating group on 5 elements; it contains only the even permutations, i.e., those built from an even number of swaps. We can treat each element as the position of a colored tetrahedron inscribed in an icosahedron, and again we can rotate the icosahedron while preserving its orientation and record the positions of the tetrahedra after each rotation. See both examples below; the icosahedron is taken from this article.
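The cube construction is easy to play with numerically. In this sketch (my own example), each of the 4 diagonals of the cube $[-1,1]^3$ is stored by one endpoint, and we read off the permutation that a rotation induces on them (a rotation may flip a diagonal end-for-end, so we match up to sign):

```python
import numpy as np

# The 4 diagonals of the cube [-1,1]^3, one representative endpoint each.
diagonals = np.array([[ 1,  1, 1],
                      [-1,  1, 1],
                      [-1, -1, 1],
                      [ 1, -1, 1]])

def diagonal_permutation(R):
    """Return the permutation of the 4 diagonals induced by rotation R."""
    perm = []
    for d in diagonals:
        image = R @ d
        matches = [i for i, e in enumerate(diagonals)
                   if np.allclose(image, e) or np.allclose(image, -e)]
        perm.append(matches[0])
    return perm

# A 90-degree rotation about the z axis: an orientation-preserving
# symmetry of the cube.
Rz = np.array([[0, -1, 0],
               [1,  0, 0],
               [0,  0, 1]])

print(diagonal_permutation(Rz))  # [1, 2, 3, 0] -- a 4-cycle of the diagonals
```

The rotation produces a 4-cycle, an odd permutation, which is consistent with the rotation group of the cube realizing all of $S_4$.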

In addition, the Cartan–Dieudonné theorem states that

every orthogonal transformation in an n-dimensional symmetric bilinear space can be described as the composition of at most n reflections.

For a rotation in 3D space, we need just 2 reflections, provided that our keys also lie in a subspace of dimension 3. This explains why $n_h$ can be just 2. Here are some illustrations to explain this geometrically: the first is from the paper, and the second was generated in Matplotlib just so I could see an example for myself.
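We can also verify the "two reflections make a rotation" fact numerically (a NumPy sketch with random reflection normals, my own example): the product of two Householder reflections in $\mathbb{R}^3$ is orthogonal with determinant $+1$, i.e. a rotation.

```python
import numpy as np

rng = np.random.default_rng(2)

def reflect(k):
    """Householder reflection I - 2 k k^T for a unit vector k."""
    return np.eye(3) - 2 * np.outer(k, k)

k1, k2 = rng.standard_normal((2, 3))
k1 /= np.linalg.norm(k1)
k2 /= np.linalg.norm(k2)

# Each reflection has det -1; their product is orthogonal with det +1,
# hence an element of SO(3): a rotation.
R = reflect(k2) @ reflect(k1)
print(np.allclose(R @ R.T, np.eye(3)))    # True
print(np.isclose(np.linalg.det(R), 1.0))  # True
```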

Acknowledgement

Big thanks to Riccardo Grazzi and Julien Siems for reading the article and providing really helpful feedback on adding more depth.

References:

https://en.wikipedia.org/wiki/Householder_transformation

https://en.wikipedia.org/wiki/Cartan–Dieudonné_theorem

https://en.wikipedia.org/wiki/Symmetric_group

https://arxiv.org/abs/2404.08819

https://en.wikipedia.org/wiki/Alternating_group