This is going to be a more exploratory post covering a scattering of interesting ideas from multiple papers on model merging. Model merging is a very interesting topic because the idea of taking models trained on different distributions and somehow using them to create a new model that can operate over all of those distributions is fascinating.
Why Merging?
Given the recent popularity of LLMs and their impressive capabilities, users are effectively finetuning these models on smaller datasets to elicit custom behaviors. So now we have many finetuned checkpoints floating around, and one can ask: could we somehow combine all these checkpoints into a mega model that has all the capabilities of the individual checkpoints? Sounds amazing if true, and it is partly true!
Main Resources
This post mainly discusses ideas explained in the following posts / papers. It may be good to read them first for better context.
Papers:
Core Concepts:
Preliminary
Typically model merging is posed using the following formulation.
Consider a base model $F_0(y \mid x; \theta)$ and $M$ datasets $D_i,\ i \in 1 \dots M$. Then $F_0$ is separately trained on $D_1 \dots D_M$, yielding $F_1 \dots F_M$. Since the architecture of $F_0$ doesn't change during training, we can just consider the trained parameters $\theta_1 \dots \theta_M$. Now, we want to obtain a single $\theta_{1:M}$, the optimal parameter such that $F(y \mid x; \theta_{1:M})$ performs well on $D_1 \dots D_M$.
A simple approach: averaging.
Indeed, averaging is a commonly used method, where $\hat{\theta}_{1:M} = \frac{1}{M}\sum_{i=1}^{M} \theta_i$. The federated learning community uses this quite often, where it is known as FedAvg.
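As a concrete sketch (my own illustration, not from any particular library), simple averaging over checkpoints stored as parameter dicts looks like this:

```python
import numpy as np

def average_checkpoints(state_dicts):
    """Uniformly average M checkpoints sharing an architecture:
    theta_hat = (1/M) * sum_i theta_i, applied per parameter tensor."""
    merged = {}
    for name in state_dicts[0]:
        merged[name] = sum(sd[name] for sd in state_dicts) / len(state_dicts)
    return merged

# Tiny usage example with two fake "checkpoints".
theta_1 = {"w": np.array([1.0, 2.0]), "b": np.array([0.0])}
theta_2 = {"w": np.array([3.0, 4.0]), "b": np.array([2.0])}
print(average_checkpoints([theta_1, theta_2]))  # {'w': [2. 3.], 'b': [1.]}
```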
Can curvature help us?
If all we have access to is $\theta_1 \dots \theta_M$, then there isn't too much we can do. However, what if we start to extract more information, like understanding the model's sensitivity to different parameter dimensions? To view this more concretely, consider $F$ to be a likelihood model, that is, $P(y \mid x; \theta)$. Now if we look at the posterior, $P(\theta \mid y, x)$, we can frame the problem of finding the parameters best at solving $D_1 \dots D_M$ as the following:

$$\hat{\theta}_{1:M} = \arg\max_\theta \sum_{i=1}^{M} \log P(\theta \mid D_i)$$
This yields $\hat{\theta}_{1:M} = \frac{1}{M}\sum_{i=1}^{M} \theta_i$: if each posterior is approximated by a Gaussian with unit covariance, $P(\theta \mid D_i) \approx \mathcal{N}(\theta; \theta_i, I)$, the objective reduces to minimizing $\frac{1}{2}\sum_{i=1}^{M}\|\theta - \theta_i\|^2$, whose gradient vanishes exactly at the mean. However, we know that the shape of the posterior likely isn't isotropic. So here we can utilize the Laplace approximation for the posterior, which I linked above. In essence, it also approximates the posterior with a Gaussian distribution, but not an isotropic one; instead, we use the second-order characteristics of the distribution to shape the covariance matrix. Specifically, the Laplace approximation models the posterior as:

$$P(\theta \mid D_i) \approx \mathcal{N}\left(\theta;\, \theta_i,\, H_i^{-1}\right)$$

where $H_i$ is the Hessian of the negative log posterior at $\theta_i$.
I am going to take a small detour now to talk about something we will use a bit later: the Fisher information matrix. First we define the score function as $\nabla_\theta \log p(y \mid x; \theta)$. It can be shown that the score has zero expectation under the model's distribution:

$$\mathbb{E}_{p(y \mid x; \theta)}\left[\nabla_\theta \log p(y \mid x; \theta)\right] = 0$$

Using this, we can also look at the score's covariance, which yields the Fisher information matrix $F$:

$$F = \mathbb{E}_{p(y \mid x; \theta)}\left[\nabla_\theta \log p(y \mid x; \theta)\, \nabla_\theta \log p(y \mid x; \theta)^{T}\right]$$
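To make the detour concrete, here is a rough sketch of how a diagonal Fisher is commonly estimated: average squared gradients of the log-likelihood over data. The `model` (a classifier) and `data_loader` here are hypothetical placeholders, and using batch gradients instead of per-example gradients is an extra approximation on top of the empirical Fisher:

```python
import torch
import torch.nn.functional as F

def estimate_diagonal_fisher(model, data_loader, n_batches=100):
    """Rough empirical estimate of the diagonal Fisher: average squared
    gradients of the label log-likelihood over data batches."""
    fisher = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
    n = 0
    for i, (x, y) in enumerate(data_loader):
        if i >= n_batches:
            break
        model.zero_grad()
        log_probs = F.log_softmax(model(x), dim=-1)
        loss = F.nll_loss(log_probs, y)  # negative log-likelihood of the labels
        loss.backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                fisher[name] += p.grad.detach() ** 2
        n += 1
    return {name: f / max(n, 1) for name, f in fisher.items()}
```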
It turns out that at the modes of the distribution we can substitute $F = -H$, where $H$ is the Hessian, into the Laplace approximation (see [Perone et al]). Then we can revisit the original objective and update it to the following:

$$\hat{\theta}_{1:M} = \arg\max_\theta \sum_{i=1}^{M} \log \mathcal{N}\left(\theta;\, \theta_i,\, F_i^{-1}\right)$$
In practice, usually only the diagonal of the Fisher is used for computational efficiency, so each merged parameter dimension is scaled by the normalized Fisher diagonal values; the authors derive a closed form solution:

$$\hat{\theta}_{1:M}^{(j)} = \frac{\sum_{i=1}^{M} F_i^{(j)}\, \theta_i^{(j)}}{\sum_{i=1}^{M} F_i^{(j)}}$$

where $j$ indexes parameter dimensions.
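Putting the closed form into code (a minimal sketch over parameter dicts, with a small epsilon guarding dimensions where every Fisher entry is zero):

```python
import numpy as np

def fisher_merge(thetas, fishers, eps=1e-8):
    """Per-dimension Fisher-weighted average:
    theta_hat = sum_i F_i * theta_i / sum_i F_i (all elementwise)."""
    merged = {}
    for name in thetas[0]:
        num = sum(f[name] * t[name] for t, f in zip(thetas, fishers))
        den = sum(f[name] for f in fishers) + eps  # guard all-zero Fisher dims
        merged[name] = num / den
    return merged
```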
As a mental exercise, think about the intuition behind this: for parameters that have high Fisher information, we want them to change less when averaging, whereas low Fisher information values can change more without as much degradation in performance. Also try to draw the connection in terms of curvature of a quadratic function like $\theta^T H \theta$: we wouldn't want to move too far along the principal axis of curvature (the eigenvector of $H$) that corresponds to the larger eigenvalue of $H$.

In addition, geometrically (disclaimer: this is hand-wavy and just how I got a visual understanding, not rigorous at all), if we look at an ellipsoid quadratic and its negative version ($-H$), we get a quadratic ellipsoid that is flipped across the domain plane. Then if we look at samples of $\nabla_\theta \log p(y \mid x; \theta)$, we can see that they too form an ellipsoid with the same principal axes. If we take each $\nabla_\theta \log p(y \mid x; \theta)$ and form a rank-1 matrix $\nabla_\theta \log p(y \mid x; \theta)\, \nabla_\theta \log p(y \mid x; \theta)^T$, and then look at the quadratic form it takes by plotting $v^T \nabla_\theta \log p\, \nabla_\theta \log p^T\, v$, we can see it forms a trough shape. If we plot the troughs that correspond to the principal axes and combine the plots, we roughly recover the negative Hessian plot. These values are eyeballed, so it won't be exact.
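The rank-1-outer-products picture can also be sanity-checked numerically. Here is a toy construction of my own (not from the papers): for $y \sim \mathcal{N}(\theta, \Sigma)$ the score is $\Sigma^{-1}(y - \theta)$, so averaging the score outer products should recover the Fisher $\Sigma^{-1}$:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model: y ~ N(theta, Sigma), so the score is Sigma^{-1} (y - theta)
# and the Fisher E[g g^T] equals Sigma^{-1}.
theta = np.array([1.0, -2.0])
Sigma = np.array([[2.0, 0.5],
                  [0.5, 1.0]])
Sigma_inv = np.linalg.inv(Sigma)

ys = rng.multivariate_normal(theta, Sigma, size=100_000)
scores = (ys - theta) @ Sigma_inv            # one score sample per row
fisher_mc = scores.T @ scores / len(ys)      # mean of rank-1 outer products

print(np.round(fisher_mc, 3))   # close to Sigma_inv
print(np.round(Sigma_inv, 3))
```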
Another angle is Task Arithmetic (TA), and the paper from [Daheim et al] describes a cool analysis of it. Consider the details of the loss function typically used: for task $i$ we optimize a regularized task loss $\bar{\ell}_i$, the raw task loss plus a weight regularization term of the form $\|\theta\|^2$ (made precise below).
A cool result they show concerns the error made when performing TA. Suppose we want to merge the parameters with per-task weightings $\{\alpha_i\}_{i=1}^{M}$. We can look at the difference between the optimal parameter, which is given as

$$\theta_{1:M} = \arg\min_\theta\; \sum_{i=1}^{M} \alpha_i \ell_i(\theta) + \frac{1}{2}\|\theta - \theta_{base}\|^2_{H_0},$$
where $\theta_{base}$ are the weights of some pretrained model that we don't want to diverge significantly from, and the Mahalanobis distance, a quadratic distance, is used as the regularizer: $\|\theta\|^2_{H_0} = \theta^T H_0 \theta$. And the TA merged parameters, which are

$$\hat{\theta}_{TA} = \theta_{base} + \sum_{i=1}^{M} \alpha_i \left(\theta_i - \theta_{base}\right).$$
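In code, TA merging is just a weighted sum of task vectors added back to the base weights (again a minimal sketch over parameter dicts):

```python
import numpy as np

def task_arithmetic_merge(theta_base, thetas, alphas):
    """TA: theta_base + sum_i alpha_i * (theta_i - theta_base), per parameter."""
    merged = {}
    for name in theta_base:
        merged[name] = theta_base[name] + sum(
            a * (t[name] - theta_base[name]) for t, a in zip(thetas, alphas)
        )
    return merged
```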
So how does $\theta_{1:M}$ compare to $\hat{\theta}_{TA}$? It turns out we can derive the difference to be

$$\theta_{1:M} - \hat{\theta}_{TA} = -H_0^{-1} \sum_{i=1}^{M} \alpha_i \left(\nabla \ell_i(\theta_{1:M}) - \nabla \ell_i(\theta_i)\right),$$

i.e., the merging error is governed by the mismatch between each task's gradient at the merged point and at its own optimum.
Then we can Taylor expand around each $\theta_i$: $\nabla \ell_i(\theta) \approx \nabla \ell_i(\theta_i) + H_i(\theta - \theta_i)$, where $H_i$ is the Hessian of $\ell_i$ at $\theta_i$. Substituting this in and solving for $\theta_{1:M}$ gives

$$\theta_{1:M} - \theta_{base} \approx \left(H_0 + \sum_{i=1}^{M} \alpha_i H_i\right)^{-1} \sum_{i=1}^{M} \alpha_i \left(H_0 + H_i\right)\left(\theta_i - \theta_{base}\right).$$
This shows that under the Taylor approximation, we should be utilizing the Hessians to shape how we merge the "task vectors" $(\theta_i - \theta_{base})$.
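A minimal sketch of this Hessian-weighted merge under a diagonal Hessian approximation (in practice $H_0$ and $H_i$ would themselves be approximated, e.g. with the diagonal Fisher from earlier; all names here are illustrative):

```python
import numpy as np

def hessian_weighted_merge(theta_base, thetas, alphas, h_base, hessians):
    """Merge task vectors with diagonal-Hessian weighting, following the
    Taylor-expansion result above:
      theta = theta_base
            + (H0 + sum_i a_i H_i)^{-1} sum_i a_i (H0 + H_i) (theta_i - theta_base)
    Diagonal Hessians are stored as arrays, so the inverse is elementwise."""
    merged = {}
    for name in theta_base:
        num = sum(a * (h_base[name] + h[name]) * (t[name] - theta_base[name])
                  for t, a, h in zip(thetas, alphas, hessians))
        den = h_base[name] + sum(a * h[name] for a, h in zip(alphas, hessians))
        merged[name] = theta_base[name] + num / den
    return merged
```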
This starts to allude to the final paper, [Tam et al], which I will briefly discuss here. It says that a general perspective on model merging is

$$\hat{\theta} = \left(\sum_{i=1}^{M} C_i\right)^{-1} \sum_{i=1}^{M} C_i\, \theta_i,$$

where $C_i$ is an "(approximate) covariance matrix of some random variable". Depending on how you set $C_i$, you can recover common techniques like Simple Averaging and Fisher Merging.
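As a quick illustration of how choices of $C_i$ recover earlier methods (toy vectors, with diagonal $C_i$ stored as arrays): setting every $C_i$ to the identity gives simple averaging, while setting $C_i$ to the Fisher diagonals gives Fisher merging.

```python
import numpy as np

def general_merge(thetas, Cs):
    """theta_hat = (sum_i C_i)^{-1} (sum_i C_i theta_i), with diagonal C_i
    stored as arrays so the inverse is elementwise."""
    return sum(C * t for t, C in zip(thetas, Cs)) / sum(Cs)

thetas = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
ident = [np.ones(2), np.ones(2)]                          # C_i = I  -> averaging
fishers = [np.array([10.0, 1.0]), np.array([1.0, 10.0])]  # C_i = F_i -> Fisher merging
print(general_merge(thetas, ident))    # [2. 3.]
print(general_merge(thetas, fishers))  # each coordinate pulled toward its high-Fisher value
```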
This goes to show that we really are warping the parameters under a linear transformation and reweighting them. I won't go further into that paper's details, but do check it out. I mainly wanted to end on this because it is interesting to think about what lies beyond the general framework they presented. Thanks for reading, and feel free to shoot me an email if there is a mistake or if you want to see other topics covered.