Model Merging

This is going to be a more exploratory post on a smattering of interesting ideas from multiple papers discussing model merging. Model merging is a very interesting topic because the idea of taking models trained on different distributions and somehow using them to create a new model that can operate over all of those distributions is fascinating.

Why Merging?

Given the recent popularity of LLMs and their impressive capabilities, users are finetuning these models on smaller datasets to elicit custom behaviors. So now we have many finetuned checkpoints floating around, and one can ask: could we somehow combine all these checkpoints into a mega model that has all the capabilities of the individual checkpoints? Sounds amazing if true, and it is partly true!

Main Resources

This post is mainly going to talk about ideas explained in the following posts / papers. It may be good to read them first for better context.

Papers:

Merging Models with Fisher-Weighted Averaging
Averaging the parameters of models that have the same architecture and initialization can provide a means of combining their respective capabilities. In this paper, we take the perspective that...
https://arxiv.org/abs/2111.09832
Model Merging by Uncertainty-Based Gradient Matching
Models trained on different datasets can be merged by a weighted-averaging of their parameters, but why does it work and when can it fail? Here, we connect the inaccuracy of weighted-averaging to...
https://arxiv.org/abs/2310.12808
Merging by Matching Models in Task Subspaces
Model merging aims to cheaply combine individual task-specific models into a single multitask model. In this work, we view past merging methods as leveraging different notions of a "task...
https://arxiv.org/abs/2312.04339
L2M: Practical posterior Laplace approximation with...
Uncertainty quantification for deep neural networks has recently evolved through many techniques. In this work, we revisit Laplace approximation, a classical approach for posterior approximation...
https://arxiv.org/abs/2107.04695

Core Concepts:

Fisher Information Matrix - Agustinus Kristiadi
An introduction and intuition of Fisher Information Matrix.
https://agustinus.kristia.de/techblog/2018/03/11/fisher-information/
Hessian and Curvatures in Machine Learning: A Differential-Geometric View - Agustinus Kristiadi
In machine learning, especially in neural networks, the Hessian matrix is often treated synonymously with curvatures. But, from calculus alone, it is not clear why can one say so. Here, we will view the loss landscape of a neural network as a hypersurface and apply a differential-geometric analysis on it.
https://agustinus.kristia.de/techblog/2020/11/02/hessian-curvatures/
Laplace's approximation
Laplace's approximation provides an analytical expression for a posterior probability distribution by fitting a Gaussian distribution with a mean equal to the MAP solution and precision equal to the observed Fisher information. The approximation is justified by the Bernstein–von Mises theorem, which states that under regularity conditions the posterior converges to a Gaussian in large samples.
https://en.wikipedia.org/wiki/Laplace's_approximation

Preliminaries

Typically, model merging is posed using the following formulation.

Suppose we have a base model $F_0(y|x;\theta)$ and $M$ datasets $D_i$, $i \in 1\ldots M$. Then $F_0$ is separately trained on $D_1 \ldots D_M$, yielding $F_1 \ldots F_M$. Since the architecture of $F_0$ doesn't change during training, we can just consider the trained parameters $\theta_1 \ldots \theta_M$. Now we want to obtain a single $\theta_{1:M}$, the optimal parameters such that $F(y|x;\theta_{1:M})$ performs well on $D_1 \ldots D_M$.

A simple approach: averaging.

Indeed, averaging is a commonly used method, where $\hat{\theta}_{1:M} = \frac{1}{M}\sum_{i=1}^M \theta_i$. The federated learning community uses this quite often, where it is known as FedAvg.

Communication-Efficient Learning of Deep Networks from Decentralized Data
Modern mobile devices have access to a wealth of data suitable for learning models, which in turn can greatly improve the user experience on the device. For example, language models can improve...
https://arxiv.org/abs/1602.05629
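To make this concrete, here is a minimal sketch of uniform checkpoint averaging over PyTorch state dicts (the function name and setup are mine, not from the FedAvg paper):

```python
# A minimal sketch of uniform parameter averaging (FedAvg-style merging).
# Assumes all checkpoints share the same architecture, keys, and shapes,
# and that parameters are floating-point tensors.
import torch

def average_checkpoints(state_dicts):
    """Uniformly average a list of state dicts: theta_hat = (1/M) sum_i theta_i."""
    merged = {}
    for key in state_dicts[0]:
        merged[key] = torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
    return merged

# Usage: merged = average_checkpoints([torch.load(p) for p in checkpoint_paths])
```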

Can curvature help us?

If all we have access to is $\theta_1 \ldots \theta_M$, then there isn't too much we can do. However, what if we start to extract more information, like understanding the model's sensitivity to different parameter dimensions? To view this more concretely, consider $F$ to be a likelihood model, that is, $P(y|x;\theta)$. Now if we look at the posterior $P(\theta|y,x)$, we can frame the problem of finding the parameters best at solving $D_1 \ldots D_M$ as the following:

$$\hat{\theta}_{1:M} = \underset{\theta}{\mathrm{arg\,max}}\; \log\prod_{i=1}^M \mathcal{N}(\theta \mid \theta_i, I)$$
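To see why, note that with unit covariance the log-density expands to (a quick step of my own, not spelled out in the papers):

$$\log\prod_{i=1}^M \mathcal{N}(\theta \mid \theta_i, I) = -\frac{1}{2}\sum_{i=1}^M \|\theta - \theta_i\|^2 + \text{const}$$

so we are just minimizing the total squared distance to the $\theta_i$.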

This yields $\hat{\theta}_{1:M} = \frac{1}{M}\sum_{i=1}^M\theta_i$. However, the assumption we made was that the posterior can be approximated by a Gaussian distribution with unit covariance, and we know the shape likely isn't going to be isotropic. So here we can utilize the Laplace approximation for the posterior, which I linked above. In essence, it also tries to approximate the posterior with a Gaussian distribution, just not an isotropic one. Instead, we look at the second-order characteristics of the distribution to shape the covariance matrix. Specifically, the Laplace approximation models the posterior as:

$$P(\theta|D_i) \approx \mathcal{N}(\theta \mid \hat{\theta}_i, [-\nabla_{\theta}^2\log p(\theta|D_i)]^{-1})$$

I am going to take a small detour now to talk about something we will use a bit later: the Fisher information matrix. First we define the score function as $\nabla_{\theta} \log p(y|x;\theta)$. It can be shown that

$$\underset{p(y|x;\theta)}{\mathbb{E}}[\nabla_{\theta}\log p(y|x;\theta)] = \vec{0}$$

Using this, we can also look at the covariance of the score, which yields the Fisher information matrix $I$:

$$I = \underset{p(y|x;\theta)}{\mathbb{E}}\left[(\nabla_{\theta}\log p(y|x;\theta))(\nabla_{\theta}\log p(y|x;\theta))^T\right]$$

It turns out that at the modes of the distribution we can substitute $I = -H$, where $H$ is the Hessian, into the Laplace approximation; see [Perone et al]. Then we can revisit the original objective and update it to the following:

$$\hat{\theta}_{1:M} = \underset{\theta}{\mathrm{arg\,max}}\; \log\prod_{i=1}^M \mathcal{N}(\theta \mid \theta_i, I_i^{-1})$$

In practice, usually only the diagonal of the Fisher is used for computational efficiency, so each merged parameter dimension is weighted by its normalized Fisher diagonal value. The authors give a closed-form solution:

$$\hat{\theta}_{1:M}^{(j)} = \frac{\sum_{i=1}^M I_i^{(j)}\theta_i^{(j)}}{\sum_{i=1}^M I_i^{(j)}}$$
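Here is a rough sketch of what this could look like in PyTorch. The function names and the batch-limited estimation loop are mine; in the spirit of Matena & Raffel, labels are sampled from the model's own predictive distribution and squared score gradients estimate the Fisher diagonal:

```python
# A minimal sketch of diagonal Fisher estimation and Fisher-weighted merging
# for a classifier. Assumes `model(x)` returns logits of shape (batch, classes).
import torch

def estimate_diag_fisher(model, data_loader, n_batches=10):
    """Estimate diag(I) from squared gradients of the log-likelihood.
    Note: a proper estimate squares per-example gradients; summing over
    each batch here is a coarse approximation for brevity."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    for b, (x, _) in enumerate(data_loader):
        if b >= n_batches:
            break
        model.zero_grad()
        log_probs = torch.log_softmax(model(x), dim=-1)
        # Sample y from the model's own predictions rather than true labels.
        y = torch.multinomial(log_probs.exp(), num_samples=1)
        log_probs.gather(1, y).sum().backward()
        for n, p in model.named_parameters():
            fisher[n] += p.grad.detach() ** 2
    return {n: f / n_batches for n, f in fisher.items()}

def fisher_merge(thetas, fishers, eps=1e-8):
    """Closed-form diagonal Fisher merge: sum_i I_i * theta_i / sum_i I_i."""
    merged = {}
    for key in thetas[0]:
        num = sum(f[key] * t[key] for t, f in zip(thetas, fishers))
        den = sum(f[key] for f in fishers) + eps
        merged[key] = num / den
    return merged
```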

As a mental exercise, think about the intuition behind this: parameters with high Fisher information should change less when averaging, whereas parameters with low Fisher information can move more without as much degradation in performance. Try to draw the connection to curvature in a quadratic form like $\vec{\theta}^T H \vec{\theta}$: we wouldn't want to move too far along the principal axis of curvature (the eigenvector of $H$) that corresponds to the larger eigenvalue of $H$.

In addition, geometrically (disclaimer: this is hand-wavy and just how I got a visual understanding, not rigorous at all), if we look at an ellipsoid quadratic equation and its negated version $(-H)$, we get a quadratic ellipsoid that is flipped across the domain plane. Then if we look at samples of $\nabla_{\theta}\log p(y|x;\theta)$, we can see that they too form an ellipsoid with the same principal axes. If we take each $\nabla_{\theta}\log p(y|x;\theta)$ and build the rank-1 matrix $\nabla_{\theta}\log p(y|x;\theta)\,\nabla_{\theta}\log p(y|x;\theta)^T$, then plot the quadratic form it induces, $\vec{v}^T\nabla_{\theta}\log p(y|x;\theta)\,\nabla_{\theta}\log p(y|x;\theta)^T\vec{v}$, it forms a trough shape. Plotting the troughs that correspond to the principal axes and combining them, we roughly recover the negative-Hessian plot. These values are eyeballed, so it won't be exact.

Red is the original ellipsoid, blue is its negated-Hessian equivalent. Green and orange are the troughs made from the rank-1 Hessians, and purple is the merged rank-1 Hessians.

Another angle is Task Arithmetic (TA), and the paper from [Daheim et al] presents a cool analysis of it. Consider the details of the loss function typically used: we optimize the task loss $\bar{\ell}_i$ for task $i$, plus a weight regularizer $\|\theta\|^2$.

A cool equation they show characterizes the error incurred when performing TA. Suppose we want to merge the parameters with per-task weightings $\{\alpha_i\}_{i=1}^M$. We can then look at the difference between the optimal multitask parameters, given as

$$\theta_{1:M} = \underset{\theta}{\mathrm{arg\,min}}\; \sum_{i=1}^M \alpha_i\, \ell_i(\theta) + \frac{1}{2}\|\theta - \theta_{\text{base}}\|_{\mathbf{H}_0}^2$$

where $\theta_{\text{base}}$ are the weights of some pretrained model that we don't want to diverge significantly from, and the Mahalanobis distance, a quadratic distance, is used as the regularizer: $\|\theta\|_{\mathbf{H}_0}^2 = \theta^T\mathbf{H}_0\theta$. And the individual task parameters that TA merges, which are

$$\theta_i = \underset{\theta}{\mathrm{arg\,min}}\; \ell_i(\theta) + \frac{1}{2}\|\theta - \theta_{\text{base}}\|_{\mathbf{H}_0}^2$$

So how does $\theta_{1:M}$ compare to the TA merge $\sum_{i=1}^M \alpha_i\theta_i$? It turns out we can derive the difference:

$$\theta_{1:M} = \theta_{\text{base}} + \sum_{i=1}^M \alpha_i(\theta_i - \theta_{\text{base}}) - \sum_{i=1}^M \alpha_i\mathbf{H}_0^{-1}\left[\nabla \ell_i(\theta_{1:M}) - \nabla\ell_i(\theta_i)\right]$$

Then we can Taylor expand $\nabla \ell_i(\theta) \approx \nabla \ell_i(\theta_i) + \mathbf{H}_i(\theta - \theta_i)$, where $\mathbf{H}_i$ is the Hessian of task $i$'s loss at $\theta_i$. Then

$$\theta_{1:M} = \theta_{\text{base}} + \sum_{i=1}^M \alpha_i\left(\mathbf{H}_0 + \sum_{t=1}^M\alpha_t\mathbf{H}_t\right)^{-1}(\mathbf{H}_0 + \mathbf{H}_i)(\theta_i - \theta_{\text{base}})$$

So it shows that, under the Taylor approximation, we should be utilizing the Hessians to shape how we merge the "task vectors" $(\theta_i - \theta_{\text{base}})$.
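Here is a minimal sketch of that merge rule in the diagonal case, where the matrix inverse reduces to elementwise division (all names are mine; diagonal Fisher estimates like the ones above could stand in for the $\mathbf{H}_i$):

```python
# A minimal sketch of Hessian-weighted task-vector merging with diagonal
# Hessian approximations stored as per-parameter tensors. With diagonal
# matrices, (H0 + sum_t a_t H_t)^-1 is just elementwise division.
import torch

def hessian_weighted_merge(theta_base, thetas, hessians, h0, alphas, eps=1e-8):
    """theta = theta_base + sum_i a_i (H0 + sum_t a_t H_t)^-1 (H0 + H_i)(theta_i - theta_base)."""
    merged = {}
    for key in theta_base:
        h_bar = h0[key] + sum(a * h[key] for a, h in zip(alphas, hessians)) + eps
        delta = sum(
            a * (h0[key] + h[key]) / h_bar * (t[key] - theta_base[key])
            for a, t, h in zip(alphas, thetas, hessians)
        )
        merged[key] = theta_base[key] + delta
    return merged
```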

This starts to allude to the final paper [Tam et al], which I will briefly discuss here. It says that a general perspective on model merging is

$$\hat{\theta}_{1:M} = \left(\sum_{i=1}^M C_i\right)^{-1}\left(\sum_{i=1}^M C_i \theta_i\right)$$
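As a quick illustration: with per-parameter diagonal $C_i$, this rule is just a per-coordinate weighted average (a minimal sketch; the diagonal restriction and names are mine, not the paper's):

```python
# A minimal sketch of the general merge rule above with per-parameter
# diagonal C_i, so the inverse is elementwise division.
import torch

def general_merge(thetas, Cs):
    """theta = (sum_i C_i)^-1 (sum_i C_i theta_i), with diagonal C_i as tensors."""
    merged = {}
    for key in thetas[0]:
        c_sum = sum(c[key] for c in Cs)
        merged[key] = sum(c[key] * t[key] for c, t in zip(Cs, thetas)) / c_sum
    return merged

# C_i = torch.ones_like(theta_i)  -> simple averaging
# C_i = diagonal Fisher I_i       -> Fisher merging
```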

Here, $C_i$ is an "(approximate) covariance matrix of some random variable". Depending on how you set $C_i$, you can model common techniques like simple averaging and Fisher merging. This goes to show that we really are warping the parameters under a linear transformation and reweighting them. I won't go further into that paper's details, but do check it out. I mainly wanted to end on this because it is interesting to think about what happens if we move beyond the general framework they presented. Thanks for reading, and feel free to shoot me an email if there is a mistake or if you want to see other topics covered.