Skip to content

Commit

Permalink
Sort out MLP walkthrough
Browse files Browse the repository at this point in the history
  • Loading branch information
bbycroft committed Nov 16, 2023
1 parent d796fcb commit d0588b4
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 10 deletions.
7 changes: 7 additions & 0 deletions src/llm/GptModelLayout.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export interface IBlkDef {
highlight: number; // 0 - 1 (0 = no highlight, 1 = full highlight)
opacity: number; // 0 - 1 (0 = transparent, 1 = opaque)
special: BlkSpecial;
transpose?: boolean; // transpose the process direction
subs?: IBlkDef[]; // substitutes for this block (i.e. render these instead)
offX?: number; // offset from the original block
offY?: number;
Expand Down Expand Up @@ -159,6 +160,7 @@ interface IBlkDefArgs {
deps?: IBlkDepArgs;
small?: boolean;
hidden?: boolean;
transpose?: boolean;
}

export interface IBlkLabel {
Expand Down Expand Up @@ -250,6 +252,7 @@ export function genGptModelLayout(shape: IModelShape, gptGpuModel: IGptModelLink
highlight: 0.0,
small: args.small ?? false,
special: args.special ?? BlkSpecial.None,
transpose: args.transpose,
idx: -1,
};
}
Expand Down Expand Up @@ -482,6 +485,7 @@ export function genGptModelLayout(shape: IModelShape, gptGpuModel: IGptModelLink
deps: { dot: [[qBlock, 'yi'], [kBlock, 'xi']], lowerTri: true, dotLen: A, special: BlKDepSpecial.Attention },
dimX: DimStyle.T, dimY: DimStyle.T,
special: BlkSpecial.Attention,
transpose: true,
name: 'Attention Matrix',
});

Expand Down Expand Up @@ -510,6 +514,7 @@ export function genGptModelLayout(shape: IModelShape, gptGpuModel: IGptModelLink
deps: { add: [[attnMtx, 'xy'], [attnMtxAgg1, 'iy'], [attnMtxAgg2, 'iy']], lowerTri: true, special: BlKDepSpecial.Softmax },
dimX: DimStyle.T, dimY: DimStyle.T,
special: BlkSpecial.Attention,
transpose: true,
name: 'Attn Matrix Softmax',
});

Expand Down Expand Up @@ -627,6 +632,7 @@ export function genGptModelLayout(shape: IModelShape, gptGpuModel: IGptModelLink
deps: { dot: [[mlpFcWeight, 'xi'], [ln2.lnResid, 'yi']], dotLen: C, add: [[mlpFcBias, 'x']] },
dimX: DimStyle.C4, dimY: DimStyle.T,
name: 'MLP',
transpose: true,
});

y += T * cell + margin;
Expand All @@ -638,6 +644,7 @@ export function genGptModelLayout(shape: IModelShape, gptGpuModel: IGptModelLink
deps: { add: [[mlpFc, 'xy']], special: BlKDepSpecial.Gelu },
dimX: DimStyle.C4, dimY: DimStyle.T,
name: 'MLP Activation',
transpose: true,
});

y += T * cell + margin;
Expand Down
10 changes: 9 additions & 1 deletion src/llm/components/DataFlow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ export function drawAttention(args: IDataFlowArgs) {
}

export function drawGeluActivation(args: IDataFlowArgs) {
let { state, center, mtx } = args;
let { state, center, mtx, blk, destIdx } = args;

// ugly
// let fontOpts = { color: opColor, mtx, size: 16 };
Expand Down Expand Up @@ -694,6 +694,14 @@ export function drawGeluActivation(args: IDataFlowArgs) {
let curveLineOpts = makeLineOpts({ color: Colors.Intermediates, mtx, thick: 3.5 });
drawLineSegs(state.render.lineRender, pts, curveLineOpts);

let srcBlk = blk.deps!.add![0].src;
let srcVal = getBlockValueAtIdx(srcBlk, destIdx);

if (isNotNil(srcVal)) {
let destVal = geluX(srcVal);
drawCircle(state.render, new Vec3(mappingX(srcVal), mappingY(destVal)), 2, 1, Colors.Intermediates, mtx);
}

let bb = new BoundingBox3d(tl, br);

return bb;
Expand Down
2 changes: 1 addition & 1 deletion src/llm/walkthrough/Walkthrough00_Intro.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ export function processUpTo(state: IProgramState, timer: ITimeInfo, block: IBlkD
// default, but switched for attention matrix
let dim0 = Dim.X;
let dim1 = Dim.Y;
if (blk.special === BlkSpecial.Attention) {
if (blk.transpose) {
dim0 = Dim.Y;
dim1 = Dim.X;
}
Expand Down
222 changes: 214 additions & 8 deletions src/llm/walkthrough/Walkthrough07_Mlp.tsx
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import { Vec3 } from "@/src/utils/vector";
import React from "react";
import { Dim, Vec3 } from "@/src/utils/vector";
import { Phase } from "./Walkthrough";
import { commentary, DimStyle, IWalkthroughArgs, setInitialCamera } from "./WalkthroughTools";
import { commentary, DimStyle, dimStyleColor, IWalkthroughArgs, setInitialCamera } from "./WalkthroughTools";
import { dimProps, findSubBlocks, splitGrid } from "../Annotations";
import { lerp } from "@/src/utils/math";
import { IBlkDef, getBlkDimensions } from "../GptModelLayout";
import { processUpTo, startProcessBefore } from "./Walkthrough00_Intro";
import { drawDataFlow } from "../components/DataFlow";
import { drawDependences } from "../Interaction";
import { makeArray, makeArrayRange } from "@/src/utils/data";

export function walkthrough07_Mlp(args: IWalkthroughArgs) {
let { walkthrough: wt, state, tools: { afterTime, c_blockRef, c_dimRef } } = args;
let { walkthrough: wt, state, layout, tools: { afterTime, c_blockRef, c_dimRef, breakAfter, cleanup } } = args;

if (wt.phase !== Phase.Input_Detail_Mlp) {
return;
}

let block = state.layout.blocks[0];
let block = layout.blocks[0];

setInitialCamera(state, new Vec3(-154.755, 0.000, -460.042), new Vec3(289.100, -8.900, 2.298));
wt.dimHighlightBlocks = [block.ln2.lnResid, block.mlpAct, block.mlpFc, block.mlpFcBias, block.mlpFcWeight, block.mlpProjBias, block.mlpProjWeight, block.mlpResult, block.mlpResidual];

commentary(wt, null, 0)`
commentary(wt)`
The next half of the transformer block, after the self-attention, is the MLP (multi-layer
perceptron). A bit of a mouthful, but here it's a simple neural network with two layers.
Expand All @@ -29,10 +37,208 @@ In the MLP, we put each of our ${c_dimRef('C = 48', DimStyle.C)} length column v
3. A ${c_blockRef('linear transformation', block.mlpProjWeight)} with a ${c_blockRef('bias', block.mlpProjBias)} added, back to a vector of length ${c_dimRef('C', DimStyle.C)}
It's also common to refer to this as a "feed-forward" network, since the data flows through it in a
single direction.
Let's track one of those vectors:
`;
breakAfter();

let t0_fadeOut = afterTime(null, 1.0);

breakAfter();

commentary(wt)`
We first run through the matrix-vector multiplication with bias added, expanding the vector to length ${c_dimRef('4 * C', DimStyle.C4)}. (Note that the output matrix is transposed here.
This is purely for vizualization purposes.)
`;
breakAfter();

let t1_process = afterTime(null, 3.0);

breakAfter();

commentary(wt)`
Next, we apply the GELU activation function to each element of the vector. This is a key part of any neural network, where we introduce some non-linearity into the model. The specific function used, GELU,
looks a lot like a ReLU function (computed as ${<code>max(0, x)</code>}), but it has a smooth curve rather than a sharp corner.
${<ReluGraph />}
`;
breakAfter();

let t2_process = afterTime(null, 3.0);

breakAfter();

commentary(wt)`
We then project the vector back down to length ${c_dimRef('C', DimStyle.C)} with another matrix-vector multiplication with bias added.
`;
breakAfter();

The output of the MLP is then added to the input of the MLP, continuing the residual pathway.
let t3_process = afterTime(null, 3.0);

breakAfter();

commentary(wt)`
Like in the self-attention + projection section, we add the result of the MLP to it's input, element-wise.
`;
breakAfter();

let t4_process = afterTime(null, 3.0);

breakAfter();
commentary(wt)`
We can now repeat this process for all of the columns in the input.`;

breakAfter();

let t5_cleanup = afterTime(null, 1.0, 0.5);
cleanup(t5_cleanup, [t0_fadeOut]);
let t6_processAll = afterTime(null, 6.0);

breakAfter();

commentary(wt)`
And that's the MLP completed. We now have the output of the transformer block, which is ready to be passed to the next block.
`;

let targetIdx = 3;
let inputBlk = block.ln2.lnResid;
let mlp1Blk = block.mlpFc;
let mlp2Blk = block.mlpAct;
let mlpRes = block.mlpResult;
let mlpResid = block.mlpResidual;

function dimExceptVector(blk: IBlkDef, axis: Dim, disable: boolean) {
if (t0_fadeOut.t === 0 || t6_processAll.t > 0) {
return;
}

if (disable) {
blk.access!.disable = true;
}

let col = splitGrid(layout, blk, axis, targetIdx + 0.5, lerp(0.0, 1.0, t0_fadeOut.t))!;

for (let sub of blk.subs!) {
sub.opacity = lerp(1.0, 0.2, t0_fadeOut.t);
}

col.opacity = 1.0;

return col!;
}

dimExceptVector(inputBlk, Dim.X, false);
let mlp1Col = dimExceptVector(mlp1Blk, Dim.Y, true);
let mlp2Col = dimExceptVector(mlp2Blk, Dim.Y, true);
let mlpResCol = dimExceptVector(mlpRes, Dim.X, true);
let mplResIdCol = dimExceptVector(mlpResid, Dim.X, true);

function processVector(blk: IBlkDef, col: IBlkDef | undefined, t: number, pinIdx: Vec3) {
if (t === 0) {
return;
}

let dim0 = blk.transpose ? Dim.Y : Dim.X;
let dim1 = blk.transpose ? Dim.X : Dim.Y;
let { cx: numCells } = dimProps(blk, dim1);

let xPos = Math.floor(lerp(0, numCells, t));

let destIdx = new Vec3().setAt(dim0, targetIdx).setAt(dim1, xPos).round_();

if (col) {
let row = splitGrid(layout, col, dim1, xPos, 0.0);
for (let a of findSubBlocks(col, dim1, 0, xPos)) {
a.access!.disable = false;
}
}

if (t < 1.0) {
drawDataFlow(state, blk, destIdx, pinIdx);
drawDependences(state, blk, destIdx);
} else if (col) {
col!.access!.disable = false;
}
}

processVector(mlp1Blk, mlp1Col, t1_process.t, new Vec3(40));
processVector(mlp2Blk, mlp2Col, t2_process.t, new Vec3(mlp1Blk.cx / 2, -15));
processVector(mlpRes, mlpResCol, t3_process.t, new Vec3(mlpRes.cx / 2, -15));
processVector(mlpResid, mplResIdCol, t4_process.t, new Vec3(mlpRes.cx / 2, -15));

if (t5_cleanup.t > 0.4) {
mlp1Blk.access!.disable = true;
mlp2Blk.access!.disable = true;
mlpRes.access!.disable = true;
mlpResid.access!.disable = true;
}

if (t6_processAll.t > 0) {
let prevInfo = startProcessBefore(state, inputBlk);
processUpTo(state, t6_processAll, mlpResid, prevInfo);
}
}


const ReluGraph: React.FC = () => {

let fnRelu = (x: number) => Math.max(0, x);
let fnGelu = (x: number) => x * 0.5 * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * Math.pow(x, 3))));

function createMapping(range0: number, range1: number, domain0: number, domain1: number) {
let m = (range1 - range0) / (domain1 - domain0);
let b = range0 - m * domain0;
return (x: number) => m * x + b;
}

let w = 200;
let h = 160;

let halfW = 3.2;
let halfH = halfW * h / w;
let hOffset = 1.1;

let xScale = createMapping(0, w, -halfW, halfW);
let yScale = createMapping(h, 0, -halfH + hOffset, halfH + hOffset);

let xPts = makeArrayRange(100, -halfW, halfW);

function makePath(fn: (x: number) => number) {
let path = "";
for (let x of xPts) {
let y = fn(x);
path += (path ? 'L' : 'M') + `${xScale(x)},${yScale(y)} `;
}
return path;
}

let vertTickVals = [-1, 1, 2, 3];

let vertTicks = vertTickVals.map(a => {
return { x: xScale(0), y: yScale(a), label: a };
});

let horizTickVals = [-3, -2, -1, 1, 2, 3];
let horizTicks = horizTickVals.map(a => {
return { x: xScale(a), y: yScale(0), label: a };
});

let tickColor = "gray";

return <div className="flex justify-center my-2">
<svg viewBox={`0 0 ${w} ${h}`} width={w} height={h} className="bg-slate-200 rounded">
<line x1={xScale(-halfW)} x2={xScale(halfW)} y1={yScale(0)} y2={yScale(0)} stroke={"gray"} strokeWidth={1} />
<line x1={xScale(0)} x2={xScale(0)} y1={yScale(-halfH + hOffset)} y2={yScale(halfH + hOffset)} stroke={"gray"} strokeWidth={1} />
{/* <path d={makePath(fnRelu)} stroke={"blue"} fill="none" strokeWidth={1} /> */}
<path d={makePath(fnGelu)} stroke={dimStyleColor(DimStyle.Intermediates).toHexColor()} fill="none" strokeWidth={3} />
{vertTicks.map((a, i) => <g key={i} transform={`translate(${a.x}, ${a.y})`}>
<line x1={-5} x2={5} y1={0} y2={0} stroke={tickColor} strokeWidth={1} />
<text x={10} y={5} fontSize={10} fill={tickColor}>{a.label}</text>
</g>)}
{horizTicks.map((a, i) => <g key={i} transform={`translate(${a.x}, ${a.y})`}>
<line x1={0} x2={0} y1={-5} y2={5} stroke={tickColor} strokeWidth={1} />
<text x={0} y={18} fontSize={10} textAnchor="middle" fill={tickColor}>{a.label}</text>
</g>)}
</svg>
</div>;
};
4 changes: 4 additions & 0 deletions src/utils/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ export function makeArray<T = number>(length: number, val?: T): T[] {
return new Array(length).fill(val ?? 0);
}

export function makeArrayRange(length: number, min: number, max: number): number[] {
return new Array(length).fill(0).map((_, i) => min + i * (max - min) / (length - 1));
}

export function oneHotArray(length: number, index: number, val: number, defaultVal: number = 0.0): number[] {
let arr = new Array(length).fill(defaultVal);
arr[index] = val;
Expand Down

0 comments on commit d0588b4

Please sign in to comment.