Understanding PaTH Attention
This is a post that walks through my understanding of an awesome paper by my friend Songlin, PaTH Attention: Position Encoding via Accumulating Householder Transformations. She presented it during our seminar’s summer bootcamp; make sure to check that out as well (YouTube).
Many of you will know RoPE, which encodes positions by rotating the queries and keys before taking their dot product, so that the attention score only depends on the relative offset:

$$
\langle R_{\theta, i}\, q_i,\; R_{\theta, j}\, k_j \rangle \;=\; q_i^\top R_{\theta,\, j-i}\, k_j ,
$$

where $R_{\theta, m}$ is a block-diagonal rotation matrix whose rotation angles are proportional to the position $m$.
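Here is a quick numpy check of that relative-position property for a single 2D rotation block (my own sanity check, not from the paper; the helper name is mine):

```python
import numpy as np

def rotation(theta):
    """2x2 rotation matrix R(theta)."""
    return np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])

rng = np.random.default_rng(0)
q, k = rng.standard_normal(2), rng.standard_normal(2)
omega = 0.3          # frequency for this 2D sub-block
i, j = 7, 3          # query position i, key position j

# RoPE rotates q and k by angles proportional to their positions...
score = (rotation(i * omega) @ q) @ (rotation(j * omega) @ k)

# ...so the attention score only depends on the relative offset j - i.
relative = q @ rotation((j - i) * omega) @ k
assert np.allclose(score, relative)
```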
It's interesting to think about complexity classes now. The above equation is fully parallelizable, meaning that no matter how many tokens are in the sequence, the depth of the computation graph is constant, $O(1)$. Now why is this important?
Well, consider a state tracking problem where we need to output the group element after a series of permutations. For example, let $\sigma_1 = (1\,2)$ and $\sigma_2 = (2\,3)$, both elements of $S_5$. The desired output after the composition is $\sigma_2 \circ \sigma_1 = (1\,3\,2)$. This problem can't be done with a model that has a fixed-depth computational graph (under standard complexity conjectures); in fact it needs a computational graph whose depth grows as the log of the input size, $\Omega(\log n)$.
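A tiny worked example of that composition in numpy (just to make it concrete; indices are 0-based here, so $(1\,2)$ becomes the swap of elements 0 and 1):

```python
import numpy as np

# A permutation is an index array: perm[x] is where element x goes.
def transposition(a, b, n=5):
    p = np.arange(n)
    p[a], p[b] = p[b], p[a]
    return p

sigma1 = transposition(0, 1)   # swap elements 0 and 1
sigma2 = transposition(1, 2)   # swap elements 1 and 2

# "sigma2 after sigma1": first apply sigma1, then sigma2.
composed = sigma2[sigma1]
print(composed)                # [2 0 1 3 4] -> the 3-cycle (0 2 1)
```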
So what models can do this? Well, certain recurrent models can; check out my previous blog post. So how can we modify transformers, which are currently in $\mathsf{TC}^0$, to do state tracking? We need to modify the computational graph somewhere such that its depth grows as we have more inputs. Revisiting the idea of position embeddings: instead of RoPE, which uses a fixed rotation matrix $R_{\theta,\, j-i}$ between positions $i$ and $j$, what if we have a function whose depth grows as the distance $|i-j|$ grows? Naturally, the idea of having not just a single rotation matrix but a composition of per-token matrices seems reasonable, but actually a cool math fact is:
Householder transformations are strictly more expressive than a single rotation matrix. While a rotation matrix belongs to the special orthogonal group $SO(d)$ (determinant $+1$), Householders generate the entire orthogonal group $O(d)$, which includes both rotations and reflections. In fact, any rotation matrix can be written as a product of Householders, but not vice versa, so working with them gives us a strictly larger space of transformations.
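A quick numerical check of that fact (my own, using the determinant as the rotation-vs-reflection test):

```python
import numpy as np

def householder(v):
    """Reflection I - 2 v v^T / ||v||^2 across the hyperplane orthogonal to v."""
    v = v / np.linalg.norm(v)
    return np.eye(len(v)) - 2.0 * np.outer(v, v)

rng = np.random.default_rng(0)
H1 = householder(rng.standard_normal(4))
H2 = householder(rng.standard_normal(4))

print(np.linalg.det(H1))        # ~ -1.0: a single Householder is a reflection, not a rotation
print(np.linalg.det(H1 @ H2))   # ~ +1.0: a product of two Householders lands in SO(d)
```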
Following this we arrive at the PaTH formulation:

$$
o_i \;=\; \sum_{j \le i} \frac{\exp\!\big(q_i^\top\, H_i H_{i-1} \cdots H_{j+1}\, k_j\big)}{Z_i}\, v_j ,
\qquad H_t = I - \beta_t\, w_t w_t^\top ,
$$

where $Z_i$ is the softmax normalization, $H_t$ is the (generalized) Householder transformation contributed by token $t$, and $w_t$ (together with the scalar $\beta_t$) is some function of the token representation $x_t$.
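To make the formulation concrete, here is a deliberately naive PyTorch sketch of how I read it: one head, causal, quadratic time, materializing the cumulative Householder product for every query/key pair. The function name and the exact indexing of the product are my own reading; the paper's Section 3 derives a far more efficient algorithm.

```python
import torch

def path_attention_naive(q, k, v, w, beta):
    """
    Naive reference for PaTH-style attention (single head, causal).
    q, k, v : (T, d)   queries, keys, values
    w       : (T, d)   unit-norm Householder directions, one per token
    beta    : (T,)     scalars; H_t = I - beta_t * w_t w_t^T
    Score(i, j) = q_i^T (H_i ... H_{j+1}) k_j, the cumulative product of the
    Householder-like transforms between key position j and query position i.
    """
    T, d = q.shape
    I = torch.eye(d, dtype=q.dtype)
    # Per-token generalized Householders, shape (T, d, d).
    H = I.unsqueeze(0) - beta[:, None, None] * torch.einsum('ti,tj->tij', w, w)

    scores = torch.full((T, T), float('-inf'), dtype=q.dtype)
    for i in range(T):
        P = I.clone()                      # accumulated transform H_i ... H_{j+1}
        scores[i, i] = q[i] @ k[i]         # j = i: empty product, plain dot product
        for j in range(i - 1, -1, -1):
            P = P @ H[j + 1]               # extend the product one token to the left
            scores[i, j] = q[i] @ P @ k[j]

    attn = torch.softmax(scores, dim=-1)   # -inf above the diagonal keeps it causal
    return attn @ v

# Tiny smoke test.
T, d = 6, 8
q, k, v = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)
w = torch.nn.functional.normalize(torch.randn(T, d), dim=-1)
beta = torch.rand(T) * 2.0
print(path_attention_naive(q, k, v, w, beta).shape)  # torch.Size([6, 8])
```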
This can still retain the associative recall capabilities of quadratic transformers because, unlike linear transformers, we have the non-linear softmax that prevents us from collapsing the attention into a single state matrix. Note: I am skipping a large chunk of the paper which goes over efficient training (my main interest is the capabilities and the formulation of the model), but Section 3 is also critical.
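To see what I mean by collapsing into a single state matrix, here is a tiny sketch (my own illustration, not from the paper) of causal linear attention computed with one fixed-size $d \times d$ running state; the non-linear softmax is exactly what blocks this rewrite.

```python
import torch

def linear_attention_recurrent(q, k, v):
    """Causal linear attention o_i = (sum_{j<=i} v_j k_j^T) q_i, computed with a
    single running d x d state -- the 'collapse' that softmax rules out."""
    T, d = q.shape
    S = torch.zeros(d, d)
    out = torch.empty(T, d)
    for i in range(T):
        S = S + torch.outer(v[i], k[i])   # fold key/value i into the fixed-size state
        out[i] = S @ q[i]
    return out

q, k, v = torch.randn(5, 4), torch.randn(5, 4), torch.randn(5, 4)
# Same result as the quadratic form sum_{j<=i} (q_i . k_j) v_j:
quad = torch.stack([sum((q[i] @ k[j]) * v[j] for j in range(i + 1)) for i in range(5)])
assert torch.allclose(linear_attention_recurrent(q, k, v), quad, atol=1e-5)
```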
Now we will show how we can solve a state tracking problem. In Appendix A: Representation Power of Transformers with PaTH Attention, we see how the PaTH attention block can solve a swapping task: given the permutation group $S_5$ of $5$ elements, consider a sequence of tokens that denotes the swaps. The sequence is constructed in the following form:

$$
\#\;\;(a_1\, b_1)\;\;(a_2\, b_2)\;\cdots\;(a_n\, b_n)
$$
The # is the start token and each one of the bracket terms is another token that denotes a particular swap action $(a\, b)$; thus, including the #, there are 21 tokens in total, each with a unique one-hot vector $x_t$. I am now going to go a bit out of order in how I explain things (just to write down my understanding). First we define the Householder weight for a swap token as

$$
w_t = e_a - e_b ,
$$

and for the # token we can take $w_t = 0$, so its transformation is just the identity.
$e_a$ and $e_b$ are basis vectors, and when we now define the Householder transformation $H_t = I - w_t w_t^\top$ (i.e. taking $\beta_t = 1$), we can see how it performs a swap operation.

For $e_a$, $e_b$, and any other basis vector $e_c$ with $c \neq a, b$:

$$
H_t e_a = e_a - (e_a - e_b)\underbrace{(e_a - e_b)^\top e_a}_{=\,1} = e_b, \qquad
H_t e_b = e_b - (e_a - e_b)\underbrace{(e_a - e_b)^\top e_b}_{=\,-1} = e_a, \qquad
H_t e_c = e_c .
$$

Thus, it just swaps $e_a$ and $e_b$ while keeping the rest in place. Cool!
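A quick numpy sanity check of the swap (my own, with arbitrarily chosen indices $a$, $b$, $c$):

```python
import numpy as np

d = 5
e = np.eye(d)                      # e[a] is the basis vector e_a
a, b, c = 0, 3, 2
w = e[a] - e[b]
H = np.eye(d) - np.outer(w, w)     # I - (e_a - e_b)(e_a - e_b)^T

assert np.allclose(H @ e[a], e[b])  # e_a  ->  e_b
assert np.allclose(H @ e[b], e[a])  # e_b  ->  e_a
assert np.allclose(H @ e[c], e[c])  # everything else stays put
```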
These are all the parameters we need. The only key and value that are non-zero are those of the first token # (say $k_\# = e_i$, matched by the last token's query $q_N = e_i$), so if we use the PaTH formulation to get the output for the last token, the only non-trivial attention logit is

$$
q_N^\top\, \big(H_N \cdots H_1\big)\, k_\# \;=\; e_i^\top\, \pi(e_i) \;=\; \big[\pi(e_i)\big]_i ,
$$

where $[\pi(e_i)]_i$ is the $i$'th element of the vector after all the permutations/swaps were applied to the original vector $e_i$. So when the composed permutation is the identity after all the swaps (more precisely, when it leaves element $i$ in place), then $[\pi(e_i)]_i = 1$, and otherwise $[\pi(e_i)]_i = 0$. There is more to it to turn this into a final prediction of $1$ or $0$, but this is the essence of it and I will end here.
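Here is a toy end-to-end simulation of this construction as I understand it (tracking element 0 with $q_N = k_\# = e_0$, 0-indexed, is my own simplification of the appendix argument):

```python
import numpy as np

d = 5
e = np.eye(d)

def swap_householder(a, b):
    w = e[a] - e[b]
    return np.eye(d) - np.outer(w, w)   # swaps e_a and e_b, fixes the rest

# A random word of transpositions over 5 elements (0-indexed).
rng = np.random.default_rng(1)
swaps = [tuple(rng.choice(d, size=2, replace=False)) for _ in range(8)]

# Key of the # token is e_0; every later key is zero, so the only
# non-trivial logit for the last query goes through the accumulated product.
k_hash = e[0]
q_last = e[0]

P = np.eye(d)
for (a, b) in swaps:                    # accumulate H_n ... H_1
    P = swap_householder(a, b) @ P

logit = q_last @ P @ k_hash             # = [pi(e_0)]_0

# Cross-check against explicitly tracking where element 0 ends up.
pos = 0
for (a, b) in swaps:
    pos = b if pos == a else a if pos == b else pos
assert np.isclose(logit, 1.0 if pos == 0 else 0.0)
print(logit)
```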
Final thoughts. I think this paper is very interesting because it adds extra computation in the positional encoding, showing how we can do more complex transformations of the tokens when comparing them for attention. It is also interesting to think about which other axes we can use to efficiently inject more computation into the transformer.