function initTransformerDemo() {
  const container = document.getElementById('transformer-architecture-demo');
  if (!container) return;

  // Static scaffold; the diagram, orbit and readouts are rendered into the
  // empty containers below by the functions further down.
  container.innerHTML = `
    <h3>Where scale invariance lives in LLM Transformers</h3>
    <p>
      LLaMA, Gemma and Qwen all use pre-normalized decoder blocks. That means most
      of their weight matrices can be rescaled by any c > 0 without changing the
      network's function – they live on multiplicative scale orbits.
    </p>
    <div class="transformer-model-tabs">
      <button class="transformer-model-tab" data-arch="llama">LLaMA</button>
      <button class="transformer-model-tab" data-arch="gemma">Gemma</button>
      <button class="transformer-model-tab" data-arch="qwen">Qwen</button>
    </div>
    <p id="transformer-arch-summary"></p>
    <ul id="transformer-arch-bullets"></ul>
    <div class="transformer-demo-panels">
      <div id="transformer-diagram"></div>
      <div class="transformer-orbit-panel">
        <div id="transformer-orbit"></div>
        <div class="transformer-orbit-caption">Scale invariance group (ℝ+, ×)</div>
        <div id="transformer-invariance-badge">Perfectly invariant under c > 0</div>
        <div class="transformer-scale-control">
          <span>0.25×</span>
          <input type="range" id="transformer-scale-slider" min="-0.6" max="0.6" step="0.01" value="0">
          <span>3.98×</span>
        </div>
      </div>
    </div>
    <p>
      Click a weight matrix on the left, then drag the slider to move along its
      scale orbit W → cW. Because the block is normalized, all these points have
      (almost) identical normalized outputs.
    </p>
    <dl class="transformer-readout">
      <dt>Model family</dt><dd id="transformer-selected-arch-label">–</dd>
      <dt>Weight matrix</dt><dd id="transformer-selected-weight-label">–</dd>
      <dt>Scale c</dt><dd id="transformer-scale-value">1.00×</dd>
      <dt>Δ normalized output</dt><dd id="transformer-diff-value">–</dd>
    </dl>
  `;

  // --- Architecture definitions for LLaMA / Gemma / Qwen-style blocks ---
  const architectures = {
    llama: {
      key: 'llama',
      displayName: 'LLaMA-style (Meta)',
      shortName: 'LLaMA',
      tagline: 'RMSNorm + SwiGLU pre-norm decoder layer (LLaMA / LLaMA 2).',
      summary: 'Pre-norm decoder-only Transformer with RMSNorm and SwiGLU MLP. This is representative of LLaMA, LLaMA 2 and many popular open LLMs.',
      bullets: [
        'RMSNorm before both attention and MLP blocks.',
        'Most linear weights only see normalized activations → individually scale-invariant.',
        'Token embeddings feed into RMSNormed blocks so inner weights live on scale orbits.'
      ],
      blocks: [
        {
          key: 'attn',
          title: 'RMSNorm → Multi-Head Self-Attention',
          normLabel: 'RMSNorm(hₗ)',
          description: 'Q, K, V and O projections all see the same normalized hidden state.',
          weights: [
            { key: 'attn.W_q', label: 'W_Q (queries)' },
            { key: 'attn.W_k', label: 'W_K (keys)' },
            { key: 'attn.W_v', label: 'W_V (values)' },
            { key: 'attn.W_o', label: 'W_O (output)' }
          ]
        },
        {
          key: 'mlp',
          title: 'RMSNorm → SwiGLU MLP',
          normLabel: 'RMSNorm(hₗ + Attn)',
          description: 'Gated feed-forward expands, gates, then projects back.',
          weights: [
            { key: 'mlp.W_up', label: 'W_up (expand)' },
            { key: 'mlp.W_gate', label: 'W_gate (SwiGLU gate)' },
            { key: 'mlp.W_down', label: 'W_down (project back)' }
          ]
        }
      ]
    },
    gemma: {
      key: 'gemma',
      displayName: 'Gemma-style (Google)',
      shortName: 'Gemma',
      tagline: 'RMSNorm + GeGLU/SwiGLU pre-norm decoder layer (Gemma).',
      summary: 'Gemma uses a very similar pre-norm decoder stack with RMSNorm and gated MLP blocks.',
      bullets: [
        'RMSNorm before attention and MLP → same radial scale invariance pattern.',
        'Attention may use standard or grouped-query layouts, but the projections are still linear maps applied after RMSNorm.',
        'Feed-forward block uses two parallel projections (value + gate) then a down projection.'
      ],
      blocks: [
        {
          key: 'attn',
          title: 'RMSNorm → Multi-Head / GQA Attention',
          normLabel: 'RMSNorm(hₗ)',
          description: 'Q, K, V share one normalized input; O reprojects back to the residual stream.',
          weights: [
            { key: 'attn.W_q', label: 'W_Q' },
            { key: 'attn.W_k', label: 'W_K' },
            { key: 'attn.W_v', label: 'W_V' },
            { key: 'attn.W_o', label: 'W_O' }
          ]
        },
        {
          key: 'mlp',
          title: 'RMSNorm → GeGLU / SwiGLU MLP',
          normLabel: 'RMSNorm(hₗ + Attn)',
          description: 'Two projections (value + gate) followed by a down projection.',
          weights: [
            { key: 'mlp.W_ff1', label: 'W_ff1 (value)' },
            { key: 'mlp.W_ff_gate', label: 'W_ff_gate (gate)' },
            { key: 'mlp.W_ff2', label: 'W_ff2 (down)' }
          ]
        }
      ]
    },
    qwen: {
      key: 'qwen',
      displayName: 'Qwen-style (Alibaba)',
      shortName: 'Qwen',
      tagline: 'RMSNorm + GQA attention pre-norm decoder layer (Qwen).',
      summary: 'Qwen uses pre-norm decoder layers with RMSNorm and grouped-query attention. The scale invariance story is the same.',
      bullets: [
        'RMSNorm before attention and MLP → radial scale invariance per weight matrix.',
        'Grouped-query attention shares K/V across heads, but still via linear projections.',
        'MLP is a gated feed-forward block similar to LLaMA and Gemma.'
      ],
      blocks: [
        {
          key: 'attn',
          title: 'RMSNorm → GQA Attention',
          normLabel: 'RMSNorm(hₗ)',
          description: 'Separate Q projection, shared K/V projections, and an output projection.',
          weights: [
            { key: 'attn.W_q', label: 'W_Q (per-head)' },
            { key: 'attn.W_kv', label: 'W_KV (shared K/V)' },
            { key: 'attn.W_o', label: 'W_O (output)' }
          ]
        },
        {
          key: 'mlp',
          title: 'RMSNorm → Gated MLP',
          normLabel: 'RMSNorm(hₗ + Attn)',
          description: 'Gated feed-forward (e.g. SwiGLU) similar to LLaMA and Gemma.',
          weights: [
            { key: 'mlp.W_in', label: 'W_in (expand)' },
            { key: 'mlp.W_gate', label: 'W_gate' },
            { key: 'mlp.W_out', label: 'W_out (down)' }
          ]
        }
      ]
    }
  };

  // --- State & helpers ---
  let currentArchKey = 'llama';
  let currentWeightKey = architectures.llama.blocks[0].weights[0].key;

  const archSummaryEl = container.querySelector('#transformer-arch-summary');
  const archBulletsEl = container.querySelector('#transformer-arch-bullets');
  const diagramEl = container.querySelector('#transformer-diagram');
  const modelTabs = container.querySelectorAll('.transformer-model-tab');
  const scaleSlider = container.querySelector('#transformer-scale-slider');
  const scaleValueEl = container.querySelector('#transformer-scale-value');
  const diffValueEl = container.querySelector('#transformer-diff-value');
  const selectedWeightEl = container.querySelector('#transformer-selected-weight-label');
  const selectedArchEl = container.querySelector('#transformer-selected-arch-label');
  const orbitEl = container.querySelector('#transformer-orbit');
  const invarianceBadgeEl = container.querySelector('#transformer-invariance-badge');

  // Orbit nodes sit at scales c = 10^exp around the unit-sphere direction.
  const nodeExponents = [-0.6, -0.3, 0.0, 0.3, 0.6];
  const orbitNodes = [];
  const invarianceTests = {};

  // Standard normal samples via the Box–Muller transform.
  function randn() {
    let u = 0, v = 0;
    while (u === 0) u = Math.random();
    while (v === 0) v = Math.random();
    return Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v);
  }

  // RMSNorm without a learned gain: x / sqrt(mean(x²) + ε).
  function rmsnorm(vec) {
    let ss = 0;
    for (let i = 0; i < vec.length; i++) ss += vec[i] * vec[i];
    const denom = Math.sqrt(ss / vec.length + 1e-12);
    const out = new Float64Array(vec.length);
    for (let i = 0; i < vec.length; i++) out[i] = vec[i] / denom;
    return out;
  }

  function l2Distance(a, b) {
    let s = 0;
    for (let i = 0; i < a.length; i++) {
      const d = a[i] - b[i];
      s += d * d;
    }
    return Math.sqrt(s);
  }

  function formatScale(c) {
    if (Math.abs(c - 1) < 1e-3) return '1.00×';
    if (c < 10) return c.toFixed(2) + '×';
    return c.toExponential(1) + '×';
  }

  // Vector-level toy: the random baseVec stands in for Wx. Scaling W by c
  // scales Wx by c, and RMSNorm removes that scale exactly, so
  // RMSNorm(c·Wx) = RMSNorm(Wx) for any c > 0.
  function createInvarianceTest() {
    const dim = 32;
    const baseVec = new Float64Array(dim);
    for (let i = 0; i < dim; i++) baseVec[i] = randn();
    const baseOut = rmsnorm(baseVec);
    return {
      isScaleInvariant: true,
      baseOut,
      forward: function (c) {
        const scaled = new Float64Array(dim);
        for (let i = 0; i < dim; i++) scaled[i] = c * baseVec[i];
        return rmsnorm(scaled);
      }
    };
  }

  // One cached test per (architecture, weight matrix) pair.
  function getInvarianceTest(archKey, weightKey) {
    const id = `${archKey}::${weightKey}`;
    if (!invarianceTests[id]) {
      invarianceTests[id] = createInvarianceTest();
    }
    return invarianceTests[id];
  }

  function findWeightMeta(arch, weightKey) {
    for (const block of arch.blocks) {
      for (const w of block.weights) {
        if (w.key === weightKey) return w;
      }
    }
    return null;
  }
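  // Matrix-level sanity check — a minimal illustrative sketch, not part of the
  // original demo wiring. It demonstrates the claim in the intro paragraph:
  // for y = RMSNorm(Wx), rescaling W → cW leaves y unchanged for any c > 0,
  // while c < 0 only flips the sign (which is why the orbit group is (ℝ+, ×)).
  // The helper name matVec, the dimension 8 and the factor 3.7 are arbitrary
  // illustrative choices.
  (function rmsnormScaleOrbitCheck() {
    const dim = 8;
    const x = new Float64Array(dim);
    const W = []; // dim × dim random matrix, rows as plain arrays
    for (let i = 0; i < dim; i++) {
      x[i] = randn();
      W.push(Array.from({ length: dim }, () => randn()));
    }
    // Computes (c·M)v.
    const matVec = (M, v, c) => {
      const out = new Float64Array(dim);
      for (let i = 0; i < dim; i++) {
        let s = 0;
        for (let j = 0; j < dim; j++) s += c * M[i][j] * v[j];
        out[i] = s;
      }
      return out;
    };
    const base = rmsnorm(matVec(W, x, 1));
    const scaled = rmsnorm(matVec(W, x, 3.7)); // move along the orbit: W → 3.7·W
    // Up to floating-point noise the normalized outputs coincide.
    console.assert(l2Distance(base, scaled) < 1e-9, 'scale orbit check failed');
    // A negative factor flips the sign instead of leaving y unchanged.
    const flipped = rmsnorm(matVec(W, x, -1));
    console.assert(Math.abs(flipped[0] + base[0]) < 1e-9, 'sign-flip check failed');
  })();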
  function buildArchitectureHTML(arch) {
    let html = `
      <div class="transformer-diagram-header">
        <div class="transformer-diagram-title">${arch.displayName}</div>
        <div class="transformer-diagram-tagline">${arch.tagline}</div>
      </div>
      <div class="transformer-diagram-io">Token IDs</div>
      <div class="transformer-diagram-io">Embedding lookup</div>
      <div class="transformer-diagram-stack-label">L × decoder blocks</div>
      <div class="transformer-diagram-stack">
    `;
    arch.blocks.forEach(block => {
      const normLabel = block.normLabel || 'Norm(hₗ)';
      const subtitle = block.description || '';
      html += `
        <div class="transformer-block">
          <div class="transformer-block-title">${block.title}</div>
          ${subtitle ? `<div class="transformer-block-subtitle">${subtitle}</div>` : ''}
          <div class="transformer-block-badge">scale-invariant sphere</div>
          <div class="transformer-block-norm">${normLabel}</div>
          <div class="transformer-block-weights">
            ${block.weights.map(w => `<button class="transformer-weight-chip" data-weight-key="${w.key}">${w.label}</button>`).join('')}
          </div>
          <div class="transformer-block-residual">h<sub>out</sub> = h<sub>in</sub> + F(Norm(h<sub>in</sub>); W)</div>
        </div>
      `;
    });
    html += `
      </div>
    `;
    return html;
  }

  function renderArchitecture() {
    const arch = architectures[currentArchKey];
    diagramEl.innerHTML = buildArchitectureHTML(arch);
    const chips = diagramEl.querySelectorAll('.transformer-weight-chip');
    chips.forEach(chip => {
      chip.addEventListener('click', () => {
        const key = chip.getAttribute('data-weight-key');
        if (!key) return;
        currentWeightKey = key;
        updateScaleAndDiff();
        highlightSelectedChip();
      });
    });
    // Keep the selection valid when switching between architectures.
    const allKeys = arch.blocks.flatMap(b => b.weights.map(w => w.key));
    if (!currentWeightKey || !allKeys.includes(currentWeightKey)) {
      currentWeightKey = allKeys[0];
    }
    highlightSelectedChip();
  }

  function highlightSelectedChip() {
    const chips = diagramEl.querySelectorAll('.transformer-weight-chip');
    chips.forEach(chip => {
      const key = chip.getAttribute('data-weight-key');
      chip.classList.toggle('active', key === currentWeightKey);
    });
  }

  function updateArchDescription() {
    const arch = architectures[currentArchKey];
    archSummaryEl.textContent = arch.summary;
    archBulletsEl.innerHTML = '';
    if (arch.bullets && arch.bullets.length) {
      arch.bullets.forEach(txt => {
        const li = document.createElement('li');
        li.textContent = txt;
        archBulletsEl.appendChild(li);
      });
    }
  }
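  // Illustrative note (added, not original wiring): the orbit nodes and the
  // slider both work in log10 space, so the node exponents map to the scales
  //   nodeExponents.map(e => formatScale(Math.pow(10, e)))
  //   → ['0.25×', '0.50×', '1.00×', '2.00×', '3.98×']
  // which matches the 0.25× and 3.98× endpoint labels in the markup above.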
  function createOrbit() {
    orbitEl.innerHTML = '';
    orbitNodes.length = 0; // reset so a re-render never keeps stale DOM references
    const center = document.createElement('div');
    center.className = 'transformer-orbit-center';
    center.innerHTML = `<span>direction on sphere</span>`;
    orbitEl.appendChild(center);

    nodeExponents.forEach(exp => {
      const node = document.createElement('div');
      node.className = 'transformer-orbit-node';
      node.dataset.exponent = exp.toString();
      const c = Math.pow(10, exp);
      node.textContent = formatScale(c);
      orbitEl.appendChild(node);
      orbitNodes.push(node);
    });

    // Lay the nodes out on a circle around the center marker.
    function positionOrbitNodes() {
      const rect = orbitEl.getBoundingClientRect();
      if (!rect.width || !rect.height) return;
      const cx = rect.width / 2;
      const cy = rect.height / 2;
      const radius = Math.min(cx, cy) - 22;
      orbitNodes.forEach((node, idx) => {
        const angle = (2 * Math.PI * idx) / nodeExponents.length - Math.PI / 2;
        const x = cx + radius * Math.cos(angle);
        const y = cy + radius * Math.sin(angle);
        node.style.left = `${x}px`;
        node.style.top = `${y}px`;
      });
    }

    window.requestAnimationFrame(positionOrbitNodes);
    window.addEventListener('resize', positionOrbitNodes);
  }

  function updateOrbitHighlight() {
    const exp = parseFloat(scaleSlider.value);
    let bestNode = null;
    let bestDist = Infinity;
    orbitNodes.forEach(node => {
      const e = parseFloat(node.dataset.exponent);
      const d = Math.abs(e - exp);
      if (d < bestDist) {
        bestDist = d;
        bestNode = node;
      }
    });
    orbitNodes.forEach(node => {
      node.classList.toggle('active', node === bestNode);
    });
  }

  function updateScaleAndDiff() {
    const arch = architectures[currentArchKey];
    // The slider value is the exponent β; the scale factor is c = 10^β.
    const beta = parseFloat(scaleSlider.value);
    const c = Math.pow(10, beta);
    scaleValueEl.textContent = formatScale(c);
    selectedArchEl.textContent = arch.displayName;
    const weightMeta = findWeightMeta(arch, currentWeightKey);
    selectedWeightEl.textContent = weightMeta ? weightMeta.label : currentWeightKey;

    const test = getInvarianceTest(currentArchKey, currentWeightKey);
    const yScaled = test.forward(c);
    const diff = l2Distance(test.baseOut, yScaled);
    diffValueEl.textContent = diff.toExponential(2);

    const tiny = 1e-10;
    if (diff < tiny) {
      invarianceBadgeEl.textContent = 'Perfectly invariant under c > 0';
    } else {
      invarianceBadgeEl.textContent = 'Almost invariant (numerical noise only)';
    }
    // Flag the badge only when the deviation exceeds numerical noise.
    invarianceBadgeEl.classList.toggle('warning', diff >= tiny);
    updateOrbitHighlight();
  }

  function updateModelTabs() {
    modelTabs.forEach(btn => {
      const k = btn.getAttribute('data-arch');
      btn.classList.toggle('active', k === currentArchKey);
    });
  }

  // --- Event wiring ---
  modelTabs.forEach(btn => {
    btn.addEventListener('click', () => {
      const k = btn.getAttribute('data-arch');
      if (!k || k === currentArchKey) return;
      currentArchKey = k;
      const arch = architectures[currentArchKey];
      currentWeightKey = arch.blocks[0].weights[0].key;
      updateModelTabs();
      updateArchDescription();
      renderArchitecture();
      updateScaleAndDiff();
    });
  });

  scaleSlider.addEventListener('input', () => {
    updateScaleAndDiff();
  });

  // --- Initial render ---
  updateModelTabs();
  updateArchDescription();
  createOrbit();
  renderArchitecture();
  updateScaleAndDiff();
}
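// Example bootstrapping — an assumption, since the original source does not
// show how initTransformerDemo() is invoked. One plausible wiring is to run it
// once the DOM is ready, so that #transformer-architecture-demo exists:
if (document.readyState === 'loading') {
  document.addEventListener('DOMContentLoaded', initTransformerDemo);
} else {
  initTransformerDemo();
}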