Understanding PaTH Attention

This is a post that walks through my understanding of an awesome paper by my friend Songlin, PaTH Attention: Position Encoding via Accumulating Householder Transformations. She presented it during our seminar's summer bootcamp; make sure to check that out as well (YouTube).

Many of you will know RoPE, which encodes relative positions by rotating queries and keys:

$$\ell_{ij} \;=\; q_i^\top\, T^{\mathrm{RoPE}}_{i\leftarrow j}\, k_j, \qquad T^{\mathrm{RoPE}}_{i\leftarrow j} \;=\; R^{\,i-j}, \qquad \tilde q_i \;=\; R^{-i} q_i, \quad \tilde k_j \;=\; R^{-j} k_j$$
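
To make this concrete, here is a minimal numpy sketch of the 2-D case (mine, not from any library): each position's query and key is rotated independently by the appropriate power of $R$, and the resulting logits depend only on the offset $i-j$.

```python
import numpy as np

def rotate_per_position(x, theta):
    """Apply the 2-D rotation R^i (angle i*theta) to x[i], for every i at once.

    Each position is handled independently: a constant-depth,
    fully parallel map over the sequence.
    """
    angles = theta * np.arange(x.shape[0])
    cos, sin = np.cos(angles), np.sin(angles)
    return np.stack([cos * x[:, 0] - sin * x[:, 1],
                     sin * x[:, 0] + cos * x[:, 1]], axis=-1)

rng = np.random.default_rng(0)
q, k = rng.standard_normal((6, 2)), rng.standard_normal((6, 2))
theta = 0.1

q_tilde = rotate_per_position(q, -theta)   # \tilde q_i = R^{-i} q_i
k_tilde = rotate_per_position(k, -theta)   # \tilde k_j = R^{-j} k_j
logits = q_tilde @ k_tilde.T               # logits[i, j] = q_i^T R^{i-j} k_j

# sanity check against an explicit R^{i-j} for one pair (i, j)
i, j = 4, 1
a = (i - j) * theta
R_ij = np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]])
assert np.allclose(logits[i, j], q[i] @ R_ij @ k[j])
```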

It's interesting to think about complexity classes now. The equations above are fully parallelizable: no matter how many tokens are in the sequence, the depth of the computation graph is constant, putting this computation in $\text{TC}^0$. Now why is this important?


Well, consider a state-tracking problem in $S_4$, where we need to output the group element after a series of permutations. For example, let $x_1 = (1,2)(3,4)$ and $x_2 = (1,3)(2,4)$, both elements of $S_4$. The desired output after the composition is $x_1 \cdot x_2 = (1,4)(2,3)$. For hard enough permutation groups (the paper's analysis uses $S_5$, whose word problem is $\text{NC}^1$-complete), this kind of state tracking cannot be done by a model with a fixed-depth computational graph; it needs a computational graph whose depth grows logarithmically with the input length.
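
As a quick sanity check of that composition (my own snippet), representing a permutation as a 0-indexed tuple `p` where `p[i]` is the image of `i`:

```python
x1 = (1, 0, 3, 2)          # (1,2)(3,4) in 1-indexed cycle notation
x2 = (2, 3, 0, 1)          # (1,3)(2,4)

def compose(f, g):
    """(f . g)(i) = f(g(i)): apply g first, then f."""
    return tuple(f[g[i]] for i in range(len(f)))

print(compose(x1, x2))     # (3, 2, 1, 0), i.e. (1,4)(2,3)
```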

So what models can do this? Well, certain recurrent models can; check out my previous blog post. So how can we modify transformers, which are currently in $\text{TC}^0$, to do state tracking? We need to modify the computational graph somewhere so that its depth grows as we see more inputs. Revisiting the idea of position embeddings: instead of RoPE, which uses the fixed matrix $R^{i-j}$, what if we use a transformation whose depth grows as the distance $i-j$ grows? Naturally, the idea of having not just a single rotation matrix $R$ but a composition seems reasonable, but there is actually a cool math fact:

Householder transformations are strictly more expressive than a single rotation matrix. While a rotation matrix belongs to the special orthogonal group $SO(n)$ (determinant +1), Householders generate the entire orthogonal group $O(n)$, which includes both rotations and reflections. In fact, any rotation matrix can be written as a product of Householders, but not vice versa, so working with them gives us a strictly larger space of transformations.
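
A quick numerical illustration of this fact (my own snippet): a single Householder has determinant $-1$, so it is a reflection rather than a rotation, while a product of two Householders lands back in $SO(n)$:

```python
import numpy as np

def householder(w):
    """H = I - 2 w w^T for a unit vector w (an orthogonal reflection)."""
    w = w / np.linalg.norm(w)
    return np.eye(len(w)) - 2.0 * np.outer(w, w)

rng = np.random.default_rng(0)
H1 = householder(rng.standard_normal(3))
H2 = householder(rng.standard_normal(3))

print(np.linalg.det(H1))                    # ~ -1.0: a reflection, not in SO(3)
print(np.linalg.det(H1 @ H2))               # ~ +1.0: two reflections make a rotation
assert np.allclose(H1 @ H1.T, np.eye(3))    # still orthogonal
```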


Following this we arrive at the PaTH formulation:

$$o_t = \frac{1}{Z_t}\sum_{j=1}^t v_j \exp\!\Big( k_j^\top \Big( \prod_{s=j+1}^t \mathbf{H}_s \Big) q_t \Big)$$

$Z_t$ is the softmax normalization, $\mathbf{H}_t = \mathbf{I} - \beta_t w_t w_t^\top$, and $w_t$ is some function of the token representation $x_t$.

The PaTH formulation still retains the associative-recall capabilities of quadratic transformers because, unlike linear transformers, we keep the non-linear softmax, which prevents the attention from collapsing into a single state matrix. Note: I am skipping a large chunk of the paper that covers efficient training; my main interest is the capabilities and the formulation of the model, but Section 3 is also critical.
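
To make the formulation concrete, here is a deliberately naive, quadratic-time reference implementation of the equation above for the last position (my own sketch in numpy, not the paper's efficient algorithm from Section 3; the function name and shapes are mine):

```python
import numpy as np

def path_attention_step(q, k, v, w, beta):
    """Naive PaTH attention output o_t for the final position t = len(k) - 1.

    q    : (d,)      query of the final token, q_t
    k, v : (t+1, d)  keys and values for positions 0..t
    w    : (t+1, d)  unit vectors defining H_s = I - beta_s * w_s w_s^T
    beta : (t+1,)    per-position scalars
    """
    t = k.shape[0] - 1
    logits = np.empty(t + 1)
    for j in range(t + 1):
        # apply the accumulated product (H_{j+1} H_{j+2} ... H_t) to q_t,
        # multiplying right-to-left so that H_t acts first
        x = q.copy()
        for s in range(t, j, -1):
            x = x - beta[s] * w[s] * (w[s] @ x)
        logits[j] = k[j] @ x
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()                      # softmax normalization, 1/Z_t
    return probs @ v                          # sum_j softmax_j * v_j

# tiny usage example with random inputs
rng = np.random.default_rng(0)
T, d = 4, 8
k = rng.standard_normal((T, d))
v = rng.standard_normal((T, d))
q = rng.standard_normal(d)
w = rng.standard_normal((T, d))
w /= np.linalg.norm(w, axis=1, keepdims=True)
beta = np.full(T, 2.0)                        # beta = 2 gives exact Householders
print(path_attention_step(q, k, v, w, beta))
```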

Now we will show how PaTH can solve a state-tracking ($\text{NC}^1$) problem. In Appendix A, Representation Power of Transformers with PaTH Attention, we see how the PaTH attention block can solve a swapping task: given the permutation group on 5 elements, $S_5$, consider a sequence of tokens denoting the swaps. The sequence is constructed in the following form:

$$\# \, [a_1 \leftrightarrow b_1] \, [a_2 \leftrightarrow b_2] \, \ldots \, [a_n \leftrightarrow b_n]$$

The $\#$ is the start token, and each bracketed term is another token denoting a particular swap action; thus, including the $\#$, there are 21 tokens in total, each with a unique one-hot vector $u$. I am now going to go a bit out of order in how I explain things (just to write down my understanding). First we define the Householder weight $W_w$.

$$W_w u = w[u] = \frac{e_x - e_y}{\sqrt{2}} \quad \text{for } u = [x \leftrightarrow y], \quad \text{and } 0 \text{ if } u = \#$$

$e_x, e_y$ are standard basis vectors, and when we now define the Householder transformation, we can see how it performs a swap operation.

$$H = I - 2ww^\top, \qquad w = \tfrac{e_x - e_y}{\sqrt{2}}.$$

$$He_x = e_x - 2w(w^\top e_x) = e_x - 2w\!\left(\tfrac{1}{\sqrt{2}}\right) = e_x - (e_x - e_y) = e_y.$$

$$He_y = e_y - 2w(w^\top e_y) = e_y - 2w\!\left(-\tfrac{1}{\sqrt{2}}\right) = e_y + (e_x - e_y) = e_x.$$

For $j \notin \{x,y\}$:

$$He_j = e_j - 2w(w^\top e_j) = e_j.$$

Thus, it just swaps $e_x$ and $e_y$ while leaving everything else in place. Cool!
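
A quick numpy check of this derivation (mine), taking $x = 1$, $y = 3$ in five dimensions:

```python
import numpy as np

e = np.eye(5)
w = (e[0] - e[2]) / np.sqrt(2)          # x = 1, y = 3 (0-indexed: coords 0 and 2)
H = np.eye(5) - 2 * np.outer(w, w)

assert np.allclose(H @ e[0], e[2])      # H e_x = e_y
assert np.allclose(H @ e[2], e[0])      # H e_y = e_x
assert np.allclose(H @ e[4], e[4])      # untouched coordinates stay put
```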

$$
\begin{aligned}
W_k u &= k[u] = \mathbf{1}\{u = \#\}\,(e_1 + 2e_2 + 3e_3 + 4e_4 + 5e_5 - e_6), \\
W_q u &= q[u] = n\,(e_1 + 2e_2 + 3e_3 + 4e_4 + 5e_5 + 54.5\,e_6), \\
W_w u &= w[u] = (e_x - e_y)/\sqrt{2} \ \text{for } u = [x \leftrightarrow y], \ \text{and } 0 \text{ if } u = \#, \\
W_v u &= v[u] = \mathbf{1}\{u = \#\}\, e_1, \\
\beta &= 2.
\end{aligned}
$$

These are all the parameters we need. The only non-zero key and value are those of the first token $\#$; if we then use the PaTH formulation to compute the attention logit between $k_0$ and the last query, we obtain:

$$s_0 \;=\; k_0^{\top}\!\Bigg(\prod_{s=1}^{n} H_s\Bigg) q_n \;=\; n\left(\sum_{i=1}^{5} i\,\pi(i) - 54.5\right).$$

where $\pi(i)$ is the $i$-th element of the vector obtained after all the permutations/swaps are applied to the original vector $[1,2,3,4,5]$. So when $\pi$ is the identity after all the swaps, $s_0 > 0$, and otherwise $s_0 < 0$. There is a bit more to turning this into a final prediction of $1$ or $-1$, but this is the essence of it, and I will end here.
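
To convince myself of the whole construction, here is a small numpy check (my own helper `s0_for_swaps`, not from the paper) that builds $k_0$, $q_n$, and the Householders for a sequence of swaps, and confirms that the sign of $s_0$ detects whether the composition is the identity:

```python
import numpy as np

e = np.eye(6)                                   # basis e_1..e_6 (0-indexed rows)

def s0_for_swaps(swaps):
    """s_0 = k_0^T (H_1 H_2 ... H_n) q_n for a list of swaps (a, b), 1-indexed."""
    n = len(swaps)
    k0 = e[0] + 2*e[1] + 3*e[2] + 4*e[3] + 5*e[4] - e[5]
    qn = n * (e[0] + 2*e[1] + 3*e[2] + 4*e[3] + 5*e[4] + 54.5*e[5])
    P = np.eye(6)
    for a, b in swaps:                          # accumulate H_1 H_2 ... H_n
        w = (e[a-1] - e[b-1]) / np.sqrt(2)
        P = P @ (np.eye(6) - 2*np.outer(w, w))
    return k0 @ P @ qn

# a composition that is NOT the identity (a 3-cycle) ...
print(s0_for_swaps([(1, 2), (2, 3)]))                    # -5.0: negative
# ... and one that is the identity (each swap undone in reverse order)
print(s0_for_swaps([(1, 2), (2, 3), (2, 3), (1, 2)]))    #  2.0: positive
```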

Final thoughts. I think this paper is very interesting because it adds the extra computation in the positional encoding, showing how we can do more complex transformations of the tokens when comparing them for attention. It is also interesting to think about what other axes we can use to efficiently inject more computation into the transformer.