MoE

Introduction

Mixtures of Experts (MoE) have gotten popular recently with the rise of large language models and multimodal reasoning. They are not a new idea, and have existed for a while in the form of ensemble methods. For example, you might have heard of bagging and boosting. Bagging refers to training different models on different random partitions of the data and then aggregating their results to produce a more robust model. Boosting involves training models sequentially, where each consecutive model is trained on data reweighted according to the previous model's performance. In addition, you have probably seen models like Gaussian Mixture Models. While they are simple, they capture the essence of the motivation for a mixture model: modeling a more complex distribution through the explicit use of simple distributions / functions.

From: https://www.youtube.com/watch?v=U8J32Z3qV8s

Key Papers

To be transparent, most of the papers I chose are the ones that Finbarr Timbers covered in his awesome blog posts on MoE, so make sure to check his page out! He already captured the key ideas, but hopefully I added some extra insights.

OUTRAGEOUSLY LARGE NEURAL NETWORKS:
THE SPARSELY-GATED MIXTURE-OF-EXPERTS LAYER

They present a model in the form

$$y = \sum_{i=1}^n G(x)_i E_i(x) \\ \text{where}, \; G(x) = \text{Softmax}(\text{KeepTopK}(H(x), k)) \\ H(x)_i = (xW_g)_i + \mathcal{N}(0,1)\cdot\text{Softplus}((xW_{\text{noise}})_i) \\ \text{KeepTopK}(v,k)_i = v_i \;\text{if}\; v_i \in \{\text{top } k\} \;\text{else}\; -\infty$$

Both $W_g, W_{\text{noise}}$ are learned through normal backpropagation. I think an important takeaway is the $W_{\text{noise}}$ parameter matrix, because it allows for exploration; under expectation, once the $E_i(x)$ converge to their optimal functions, $W_{\text{noise}}$ should converge to zero as $W_g$ converges to its optimal value. In addition, $k$ should be larger than one, because the Switch Transformer authors write that

Shazeer et al. (2017) conjectured that routing to k > 1 experts was necessary in order to have non-trivial gradients to the routing functions. The authors intuited that learning to route would not work without the ability to compare at least two experts.
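
To make this concrete, here is a minimal numpy sketch of the noisy top-k gate for a single token. The zero initialization of $W_g$ and $W_{\text{noise}}$ matches a detail discussed later in this section, but the shapes and toy values are my own:

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def noisy_top_k_gating(x, W_g, W_noise, k, rng=None):
    """Sketch of the noisy top-k gate G(x) for a single token x."""
    if rng is None:
        rng = np.random.default_rng(0)
    clean = x @ W_g                                    # (n_experts,)
    h = clean + rng.standard_normal(clean.shape) * softplus(x @ W_noise)
    # KeepTopK: everything outside the top k becomes -inf before the softmax
    topk = np.argsort(h)[-k:]
    masked = np.full_like(h, -np.inf)
    masked[topk] = h[topk]
    e = np.exp(masked - h[topk].max())
    return e / e.sum()                                 # sparse gate weights

# toy usage: 8-dim input, 4 experts, route to the top 2
d, n_experts = 8, 4
W_g, W_noise = np.zeros((d, n_experts)), np.zeros((d, n_experts))
print(noisy_top_k_gating(np.ones(d), W_g, W_noise, k=2))
```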

The authors convey that a problem that often arises in these setups is

We have observed that the gating network tends to converge to a state where it always produces large weights for the same few experts. This imbalance is self-reinforcing, as the favored experts are trained more rapidly and thus are selected even more by the gating network. Eigen et al. (2013) describe the same phenomenon, and use a hard constraint at the beginning of training to avoid this local minimum. Bengio et al. (2015) include a soft constraint on the batch-wise average of each gate.

To mitigate this issue, they introduce an Importance loss term that penalizes variation of the gating values over a batch $X$.

$$\text{Importance}(X) = \sum_{x\in X}G(x) \\ L_\text{importance}(X) \propto \text{CV}(\text{Importance}(X))^2$$

Where CV is the coefficient of variation $\sigma/\mu$. This encourages the model to have uniform gating across a batch. However, this alone is still not sufficient, for the following reason.

The authors write that

We want to define an additional loss function to encourage experts to receive roughly equal numbers of training examples. Unfortunately, the number of examples received by an expert is a discrete quantity, so it can not be used in backpropagation.

This also helps in a distributed setup, where computation is spread more evenly across devices.

The problem with the $\text{Importance}$ loss term is that, as you sum across the batch, you lose information about the gate values for individual data points in the batch; this information loss is why the aforementioned problem arises. So the authors define a new metric.

Let $P(x,i)$ denote "the probability that $G(x)_i$ is nonzero, given a new random choice of noise on element $i$, but keeping the already-sampled choices of noise on the other elements". They then create an additional loss term that balances these values per column, which prevents the degenerate case that the Importance term can suffer from.

$$\text{Load}(X)_i = \sum_{x\in X}P(x,i) \\ L_{\text{load}}(X) \propto \text{CV}(\text{Load}(X))^2$$

Something I am not sure about is why we can't just use the Load loss term and drop the Importance term. One final detail is that they initialize $W_{\text{noise}}, W_g$ to all zeros, because that gives a uniform weighting over the experts initially, which helps allow them to specialize.
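
As a rough illustration of the CV-squared penalty, here is a numpy sketch of the Importance term (the Load term applies the same penalty to $\text{Load}(X)$, which requires the noise model and is omitted here); the `w_importance` scale and the toy gate values are my own assumptions:

```python
import numpy as np

def cv_squared(v, eps=1e-10):
    """Squared coefficient of variation, (sigma / mu)^2."""
    return v.var() / (v.mean() ** 2 + eps)

def importance_loss(gates, w_importance=0.1):
    """gates: (batch, n_experts) gate values G(x). Importance(X) sums the
    gates over the batch; penalizing its CV^2 pushes per-expert importance
    toward uniform."""
    importance = gates.sum(axis=0)
    return w_importance * cv_squared(importance)

# expert 0 dominates and expert 2 is dead, so the penalty is large
gates = np.array([[0.9, 0.1, 0.0],
                  [0.8, 0.2, 0.0],
                  [0.7, 0.3, 0.0]])
print(importance_loss(gates))
```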

Switch Transformers: Scaling to Trillion Parameter Models
with Simple and Efficient Sparsity

Fig 2 from the paper; note the routing function only looks at its current token and is independent of previous tokens

Each token is routed to exactly one expert. This differs from previous works arguing that tokens should be routed to multiple experts. The authors show that this is no longer necessary, and that it also improves computational efficiency.

Figure 3 from paper

In this figure we see that for a batch of tokens, each expert has a budget of how many tokens it can process in total. This is denoted the expert capacity. If it's too small, then some tokens won't be processed (the blue one in the left picture); however, too large a capacity is also inefficient.
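
A minimal sketch of how the capacity is computed; the capacity factor of 1.25 is a representative value, not necessarily the paper's exact setting:

```python
import math

def expert_capacity(tokens_per_batch: int, num_experts: int,
                    capacity_factor: float = 1.25) -> int:
    """Max tokens one expert may process per batch. A factor > 1.0 leaves
    headroom for imbalanced routing; tokens that overflow an expert's
    budget skip it and pass through the residual connection."""
    return math.ceil(tokens_per_batch / num_experts * capacity_factor)

# 512 tokens over 8 experts: 64 each with no slack, 80 with 25% headroom
print(expert_capacity(512, 8, 1.0), expert_capacity(512, 8, 1.25))
```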


For their load balancing loss, let $N$ be the number of experts and $B$ be a batch with $T$ tokens. The loss is:

$$\text{loss} = \alpha N \sum_{i=1}^N f_i P_i \\ f_i = \frac{1}{T}\sum_{x\in B}\mathbf{1}\{\operatorname{argmax}\, p(x) = i\}, \quad \mathbf{1} \text{ is the indicator function} \\ P_i = \frac{1}{T}\sum_{x\in B}p_i(x)$$

One neat thing is that with this loss, you don't need both a load balancing loss and an importance loss as the previous paper had. Let's unpack what this loss is doing. Since $f_i$ will be roughly aligned with $P_i$, we can say, in a hand-wavy way, that the loss is minimized when both are uniform. This also prevents cases where $p_i(x)$ is very unimodal, because for the same vector $F = \{f_i\}_{i=1}^N$ there can be multiple compatible distributions $p(x)$, and the uniform $p$ minimizes the loss the most.
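
Here is a small numpy sketch of this loss as I read it; the value of $\alpha$ and the toy routing probabilities are my own choices:

```python
import numpy as np

def switch_load_balancing_loss(router_probs, alpha=0.01):
    """router_probs: (T, N) softmax probabilities p(x) per token.
    f = fraction of tokens argmax-routed to each expert,
    P = mean router probability per expert; loss = alpha * N * <f, P>."""
    T, N = router_probs.shape
    f = np.bincount(router_probs.argmax(axis=1), minlength=N) / T
    P = router_probs.mean(axis=0)
    return alpha * N * np.sum(f * P)

# near-uniform routing sits at the minimum, which is roughly alpha
probs = np.full((16, 4), 0.25)
probs[np.arange(16), np.arange(16) % 4] += 0.01   # break argmax ties evenly
probs /= probs.sum(axis=1, keepdims=True)
print(switch_load_balancing_loss(probs))          # ~0.01
```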

Importantly, the authors show that there is consistently an improvement when adding more experts, and this is done with the same computational budget (part of Figure 4, below). They also show that the scaling is better than traditional dense scaling (Figure 6, below).

Part of figure 4 from the paper

Figure 6 from paper

Hash Layers For Large Sparse Models

The MoE layers are implemented to replace the feed-forward networks in the original Transformer, Switch Transformer style. Most papers replace the FFN with MoE layers because FFNs are "the most computationally expensive part in a Transformer-based network" (Zhou et al. 2022).

I found this paper pretty surprising because you can get good performance with a random mapping between a token and the FFN it gets routed to. This seems counterintuitive, because one would expect that a dynamic routing model, able to decide which expert to send the token to depending on the token's embedding, would provide more flexibility. The authors write:

We are free to choose from various possible hash functions, which we will consider below. However, for training purposes, the hash function is fixed in advance, and in this way, our routing mechanism requires no training and has no adjustable parameters …
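
A minimal sketch of fixed-hash routing; the md5-based bucketing of token ids is my assumption, since the paper considers several possible hash functions:

```python
import hashlib

def hash_route(token_id: int, num_experts: int) -> int:
    """Fixed hash routing: every vocabulary item maps to one expert,
    decided before training, with no learnable parameters."""
    digest = hashlib.md5(str(token_id).encode()).hexdigest()
    return int(digest, 16) % num_experts

# the same token id always lands on the same expert
print([hash_route(t, num_experts=8) for t in (17, 17, 42, 1024)])
```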


So one problem with this is that, because of the Zipfian distribution, which models word frequencies well, the distribution of experts being used will also be skewed. So they came up with Balanced Hash, which uses the distribution of the training data and rehashes to obtain a less skewed distribution over the hash buckets. Another version is Clustered Hash, which performs k-means on the token embeddings and hashes similar tokens to the same expert. Interestingly, they also try the opposite of this: within a cluster from k-means, they spread the tokens of that cluster out over the buckets. The authors' motivation for this is:

very similar tokens need fine distinctions which requires more model capacity (hence assigning to different experts)

One final version they try is to hash parts of the weight matrices of the feed-forward network $B(\text{relu}(A(h)))$:

$$v = \text{relu}([A_{k_1}(h), ..., A_{k_N}(h)]), \quad \text{FFN}(h) = [B_{k_1}(v), ..., B_{k_N}(v)]$$

where $k_i$ is determined by the hash: $k_i = \text{hash}_i(x)$.

DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning

In most of the other works on this page, we use a top-k selection over the gating values. The authors of this paper argue that this can lead to instabilities during training, because the loss landscape is no longer smooth. So to reiterate, the prior MoE models are often equivalent to solving:

$$\underset{f_1, ..., f_n, w}{\min}\frac{1}{N}\sum_{(x,y)\in D}\ell\Big(y, \sum_{i=1}^{n}f_i(x)w_i\Big) \\ \text{s.t.}\;\; \|w\|_0\leq k, \;\; \sum_{i=1}^n w_i = 1, \;\; w \geq 0$$

Here the $L_0$ norm constraint is what makes this difficult for our usual gradient-based optimizers. Consequently, the contribution of this work is to convert this into an unconstrained optimization problem. Pretty cool!

$$r(z)_i = \prod_{j\in B(i-1)}z_j\prod_{j\in [m] \backslash B(i-1)}(1-z_j)$$

This formula maps binary numbers to one-hot vectors; here $B(\ell)$ is the set of bit positions set to one in the binary representation of $\ell$. For example, if $z = [1,0]$, then $r(z)_2 = 1$, since I am using 0-indexing. Thus we can use this to obtain a mixture over $k$ experts from a stack of $k$ binary vectors, which is $Z$.
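
A small sketch of $r(z)$; the MSB-first bit ordering is my assumption about the indexing convention:

```python
import numpy as np

def r(z):
    """Map an m-bit (possibly relaxed) vector z to a length-2^m vector:
    entry i multiplies z_j over the set bits of i and (1 - z_j) over the
    unset bits, so a hard binary z yields a one-hot output."""
    m = len(z)
    out = np.empty(2 ** m)
    for i in range(2 ** m):
        bits = [(i >> (m - 1 - j)) & 1 for j in range(m)]  # MSB-first
        out[i] = np.prod([z[j] if b else 1 - z[j]
                          for j, b in enumerate(bits)])
    return out

print(r(np.array([1.0, 0.0])))  # [0. 0. 1. 0.], one-hot at index 2
```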

$$q(\alpha, Z) = \sum_{i=1}^k\sigma(\alpha)_i r(z^{(i)})$$

So our new optimization problem becomes:

$$\underset{f_1, ..., f_n, \alpha, Z}{\min} \frac{1}{N}\sum_{(x,y)\in D} \ell\Big(y, \sum_{i=1}^n f_i(x)q(\alpha, Z)_i\Big) \\ z^{(i)}\in \{0,1\}^m, \; i\in[k]$$

But this is still not that useful, because $z^{(i)}$ is still a binary vector, which makes this a combinatorial optimization problem, which is not what we want. So instead let's relax $z^{(i)}$ to be continuous, which we can do with the following.

Let $S(t)$ be a smooth function that can exactly reach 0 and 1.

$$\tilde{q}(\alpha, Z) \coloneqq q(\alpha, S(Z)) = \sum_{i=1}^k\sigma(\alpha)_i r(S(z^{(i)}))$$
$$\underset{f_1, ..., f_n, \alpha, Z}{\min} \frac{1}{N}\sum_{(x,y)\in D} \ell\Big(y, \sum_{i=1}^n f_i(x)\tilde{q}(\alpha, Z)_i\Big) + \lambda \Omega(Z)$$

The entropy isn't calculated directly on $Z$; instead $\Omega(Z)\coloneqq \sum_{i=1}^k h(r(S(z^{(i)})))$, where $h$ is an entropy function. The authors state that the entropy regularization isn't strictly needed, because empirically the $z^{(i)}$ become binary vectors anyway, but the entropy term helps with faster convergence. Note that $\alpha, Z$ do not depend on $x$, but you can easily make them input-dependent via a linear transformation.
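
For concreteness, here is the cubic smooth-step that I believe the paper uses for $S(t)$; double-check the exact coefficients against the paper before relying on them:

```python
import numpy as np

def smooth_step(t, gamma=1.0):
    """Cubic smooth-step: exactly 0 for t <= -gamma/2, exactly 1 for
    t >= gamma/2, and smooth in between, so entries of z can reach hard
    binary values (with zero gradient once saturated)."""
    t = np.asarray(t, dtype=float)
    mid = -2.0 / gamma**3 * t**3 + 3.0 / (2.0 * gamma) * t + 0.5
    return np.where(t <= -gamma / 2, 0.0,
                    np.where(t >= gamma / 2, 1.0, mid))

print(smooth_step(np.array([-1.0, -0.2, 0.0, 0.2, 1.0])))
# [0.    0.216 0.5   0.784 1.   ]
```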

BASE Layers: Simplifying Training of Large, Sparse Models

This paper has a similar setup to the previous paper, with some key differences. To go over their notation: they have $E$ experts, each denoted $f_e$ with a learnable representation $w_e\in \mathbb{R}^D$ that allows us to do routing. $h_t$ is the token embedding and $a_t \in \{1,..., E\}$ is the assignment of token $t$ to an expert. So the overall model takes the following form:

$$\sigma(h_t\cdot w_{a_t})f_{a_t}(h_t) + h_t$$

The assignment during training and testing is different: the authors write that

During training, we maximize model throughput by assigning an equal number of tokens to each expert. At test time, we simply assign each token to its highest scoring expert.

So during training they solve the well-studied assignment problem

$$\text{maximize}\sum_t h_t\cdot w_{a_t} \\ \text{s.t.}\;\; \forall e: \sum_{t=1}^T \mathbb{1}_{a_t=e} = \frac{T}{E}$$

Here $T$ is the number of tokens. One potential algorithm to solve this is the famous Hungarian matching algorithm, but that is $O(n^3)$ and not parallelizable. Instead the authors use the auction algorithm (Bertsekas, 1992), which is "more easily parallelizable on GPUs than the Hungarian Algorithm".


One important point is that since the partition of the dataset over workers would not be IID, they randomly shuffle the tokens across workers before calculating the assignments.
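
To build intuition for the balanced assignment, here is a greedy stand-in under the same $T/E$ capacity constraint; this is NOT the auction algorithm the paper uses, just an illustration:

```python
import numpy as np

def greedy_balanced_assign(scores):
    """scores: (T, E) token-expert affinities h_t . w_e. Visit tokens in
    order of their best score and give each to its highest-scoring expert
    that still has capacity T/E. A cheap heuristic, not the auction
    algorithm from the paper."""
    T, E = scores.shape
    cap, counts = T // E, np.zeros(E, dtype=int)
    assign = np.full(T, -1)
    for t in np.argsort(-scores.max(axis=1)):     # most confident first
        for e in np.argsort(-scores[t]):          # preferred experts first
            if counts[e] < cap:
                assign[t], counts[e] = e, counts[e] + 1
                break
    return assign

scores = np.random.default_rng(0).standard_normal((8, 2))
print(greedy_balanced_assign(scores))             # exactly 4 tokens each
```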

Mixture-of-Experts with Expert Choice Routing

The authors start off by highlighting some of the problems with previous MoE models, where the routing function decides, for each token, which expert to route it to. These problems revolve around load imbalance, and as we have seen so far, there are multiple heuristics / regularizations to encourage a more uniform load over the experts. To be concise, here is how they implement their method:

$$S = \text{softmax}(XW_g), \quad S\in\mathbb{R}^{n\times e} \\ G,I = \text{TopK}(S^T, k), \quad P = \text{Onehot}(I)$$

$X \in \mathbb{R}^{n \times d}$ is the matrix of $n$ token embeddings of dimension $d$, $I$ is an index matrix, and $G$ holds the weights of the selection. A permutation matrix $P \in \mathbb{R}^{e \times k \times n}$ is then calculated from $I$ to reshuffle the indices such that $X_{in} = PX \in \mathbb{R}^{e \times k \times d}$ lets you index the tokens per expert. Then the FFN and the reverse shuffle/weighting are defined as:

$$\forall i, \; X_e[i] = \text{GeLU}(X_{in}[i]W_1[i])W_2[i]^T \\ X_{out}[l,d] = \sum_{i,j}P[i,j,l]\,G[i,j]\,X_e[i,j,d]$$
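
Here is a numpy sketch of the routing step as I understand it, with plain fancy indexing standing in for the one-hot permutation matmul; all names and shapes are illustrative:

```python
import numpy as np

def expert_choice_route(X, W_g, k):
    """Each expert picks its own top-k tokens, so expert load is balanced
    by construction (a token may be chosen by several experts or none)."""
    logits = X @ W_g                                   # (n, e)
    S = np.exp(logits - logits.max(axis=1, keepdims=True))
    S /= S.sum(axis=1, keepdims=True)                  # softmax over experts
    St = S.T                                           # (e, n)
    I = np.argsort(-St, axis=1)[:, :k]                 # top-k tokens per expert
    G = np.take_along_axis(St, I, axis=1)              # their gating weights
    X_in = X[I]                                        # (e, k, d) gathered tokens
    return G, I, X_in

rng = np.random.default_rng(0)
X, W_g = rng.standard_normal((16, 4)), rng.standard_normal((4, 3))
G, I, X_in = expert_choice_route(X, W_g, k=2)
print(I.shape, X_in.shape)  # (3, 2) (3, 2, 4)
```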

Notice here that there is nothing to prevent multiple experts from taking in the same tokens. Since a more even spread of tokens seems to better utilize the model's capacity, they introduce an entropy regularization of the form:

$$\underset{A}{\max}\;\langle S^T, A\rangle + \lambda H(A) \\ \text{s.t.}\;\; \forall i: \sum_{j'}A[i,j']=k; \quad \forall j: \sum_{i'}A[i',j]\leq b; \quad \forall i,j: \; 0\leq A[i,j]\leq 1$$

Where $H(A)$ computes the entropy. So this is an entropy-regularized linear program, and after obtaining $A$, the top-k is calculated using $A$ instead of $S^T$. What this allows is: if $S^T$'s top-k is rather skewed towards a small set of tokens, then depending on how much larger the top-k values are compared to the rest, and on $\lambda$, $A$ can reselect a more uniform set of tokens.