A mixture-of-experts layer is a bet that you can buy capacity without buying compute: replace one dense FFN with $E$ of them, let a router send each token to $k \ll E$ of them, and your parameter count grows while per-token FLOPs barely move. The bet mostly pays off, and the bill arrives at the router. Everything that makes MoE hard to train and annoying to serve (instability, dropped tokens, load imbalance, all-to-all traffic, a memory footprint set by total rather than active parameters) traces back to one discrete decision made per token, per layer. This is a guide to that decision, not to the marketing slide.

The router is a discrete bottleneck

The MoE layer swaps the position-wise FFN for $E$ expert FFNs plus a gating network. The router is usually just a linear projection $s = W_r h$ producing one logit per expert; a softmax turns the logits into routing weights $g = \mathrm{softmax}(s)$, and you keep the top-$k$. Switch Transformers (Fedus et al., 2021) pushed this to the limit with top-1 (exactly one expert per token), arguing that routing quality survives even the most aggressive sparsity; GShard (Lepikhin et al., 2020) used top-2. The token output is the gate-weighted sum over the selected experts:

$$y = \sum_{i \in \mathrm{TopK}(g)} \frac{g_i}{\sum_{j \in \mathrm{TopK}(g)} g_j} \, E_i(h)$$

where the surviving gates are renormalized to sum to one (Mixtral-style). With $k = 1$, Switch keeps the raw probability $g_i$ instead: a single renormalized gate is identically 1, which would cut the router's only gradient path through the task loss.

Here is what most MoE explanations step around: the decision that matters most in routing is not differentiable. Through the task loss, the router learns only via the magnitude of the gates it places on experts it already selected: $\partial y / \partial g_i$ exists for $i$ in the chosen set, because $g_i$ scales an expert output that touches the loss. Which experts get selected is an argmax/top-$k$, and argmax has zero gradient almost everywhere. Nothing in the loss says “you should have routed this token to expert 7 instead of expert 3.” The original noisy top-$k$ gate (Shazeer et al., 2017, “Outrageously Large Neural Networks”) added tunable Gaussian noise to the logits before the top-$k$, which helps balance load and manufactures some of the exploration that the selection step otherwise lacks.
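A tiny autograd experiment makes this concrete. The sketch below is a minimal illustration, assuming a single token, four experts whose "outputs" are hypothetical scalars, and the renormalized top-$k$ gate from above; the helper name `router_logit_grad` is made up for this example. It checks which router logits actually receive gradient.

```python
import torch
import torch.nn.functional as F

def router_logit_grad(logits, expert_out, k, renormalize=True):
    """Gradient of a toy top-k MoE output w.r.t. the router logits (one token)."""
    logits = logits.clone().requires_grad_(True)
    probs = F.softmax(logits, dim=-1)
    topk_p, topk_i = probs.topk(k)                 # topk_i is an integer tensor: no gradient
    gates = topk_p / topk_p.sum() if renormalize else topk_p
    y = (gates * expert_out[topk_i]).sum()         # scalar stand-in for the layer output
    y.backward()
    return logits.grad, topk_i

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
expert_out = torch.tensor([1.0, 2.0, 3.0, 4.0])    # hypothetical scalar expert outputs

grad, chosen = router_logit_grad(logits, expert_out, k=2)
print(chosen.tolist())  # [0, 1]
print(grad)             # nonzero for experts 0 and 1, (numerically) zero for 2 and 3
```

Experts 2 and 3 get no signal at all, even though expert 3 has the largest output: the loss cannot tell the router it should have picked them. Running the same function with `k=1, renormalize=True` returns an all-zero gradient, which is why top-1 routing keeps the raw probability as its gate.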

moe_route.py
import torch
import torch.nn.functional as F

def moe_route(x, W_r, k, n_experts):
    # x: [T, d] tokens, W_r: [d, E] router weights
    logits = x @ W_r                                       # [T, E]
    probs = F.softmax(logits.float(), dim=-1)              # route in fp32 (see stability)
    topk_p, topk_i = probs.topk(k, dim=-1)                 # [T, k]
    # Renormalize surviving gates (Mixtral-style). For k == 1 that would make every
    # gate exactly 1 and kill the router gradient, so keep the raw probability then.
    gates = topk_p / topk_p.sum(-1, keepdim=True) if k > 1 else topk_p
    # Switch/GShard load-balance aux loss (top-1 dispatch form); multiply by alpha outside
    P = probs.mean(0)                                           # mean router prob per expert
    f = F.one_hot(topk_i[:, 0], n_experts).float().mean(0)      # dispatch fraction (no grad)
    aux = n_experts * (f * P).sum()
    return topk_i, gates, aux

Keeping experts busy: load balancing and capacity

Left alone, routers collapse onto a handful of experts. The failure is self-reinforcing: a popular expert sees more tokens, trains faster, becomes more attractive, and the long tail of experts receives little or no gradient. Two families of fixes exist, and the difference between them is where you put the balancing pressure.

Auxiliary load-balance loss

GShard and Switch add a term that is minimized when tokens spread uniformly. With $f_i$ the fraction of tokens dispatched to expert $i$ and $P_i$ the mean router probability for expert $i$ over the batch:

$$L_{aux} = \alpha \, E \sum_{i=1}^{E} f_i \, P_i$$

$f_i$ is a hard fraction computed from the discrete dispatch and carries no gradient; $P_i$ is the soft, differentiable factor, so the gradient nudges the router to move probability mass away from overloaded experts and toward under-used ones.
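The scale of this term is easy to sanity-check by hand. A minimal sketch, assuming plain Python lists for $f$ and $P$ (the function name is hypothetical), evaluates it at the two extremes: a perfectly balanced batch and a fully collapsed router.

```python
def switch_aux_loss(f, P, alpha=1.0):
    """alpha * E * sum_i f_i * P_i, with f = dispatch fractions, P = mean router probs."""
    E = len(f)
    return alpha * E * sum(fi * pi for fi, pi in zip(f, P))

E = 8
uniform = [1.0 / E] * E
collapsed_f = [1.0] + [0.0] * (E - 1)   # every token dispatched to expert 0
collapsed_P = [1.0] + [0.0] * (E - 1)   # router puts all its mass on expert 0

print(switch_aux_loss(uniform, uniform))          # 1.0  (perfect balance)
print(switch_aux_loss(collapsed_f, collapsed_P))  # 8.0  (full collapse, equals E)
```

The factor of $E$ is what makes a balanced batch score exactly $\alpha$ regardless of expert count, so the same coefficient transfers across model sizes; a fully collapsed router scores $\alpha E$.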

The coefficient $\alpha$ stays small (Switch used 0.01) because this loss competes directly with the language-modeling objective: crank it up and you balance perfectly while learning nothing.

Aux-loss-free bias adjustment

Because the aux loss is a gradient fighting your real loss, newer designs route the balancing signal around the gradient entirely. DeepSeek’s loss-free strategy (Wang et al., 2024, arXiv 2408.15664) adds a per-expert bias $b_i$ to the scores used only for top-$k$ selection, never to the gate weight that combines outputs, and updates it with a feedback controller after each training step:

$$b_i \leftarrow b_i + u \cdot \mathrm{sign}(\bar{c} - c_i)$$

where $c_i$ is expert $i$'s load on the most recent batch, $\bar{c}$ the mean load across experts, and $u$ the update rate. Underloaded experts get bumped up, overloaded ones down; the gate values that actually scale expert outputs stay untouched by the balancing logic.

aux_free.py
import torch

def route(x, W_r, bias, k):
    scores = x @ W_r                                       # [T, E]
    sel = (scores + bias).topk(k, dim=-1).indices          # bias affects SELECTION only
    gates = torch.sigmoid(scores).gather(-1, sel)          # gate from raw score, not bias (V3 also renormalizes these)
    return sel, gates

@torch.no_grad()
def update_bias(bias, sel, update_rate):
    # call after the optimizer step; bias is a buffer the optimizer never touches
    load = torch.bincount(sel.flatten(), minlength=bias.numel()).float()
    bias += update_rate * torch.sign(load.mean() - load)
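One way to watch the controller work in isolation is to freeze a deliberately skewed router and run only the bias updates. This is a toy simulation under stated assumptions: the scores are random with a hypothetical +2 head start for expert 0, routing is top-1, and `update_rate=0.05` is chosen for a fast demo, not taken from any paper.

```python
import torch

def simulate_bias_controller(steps=200, T=1024, E=4, k=1, update_rate=0.05, seed=0):
    """Freeze a skewed router and run only the sign-based bias updates."""
    g = torch.Generator().manual_seed(seed)
    # hypothetical router scores: expert 0 gets a +2 head start on every token
    scores = torch.randn(T, E, generator=g) + torch.tensor([2.0] + [0.0] * (E - 1))
    bias = torch.zeros(E)

    def measure():
        sel = (scores + bias).topk(k, dim=-1).indices      # bias only changes selection
        load = torch.bincount(sel.flatten(), minlength=E).float()
        return load, (load.max() / load.mean()).item()

    _, start = measure()
    for _ in range(steps):
        load, _ = measure()
        bias += update_rate * torch.sign(load.mean() - load)
    _, end = measure()
    return start, end, bias

start, end, bias = simulate_bias_controller()
print(f"max/mean load: {start:.2f} -> {end:.2f}")
print("learned bias:", bias)
```

The controller converges to a bias that roughly cancels expert 0's head start, without ever touching a gradient; because the update is a fixed-size sign step, the loads keep jittering slightly around balance rather than settling exactly.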

DeepSeek-V3 uses this with only a tiny complementary sequence-wise aux loss; treat the exact recipe as version-specific rather than canonical.

Expert capacity and token dropping

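Expert parallelism needs static tensor shapes for its all-to-all, so each expert gets a fixed-size buffer: capacity equals capacity_factor × tokens / E (top-$k$ implementations typically also multiply by $k$, since each token makes $k$ assignments). Tokens that overflow their expert’s buffer are dropped: they skip the FFN and pass through on the residual. The capacity factor is the throughput-versus-quality dial in a single number: above 1.0 it buys slack for imbalance at the cost of padded compute and memory, below 1.0 it drops more tokens.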
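A minimal sketch of the capacity arithmetic and the resulting drops, assuming top-1 routing and buffers filled in token order (the simplest policy; real systems may prioritize differently). The function names and the router choices are hypothetical.

```python
import math

def expert_capacity(n_tokens, n_experts, capacity_factor, k=1):
    # capacity_factor * k * tokens / E, rounded up; k=1 is the top-1 form
    return math.ceil(capacity_factor * k * n_tokens / n_experts)

def dispatch_with_capacity(expert_ids, n_experts, capacity):
    """Top-1 dispatch in token order; returns a keep mask (False = dropped)."""
    used = [0] * n_experts
    keep = []
    for e in expert_ids:
        keep.append(used[e] < capacity)   # slot still free in this expert's buffer?
        used[e] += 1
    return keep

ids = [0, 0, 0, 0, 1, 2, 3, 3]           # hypothetical router choices for 8 tokens, 4 experts
for cf in (1.0, 2.0):
    cap = expert_capacity(len(ids), 4, cf)
    keep = dispatch_with_capacity(ids, 4, cap)
    print(cf, cap, keep.count(False))    # "1.0 2 2", then "2.0 4 0"
```

Doubling the capacity factor removes every drop here, but it also doubles every expert's buffer, so experts 1 and 2 now pad three empty slots each.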

The systems community’s answer to the capacity dilemma is to refuse it: MegaBlocks (Gale et al., 2023, arXiv 2211.15841) reformulates the expert computation as block-sparse GEMMs so variable token counts need no padding and nothing gets dropped. “Dropless” is the right default when your kernels support it.

Architectural axes: granularity, shared experts, placement

Four knobs define an MoE’s shape, and they interact.

  • Number of experts. Switch scaled to thousands of experts for pretraining throughput studies; production decoder LLMs cluster much lower. Mixtral 8x7B uses 8 experts with top-2; DeepSeek-V3 reports 256 routed experts with top-8. These counts are version-specific architecture choices, not laws.
  • Expert granularity. DeepSeekMoE (Dai et al., 2024, arXiv 2401.06066) splits each expert into several smaller ones and raises $k$ to compensate, holding FLOPs roughly constant. The payoff is combinatorial: with more, finer experts and a larger top-$k$, the number of distinct expert combinations a token can select explodes, which gives routing more expressive specialization without more compute.
  • Shared / always-on experts. Route every token through one or two experts unconditionally, in addition to the top-$k$. The shared experts absorb common knowledge that every token needs, so the routed experts are freed to specialize instead of each relearning the same baseline transformation.
  • Placement. MoE typically replaces the FFN sublayer while attention stays dense and shared. Some designs interleave — MoE every other block, dense elsewhere — or keep the first and last layers dense, where routing tends to be noisiest.
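The granularity argument can be checked with binomial coefficients. A small worked example, assuming a hypothetical coarse layer of 16 experts with top-2, and a fine-grained version that splits each expert into 4 (64 experts, top-8):

```python
from math import comb

# Each fine expert has 1/4 of the coarse FFN width, so 8 fine experts
# cost about as much per token as 2 coarse ones.
coarse = comb(16, 2)   # ways to pick 2 of 16 experts
fine = comb(64, 8)     # ways to pick 8 of 64 experts
print(coarse)  # 120
print(fine)    # 4426165368
```

Same per-token compute, but the router chooses among billions of expert combinations instead of 120.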

Stability: collapse, z-loss, and precision

MoE training fails in characteristic ways: a sudden loss spike, or a slow divergence as the router commits to a degenerate solution. Three levers matter more than the rest.

Router z-loss. ST-MoE (Zoph et al., 2022, arXiv 2202.08906) penalizes the magnitude of the router logits directly:

$$L_z = \frac{1}{B} \sum_{b=1}^{B} \left( \log \sum_{i=1}^{E} e^{x_i^{(b)}} \right)^2$$

where $B$ is the number of tokens and $x^{(b)} \in \mathbb{R}^E$ are the router logits for token $b$. This keeps the log-sum-exp small so the softmax never has to exponentiate large logits, which is where bf16's short mantissa produces large absolute rounding errors (and, for large enough logits in any format, overflow). It is cheap insurance, and it also tends to improve final quality, not just stability.
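A minimal PyTorch sketch of the z-loss, assuming router logits of shape `[B, E]` (the function name is hypothetical; in training it would be multiplied by a small coefficient and added to the total loss). The demo shifts every logit by 30: the softmax, and therefore routing, is unchanged, but the z-loss sees the larger magnitude.

```python
import math
import torch

def router_z_loss(logits):
    """Mean over tokens of (logsumexp of router logits)^2; logits: [B, E]."""
    return torch.logsumexp(logits.float(), dim=-1).pow(2).mean()

B, E = 4, 8
small = torch.zeros(B, E)    # every logit 0 -> logsumexp = log(E)
large = small + 30.0         # same softmax, logsumexp shifted up by 30

print(torch.allclose(torch.softmax(small, -1), torch.softmax(large, -1)))  # True
print(router_z_loss(small).item(), math.log(E) ** 2)  # equal up to float32 rounding
print(router_z_loss(large).item())                     # (30 + log 8)^2, roughly 1029
```

The cross-entropy loss is blind to this shift, which is exactly why a separate term is needed to keep the logits small.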

Precision. Compute the router in fp32 even when the rest of the block runs in bf16. Routing is a discrete decision: a rounding error that flips one logit past another changes which expert a token visits, which changes the gradient that expert receives, which is a non-smooth perturbation the optimizer cannot average away. The expert FFNs can stay in bf16; the gate cannot.
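bf16 keeps fp32's exponent range but has only 8 significand bits, so near 1.0 its representable values are $2^{-7} = 0.0078125$ apart. A two-line check, using two hypothetical logits 0.001 apart, shows what that does to a close routing decision. In a real router, bf16 rounding inside the matmul accumulation can reorder logits, not just tie them; this sketch only shows the simplest effect.

```python
import torch

logits = torch.tensor([1.002, 1.001])      # two experts' logits, 0.001 apart
print(logits.argmax().item())              # 0: fp32 keeps them distinct
as_bf16 = logits.to(torch.bfloat16)
print(as_bf16.tolist())                    # [1.0, 1.0]: the gap is gone, the choice is arbitrary
```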

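Initialization. Switch found that reducing the initialization scale by a factor of 10 (a truncated normal with standard deviation $\sqrt{s/n}$ for fan-in $n$, using $s = 0.1$ instead of the default $1.0$) stabilized training. It matters most early on, when the gate is most prone to collapse before any expert has learned anything worth routing to.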
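A minimal sketch of that initialization applied to a router, assuming a hypothetical `d=1024`, `E=64` router and truncation at two standard deviations (the truncation bounds are an assumption here; Switch applied the reduced scale throughout the model, not only to the router). The helper name is hypothetical; `torch.nn.init.trunc_normal_` takes absolute cut-offs `a` and `b`.

```python
import math
import torch

def switch_init_(weight, scale=0.1):
    """Truncated-normal init with std = sqrt(scale / fan_in), cut at +/-2 std.

    weight: [out_features, in_features], as stored by torch.nn.Linear.
    """
    fan_in = weight.shape[1]
    std = math.sqrt(scale / fan_in)
    with torch.no_grad():
        torch.nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
    return std

router = torch.nn.Linear(1024, 64, bias=False)   # hypothetical d=1024, E=64 router
std = switch_init_(router.weight)                # about 0.0099 for scale=0.1
```

With `scale=1.0` the same call gives a std about 3.2x larger, so the initial logits, and the initial confidence of the router, are correspondingly sharper.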

The systems bill: expert parallelism and all-to-all

Sparsity is a memory and communication architecture wearing a compute-savings costume, and expert parallelism is where that becomes obvious. Each device holds a subset of the experts; a token computed on device A whose expert lives on device B has to travel. That means two all-to-all collectives per MoE layer in the forward pass (a dispatch that scatters tokens to the devices owning their experts, and a combine that gathers the results back), mirrored by two more in the backward pass.
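The dispatch step boils down to a send-count matrix: how many (token, expert) rows each device ships to each other device. A minimal sketch, assuming experts are sharded contiguously across devices and routing is top-$k$; the function name and the toy routing are hypothetical, and real frameworks compute these split sizes inside their own all-to-all wrappers.

```python
import torch

def all_to_all_counts(expert_ids_per_device, n_experts, n_devices):
    """Rows each device sends to each other device in the dispatch all-to-all.

    expert_ids_per_device[src]: int tensor [T_src, k] of routed expert ids.
    Assumes expert e lives on device e // (n_experts // n_devices).
    """
    per_dev = n_experts // n_devices
    counts = torch.zeros(n_devices, n_devices, dtype=torch.long)
    for src, ids in enumerate(expert_ids_per_device):
        dst = ids.flatten() // per_dev                 # owning device per (token, slot)
        counts[src] += torch.bincount(dst, minlength=n_devices)
    return counts  # dispatch sends counts[src, dst]; combine sends counts.T back

# 2 devices, 4 experts (0-1 on device 0, 2-3 on device 1), top-2 routing
ids0 = torch.tensor([[0, 2], [1, 3]])   # device 0's two tokens
ids1 = torch.tensor([[2, 3], [0, 3]])   # device 1's two tokens
print(all_to_all_counts([ids0, ids1], n_experts=4, n_devices=2))
# tensor([[2, 2],
#         [1, 3]])
```

Diagonal entries stay on-device; everything off the diagonal crosses the interconnect, and the combine step returns the same volume along the transposed routes.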

All-to-all is often the dominant communication cost, and it sits squarely on the critical path: the experts cannot start until dispatch lands, and the next layer cannot start until combine returns. Its cost grows with the expert-parallel degree, and unlike the data-parallel gradient all-reduce, it cannot simply be deferred and hidden behind the backward pass. In a composed 3D/4D layout you want expert parallelism mapped onto the highest-bandwidth domain you can afford (intra-node NVLink before inter-node fabric), the same way you place tensor parallelism. The fixed capacity buffers described earlier exist precisely so these collectives have static shapes; that constraint is upstream of token dropping.

The expert computation itself is a batched/grouped GEMM: gather each expert’s assigned tokens, run its FFN, scatter back. When token counts per expert are small and uneven, those GEMMs are skinny and the kernel spends its time streaming weights rather than doing math.
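The gather/compute/scatter pattern can be sketched in plain PyTorch as a loop over experts. This is a minimal single-device illustration under stated assumptions (hypothetical function name, experts as ordinary modules, no capacity limit); production stacks fuse the loop into grouped GEMM or block-sparse kernels such as MegaBlocks.

```python
import torch

def moe_expert_forward(x, topk_i, gates, experts):
    """Gather each expert's tokens, run its FFN once, scatter weighted results back.

    x: [T, d]; topk_i, gates: [T, k]; experts: list of modules mapping [n, d] -> [n, d].
    """
    T, k = topk_i.shape
    flat_e = topk_i.reshape(-1)                          # expert id per (token, slot)
    flat_tok = torch.arange(T).repeat_interleave(k)      # token index per (token, slot)
    flat_gate = gates.reshape(-1)
    order = torch.argsort(flat_e)                        # group assignments by expert
    counts = torch.bincount(flat_e, minlength=len(experts)).tolist()
    out = torch.zeros_like(x)
    start = 0
    for e, n in enumerate(counts):
        idx = order[start:start + n]
        start += n
        if n == 0:
            continue                                     # idle expert: no GEMM at all
        tok = flat_tok[idx]
        y = experts[e](x[tok])                           # one (possibly skinny) GEMM
        out.index_add_(0, tok, y * flat_gate[idx].unsqueeze(-1))
    return out
```

Each `experts[e](x[tok])` call is a GEMM with `n` rows. When `n` is a handful, the matmul spends its time reading the expert's full weight matrix, which is the skinny-GEMM problem described above.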

Inference: active vs total, and why batch size rules everything

The headline number is the active-versus-total parameter ratio. Mixtral 8x7B reports roughly 47B total parameters with about 13B active per token (the experts are sparse, attention is shared); DeepSeek-V3’s report describes 671B total with 37B active. These are publicly reported figures tied to specific model versions and routing configs — quote them with the version attached and treat any “active params” claim without the routing recipe as unverified.
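The Mixtral figures can be roughly reproduced from the model shape. A back-of-the-envelope sketch, assuming the published Mixtral 8x7B configuration (32 layers, model width 4096, expert FFN width 14336, 8 KV heads of dim 128 with grouped-query attention, vocab 32000, 8 experts, top-2, untied output head) and ignoring norm weights:

```python
layers, d, ffn, n_kv, head_dim, vocab, n_experts, top_k = 32, 4096, 14336, 8, 128, 32000, 8, 2

expert = 3 * d * ffn                              # SwiGLU FFN: gate, up, down projections
attn = 2 * d * d + 2 * d * (n_kv * head_dim)      # Wq, Wo full width; Wk, Wv grouped-query
router = d * n_experts
embed = 2 * vocab * d                             # input embedding + output head

total = layers * (n_experts * expert + attn + router) + embed
active = layers * (top_k * expert + attn + router) + embed
print(f"total  ~ {total / 1e9:.1f}B")   # ~46.7B
print(f"active ~ {active / 1e9:.1f}B")  # ~12.9B
```

Experts make up about 97% of the total, while attention and embeddings, which every token uses, are a large share of the active count; that asymmetry is why active parameters are far more than $k/E$ of the total.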

The ratio is the whole pitch: per-token compute scales with active parameters, while model capacity scales with total. But your memory footprint also scales with total: you must keep every expert resident in HBM even though each token touches only $k$ of them. So the bottleneck shifts relative to a dense model of the same active size. Decode is rarely FLOP-bound to begin with, and MoE makes it worse: at small batch, each expert sees only a handful of tokens, the per-expert GEMM is memory-bandwidth-bound (you stream the full expert weights to multiply a few rows), and the all-to-all and routing overhead dominate wall-clock.

Batching is the lever that rescues this. As batch grows, tokens spread across all $E$ experts, each expert accumulates enough rows to amortize loading its weights, and arithmetic intensity climbs. The catch is that per-expert token counts are stochastic: a batch of $B$ tokens gives each expert $Bk/E$ rows on average, so you need $B \gg E/k$ before every expert reliably sees work, and considerably more before each per-expert GEMM is large enough to amortize its weight load. A bigger batch also inflates the all-to-all payload. The practical reading: MoE rewards high-concurrency serving and punishes low-latency single-stream decode, which is the exact opposite of what a naive active-FLOP count would predict.
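A quick way to see the batch effect is to compute how many experts a decode step touches and how many rows each gets. A sketch under an idealized assumption: each token picks $k$ of $E$ experts uniformly at random and independently of other tokens (real routers are skewed and correlated, which makes things worse). The 8-expert, top-2 setting mirrors Mixtral; the function names are hypothetical.

```python
def expected_experts_touched(batch, n_experts, k):
    # P(a given expert is among one token's k picks) = k / E under uniform routing;
    # with independent tokens, E * (1 - (1 - k/E)^B) experts see at least one row.
    return n_experts * (1 - (1 - k / n_experts) ** batch)

def mean_rows_per_expert(batch, n_experts, k):
    return batch * k / n_experts

for B in (1, 4, 64, 1024):
    print(B, round(expected_experts_touched(B, 8, 2), 2), mean_rows_per_expert(B, 8, 2))
```

At $B = 1$ you stream 2 experts' weights to process 2 rows; at $B = 4$ you already stream about 5.5 experts for an average of 1 row each, which is the worst of both worlds. Only around $B$ in the hundreds to thousands does each expert's GEMM get enough rows to stop being a weight-streaming exercise.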

What to take away

  • MoE buys parameters per FLOP, and you repay it in routing complexity, communication, and a footprint set by total parameters. Budget for the memory and the all-to-all, not just the active-param compute.
  • The router is the locus of every hard problem. Balance it explicitly — aux loss if you accept the gradient interference, aux-loss-free bias if you don’t — compute it in fp32, and keep its logit magnitudes in check with a z-loss.
  • The capacity factor is your quality/throughput dial; if your stack supports block-sparse (dropless) experts, prefer it and stop tuning the dial.
  • For serving, the only questions that matter are total-versus-active and whether your batch is large enough to keep every expert’s GEMM out of the memory-bandwidth floor. A quoted active-param figure without the model version and routing config is a rumor, not a spec.