Recurrent Networks and Test Time Training (TTT)

Notes on Interesting Papers on Recurrent Networks and their connection to Test Time Training (TTT)

Songlin has a great slide deck on most of the papers I will go through here. I am just writing this for my own learning purposes.

Learning to (Learn at Test Time): RNNs with Expressive Hidden States

This paper presents an interesting perspective on how to update the hidden state of a recurrent model. Let's start by reviewing the traditional recurrent network structure:

$$W_t = g(x_t, W_{t-1}), \quad z_t = f(x_t; W_t)$$

The idea is to give the hidden state its own optimization process. At test time, when a new token $x_t$ comes in, instead of just transforming the hidden state with a traditional fixed function, we still have some function to update it, but that function is a gradient step on some loss function $\ell$.

$$W_t = g(x_t, W_{t-1}) = W_{t-1} - \eta \nabla \ell(W_{t-1}; x_t)$$

The question then is how we can construct such a loss function. The one the authors investigate takes the following form:

$$\begin{aligned} \ell(W; x_t) &= \lVert f(\theta_K x_t; W) - \theta_V x_t \rVert^2 \\ z_t &= f(\theta_Q x_t; W_t) \end{aligned}$$

Here, the $\theta$ parameters project $x_t$ and are trainable. This formulation necessitates two optimization loops:

  1. Inner loop: optimizes the weights $W$ (which can be viewed as the weights of $f$)
  2. Outer loop: optimizes all remaining parameters, including the $\theta$ projections

This dual-loop structure elegantly avoids the need to backpropagate through the same variable twice, thus circumventing the computation of Hessians.
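To make this concrete, here is a minimal numpy sketch of the inner loop, assuming $f$ is linear, $f(u; W) = Wu$ (the paper also considers an MLP for $f$), and using the $\frac{1}{2}\lVert\cdot\rVert^2$ convention so the gradient is clean. The dimensions, initialization, and learning rate here are all made up:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                       # token / hidden dimension (made up for this sketch)
theta_K, theta_V, theta_Q = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
W = np.zeros((d, d))        # hidden state = weights of the inner model f
eta = 0.1                   # inner-loop learning rate

def ttt_step(W, x):
    """One recurrent step: inner-loop gradient update on W, then read out z_t."""
    k, v, q = theta_K @ x, theta_V @ x, theta_Q @ x
    # grad of 0.5 * ||W k - v||^2 w.r.t. W is (W k - v) k^T
    W = W - eta * np.outer(W @ k - v, k)
    z = W @ q               # z_t = f(theta_Q x_t; W_t)
    return W, z

for x in rng.normal(size=(16, d)):   # a toy sequence of 16 tokens
    W, z = ttt_step(W, x)
```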


Parallelizing Linear Transformers with the Delta Rule over Sequence Length

Starting with the preliminaries, single-head softmax attention is:

$$q_t, k_t, v_t = W_Q x_t, W_K x_t, W_V x_t, \quad o_t = \sum_{i=1}^t \frac{\exp(k_i^\top q_t)}{\sum_{j=1}^t \exp(k_j^\top q_t)} v_i$$

We can view $\exp(k_i^\top q_t)$ as a kernel and replace it with $\phi(k_i)^\top \phi(q_t)$, where $\phi: \mathbb{R}^d \rightarrow \mathbb{R}^n$. As $n \rightarrow \infty$, we can create a feature map based on the Taylor series expansion of $\exp(k_i^\top q_t)$, where:

$$\exp(x) = \sum_{n=0}^{\infty} \frac{x^n}{n!}$$

This results in:

$$o_t = \sum_{i=1}^{t} \frac{\phi(k_i)^\top \phi(q_t)}{\sum_{j=1}^{t} \phi(k_j)^\top \phi(q_t)} v_i = \frac{\left(\sum_{i=1}^t v_i \phi(k_i)^\top\right) \phi(q_t)}{\left(\sum_{j=1}^t \phi(k_j)^\top\right) \phi(q_t)}$$
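As an aside that connects to the last section of these notes, here is a quick sketch of a finite feature map from truncating the Taylor series at degree 2, so that $\phi(k)^\top \phi(q) = 1 + k^\top q + (k^\top q)^2 / 2$. The vectors and scaling are made up, and the approximation is only good when $k^\top q$ is small:

```python
import numpy as np

def phi(u):
    # Degree-2 Taylor feature map: phi(k) . phi(q) = 1 + k.q + (k.q)^2 / 2
    return np.concatenate(([1.0], u, np.outer(u, u).ravel() / np.sqrt(2)))

rng = np.random.default_rng(0)
k, q = rng.normal(size=4) / 2, rng.normal(size=4) / 2
print(np.exp(k @ q))      # the softmax kernel value
print(phi(k) @ phi(q))    # its second-order Taylor approximation
```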

For each time $t$, if we ignore the denominator $\left(\sum_{j=1}^t \phi(k_j)^\top\right) \phi(q_t)$ and assume $\phi$ is the identity mapping, we can derive a recurrent formulation:

$$S_t = S_{t-1} + v_t k_t^\top, \quad o_t = S_t q_t$$

Here, $S$ is updated with a rank-1 matrix $v_t k_t^\top$, and $q_t$ is then transformed to yield $o_t$.
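Here is a minimal numpy sketch (shapes made up) checking that this recurrence computes exactly unnormalized causal linear attention:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 16, 8
K, V, Q = (rng.normal(size=(T, d)) for _ in range(3))

# Recurrent form: rank-1 state updates.
S = np.zeros((d, d))
outs = []
for k, v, q in zip(K, V, Q):
    S = S + np.outer(v, k)     # S_t = S_{t-1} + v_t k_t^T
    outs.append(S @ q)         # o_t = S_t q_t

# Parallel form: causal (unnormalized) linear attention.
A = np.tril(Q @ K.T)           # A[t, i] = k_i . q_t for i <= t
print(np.allclose(outs, A @ V))    # True: the two forms match
```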


Optimizing S with the Delta Rule

To ensure $S_t q_t$ is close to $v_t$ when $q_t$ is close to $k_t$, we can define the optimization problem:

$$\mathcal{L}_t(S) = \frac{1}{2} \left\lVert S k_t - v_t \right\rVert^2$$

The update can then be modified with a gradient step:

$$\begin{aligned} S_t &= S_{t-1} - \beta_t \nabla_{S_{t-1}} \mathcal{L}_t(S_{t-1}) = S_{t-1} - \beta_t (S_{t-1} k_t - v_t) k_t^\top \\ &= S_{t-1}(I - \beta_t k_t k_t^\top) + \beta_t v_t k_t^\top \end{aligned}$$

This Delta Rule allows $o_t = S_t q_t$ to approximate $v_t$ when $q_t$ is close to $k_t$, while ensuring past key-value pairs are preserved.
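A small sketch of one delta-rule step (dimensions made up). Note the nice special case: with $\beta_t = 1$ and a unit-norm key, the association $k_t \mapsto v_t$ is written exactly:

```python
import numpy as np

def delta_rule_step(S, k, v, beta):
    """S_t = S_{t-1} (I - beta k k^T) + beta v k^T, i.e. one gradient step
    on L(S) = 0.5 * ||S k - v||^2 with step size beta."""
    return S - beta * np.outer(S @ k - v, k)

rng = np.random.default_rng(0)
d = 8
S = np.zeros((d, d))
k = rng.normal(size=d); k /= np.linalg.norm(k)   # unit-norm key
v = rng.normal(size=d)
S = delta_rule_step(S, k, v, beta=1.0)
print(np.allclose(S @ k, v))   # True: with beta = 1 and ||k|| = 1, k -> v exactly
```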


Titans: Learning to Memorize at Test Time

By adding a momentum term to the gradient-descent view, we obtain the following modified update:

$$\begin{aligned} G_t &= \eta_t G_{t-1} - \theta_t \nabla \mathcal{L}_t(S_{t-1} \mid x_t) \\ S_t &= S_{t-1} + G_t \end{aligned}$$

In the paper they note that $\eta_t$ and $\theta_t$ are data-dependent, controlling how responsive the updates are.
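A minimal sketch of this update, reusing the quadratic loss from the delta rule; in the paper $\eta_t$ and $\theta_t$ come from learned, data-dependent gates, while here they are just fixed scalars:

```python
import numpy as np

def titans_step(S, G, k, v, eta, theta):
    """Momentum-accumulated update (sketch):
    G_t = eta * G_{t-1} - theta * grad L_t(S_{t-1}), S_t = S_{t-1} + G_t."""
    grad = np.outer(S @ k - v, k)   # grad of 0.5 * ||S k - v||^2 w.r.t. S
    G = eta * G - theta * grad
    return S + G, G

rng = np.random.default_rng(0)
d = 8
S, G = np.zeros((d, d)), np.zeros((d, d))
for k, v in zip(rng.normal(size=(16, d)), rng.normal(size=(16, d))):
    S, G = titans_step(S, G, k, v, eta=0.9, theta=0.1)
```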


Gated Delta Networks: Improving Mamba2 with Delta Rule

By introducing a data-dependent decay $\alpha_t \in (0,1)$, the model gains finer control over how much of the state is retained at each step:

$$S_t = S_{t-1}\left(\alpha_t (I - \beta_t k_t k_t^\top)\right) + \beta_t v_t k_t^\top$$

But the question I had is: why does $\alpha_t$ multiply the term $I - \beta_t k_t k_t^\top$? An interpretation of the DeltaNet formulation is:

$$\begin{aligned} S_t &= S_{t-1} - v_t^{\text{old}} k_t^\top + v_t^{\text{new}} k_t^\top \\ &= S_{t-1} - (S_{t-1} k_t) k_t^\top + v_t k_t^\top \end{aligned}$$

You can see that $\alpha_t$ is modulating the state after the prior key-value association has been removed: the retained memory is decayed, while the newly written association $\beta_t v_t k_t^\top$ is left untouched.
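A sketch of one gated step under this interpretation: erase the value currently stored at $k_t$, decay what remains, then write the new association (shapes and gate values made up):

```python
import numpy as np

def gated_delta_step(S, k, v, alpha, beta):
    # Erase the prior association for k, then decay the retained memory
    # and write the new one: S_t = alpha * (S (I - beta k k^T)) + beta v k^T
    S = S - beta * np.outer(S @ k, k)
    return alpha * S + beta * np.outer(v, k)

rng = np.random.default_rng(0)
d = 8
k, v = rng.normal(size=d), rng.normal(size=d)
S = gated_delta_step(np.zeros((d, d)), k, v, alpha=0.95, beta=0.5)
```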


Test-Time Regression: A Unifying Framework for Designing Sequence Models with Associative Memory

This paper provides a general perspective on associative memory. The general optimization objective is:

$$\min_{m \in \mathcal{M}} \sum_{i=1}^T \frac{1}{2} \gamma_i \left\lVert v_i - m(k_i) \right\rVert_2^2$$

In this framework, we generalize the function that processes keys: instead of $S k_i$, we use $m(k_i)$. For an analytical solution, restrict $m$ to a linear map $M$:

$$M_t = \argmin_{M} \frac{1}{2} \sum_{i=1}^t \left\lVert M k_i - v_i \right\rVert_2^2$$

Assume $v_i \in \mathbb{R}^{d_v}$ and $k_i \in \mathbb{R}^{d_k}$, so $M_t \in \mathbb{R}^{d_v \times d_k}$. Stacking the keys and values as columns of $K \in \mathbb{R}^{d_k \times t}$ and $V \in \mathbb{R}^{d_v \times t}$, the gradient becomes zero when:

$$\begin{aligned} \nabla_M &= \sum_{i=1}^t (M k_i - v_i) k_i^\top = (M K - V) K^\top = 0 \\ \Rightarrow \; M &= V K^\top (K K^\top)^{-1} \end{aligned}$$

If $K K^\top$ isn't invertible, use the pseudo-inverse. In addition, if $K^\top K = I$, we recover vanilla linear attention:

$$o_t = V K^\top q_t$$
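A short sketch (shapes made up) of the closed-form solution; with more keys than dimensions, $K K^\top$ is invertible, and you can see how the regression read-out differs from vanilla linear attention, which drops the $(K K^\top)^{-1}$ correction:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_v, t = 8, 4, 32
K = rng.normal(size=(d_k, t))             # keys as columns
V = rng.normal(size=(d_v, t))             # values as columns

# Closed-form least-squares memory (K has full row rank here, so K K^T is invertible)
M = V @ K.T @ np.linalg.inv(K @ K.T)
print(np.allclose(M, V @ np.linalg.pinv(K)))   # same as the pseudo-inverse solution

q = rng.normal(size=d_k)
print(M @ q)           # regression read-out
print(V @ K.T @ q)     # vanilla linear attention read-out, no (K K^T)^{-1} term
```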

There is more to this paper that I highly recommend checking out, such as what happens when we introduce softmax and approximate it as non-parametric regression. I think a fun investigation would be to do a Taylor expansion of the $\exp$ function and see how good an approximation we can get with a few orders, converting softmax attention into a linear one (the degree-2 feature map sketched above is a first step in this direction).