Keyboard shortcuts

Press ← or → to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Complete Code

Here is the complete source code for the model, along with all files from previous chapters.

model.ts

/**
 * GPT Model
 *
 * Defines the model's structure: configuration, weight matrices, and
 * parameter collection. Includes the forward pass (gpt) and KV cache
 * for inference. Save/load is introduced in a later chapter.
 */

import { Value, vsum } from "./autograd.js";
import { type Matrix, linear, softmax, rmsnorm } from "./nn.js";
import { gauss } from "./rng.js";

// --- Configuration ---

/** Hyperparameters that define the model's architecture. */
export interface GPTConfig {
  nLayer: number;    // number of transformer layers
  nEmbd: number;     // width of the hidden state / embedding vectors
  blockSize: number; // maximum sequence length (rows in positionEmbedding)
  nHead: number;     // attention heads per layer
  headDim: number;   // width of each head's slice (nHead * headDim is assumed to equal nEmbd — TODO confirm)
  vocabSize: number; // number of distinct token ids
}

/** Projection matrices for one attention block (each created nEmbd x nEmbd). */
interface Attention {
  query: Matrix;  // hidden state -> query vector
  key: Matrix;    // hidden state -> key vector
  value: Matrix;  // hidden state -> value vector
  output: Matrix; // concatenated head outputs -> hidden dimension
}

/** Feed-forward weights for one layer: expand to 4x nEmbd, then compress back. */
interface MLP {
  hidden: Matrix; // (4*nEmbd x nEmbd) expansion
  output: Matrix; // (nEmbd x 4*nEmbd) projection back down
}

/** One transformer layer: an attention block followed by an MLP block. */
interface Layer {
  attention: Attention;
  mlp: MLP;
}

/** Every learnable weight matrix in the model. */
export interface Weights {
  tokenEmbedding: Matrix;    // (vocabSize x nEmbd): token id -> embedding row
  positionEmbedding: Matrix; // (blockSize x nEmbd): position -> embedding row
  output: Matrix;            // (vocabSize x nEmbd): final hidden state -> logits
  layers: Layer[];           // one entry per transformer layer
}

/** The trained (or untrained) model: config + weights + flattened parameter list. */
export interface Model {
  config: GPTConfig;
  weights: Weights;
  params: Value[]; // every scalar weight, flattened for the optimizer to iterate
}

// --- Model Creation ---

/** Allocate an (nout x nin) matrix of Values drawn from N(0, std). */
function matrix(nout: number, nin: number, std = 0.08): Matrix {
  const rows: Matrix = [];
  for (let r = 0; r < nout; r++) {
    const row: Value[] = [];
    for (let c = 0; c < nin; c++) {
      // Row-major fill order — keeps the seeded RNG stream deterministic.
      row.push(new Value(gauss(0, std)));
    }
    rows.push(row);
  }
  return rows;
}

/** Create a new model with randomly initialized weights. */
export function createModel(config: GPTConfig): Model {
  const { nEmbd, nLayer, blockSize, vocabSize } = config;

  const weights: Weights = {
    tokenEmbedding: matrix(vocabSize, nEmbd),
    positionEmbedding: matrix(blockSize, nEmbd),
    output: matrix(vocabSize, nEmbd),
    layers: Array.from({ length: nLayer }, () => ({
      attention: {
        query: matrix(nEmbd, nEmbd),
        key: matrix(nEmbd, nEmbd),
        value: matrix(nEmbd, nEmbd),
        output: matrix(nEmbd, nEmbd),
      },
      mlp: {
        hidden: matrix(4 * nEmbd, nEmbd),
        output: matrix(nEmbd, 4 * nEmbd),
      },
    })),
  };

  const allMatrices: Matrix[] = [
    weights.tokenEmbedding,
    weights.positionEmbedding,
    weights.output,
    ...weights.layers.flatMap((layer) => [
      layer.attention.query, layer.attention.key,
      layer.attention.value, layer.attention.output,
      layer.mlp.hidden, layer.mlp.output,
    ]),
  ];
  const params = allMatrices.flatMap((mat) => mat.flatMap((row) => row));

  return { config, weights, params };
}

// --- KV Cache ---

/** Create fresh key/value caches for a new sequence. Must be called per-sequence. */
export function createKVCache(model: Model): {
  keys: Value[][][];
  values: Value[][][];
} {
  const { nLayer } = model.config;
  // One empty list of cached vectors per transformer layer.
  const keys: Value[][][] = [];
  const values: Value[][][] = [];
  for (let li = 0; li < nLayer; li++) {
    keys.push([]);
    values.push([]);
  }
  return { keys, values };
}

// --- Forward Pass ---

/**
 * Run one step of the GPT: given a token at a position, return logits
 * over the vocabulary for the next token.
 *
 * The keys/values caches are mutated (appended to) on each call —
 * this is the KV cache that avoids recomputing attention for past positions.
 *
 * @param model config + weights to run the forward pass with
 * @param tokenId id of the current token (row index into tokenEmbedding)
 * @param posId position of the token in the sequence (row index into
 *   positionEmbedding; must be < config.blockSize or the lookup is undefined)
 * @param keys per-layer cache of key vectors for earlier positions (mutated)
 * @param values per-layer cache of value vectors for earlier positions (mutated)
 * @returns raw (pre-softmax) logits over the vocabulary
 */
export function gpt(
  model: Model,
  tokenId: number,
  posId: number,
  keys: Value[][][],
  values: Value[][][],
): Value[] {
  const { nLayer, nHead, headDim } = model.config;
  const { weights } = model;

  // Step 1: Embedding lookup
  // Combine "what word is this?" with "where does it appear?" into a single
  // vector. This is the hidden state that flows through the rest of the network.
  const tokenVector: Value[] = weights.tokenEmbedding[tokenId];
  const positionVector: Value[] = weights.positionEmbedding[posId];
  let hidden: Value[] = tokenVector.map((t, i) => t.add(positionVector[i]));

  // Normalize before the first layer to keep values at a stable scale
  hidden = rmsnorm(hidden);

  // Step 2: Transformer layers
  // Each layer has two blocks: attention (gather context from other tokens)
  // followed by MLP (process the gathered information). Both use residual
  // connections so the input is added back to the output of each block.
  for (let li = 0; li < nLayer; li++) {
    const layer = weights.layers[li];

    // --- Attention block: look at previous tokens to gather context ---

    // Save the hidden state so we can add it back after the block (residual)
    const beforeAttention: Value[] = hidden;
    hidden = rmsnorm(hidden);

    // Project the hidden state into query, key, and value vectors.
    // These are structurally identical projections — training teaches them
    // to play different roles: query asks "what am I looking for?",
    // key advertises "what do I contain?", value carries "what to retrieve".
    const query: Value[] = linear(hidden, layer.attention.query);
    const key: Value[] = linear(hidden, layer.attention.key);
    const value: Value[] = linear(hidden, layer.attention.value);

    // Cache the key and value so future tokens can attend to this position
    keys[li].push(key);
    values[li].push(value);

    // Each head independently attends to a different slice of the vectors,
    // allowing the model to track multiple relationships at once
    // NOTE(review): the per-head slicing assumes nHead * headDim === nEmbd —
    // confirm configs are validated elsewhere.
    const attentionOutput: Value[] = [];
    for (let h = 0; h < nHead; h++) {
      const headStart = h * headDim;
      const headQuery = query.slice(headStart, headStart + headDim);
      const headKeys = keys[li].map((ki) => ki.slice(headStart, headStart + headDim));
      const headValues = values[li].map((vi) => vi.slice(headStart, headStart + headDim));

      // Scaled dot-product attention: score = (query · key) / √headDim
      const attnLogits = headKeys.map((cachedKey) =>
        vsum(headQuery.map((q, j) => q.mul(cachedKey[j]))).div(Math.sqrt(headDim))
      );

      // Softmax converts scores into weights that sum to 1
      const attnWeights = softmax(attnLogits);

      // Weighted sum of value vectors — high-scoring positions contribute more
      for (let j = 0; j < headDim; j++) {
        attentionOutput.push(vsum(attnWeights.map((w, t) => w.mul(headValues[t][j]))));
      }
    }

    // Project concatenated head outputs back to the hidden dimension
    hidden = linear(attentionOutput, layer.attention.output);
    // Residual connection: add back what we had before attention
    hidden = hidden.map((h, i) => h.add(beforeAttention[i]));

    // --- MLP block: process each token's representation independently ---

    const beforeMLP: Value[] = hidden;
    hidden = rmsnorm(hidden);
    hidden = linear(hidden, layer.mlp.hidden);   // expand to 4x wider
    hidden = hidden.map((h) => h.relu());         // nonlinearity
    hidden = linear(hidden, layer.mlp.output);    // compress back
    // Residual connection: add back what we had before the MLP
    hidden = hidden.map((h, i) => h.add(beforeMLP[i]));
  }

  // Step 3: Output projection
  // Project the final hidden state to vocabulary size — one score per word.
  // These raw scores (logits) will be passed through softmax later to get
  // a probability distribution over the next token.
  return linear(hidden, weights.output);
}

rng.ts

/**
 * Seeded Random Number Generator
 *
 * JavaScript has no built-in seeded RNG, so we roll our own:
 * - mulberry32 for uniform random numbers
 * - Box-Muller transform for Gaussian (normal) distribution
 * - Fisher-Yates for array shuffling
 * - Weighted random choice for sampling from probability distributions
 */

let _rngState = 0; // current 32-bit PRNG state: reset by seed(), advanced by random()

/** Seed the PRNG. The seed is truncated to a signed 32-bit integer. */
export function seed(s: number): void {
  _rngState = s | 0;
}

/** Mulberry32: a fast, seedable 32-bit PRNG. Returns a uniform float in [0, 1). */
export function random(): number {
  _rngState = (_rngState + 0x6d2b79f5) | 0; // advance state by the mulberry32 increment
  let t = _rngState;
  // Mix the state bits (xor-shifts + 32-bit multiplies) to decorrelate outputs
  t = Math.imul(t ^ (t >>> 15), t | 1);
  t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
  // >>> 0 reinterprets as uint32; dividing by 2^32 maps it into [0, 1)
  return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
}

/** Box-Muller transform: convert two uniform samples into a Gaussian sample. */
export function gauss(mean: number, std: number): number {
  // Redraw until u1 is nonzero so Math.log never sees 0.
  let u1: number;
  do {
    u1 = random();
  } while (u1 === 0);
  const u2 = random();
  const magnitude = Math.sqrt(-2 * Math.log(u1));
  const angle = 2 * Math.PI * u2;
  return mean + std * magnitude * Math.cos(angle);
}

/** Fisher-Yates shuffle: randomly reorder an array in-place. */
export function shuffle<T>(arr: T[]): void {
  // Walk backwards, swapping each slot with a uniformly chosen earlier slot.
  for (let i = arr.length - 1; i > 0; i--) {
    const pick = Math.floor(random() * (i + 1));
    const held = arr[i];
    arr[i] = arr[pick];
    arr[pick] = held;
  }
}

/** Sample an index from a weighted distribution. Used for token sampling during inference. */
export function weightedChoice(weights: number[]): number {
  let total = 0;
  for (const w of weights) total += w;
  // Walk the cumulative distribution until the random threshold is crossed.
  let remaining = random() * total;
  for (let i = 0; i < weights.length; i++) {
    remaining -= weights[i];
    if (remaining <= 0) return i;
  }
  // Floating-point round-off can leave a tiny positive remainder: fall back
  // to the last index.
  return weights.length - 1;
}

nn.ts

/**
 * Neural Network Primitives
 *
 * The building blocks that the GPT architecture assembles:
 * - linear: matrix-vector multiply (the fundamental neural net operation)
 * - softmax: convert raw scores into a probability distribution
 * - rmsnorm: normalize activations to stabilize training
 *
 * These mirror PyTorch's torch.nn.functional — general-purpose operations
 * used by the model, training loop, and inference.
 */

import { Value, vsum } from "./autograd.js";

/** A 2-D array of autograd scalars, indexed as Matrix[row][col]. */
export type Matrix = Value[][];

/** y = Wx: multiply a weight matrix by an input vector. */
export function linear(input: Value[], weights: Matrix): Value[] {
  // One output element per weight row: the row's dot product with the input.
  const outputs: Value[] = [];
  for (const row of weights) {
    outputs.push(vsum(row.map((w, i) => w.mul(input[i]))));
  }
  return outputs;
}

/** Convert raw logits to probabilities. Subtracts max for numerical stability. */
export function softmax(logits: Value[]): Value[] {
  // Softmax is shift-invariant, so subtracting the max keeps exp() bounded.
  const maxVal = logits.reduce((best, v) => Math.max(best, v.data), -Infinity);
  const exps: Value[] = [];
  for (const logit of logits) {
    exps.push(logit.sub(maxVal).exp());
  }
  const total = vsum(exps);
  return exps.map((e) => e.div(total));
}

/** Root Mean Square normalization: rescale the vector so its RMS is ~1. */
export function rmsnorm(input: Value[]): Value[] {
  // Mean of squares, plus a tiny epsilon so pow(-0.5) never hits zero.
  const meanSquare = vsum(input.map((xi) => xi.mul(xi))).div(input.length);
  const inverseRms = meanSquare.add(1e-5).pow(-0.5);
  return input.map((xi) => xi.mul(inverseRms));
}

autograd.ts

/**
 * Autograd Engine
 *
 * A scalar-valued automatic differentiation engine. Each Value node records:
 * - its forward-pass result (data)
 * - its gradient w.r.t. the loss (grad), filled in by backward()
 * - its children in the computation graph and the local derivatives
 *
 * The backward() method applies the chain rule via reverse-mode autodiff:
 * topologically sort the graph, then propagate gradients from output to inputs.
 *
 * "If I nudge this parameter slightly, how does the loss change?"
 */

/**
 * A scalar node in the reverse-mode autodiff graph.
 *
 * Each Value records its forward result (data), its gradient w.r.t. the
 * node backward() was called on (grad), the operand nodes it was computed
 * from (children), and the local derivative w.r.t. each child (localGrads).
 */
export class Value {
  data: number;         // forward-pass result
  grad: number;         // d(output)/d(this), accumulated by backward()
  children: Value[];    // operand nodes this value was computed from
  localGrads: number[]; // d(this)/d(children[i]), parallel to children

  constructor(data: number, children: Value[] = [], localGrads: number[] = []) {
    this.data = data;
    this.grad = 0;
    this.children = children;
    this.localGrads = localGrads;
  }

  /** this + other; local derivative is 1 w.r.t. both operands. */
  add(other: Value | number): Value {
    const o = typeof other === "number" ? new Value(other) : other;
    return new Value(this.data + o.data, [this, o], [1, 1]);
  }

  /** this * other; d/da = b, d/db = a. */
  mul(other: Value | number): Value {
    const o = typeof other === "number" ? new Value(other) : other;
    return new Value(this.data * o.data, [this, o], [o.data, this.data]);
  }

  /** this ** n for constant n; d/dx = n * x^(n-1). */
  pow(n: number): Value {
    return new Value(this.data ** n, [this], [n * this.data ** (n - 1)]);
  }

  /** Natural log; d/dx = 1/x. */
  log(): Value {
    return new Value(Math.log(this.data), [this], [1 / this.data]);
  }

  /** e^x; the derivative is the output itself. */
  exp(): Value {
    return new Value(Math.exp(this.data), [this], [Math.exp(this.data)]);
  }

  /** max(0, x); gradient flows only where x > 0. */
  relu(): Value {
    return new Value(Math.max(0, this.data), [this], [this.data > 0 ? 1 : 0]);
  }

  /** -x, expressed as x * -1 so it stays in the graph. */
  neg(): Value {
    return this.mul(-1);
  }

  /** this - other, expressed as this + (-other). */
  sub(other: Value | number): Value {
    const o = typeof other === "number" ? new Value(other) : other;
    return this.add(o.neg());
  }

  /** this / other, expressed as this * other^-1. */
  div(other: Value | number): Value {
    const o = typeof other === "number" ? new Value(other) : other;
    return this.mul(o.pow(-1));
  }

  /**
   * Reverse-mode backpropagation from this node.
   *
   * Topologically sorts the reachable graph with an explicit stack rather
   * than recursion — deep graphs (long chains of ops, as training produces)
   * would otherwise overflow the JS call stack — then applies the chain
   * rule in reverse order, accumulating each node's gradient into its
   * children scaled by the recorded local derivatives.
   */
  backward(): void {
    const topo: Value[] = [];
    const visited = new Set<Value>([this]);
    // Each frame tracks a node and how many of its children were expanded;
    // a node is emitted (post-order) once all its children have been.
    const stack: Array<{ node: Value; next: number }> = [{ node: this, next: 0 }];
    while (stack.length > 0) {
      const frame = stack[stack.length - 1];
      if (frame.next < frame.node.children.length) {
        const child = frame.node.children[frame.next++];
        if (!visited.has(child)) {
          visited.add(child);
          stack.push({ node: child, next: 0 });
        }
      } else {
        stack.pop();
        topo.push(frame.node);
      }
    }
    // Seed the output gradient, then propagate from output toward inputs.
    this.grad = 1;
    for (const v of topo.reverse()) {
      for (let i = 0; i < v.children.length; i++) {
        v.children[i].grad += v.localGrads[i] * v.grad;
      }
    }
  }
}

/** Sum a list of Values through the computation graph. */
export function vsum(values: Value[]): Value {
  // Fold left starting from a fresh zero node, same as reduce with a seed.
  let total = new Value(0);
  for (const v of values) {
    total = total.add(v);
  }
  return total;
}

tokenizer.ts

/**
 * Tokenizer
 *
 * Translates strings to sequences of integers ("tokens") and back.
 * Builds a word-level vocabulary from the training corpus, with a BOS
 * (Beginning of Sequence) delimiter appended on each side.
 */

/** String <-> token-id translation contract. */
export interface Tokenizer {
  vocabSize: number; // word count + 1 (the extra id is BOS)
  BOS: number;       // Beginning-of-Sequence token id (the last id in the vocab)
  encode(sentence: string): number[]; // words -> ids, wrapped in BOS on both sides
  decode(tokens: number[]): string;   // ids -> space-joined words
}

/** Word-level tokenizer. Discovers the word vocabulary from the corpus. */
export function createWordTokenizer(sentences: string[]): Tokenizer {
  // Unique words across the corpus, sorted so ids are stable and deterministic.
  const words = [...new Set(sentences.flatMap((d) => d.split(" ")))].sort();
  // O(1) word -> id lookup; the original indexOf scan was O(vocab) per word.
  const wordToId = new Map<string, number>();
  words.forEach((w, i) => wordToId.set(w, i));
  const BOS = words.length;
  const vocabSize = words.length + 1;

  return {
    vocabSize,
    BOS,
    encode(sentence: string): number[] {
      // Out-of-vocabulary words map to -1, matching indexOf semantics;
      // callers are expected to encode in-vocabulary text only.
      return [BOS, ...sentence.split(" ").map((w) => wordToId.get(w) ?? -1), BOS];
    },
    decode(tokens: number[]): string {
      return tokens.map((t) => words[t]).join(" ");
    },
  };
}