function initTransformerDemo() {
const container = document.getElementById('transformer-architecture-demo');
if (!container) return;
container.innerHTML = `
`;
// --- Architecture definitions for LLaMA / Gemma / Qwen-style blocks ---
// Static catalogue keyed by architecture id. The id is what a model tab's
// `data-arch` attribute is compared against and what `currentArchKey` holds.
//
// Fields read by code visible in this function:
//   - summary, bullets      -> rendered by updateArchDescription()
//   - displayName           -> shown by updateScaleAndDiff()
//   - blocks[].weights[]    -> `key` identifies the selectable weight chip,
//                              `label` is its display text (findWeightMeta()).
// shortName, tagline and the per-block title / normLabel / description are not
// read by any code visible here — presumably consumed by the diagram markup;
// verify against buildArchitectureHTML().
//
// Weight keys only need to be unique within one architecture: findWeightMeta()
// returns the first match inside a single architecture, and getInvarianceTest()
// namespaces its cache as `${archKey}::${weightKey}`.
const architectures = {
// The initial state reads architectures.llama.blocks[0].weights[0], so this
// entry must keep at least one block with at least one weight.
llama: {
key: 'llama',
displayName: 'LLaMA-style (Meta)',
shortName: 'LLaMA',
tagline: 'RMSNorm + SwiGLU pre-norm decoder layer (LLaMA / LLaMA 2).',
summary: 'Pre-norm decoder-only Transformer with RMSNorm and SwiGLU MLP. This is representative of LLaMA, LLaMA 2 and many popular open LLMs.',
bullets: [
'RMSNorm before both attention and MLP blocks.',
'Most linear weights only see normalized activations → individually scale-invariant.',
'Token embeddings feed into RMSNormed blocks so inner weights live on scale orbits.'
],
blocks: [
{
key: 'attn',
title: 'RMSNorm → Multi-Head Self-Attention',
normLabel: 'RMSNorm(hₗ)',
description: 'Q, K, V and O projections all see the same normalized hidden state.',
weights: [
{ key: 'attn.W_q', label: 'W_Q (queries)' },
{ key: 'attn.W_k', label: 'W_K (keys)' },
{ key: 'attn.W_v', label: 'W_V (values)' },
{ key: 'attn.W_o', label: 'W_O (output)' }
]
},
{
key: 'mlp',
title: 'RMSNorm → SwiGLU MLP',
normLabel: 'RMSNorm(hₗ + Attn)',
description: 'Gated feed-forward expands, gates, then projects back.',
weights: [
{ key: 'mlp.W_up', label: 'W_up (expand)' },
{ key: 'mlp.W_gate', label: 'W_gate (SwiGLU gate)' },
{ key: 'mlp.W_down', label: 'W_down (project back)' }
]
}
]
},
// Same two-block shape as the LLaMA entry, with its own MLP weight keys.
gemma: {
key: 'gemma',
displayName: 'Gemma-style (Google)',
shortName: 'Gemma',
tagline: 'RMSNorm + GeGLU/SwiGLU pre-norm decoder layer (Gemma).',
summary: 'Gemma uses a very similar pre-norm decoder stack with RMSNorm and gated MLP blocks.',
bullets: [
'RMSNorm before attention and MLP → same radial scale invariance pattern.',
'Attention often uses standard or grouped-query layouts but still linear projections after RMSNorm.',
'Feed-forward block uses two parallel projections (value + gate) then a down projection.'
],
blocks: [
{
key: 'attn',
title: 'RMSNorm → Multi-Head / GQA Attention',
normLabel: 'RMSNorm(hₗ)',
description: 'Q, K, V share one normalized input; O reprojects back to the residual stream.',
weights: [
{ key: 'attn.W_q', label: 'W_Q' },
{ key: 'attn.W_k', label: 'W_K' },
{ key: 'attn.W_v', label: 'W_V' },
{ key: 'attn.W_o', label: 'W_O' }
]
},
{
key: 'mlp',
title: 'RMSNorm → GeGLU / SwiGLU MLP',
normLabel: 'RMSNorm(hₗ + Attn)',
description: 'Two projections (value + gate) followed by a down projection.',
weights: [
{ key: 'mlp.W_ff1', label: 'W_ff1 (value)' },
{ key: 'mlp.W_ff_gate', label: 'W_ff_gate (gate)' },
{ key: 'mlp.W_ff2', label: 'W_ff2 (down)' }
]
}
]
},
// The attention block here lists three weights (a combined W_KV entry)
// rather than the four used by the other two architectures.
qwen: {
key: 'qwen',
displayName: 'Qwen-style (Alibaba)',
shortName: 'Qwen',
tagline: 'RMSNorm + GQA attention pre-norm decoder layer (Qwen).',
summary: 'Qwen uses pre-norm decoder layers with RMSNorm and grouped-query attention. The scale invariance story is the same.',
bullets: [
'RMSNorm before attention and MLP → radial scale invariance per weight matrix.',
'Grouped-query attention shares K/V across heads but still via linear projections.',
'MLP is a gated feed-forward block similar to LLaMA and Gemma.'
],
blocks: [
{
key: 'attn',
title: 'RMSNorm → GQA Attention',
normLabel: 'RMSNorm(hₗ)',
description: 'Separate Q projection, shared K/V projections, and an output projection.',
weights: [
{ key: 'attn.W_q', label: 'W_Q (per-head)' },
{ key: 'attn.W_kv', label: 'W_KV (shared K/V)' },
{ key: 'attn.W_o', label: 'W_O (output)' }
]
},
{
key: 'mlp',
title: 'RMSNorm → Gated MLP',
normLabel: 'RMSNorm(hₗ + Attn)',
description: 'Gated feed-forward (e.g. SwiGLU) similar to LLaMA and Gemma.',
weights: [
{ key: 'mlp.W_in', label: 'W_in (expand)' },
{ key: 'mlp.W_gate', label: 'W_gate' },
{ key: 'mlp.W_out', label: 'W_out (down)' }
]
}
]
}
};
// --- State & helpers ---

// Current selection. Both are reassigned by the tab and chip click handlers;
// the initial weight is the first weight of the first block of the initial
// architecture.
let currentArchKey = 'llama';
let currentWeightKey = architectures[currentArchKey].blocks[0].weights[0].key;

// Base-10 exponents of the scale factors drawn as orbit nodes (c = 10^exp).
const nodeExponents = [-0.6, -0.3, 0.0, 0.3, 0.6];
// Orbit node elements, appended by createOrbit() in nodeExponents order.
const orbitNodes = [];
// Cache of invariance tests keyed by `${archKey}::${weightKey}`.
const invarianceTests = {};

// Handles to the elements the demo reads from or writes to. All lookups are
// scoped to the demo container.
const modelTabs = container.querySelectorAll('.transformer-model-tab');
const archSummaryEl = container.querySelector('#transformer-arch-summary');
const archBulletsEl = container.querySelector('#transformer-arch-bullets');
const diagramEl = container.querySelector('#transformer-diagram');
const orbitEl = container.querySelector('#transformer-orbit');

const scaleSlider = container.querySelector('#transformer-scale-slider');
const scaleValueEl = container.querySelector('#transformer-scale-value');
const diffValueEl = container.querySelector('#transformer-diff-value');
const invarianceBadgeEl = container.querySelector('#transformer-invariance-badge');

const selectedArchEl = container.querySelector('#transformer-selected-arch-label');
const selectedWeightEl = container.querySelector('#transformer-selected-weight-label');
/**
 * Draws one pseudo-random sample with the Box–Muller transform, i.e. a value
 * distributed as a standard normal (as far as Math.random() is uniform).
 *
 * Uniform draws equal to exactly 0 are rejected and redrawn, so Math.log()
 * never receives 0.
 *
 * @returns {number} A finite random number.
 */
function randn() {
  let u;
  let v;
  do {
    u = Math.random();
  } while (u === 0);
  do {
    v = Math.random();
  } while (v === 0);
  const radius = Math.sqrt(-2.0 * Math.log(u));
  const angle = 2.0 * Math.PI * v;
  return radius * Math.cos(angle);
}
/**
 * RMS-normalises a vector: every element is divided by
 * sqrt(mean(x²) + 1e-12). The input is not modified.
 *
 * @param {ArrayLike<number> & Iterable<number>} vec - Values to normalise.
 * @returns {Float64Array} A new array of the same length as `vec`.
 */
function rmsnorm(vec) {
  let sumSquares = 0;
  for (const x of vec) {
    sumSquares += x * x;
  }
  // The small epsilon keeps the divisor non-zero for an all-zero input.
  const rms = Math.sqrt(sumSquares / vec.length + 1e-12);
  return Float64Array.from(vec, (x) => x / rms);
}
/**
 * Euclidean distance between two equally indexed vectors.
 *
 * Iterates over the indices of `a`; `b` is read at the same positions.
 *
 * @param {number[]|Float64Array} a - First vector (defines the length).
 * @param {number[]|Float64Array} b - Second vector.
 * @returns {number} sqrt(Σ (a[i] - b[i])²).
 */
function l2Distance(a, b) {
  let sumSquares = 0;
  for (const [i, ai] of a.entries()) {
    const delta = ai - b[i];
    sumSquares += delta * delta;
  }
  return Math.sqrt(sumSquares);
}
/**
 * Formats a scale factor for display, followed by a "×" sign.
 *
 * - Values within 1e-3 of 1 are shown as exactly "1.00×".
 * - Values below 10 that two decimals can represent (zero, or magnitude of at
 *   least 0.01) use fixed notation, e.g. "2.50×".
 * - Everything else uses one-digit exponential notation, e.g. "1.0e+2×".
 *   This includes non-zero values with magnitude under 0.01, which fixed
 *   notation would otherwise collapse to a misleading "0.00×" / "0.01×".
 *
 * @param {number} c - Scale factor.
 * @returns {string} Human-readable scale label.
 */
function formatScale(c) {
  if (Math.abs(c - 1) < 1e-3) {
    return '1.00×';
  }
  const magnitude = Math.abs(c);
  const fitsTwoDecimals = magnitude === 0 || magnitude >= 0.01;
  if (c < 10 && fitsTwoDecimals) {
    return `${c.toFixed(2)}×`;
  }
  return `${c.toExponential(1)}×`;
}
/**
 * Builds a vector-level toy for y = RMSNorm(Wx): a fixed random 32-element
 * vector stands in for the pre-norm activation, and scaling it by c before
 * normalising should leave the result unchanged.
 *
 * @returns {{isScaleInvariant: boolean, baseOut: Float64Array,
 *            forward: function(number): Float64Array}}
 *   `baseOut` is the normalised unscaled vector; `forward(c)` normalises the
 *   same vector after multiplying every element by `c`.
 */
function createInvarianceTest() {
  const dim = 32;
  const baseVec = Float64Array.from({ length: dim }, () => randn());
  return {
    isScaleInvariant: true,
    baseOut: rmsnorm(baseVec),
    forward(c) {
      return rmsnorm(baseVec.map((x) => c * x));
    },
  };
}
/**
 * Returns the invariance test for an (architecture, weight) pair, creating
 * and caching it on first use so repeated lookups reuse the same random
 * vector.
 *
 * @param {string} archKey - Architecture id.
 * @param {string} weightKey - Weight key within that architecture.
 * @returns {object} The cached object produced by createInvarianceTest().
 */
function getInvarianceTest(archKey, weightKey) {
  const cacheKey = `${archKey}::${weightKey}`;
  const cached = invarianceTests[cacheKey];
  if (cached) {
    return cached;
  }
  const created = createInvarianceTest();
  invarianceTests[cacheKey] = created;
  return created;
}
/**
 * Looks up a weight descriptor by key across all blocks of an architecture.
 *
 * @param {{blocks: Array<{weights: Array<{key: string}>}>}} arch
 *   Architecture definition to search.
 * @param {string} weightKey - Key to match against each weight's `key`.
 * @returns {object|null} The first matching weight object (blocks are
 *   searched in order), or null when no weight has that key.
 */
function findWeightMeta(arch, weightKey) {
  for (const { weights } of arch.blocks) {
    const match = weights.find((weight) => weight.key === weightKey);
    if (match) {
      return match;
    }
  }
  return null;
}
/**
 * Produces the markup string that renderArchitecture() assigns to the
 * diagram element.
 *
 * NOTE(review): the template is empty in this view and `arch` is unused,
 * while renderArchitecture() expects '.transformer-weight-chip' elements
 * carrying a 'data-weight-key' attribute — confirm the markup was not
 * stripped from this file.
 *
 * @param {object} arch - Architecture definition (currently unread).
 * @returns {string} The diagram markup.
 */
function buildArchitectureHTML(arch) {
  return `
`;
}
/**
 * Re-renders the diagram for the current architecture, wires a click handler
 * onto every weight chip, and makes sure the selected weight key belongs to
 * this architecture (falling back to its first weight otherwise).
 */
function renderArchitecture() {
  const arch = architectures[currentArchKey];
  diagramEl.innerHTML = buildArchitectureHTML(arch);

  // Clicking a chip selects its weight and refreshes the readouts.
  const selectChip = (chip) => {
    const key = chip.getAttribute('data-weight-key');
    if (!key) {
      return;
    }
    currentWeightKey = key;
    updateScaleAndDiff();
    highlightSelectedChip();
  };
  for (const chip of diagramEl.querySelectorAll('.transformer-weight-chip')) {
    chip.addEventListener('click', () => selectChip(chip));
  }

  // Collect every weight key of this architecture, in block order.
  const allKeys = [];
  for (const block of arch.blocks) {
    for (const weight of block.weights) {
      allKeys.push(weight.key);
    }
  }

  const selectionStillValid =
    Boolean(currentWeightKey) && allKeys.includes(currentWeightKey);
  if (!selectionStillValid) {
    currentWeightKey = allKeys[0];
  }
  highlightSelectedChip();
}
/**
 * Marks the chip whose 'data-weight-key' equals the current weight key with
 * the 'active' class and clears that class from every other chip.
 */
function highlightSelectedChip() {
  for (const chip of diagramEl.querySelectorAll('.transformer-weight-chip')) {
    const isSelected = chip.getAttribute('data-weight-key') === currentWeightKey;
    chip.classList.toggle('active', isSelected);
  }
}
/**
 * Shows the current architecture's summary text and rebuilds its bullet list
 * (one <li> per bullet). Text is assigned via textContent, so it is never
 * interpreted as markup.
 */
function updateArchDescription() {
  const { summary, bullets } = architectures[currentArchKey];
  archSummaryEl.textContent = summary;

  // Drop the previous architecture's bullets before adding the new ones.
  archBulletsEl.innerHTML = '';
  if (!bullets || !bullets.length) {
    return;
  }
  for (const text of bullets) {
    const item = document.createElement('li');
    item.textContent = text;
    archBulletsEl.appendChild(item);
  }
}
/**
 * Builds the orbit widget: a centre label plus one node per entry of
 * nodeExponents (labelled with the scale 10^exponent), laid out evenly on a
 * circle. Node elements are appended to orbitNodes; positions are computed on
 * the next animation frame and again on every window resize.
 */
function createOrbit() {
  orbitEl.innerHTML = '';

  const center = document.createElement('div');
  center.className = 'transformer-orbit-center';
  center.innerHTML = `
direction on sphere
`;
  orbitEl.appendChild(center);

  for (const exponent of nodeExponents) {
    const node = document.createElement('div');
    node.className = 'transformer-orbit-node';
    // Stored so updateOrbitHighlight() can compare it with the slider value.
    node.dataset.exponent = String(exponent);
    node.textContent = formatScale(Math.pow(10, exponent));
    orbitEl.appendChild(node);
    orbitNodes.push(node);
  }

  const positionOrbitNodes = () => {
    const { width, height } = orbitEl.getBoundingClientRect();
    // Nothing to lay out while the orbit has no measurable size.
    if (!width || !height) {
      return;
    }
    const cx = width / 2;
    const cy = height / 2;
    // Keep the nodes 22px inside the smaller half-dimension.
    const radius = Math.min(cx, cy) - 22;
    const slots = nodeExponents.length;
    orbitNodes.forEach((node, idx) => {
      // The -π/2 offset puts the first node at the top of the circle.
      const angle = (2 * Math.PI * idx) / slots - Math.PI / 2;
      node.style.left = `${cx + radius * Math.cos(angle)}px`;
      node.style.top = `${cy + radius * Math.sin(angle)}px`;
    });
  };

  window.requestAnimationFrame(positionOrbitNodes);
  window.addEventListener('resize', positionOrbitNodes);
}
/**
 * Gives the 'active' class to the orbit node whose stored exponent is closest
 * to the slider's current value and removes it from all others. On a tie the
 * earlier node wins; if no distance is comparable (e.g. the slider value does
 * not parse), no node is active.
 */
function updateOrbitHighlight() {
  const sliderExponent = parseFloat(scaleSlider.value);

  const nearest = orbitNodes.reduce(
    (best, node) => {
      const gap = Math.abs(parseFloat(node.dataset.exponent) - sliderExponent);
      // Strict comparison: ties keep the earlier node, NaN never wins.
      return gap < best.gap ? { node, gap } : best;
    },
    { node: null, gap: Infinity }
  ).node;

  for (const node of orbitNodes) {
    node.classList.toggle('active', node === nearest);
  }
}
/**
 * Refreshes every readout that depends on the slider or the selection:
 * the scale label (c = 10^slider), the selected architecture and weight
 * labels, the L2 difference between the unscaled and scaled toy outputs, the
 * invariance badge, and the highlighted orbit node.
 */
function updateScaleAndDiff() {
  const arch = architectures[currentArchKey];
  // The slider holds a base-10 exponent, not the scale itself.
  const c = Math.pow(10, parseFloat(scaleSlider.value));

  scaleValueEl.textContent = formatScale(c);
  selectedArchEl.textContent = arch.displayName;

  const weightMeta = findWeightMeta(arch, currentWeightKey);
  // Fall back to the raw key when the weight is unknown to this architecture.
  selectedWeightEl.textContent = weightMeta ? weightMeta.label : currentWeightKey;

  const test = getInvarianceTest(currentArchKey, currentWeightKey);
  const diff = l2Distance(test.baseOut, test.forward(c));
  diffValueEl.textContent = diff.toExponential(2);

  // Below this threshold the difference is reported as exactly invariant.
  const tiny = 1e-10;
  invarianceBadgeEl.textContent =
    diff < tiny
      ? 'Perfectly invariant under c > 0'
      : 'Almost invariant (numerical noise only)';
  // Neither outcome is treated as a warning.
  invarianceBadgeEl.classList.remove('warning');

  updateOrbitHighlight();
}
/**
 * Sets the 'active' class on the model tab whose 'data-arch' attribute
 * matches the current architecture and clears it on the rest.
 */
function updateModelTabs() {
  for (const tab of modelTabs) {
    const isCurrent = tab.getAttribute('data-arch') === currentArchKey;
    tab.classList.toggle('active', isCurrent);
  }
}
// --- Event wiring ---

// Model tabs: switch architecture, reset the selected weight to the new
// architecture's first weight, then refresh every dependent view.
modelTabs.forEach((btn) => {
  btn.addEventListener('click', () => {
    const k = btn.getAttribute('data-arch');
    if (!k || k === currentArchKey) {
      return;
    }
    // Validate before touching state: an unknown key would otherwise be
    // stored in currentArchKey and make this and every later update throw.
    if (!Object.prototype.hasOwnProperty.call(architectures, k)) {
      return;
    }
    currentArchKey = k;
    const arch = architectures[currentArchKey];
    currentWeightKey = arch.blocks[0].weights[0].key;
    updateModelTabs();
    updateArchDescription();
    renderArchitecture();
    updateScaleAndDiff();
  });
});

// Scale slider: recompute the scale label, difference and orbit highlight on
// every change.
scaleSlider.addEventListener('input', () => {
  updateScaleAndDiff();
});
// --- Initial render ---
// Paint the demo once for the initial selection. The call order matters.
updateModelTabs();
updateArchDescription();
// createOrbit() fills orbitNodes; updateScaleAndDiff() reaches them through
// updateOrbitHighlight(), so the orbit has to exist first.
createOrbit();
// renderArchitecture() may replace currentWeightKey with a key valid for the
// current architecture, which updateScaleAndDiff() then reads.
renderArchitecture();
updateScaleAndDiff();
}