
Weight-Tying: The Small, Gentle Read–Write Symmetry in Language Models

TL;DR — A language model at every step asks one simple question — given the numerical representation of the context so far, what is the probability of each vocabulary token being next? This note unfolds that question cleanly: symbols first, forward & backward passes next, then the split of gradients that makes a single embedding table both reader and writer.

Flow (how to read this note)

  1. State the single desideratum that drives everything.
  2. Fix notation and the primitive objects we will manipulate.
  3. Walk forward: from tokens to probabilities.
  4. Walk backward: show the two distinct gradient paths into the embedding matrix and why they differ in character.
  5. Give a per-position worked derivative so nothing feels like magic.
  6. Read the consequences: symmetry, sparsity vs. density, and computational cost.

Prelude — the single desideratum

At each time step the model must compute a distribution over vocabulary tokens. Put plainly:

Given the numeric representation of the context so far, compute $p(\text{next token} = v)$ for every $v\in\{0,\dots,V-1\}$.

Everything that follows is a careful unpacking of how we represent the context, how the model turns that representation into logits and probabilities, and how gradients flow back into the same parameters that produced the representation.

Notation and primitives

We define the notation so the shapes and roles are obvious.

  • $B$ — batch size (parallel independent sequences).
  • $T$ — sequence length (tokens per sequence).
  • $D$ — model hidden / embedding dimension.
  • $V$ — vocabulary size.

Primitive tensors:

  • $X_{\text{ids}}\in\{0,\dots,V-1\}^{B\times T}$: token indices after tokenization.
  • $E\in\mathbb{R}^{V\times D}$: token embedding matrix; row $E_{v,:}$ is the vector for token $v$.
  • $P\in\mathbb{R}^{T\times D}$: positional embeddings (or any positional encoding of length $T$).
  • $X\in\mathbb{R}^{B\times T\times D}$: token vectors with position; for a sequence position $(b,t)$, $X_{b,t,:} = E_{X_{\text{ids}_{b,t}},:} + P_{t,:}$.
  • $\mathcal{T}(\cdot)$: the transformer stack — a deterministic, differentiable function mapping $X\mapsto H$.
  • $H=\mathcal{T}(X)\in\mathbb{R}^{B\times T\times D}$: final hidden states.
  • Tied LM head: $W = E^\top\in\mathbb{R}^{D\times V}$. (Tying enforces the same space for read/write.)
  • $b\in\mathbb{R}^V$: optional bias on logits.
  • $\ell$: scalar loss (we use mean cross-entropy over batch and time).

These primitives suffice to write the forward and backward equations compactly and exactly.
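
To make the shapes concrete, here is a minimal sketch of these primitives in PyTorch (a framework assumption on my part; the note itself is framework-agnostic), with toy sizes chosen so the shapes are easy to check by eye. The variable names simply mirror the symbols above.

```python
import torch

B, T, D, V = 2, 8, 16, 100            # toy sizes: batch, time, hidden dim, vocab

X_ids = torch.randint(0, V, (B, T))   # token indices, shape (B, T)
E = torch.randn(V, D)                 # embedding matrix; row E[v] is the vector for token v
P = torch.randn(T, D)                 # positional embeddings, one row per position
b = torch.zeros(V)                    # optional bias on the logits

X = E[X_ids] + P                      # (B, T, D): row lookup plus broadcast positions
print(X.shape)                        # torch.Size([2, 8, 16])
```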

Forward pass

[Figure: forward-pass pipeline — raw text → tokenizer → embeddings + position → transformer → logits ($Z = HE^\top + b$) → softmax → sample.]

  1. Embed + add position

    $$X_{b,t,:} = E_{X_{\text{ids}_{b,t}},:} + P_{t,:},\qquad X\in\mathbb{R}^{B\times T\times D}.$$

    Why: we need a dense vector representation for each token and a way to indicate position.

  2. Transform

    $$H = \mathcal{T}(X),\qquad H\in\mathbb{R}^{B\times T\times D}.$$

    Why: the transformer composes attention and MLP layers to convert per-position inputs into context-aware representations.

  3. Logits via tied head

    $$Z = HE^\top + \mathbf{1}_{BT}b^\top,\qquad Z\in\mathbb{R}^{B\times T\times V}.$$

    Here $\mathbf{1}_{BT}$ indicates broadcasting the bias to every $(b,t)$ slice.

    Why: each hidden $h\in\mathbb{R}^D$ scores every vocabulary row by inner product with that row; tying uses the same rows that produced the embeddings.

[Figure: projection to logits — hidden states $H$ multiplied by $E^\top$ give logits $Z = HE^\top + b$.]

  4. Softmax → probabilities

    $$p_{b,t,j}=\frac{\exp(Z_{b,t,j})}{\sum_{k=1}^V\exp(Z_{b,t,k})}.$$
  5. Mean cross-entropy loss

    For targets $Y\in\{1,\dots,V\}^{B\times T}$,

    $$\ell = -\frac{1}{BT}\sum_{b=1}^B\sum_{t=1}^T \log p_{b,t,Y_{b,t}}.$$

These are the only forward equations required; they arise directly from the desideratum and the chosen primitives.
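
The forward equations translate almost line for line into code. Below is a hedged sketch in PyTorch in which the transformer stack $\mathcal{T}(\cdot)$ is replaced by a trivial differentiable stand-in (`transformer_stub`, a name invented for this sketch), since its internals do not matter for the read/write story; targets use 0-based indices as PyTorch expects.

```python
import torch
import torch.nn.functional as F

B, T, D, V = 2, 8, 16, 100
X_ids = torch.randint(0, V, (B, T))
Y = torch.randint(0, V, (B, T))              # next-token targets (0-based indices)

E = torch.randn(V, D, requires_grad=True)    # tied embedding / head matrix
P = torch.randn(T, D)
b = torch.zeros(V)

# 1. Embed + add position
X = E[X_ids] + P                             # (B, T, D)

# 2. Transform: stand-in for the transformer stack T(.)
def transformer_stub(x):
    return torch.tanh(x)                     # any differentiable (B,T,D) -> (B,T,D) map

H = transformer_stub(X)

# 3. Logits via the tied head: Z = H E^T + b
Z = H @ E.T + b                              # (B, T, V)

# 4. Softmax -> probabilities
p = Z.softmax(dim=-1)

# 5. Mean cross-entropy over batch and time
#    (F.cross_entropy applies log-softmax internally, so it takes the logits Z, not p)
loss = F.cross_entropy(Z.reshape(B * T, V), Y.reshape(B * T))
```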

Backward pass — the two distinct contributions into EE (chain rule)

Because $E$ appears in two roles — as a lookup for inputs and as the transpose used to compute logits — the gradient $\frac{\partial \ell}{\partial E}$ decomposes into two additive pieces.

[Figure: gradient flow into the embedding matrix $E$ — the input-side (sparse) and output-side (dense) contributions sum to $\partial\ell/\partial E$.]

1. Immediate quantity from softmax + CE (logit residuals)

Define the logit residual tensor

$$\Delta_{b,t,j}\equiv\frac{\partial \ell}{\partial Z_{b,t,j}} = \frac{1}{BT}\bigl(p_{b,t,j} - \mathbf{1}_{j=Y_{b,t}}\bigr).$$

This $\Delta\in\mathbb{R}^{B\times T\times V}$ is the primitive starting point for gradients flowing into both the head (output side) and back into $H$.
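
A small sketch of this quantity (assuming logits `Z` and targets `Y` shaped as in the forward-pass sketch): the residual is softmax minus one-hot, scaled by $1/BT$ because the loss is a mean, and it matches what autograd reports for $\partial\ell/\partial Z$.

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 8, 100
Z = torch.randn(B, T, V, requires_grad=True)   # logits from the head
Y = torch.randint(0, V, (B, T))                # target indices

# Residual: softmax minus one-hot, scaled by 1/(B*T) because the loss is a mean.
p = Z.detach().softmax(dim=-1)
Delta = (p - F.one_hot(Y, V).float()) / (B * T)

# Sanity check: autograd reports the same d(loss)/dZ.
loss = F.cross_entropy(Z.reshape(B * T, V), Y.reshape(B * T))
loss.backward()
assert torch.allclose(Z.grad, Delta, atol=1e-6)
```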

2. Output-side (projection) contribution — dense outer products

From $Z_{b,t,:} = H_{b,t,:}E^\top + b$ we obtain the gradient on $E$ coming from the head:

$$\Big(\frac{\partial \ell}{\partial E}\Big)^{\text{(out)}} = \sum_{b=1}^B\sum_{t=1}^T \Delta_{b,t,:}^{\top}\otimes H_{b,t,:}.$$

Concretely, for vocabulary row $v$,

$$\Big(\frac{\partial \ell}{\partial E_{v,:}}\Big)^{\text{(out)}} = \sum_{b,t} \Delta_{b,t,v}\, H_{b,t,:} \in\mathbb{R}^D.$$

Character: dense — typically every $v$ receives some (small) contribution because the softmax probabilities are dense.
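
In code this dense contribution is a single einsum over the batch and time axes; the sketch below uses random stand-ins for $\Delta$ and $H$ purely to show the shapes.

```python
import torch

B, T, D, V = 2, 8, 16, 100
Delta = torch.randn(B, T, V)   # logit residuals d(loss)/dZ
H = torch.randn(B, T, D)       # final hidden states

# Sum over (b, t) of outer products Delta[b,t,:] (x) H[b,t,:]  ->  a dense (V, D) update.
grad_E_out = torch.einsum('btv,btd->vd', Delta, H)
```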

3. Input-side (lookup) contribution — sparse accumulation

Backpropagating through the transformer yields gradients with respect to its inputs:

$$G_{b,t,:}\equiv\frac{\partial \ell}{\partial X_{b,t,:}}\in\mathbb{R}^D.$$

Since $X_{b,t,:}=E_{X_{\text{ids}_{b,t}},:}+P_{t,:}$, the corresponding row of $E$ for the token $v=X_{\text{ids}_{b,t}}$ receives:

$$\Big(\frac{\partial \ell}{\partial E_{v,:}}\Big)^{\text{(in)}} \mathrel{+}= G_{b,t,:}.$$

Character: sparse — only embedding rows for tokens present in the mini-batch are touched.
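
The sparse accumulation is a scatter-add of $G$ into the rows of $E$ selected by the token ids; a short sketch, again with random stand-ins:

```python
import torch

B, T, D, V = 2, 8, 16, 100
X_ids = torch.randint(0, V, (B, T))   # token ids used in the lookup
G = torch.randn(B, T, D)              # gradients d(loss)/dX from the transformer

# Scatter-add each G[b,t,:] into row X_ids[b,t] of a (V, D) accumulator.
grad_E_in = torch.zeros(V, D)
grad_E_in.index_add_(0, X_ids.reshape(-1), G.reshape(-1, D))
# Rows for tokens absent from the batch stay exactly zero.
```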

4. Total gradient on $E$

By linearity,

$$\frac{\partial \ell}{\partial E} = \Big(\frac{\partial \ell}{\partial E}\Big)^{\text{(out)}} + \Big(\frac{\partial \ell}{\partial E}\Big)^{\text{(in)}}.$$

Pragmatically, frameworks accumulate both contributions into E.grad during a single backward pass.
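
The decomposition can be checked end to end. The sketch below builds a tiny tied model (with the same `transformer_stub` stand-in as before), lets autograd populate `E.grad`, then reconstructs the output-side and input-side pieces by hand and verifies that their sum matches.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, D, V = 2, 8, 16, 100
X_ids = torch.randint(0, V, (B, T))
Y = torch.randint(0, V, (B, T))

E = torch.randn(V, D, requires_grad=True)   # tied embedding / head matrix
P = torch.randn(T, D)

X = E[X_ids] + P                   # input-side use of E (lookup)
X.retain_grad()                    # keep G = d(loss)/dX for the manual check
H = torch.tanh(X)                  # stand-in for the transformer stack
Z = H @ E.T                        # output-side use of E (tied head, bias omitted)
loss = F.cross_entropy(Z.reshape(B * T, V), Y.reshape(B * T))
loss.backward()                    # autograd sums BOTH contributions into E.grad

# Reconstruct the two pieces by hand.
p = Z.detach().softmax(dim=-1)
Delta = (p - F.one_hot(Y, V).float()) / (B * T)              # logit residuals
grad_out = torch.einsum('btv,btd->vd', Delta, H.detach())    # dense, output side
grad_in = torch.zeros(V, D).index_add_(0, X_ids.reshape(-1),
                                       X.grad.reshape(-1, D))  # sparse, input side

assert torch.allclose(E.grad, grad_out + grad_in, atol=1e-5)
```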

A per-position worked derivative (so there is no mystery)

Fix a single position $(b,t)$. Let $h=H_{b,t,:}\in\mathbb{R}^D$, $z=Z_{b,t,:}\in\mathbb{R}^V$, and $y=Y_{b,t}$ the correct index.

  • $p_j=\dfrac{e^{z_j}}{\sum_k e^{z_k}}$.
  • $\dfrac{\partial \ell_{b,t}}{\partial z_j}=p_j-\mathbf{1}_{j=y}$.
  • Output-side update for $E_{v,:}$ from this position: $\Big(\dfrac{\partial \ell_{b,t}}{\partial E_{v,:}}\Big)^{\text{(out,pos)}} = (p_v-\mathbf{1}_{v=y})\,h^\top$.
  • Gradient into $h$ from the head: $\dfrac{\partial \ell_{b,t}}{\partial h} = \sum_{v=1}^V (p_v-\mathbf{1}_{v=y})\,E_{v,:}^\top = E^\top (p - e_y)$.
[Figure: outer-product update per position — the residual $(p - e_y)$ and the hidden state $h$ produce the output-side update to $E$.]

Summing these per-position contributions across $(b,t)$ recovers the tensor equations above. This shows how softmax residuals produce outer-product updates on $E$ and how they also feed back into the model via $h$.
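
A numerical check of the per-position formulas, with a single hidden vector and toy sizes chosen for this sketch:

```python
import torch
import torch.nn.functional as F

D, V = 16, 100
h = torch.randn(D, requires_grad=True)     # hidden state at one position
E = torch.randn(V, D, requires_grad=True)  # tied embedding / head rows
y = 3                                      # correct token index

z = E @ h                                  # logits for this position, shape (V,)
loss = F.cross_entropy(z.unsqueeze(0), torch.tensor([y]))
loss.backward()

p = (E.detach() @ h.detach()).softmax(dim=0)
residual = p.clone(); residual[y] -= 1.0   # p - e_y

# Outer-product update on E, and the gradient fed back into h.
assert torch.allclose(E.grad, torch.outer(residual, h.detach()), atol=1e-6)
assert torch.allclose(h.grad, E.detach().T @ residual, atol=1e-6)
```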

Computational Accounting

Consider $B=4$, $T=128$, $D=4096$, $V=32{,}000$.

  • $BT = 4\times128 = 512$.
  • $(BT)\times D = 512\times4096 = 2{,}097{,}152$.
  • Multiply by $V$: $2{,}097{,}152\times 32{,}000 = (2{,}097{,}152\times32)\times1000 = 67{,}108{,}864\times1000 = 67{,}108{,}864{,}000$.

So roughly $6.71\times 10^{10}$ multiply–accumulate operations are required for that single forward pass through the head for the batch (the backward pass through the head costs roughly twice that, one such product each for $\partial\ell/\partial H$ and $\partial\ell/\partial E$) — a salutary reminder that the projection to vocabulary is often the dominant cost when $V$ is large.
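
The same count in a couple of lines of Python, for anyone who wants to rerun it with their own shapes:

```python
B, T, D, V = 4, 128, 4096, 32_000
macs = B * T * D * V          # multiply-accumulates in Z = H E^T for one batch
print(f"{macs:,}")            # 67,108,864,000  (about 6.7e10)
```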

Interpretation and Consequences

  • Read–write symmetry. Tying $W=E^\top$ means the vector space used to read tokens (embedding lookup) is identical to the space used to write tokens (projection). This reduces parameter redundancy and aligns representation and generation semantics.
  • Two distinct learning channels. An embedding row changes for two reasons: (i) its use as an input (sparse updates) and (ii) its role in predicting outputs (dense outer-product updates). Both matter, but they have very different computational and statistical character.
  • Sparsity vs. density. Input-side updates touch only the rows present in the batch (sparse). Output-side updates touch all $V$ rows (dense) because softmax spreads probability mass over the vocabulary. This explains why the head can be a bottleneck in large-vocabulary models.
  • Practical design choices. If $V$ is large, we can turn to techniques such as adaptive softmax, vocabulary culling, sampling-based losses, or token clustering to reduce the dense cost while preserving predictive fidelity.

Closing Remark

The model reads by looking up embedding rows and writes by taking inner products with those same rows; tying them makes the read and write vocabularies identical, and backpropagation simply sums the sparse read-derived corrections with the dense write-derived outer-product corrections into one embedding table.