Centered Kernel Alignment

After reading the enlightening paper The Platonic Representation Hypothesis and talking to my mentor Jacob Huh, I became very interested in the tool of comparing kernels. Although neural networks are black boxes that are hard to interpret, there are ways we can compare their representational similarity. While this by no means provides a complete picture of how representations and networks differ from each other, it does provide a good starting point. The particular method I will be diving into for this post is called Centered Kernel Alignment (CKA).

Problem Setup

If we have a neural network like the one illustrated on the right, then we can tap into its internal representations in various ways, e.g., by averaging token embeddings at a specific layer. Now the question is: given two distributions of representations $\mathcal{X}, \mathcal{Y}$, how can we compare them?


The Hilbert-Schmidt Independence Criterion (HSIC)

We are going to use some tools, namely the Reproducing Kernel Hilbert Space (RKHS). By definition, we have a Hilbert space $\mathcal{H}$, which is a vector space with a defined inner product, and we need a kernel function $K: X \times X \rightarrow \mathbb{R}$ such that $\forall f \in \mathcal{H},\ f(x) = \langle f, K(x, \cdot) \rangle$. Now, for each distribution of representations, let's create a different RKHS, $\mathcal{F}$ and $\mathcal{G}$, each equipped with its own kernel, $K$ and $L$ respectively (in practice we use the same kernel function).
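As a concrete sketch of the kernel-matrix side of this setup, here is one common choice of kernel, the RBF (Gaussian) kernel, applied to a matrix of representations. The function name `rbf_kernel` and the bandwidth parameter `sigma` are my own choices for illustration, not anything prescribed by the CKA papers:

```python
import numpy as np

def rbf_kernel(X, sigma=1.0):
    """Pairwise RBF kernel matrix: K[i, j] = exp(-||x_i - x_j||^2 / (2 sigma^2)).

    X is an (n, d) matrix whose rows are representation vectors.
    """
    sq_norms = np.sum(X**2, axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2 * X @ X.T
    return np.exp(-sq_dists / (2 * sigma**2))
```

The resulting matrix is symmetric with ones on the diagonal, since every point has distance zero to itself.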

$$C_{x,y} = \mathbb{E}_{x\sim X,\, y \sim Y}\big[(K(x, \cdot) - \mu_x) \otimes (L(y, \cdot) - \mu_y)\big]$$
$$\mu_x = \mathbb{E}_{x \sim X}[K(x,\cdot)], \qquad \mu_y = \mathbb{E}_{y \sim Y}[L(y,\cdot)]$$

This is the cross-covariance operator, and it captures how "aligned" $\mathcal{F}$ and $\mathcal{G}$ are under the distributions $X$ and $Y$ respectively. To add more clarity: by looking at the tensor product between the (centered) kernel functions for a sampled pair $x, y$ of inputs, you can tell how "aligned" their kernel distances are across the range of $x$'s and $y$'s.

If we sample some inputs and map them into their respective representation spaces via $\phi_1, \phi_2$, then the associations across their dot-product distances are captured by the tensor product in the RKHS. The $i,j$ element of $C_{x,y}$ captures the cross-similarity of these associations.

If we have a set of ordered empirical samples $(x_i, y_i)_{i=1}^{n}$, then we can approximate $C_{x,y}$ via the following equations.

$$\bar{K}_{i,j} = K(x_i, x_j), \qquad \bar{L}_{i,j} = L(y_i, y_j)$$
$$H = I - \frac{1}{n}\mathbb{1}\mathbb{1}^T$$
$$\tilde{K} = H\bar{K}H, \qquad \tilde{L} = H\bar{L}H$$
$$\text{HSIC}(\bar{K},\bar{L}) = \frac{1}{(n-1)^2}\operatorname{tr}(\tilde{K}\tilde{L})$$
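The equations above translate almost line-for-line into NumPy. This is a minimal sketch of the (biased) empirical HSIC estimator; the function name `hsic` is my own, and it assumes the two kernel matrices were computed over the same $n$ paired samples:

```python
import numpy as np

def hsic(K, L):
    """Biased empirical HSIC estimator from two n x n kernel matrices.

    Mirrors the equations: center both Gram matrices with H = I - (1/n) 11^T,
    then take the normalized trace of their product.
    """
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n   # centering matrix
    K_c = H @ K @ H                        # K-tilde
    L_c = H @ L @ H                        # L-tilde
    return np.trace(K_c @ L_c) / (n - 1) ** 2
```

Note that the estimator is symmetric in its two arguments, and for kernel matrices built from independent samples it will be small (but not exactly zero at finite $n$).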

HSIC being equal to $0$ implies that $X, Y$ are independent (for characteristic kernels such as the RBF kernel), and the HSIC equals the squared Maximum Mean Discrepancy $\mathrm{MMD}^2\big(P(X,Y) \,\|\, P(X)P(Y)\big)$ between the joint distribution and the product of the marginals.

Finally, to obtain the CKA, we normalize the HSIC to make it invariant to uniform scaling.

$$\text{CKA}(\bar{K},\bar{L}) = \frac{\text{HSIC}(\bar{K},\bar{L})}{\sqrt{\text{HSIC}(\bar{K},\bar{K})\,\text{HSIC}(\bar{L},\bar{L})}}$$
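Putting the two pieces together, CKA is just HSIC normalized by the self-similarity of each kernel matrix. A minimal self-contained sketch (again with my own function names):

```python
import numpy as np

def hsic(K, L):
    """Biased empirical HSIC from two n x n kernel matrices."""
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n
    return np.trace((H @ K @ H) @ (H @ L @ H)) / (n - 1) ** 2

def cka(K, L):
    """Normalized HSIC; lies in [0, 1], with 1 meaning identical kernels."""
    return hsic(K, L) / np.sqrt(hsic(K, K) * hsic(L, L))
```

The normalization is what buys invariance to uniform scaling: multiplying either kernel matrix by a positive constant cancels between numerator and denominator.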

So the important things to note are: if the representations across two networks are the same up to an orthogonal transformation, then the CKA is invariant to it; however, it is not invariant to arbitrary invertible linear transformations. It is also interesting that if $\phi_1, \phi_2$ are identical, then:
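The orthogonal-invariance claim is easy to check numerically for the linear kernel, since $(XQ)(XQ)^T = XX^T$ whenever $Q^TQ = I$. A small sketch (the helper `linear_cka` and the demo variables are my own, for illustration):

```python
import numpy as np

def linear_cka(X, Y):
    """CKA with a linear kernel, computed from (n, d) representation matrices."""
    n = X.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n
    K_c = H @ (X @ X.T) @ H
    L_c = H @ (Y @ Y.T) @ H
    return np.trace(K_c @ L_c) / np.sqrt(np.trace(K_c @ K_c) * np.trace(L_c @ L_c))

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 8))
Y = rng.normal(size=(30, 8))
Q, _ = np.linalg.qr(rng.normal(size=(8, 8)))  # random orthogonal matrix
# linear_cka(X @ Q, Y) matches linear_cka(X, Y) exactly, while a general
# invertible map (e.g. X @ np.diag(np.arange(1.0, 9.0))) typically changes it.
```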

$$\operatorname{tr}(C_{x,y}^T C_{x,y}) = \operatorname{tr}(C_{x,x}^T C_{y,y})$$

Here I am abusing notation a bit: $C_{a,b}$ denotes an empirical estimate.

References

This blog is great: https://jejjohnson.github.io/research_journal/appendix/similarity/hsic/

https://arxiv.org/pdf/1905.00414

https://www.youtube.com/watch?v=6Xp0i--pg5M

https://www.gatsby.ucl.ac.uk/~gretton/coursefiles/lecture5_covarianceOperator.pdf

ChatGPT