Recurrent Networks and Test Time Training (TTT)
Notes on Interesting Papers on Recurrent Networks and their connection to Test Time Training (TTT)
Songlin has a great slidedeck on most of the papers I will go through here. I am just writing this for my own learning purposes.
Learning to (Learn at Test Time): RNNs with Expressive Hidden States
This paper presents an interesting perspective on how to update the hidden state of a recurrent model. Let's start by reviewing the traditional recurrent network structure: a fixed update rule and a fixed output rule,

$$ s_t = u(s_{t-1}, x_t), \qquad z_t = g(s_t, x_t). $$

The idea is then to make the hidden state itself a small model $W$ with its own optimization process. At test time, when a new token $x_t$ comes in, we still have some function that updates the hidden state, but instead of a fixed transformation, that function is a gradient update on some loss $\ell$:

$$ W_t = W_{t-1} - \eta \, \nabla \ell(W_{t-1}; x_t), \qquad z_t = f(x_t; W_t). $$
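As a minimal sketch of this idea (assuming, for illustration only, a linear inner model $W$ and a simple squared reconstruction loss; the paper's actual loss comes next), the hidden-state update becomes one gradient step per token:

```python
import numpy as np

def ttt_style_update(W, x_t, lr=0.1):
    """One hidden-state update written as a gradient step (sketch).

    Assumes the hidden state W is a (d, d) matrix acting as a tiny linear model
    and the inner loss is 0.5 * ||W x_t - x_t||^2 (a placeholder loss).
    """
    grad = np.outer(W @ x_t - x_t, x_t)  # d(loss)/dW for the squared loss
    return W - lr * grad                 # gradient step instead of a fixed transition

d = 4
W = np.zeros((d, d))
for x_t in np.random.randn(10, d):       # stream of test-time tokens
    W = ttt_style_update(W, x_t)
    z_t = W @ x_t                        # output read out with the updated state
```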
The question then is how we can construct such a loss function. The ones the authors investigate take the following form:

$$ \ell(W; x_t) = \big\lVert f(\theta_K x_t; W) - \theta_V x_t \big\rVert^2 $$

Here, the parameters $\theta_K$ and $\theta_V$ project the input $x_t$ and are trainable. This formulation necessitates two optimization loops:
- Inner loop: Optimizes $W$ (which can be viewed as the weights of $f$)
- Outer loop: Optimizes all remaining parameters, including $\theta_K$, $\theta_V$, and the $\theta_Q$ used in the output rule
This dual-loop structure elegantly avoids the need to backpropagate through the same variable twice, thus circumventing the computation of Hessians.
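Putting the two loops together, here is a sketch of the inner-loop scan with the reconstruction loss above, assuming a linear inner model $f(x; W) = W x$ (the function names and shapes are mine, not the paper's). The outer loop would train $\theta_K$, $\theta_V$, $\theta_Q$ by backpropagating the task loss through this entire scan during training:

```python
import numpy as np

def ttt_layer_forward(X, theta_K, theta_V, theta_Q, lr=0.1):
    """Inner-loop scan of a TTT-style layer (sketch, assuming f(x; W) = W x).

    For each token: one gradient step on 0.5 * ||f(theta_K x_t; W) - theta_V x_t||^2,
    then emit z_t = f(theta_Q x_t; W_t).
    """
    d = theta_K.shape[0]
    W = np.zeros((d, d))                  # hidden state = weights of the inner model f
    outputs = []
    for x_t in X:
        k, v, q = theta_K @ x_t, theta_V @ x_t, theta_Q @ x_t
        grad = np.outer(W @ k - v, k)     # gradient of the reconstruction loss wrt W
        W = W - lr * grad                 # inner-loop (test-time) update
        outputs.append(W @ q)             # output rule
    return np.stack(outputs)

# Outer loop (not shown): optimize theta_K, theta_V, theta_Q with the task loss
# by backpropagating through this scan during training.
rng = np.random.default_rng(0)
d_model, d = 8, 4
X = rng.normal(size=(16, d_model))
Z = ttt_layer_forward(X, *(0.1 * rng.normal(size=(d, d_model)) for _ in range(3)))
```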
Parallelizing Linear Transformers with the Delta Rule over Sequence Length
Starting with the preliminaries: single-head softmax attention computes

$$ o_t = \sum_{i=1}^{t} \frac{\exp(q_t^\top k_i)}{\sum_{j=1}^{t} \exp(q_t^\top k_j)} \, v_i $$

We can view $\exp(q_t^\top k_i)$ as a kernel and replace it with $\phi(q_t)^\top \phi(k_i)$, where $\phi$ is a feature map. As $\exp(x) = \sum_{n=0}^{\infty} x^n / n!$, we can create a feature map based on the Taylor series expansion of $\exp(q^\top k)$; truncating at second order, for example,

$$ \exp(q^\top k) \approx 1 + q^\top k + \tfrac{1}{2}\big(q^\top k\big)^2 = \phi(q)^\top \phi(k), \qquad \phi(x) = \Big[\,1,\; x,\; \tfrac{1}{\sqrt{2}}\, x \otimes x \,\Big] $$

This results in:

$$ o_t = \frac{\sum_{i=1}^{t} \phi(q_t)^\top \phi(k_i)\, v_i}{\sum_{i=1}^{t} \phi(q_t)^\top \phi(k_i)} = \frac{\big(\sum_{i=1}^{t} v_i\, \phi(k_i)^\top\big)\, \phi(q_t)}{\big(\sum_{i=1}^{t} \phi(k_i)\big)^\top \phi(q_t)} $$
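As a small sketch of this feature-map view (the second-order truncation and the normalization choices here are my own, not the paper's), we can compute causal attention with $\exp(q^\top k)$ replaced by $\phi(q)^\top \phi(k)$:

```python
import numpy as np

def phi_taylor2(x):
    """Second-order Taylor feature map: phi(q) . phi(k) = 1 + q.k + 0.5 * (q.k)^2."""
    return np.concatenate([[1.0], x, np.outer(x, x).ravel() / np.sqrt(2.0)])

def linearized_attention(Q, K, V):
    """Causal attention with exp(q.k) replaced by phi(q).phi(k) (sketch)."""
    outputs = []
    for t in range(len(Q)):
        w = np.array([phi_taylor2(Q[t]) @ phi_taylor2(K[i]) for i in range(t + 1)])
        outputs.append((w / w.sum()) @ V[: t + 1])   # normalized, as in the softmax case
    return np.stack(outputs)
```

Note that $1 + x + x^2/2 > 0$ for all real $x$, so the weights stay positive and the normalization is safe.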
For each time $t$, if we ignore the denominator and assume $\phi$ is the identity mapping, we can derive a recurrent formulation:

$$ S_t = S_{t-1} + v_t k_t^\top, \qquad o_t = S_t q_t $$

Here, $S_{t-1}$ is updated with a rank-1 matrix $v_t k_t^\top$, and $q_t$ is then transformed by $S_t$ to yield $o_t$.
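Written as a scan (a sketch that follows the simplifications above: no denominator, $\phi$ equal to the identity):

```python
import numpy as np

def linear_attention_recurrent(Q, K, V):
    """S_t = S_{t-1} + v_t k_t^T ;  o_t = S_t q_t  (denominator ignored)."""
    d_k, d_v = K.shape[1], V.shape[1]
    S = np.zeros((d_v, d_k))              # matrix-valued hidden state
    outputs = []
    for q_t, k_t, v_t in zip(Q, K, V):
        S = S + np.outer(v_t, k_t)        # rank-1 additive update
        outputs.append(S @ q_t)           # readout with the query
    return np.stack(outputs)
```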
Optimizing S with the Delta Rule
To ensure $S q$ is close to $v_t$ when the query $q$ is close to $k_t$, we can define the optimization problem:

$$ \mathcal{L}_t(S) = \tfrac{1}{2}\, \big\lVert S k_t - v_t \big\rVert^2 $$

The update can then be modified with a gradient step:

$$ S_t = S_{t-1} - \beta_t \nabla_S \mathcal{L}_t(S_{t-1}) = S_{t-1} - \beta_t \big(S_{t-1} k_t - v_t\big) k_t^\top = S_{t-1}\big(I - \beta_t k_t k_t^\top\big) + \beta_t v_t k_t^\top $$

This Delta Rule allows $S_t q$ to approximate $v_t$ when $q$ is close to $k_t$, while keeping past key-value pairs largely preserved.
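The same scan with this gradient step swapped in (a sketch; a fixed scalar beta stands in for the per-token, data-dependent $\beta_t$ typically used):

```python
import numpy as np

def delta_rule_recurrent(Q, K, V, beta=0.5):
    """S_t = S_{t-1}(I - beta k_t k_t^T) + beta v_t k_t^T ;  o_t = S_t q_t."""
    d_k, d_v = K.shape[1], V.shape[1]
    S = np.zeros((d_v, d_k))
    outputs = []
    for q_t, k_t, v_t in zip(Q, K, V):
        v_old = S @ k_t                              # value currently bound to k_t
        S = S + beta * np.outer(v_t - v_old, k_t)    # erase the old value, write the new one
        outputs.append(S @ q_t)
    return np.stack(outputs)
```

The update line is algebraically identical to $S_{t-1}(I - \beta k_t k_t^\top) + \beta v_t k_t^\top$, which makes the "erase then write" reading explicit.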
Titans: Learning to Memorize at Test Time
By adding a momentum term to the gradient descent view, we can obtain the following modified update:

$$ m_t = \eta_t\, m_{t-1} - \theta_t\, \nabla_S \mathcal{L}_t(S_{t-1}), \qquad S_t = S_{t-1} + m_t $$

where $\mathcal{L}_t$ is the same key-value reconstruction loss as before. In the paper they note that $\eta_t$ and $\theta_t$ are data dependent, controlling how responsive the updates are to the current token.
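A sketch of the momentum view, with fixed scalars standing in for the data-dependent $\eta_t$ and $\theta_t$:

```python
import numpy as np

def momentum_delta_scan(K, V, eta=0.9, theta=0.1):
    """m_t = eta * m_{t-1} - theta * grad ;  S_t = S_{t-1} + m_t  (sketch)."""
    d_k, d_v = K.shape[1], V.shape[1]
    S = np.zeros((d_v, d_k))                  # memory
    m = np.zeros((d_v, d_k))                  # momentum / "surprise" buffer
    for k_t, v_t in zip(K, V):
        grad = np.outer(S @ k_t - v_t, k_t)   # gradient of 0.5 * ||S k_t - v_t||^2
        m = eta * m - theta * grad            # accumulate surprise with momentum
        S = S + m                             # memory update
    return S
```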
Gated Delta Networks: Improving Mamba2 with Delta Rule
By introducing a data-dependent gate $\alpha_t \in (0, 1)$, there is greater modulation in the updates:

$$ S_t = \alpha_t\, S_{t-1}\big(I - \beta_t k_t k_t^\top\big) + \beta_t v_t k_t^\top $$

But the question I had is: why is $\alpha_t$ only on the $S_{t-1}\big(I - \beta_t k_t k_t^\top\big)$ term? An interpretation of the DeltaNet formulation is

$$ S_t = S_{t-1} \;-\; \beta_t\, \underbrace{(S_{t-1} k_t)}_{\text{old value for } k_t}\, k_t^\top \;+\; \beta_t\, v_t k_t^\top $$

You can see that removing the prior key-value association (while keeping the rest of the memory) is exactly what the first term expresses, and it is this post-erasure memory that $\alpha_t$ is modulating; the newly written association is left intact.
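A sketch of the gated step, again with fixed scalars in place of the data-dependent $\alpha_t$ and $\beta_t$:

```python
import numpy as np

def gated_delta_rule_step(S, k_t, v_t, alpha=0.95, beta=0.5):
    """S_t = alpha * S_{t-1}(I - beta k_t k_t^T) + beta * v_t k_t^T  (sketch)."""
    erased = S - beta * np.outer(S @ k_t, k_t)          # old memory with k_t's association removed
    return alpha * erased + beta * np.outer(v_t, k_t)   # decay the old memory, write the new one
```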
Test-Time Regression: A Unifying Framework for Designing Sequence Models with Associative Memory
This paper provides a general perspective on associative memory. The general optimization objective is:

$$ \min_{M_t} \; \sum_{i=1}^{t} \gamma_{t,i}\, \big\lVert M_t\, \phi(k_i) - v_i \big\rVert^2 $$

In this framework, we generalize the function that processes keys: instead of $k_i$, we use $\phi(k_i)$. For an analytical solution, assume $\gamma_{t,i} = 1$ for all $i$, so every past pair is weighted equally. The gradient becomes zero when:

$$ M_t \sum_{i=1}^{t} \phi(k_i)\, \phi(k_i)^\top = \sum_{i=1}^{t} v_i\, \phi(k_i)^\top \quad\Longrightarrow\quad M_t = \Big(\sum_{i=1}^{t} v_i\, \phi(k_i)^\top\Big) \Big(\sum_{i=1}^{t} \phi(k_i)\, \phi(k_i)^\top\Big)^{-1} $$

If $\sum_{i} \phi(k_i)\, \phi(k_i)^\top$ isn't invertible, use the pseudo-inverse. In addition, if we assume the keys are orthonormal so that $\sum_{i} \phi(k_i)\, \phi(k_i)^\top = I$, we recover vanilla linear attention:

$$ M_t = \sum_{i=1}^{t} v_i\, \phi(k_i)^\top, \qquad o_t = M_t\, \phi(q_t) $$
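A sketch of the closed-form readout with equal weights ($\gamma_{t,i} = 1$) and the pseudo-inverse fallback; the running sums keep the per-step state constant in the sequence length:

```python
import numpy as np

def test_time_regression_readout(Q, K, V, phi=lambda x: x):
    """o_t = M_t phi(q_t), M_t = (sum_i v_i phi(k_i)^T)(sum_i phi(k_i) phi(k_i)^T)^+  (sketch)."""
    d_phi = phi(K[0]).shape[0]
    Vk = np.zeros((V.shape[1], d_phi))           # running sum of v_i phi(k_i)^T
    KK = np.zeros((d_phi, d_phi))                # running sum of phi(k_i) phi(k_i)^T
    outputs = []
    for q_t, k_t, v_t in zip(Q, K, V):
        fk = phi(k_t)
        Vk += np.outer(v_t, fk)
        KK += np.outer(fk, fk)
        M = Vk @ np.linalg.pinv(KK)              # pseudo-inverse handles a singular covariance
        outputs.append(M @ phi(q_t))
    return np.stack(outputs)

# Replacing np.linalg.pinv(KK) with the identity gives back vanilla linear attention.
```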
There is more to this paper that I highly recommend checking out; it explores different forms, such as what happens when we introduce softmax and approximate it as non-parametric regression. I think a fun investigation would be to take a Taylor expansion of the $\exp$ function and see how good an approximation we can get with a few orders, to convert softmax attention into a linear one.