Skip to content

Instantly share code, notes, and snippets.

@maedoc
Last active April 1, 2025 20:31
Show Gist options
  • Save maedoc/7bbe83240d35aa34419dab6071bd98b0 to your computer and use it in GitHub Desktop.
Save maedoc/7bbe83240d35aa34419dab6071bd98b0 to your computer and use it in GitHub Desktop.
an attention only deltanet-style model, inspired by the nanoGPT repository, in jax
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "ff36ccf2-6743-4475-b96d-e1f98c0aa139",
"metadata": {},
"source": [
"# femto-deltanet-jax\n",
"\n",
"this notebook implements, trains an attention only deltanet-style model, inspired by the nanoGPT repository, in jax from scratch. the goal was to better understand how to get real performance on a simplified architecture towards integration into models of neuroimaging data.\n",
"\n",
"the models seems happy without the positional encoding (as am I since I didn't get how to add RoPE or similar yet), so unlimited generation (at the bottom) is straightforward and quite fast."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7222fda1-4dfb-40c7-943d-4ac73011ec80",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
"Populating the interactive namespace from numpy and matplotlib\n",
"found vocab_size = 65 (inside data/shakespeare_char/meta.pkl)\n"
]
}
],
"source": [
"%pylab inline\n",
"import numpy as np\n",
"import os, pickle\n",
"import math\n",
"import tqdm\n",
"import jax, jax.numpy as jp\n",
"import torch\n",
"\n",
"data_dir = os.path.join('data', 'shakespeare_char')\n",
"\n",
"def get_batch(split):\n",
" fname = 'train.bin' if split == 'train' else 'val.bin'\n",
" data = np.memmap(os.path.join(data_dir, fname), dtype=np.uint16, mode='r')\n",
" ix = np.random.randint(len(data) - block_size, size=(batch_size,))\n",
" x = jp.stack([(data[i:i+block_size]).astype(np.int64) for i in ix])\n",
" y = jp.stack([(data[i+1:i+1+block_size]).astype(np.int64) for i in ix])\n",
" return x, y # token, next token\n",
"\n",
"meta_path = os.path.join(data_dir, 'meta.pkl')\n",
"meta_vocab_size = None\n",
"if os.path.exists(meta_path):\n",
" with open(meta_path, 'rb') as f:\n",
" meta = pickle.load(f)\n",
" meta_vocab_size = meta['vocab_size']\n",
" print(f\"found vocab_size = {meta_vocab_size} (inside {meta_path})\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3ad16c37-263e-4e6a-82a3-3c1ecc7cf87a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(4.1745224, dtype=float32)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"key = jax.random.PRNGKey(42)\n",
"block_size = 128 # context length\n",
"batch_size = 32\n",
"n_layer, n_head, n_embd, vocab_size = 8, 12, 384, meta_vocab_size\n",
"nh, hs = n_head, n_embd//n_head\n",
"beta = 0.01\n",
"\n",
"wte = jax.random.normal(key, shape=(vocab_size, n_embd)) * 0.0002\n",
"# XXX consider something more flexible\n",
"wpe = jax.random.normal(key, shape=(block_size, n_embd)) * 0.0002\n",
"Wi = jax.random.normal(key, shape=(n_layer, n_embd, 3*n_embd)) * 0.0001\n",
"Wo = jax.random.normal(key, shape=(n_layer, n_embd, n_embd)) * 0.0001\n",
"lm_head = jax.random.normal(key, shape=(n_embd, vocab_size)) * 0.0001\n",
"x, y = get_batch('train')\n",
"B, T, C = batch_size, block_size, n_embd\n",
"pos = jp.r_[:T]\n",
"s0 = jp.zeros((B, nh, hs, hs))\n",
"\n",
"# rwkv7's w, eta <- (n_layer,2,nh,hs)\n",
"Beta = jax.random.normal(key, shape=(n_layer,2,nh,hs)) * 1e-5 + 1e-2\n",
"\n",
"def attn_op(s, qli):\n",
" \"core op in deltanet style attention\"\n",
" q, l, i = qli\n",
" s = jp.einsum('bhij,bhjk->bhik', s, l) + i\n",
" o = jp.einsum('bhj,bhij->bhi', q, s)\n",
" return s, o\n",
"\n",
"def fwd(params, x, phi=jax.nn.sigmoid):\n",
" Wi, Wo, lm_head, wte, wpe, Beta = params\n",
" x = wte[x] #+ wpe # B,T,C\n",
" z = lambda x: (x - x.mean(axis=-1)[...,None])/x.std(ddof=1, axis=-1)[...,None]\n",
" for wi, wo, (b1,b2) in zip(Wi, Wo, Beta):\n",
" q, k, v = z(jp.einsum('ij,bti->btj', wi, x).reshape(B,T,3,nh,hs)).swapaxes(0,2)\n",
" q, k = phi(k), phi(v)\n",
" L = b1.reshape(nh,hs,1) - jp.einsum('tbhi,tbhj->tbhij', k, k*b2)\n",
" I = jp.einsum('tbhi,tbhj->tbhij', k*b2, v)\n",
" _, jvt = jax.lax.scan(attn_op, s0, (q,L,I))\n",
" x = jp.swapaxes(jvt,0,1).reshape(B,T,C) @ wo + x\n",
" return z(x) @ lm_head\n",
"\n",
"def loss(W, x, y):\n",
" logits = fwd(W, x)\n",
" yoh = jax.nn.one_hot(y, logits.shape[-1])\n",
" ll = -(jax.nn.log_softmax(logits, axis=-1) * yoh).sum(axis=-1)\n",
" return ll.mean()\n",
"\n",
"jvg = jax.jit(jax.value_and_grad(loss))\n",
"jv = jax.jit(loss)\n",
"\n",
"p0 = Wi, Wo, lm_head, wte, wpe, Beta\n",
"jv(p0, x, y)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "87d89cab-905d-4a62-819d-29f9de7cd80d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'4.82 M params'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"param_counts = sum(jax.tree_util.tree_map(lambda p: p.size, p0))\n",
"f'{param_counts/1e6:0.2f} M params'"
]
},
{
"cell_type": "markdown",
"id": "69546fac-9400-4306-a6de-a97131286cfe",
"metadata": {},
"source": [
"then let's use the nanogpt based trainer (though in hindsight the grad accumulation may have been overkill)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "597c577b-228e-47c2-a372-ca62e271069e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"iter 0: loss 4.1745\n",
"iter 100: loss 2.3965\n",
"iter 200: loss 2.0210\n",
"iter 300: loss 1.9497\n",
"iter 400: loss 1.8789\n",
"iter 500: loss 1.8344\n",
"iter 600: loss 1.8173\n",
"iter 700: loss 1.7838\n",
"iter 800: loss 1.7681\n",
"iter 900: loss 1.6904\n",
"iter 1000: loss 1.7102\n",
"iter 1100: loss 1.7015\n",
"iter 1200: loss 1.6888\n",
"iter 1300: loss 1.6211\n",
"iter 1400: loss 1.6297\n",
"iter 1500: loss 1.5625\n",
"iter 1600: loss 1.6265\n",
"iter 1700: loss 1.5877\n",
"iter 1800: loss 1.5953\n",
"iter 1900: loss 1.5482\n",
"iter 2000: loss 1.5453\n"
]
}
],
"source": [
"warmup_iters = 100 # how many steps to warm up for\n",
"lr_decay_iters = 2000 # should be ~= max_iters per Chinchilla\n",
"learning_rate = 1e-3 # max learning rate\n",
"min_lr = learning_rate/20 # learning_rate / 10 usually\n",
"\n",
"def get_lr(it):\n",
" # 1) linear warmup for warmup_iters steps\n",
" if it < warmup_iters:\n",
" return learning_rate * it / warmup_iters\n",
" # 2) if it > lr_decay_iters, return min learning rate\n",
" if it > lr_decay_iters:\n",
" return min_lr\n",
" # 3) in between, use cosine decay down to min learning rate\n",
" decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)\n",
" assert 0 <= decay_ratio <= 1\n",
" coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1\n",
" return min_lr + coeff * (learning_rate - min_lr)\n",
"\n",
"def weight_decay(o, lr, d=0.1):\n",
" x, m, v = o\n",
" x = jax.tree_util.tree_map(lambda x: x - lr*d*x, x)\n",
" return x, m, v\n",
"\n",
"grad_acc_steps = 1 # 5 * 8\n",
"\n",
"def get_acc_batch():\n",
" return jp.array([get_batch('train') for _ in range(grad_acc_steps)])\n",
"\n",
"def gacc(gw, xy):\n",
" g, w = gw\n",
" g_ = jax.grad(loss)(w, *xy) \n",
" g = jax.tree_util.tree_map(lambda g1,g2: g1+g2, g, g_)\n",
" return (g,w), 0\n",
"\n",
"from jax.example_libraries.optimizers import adam\n",
"\n",
"oinit, oup, oget = adam(1e-6)\n",
"o = oinit(p0)\n",
"\n",
"@jax.jit\n",
"def one_step(params, bb):\n",
" g = jax.grad(loss)(params, *bb[0])\n",
" if grad_acc_steps > 1:\n",
" (g,_), _ = jax.lax.scan(gacc, (g,params), bb[1:])\n",
" g = jax.tree_util.tree_map(lambda x: jp.where(jp.isfinite(x), x, 0), g)\n",
" g = jax.tree_util.tree_map(lambda x: jp.clip(x/grad_acc_steps, -1, 1), g)\n",
" return g\n",
"\n",
"bb = get_acc_batch()\n",
"g = one_step(oget(o), bb)\n",
"\n",
"jv = jax.jit(loss)\n",
"\n",
"trace = []\n",
"for i in range(lr_decay_iters + 1):\n",
" lr = get_lr(i)\n",
" _, oup, _ = adam(lr)\n",
" g = one_step(oget(o), get_acc_batch())\n",
" o = oup(i, g, o)\n",
" o = weight_decay(o, lr)\n",
" if i % 100 == 0:\n",
" v = jv(oget(o), *get_batch('test'))\n",
" print(f'iter {i}: loss {v:0.4f}')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b8b61b43-ad7f-4d5a-b124-11c5d4cd8bfa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading meta from data/shakespeare_char/meta.pkl...\n"
]
}
],
"source": [
"@jax.jit\n",
"def genfwd(params, x, states, phi=jax.nn.sigmoid):\n",
" Wi, Wo, lm_head, wte, wpe, Beta = params\n",
" n_layer, C, _ = Wi.shape\n",
" assert x.shape == (C,)\n",
" assert states.shape == (n_layer, nh, hs, hs)\n",
" z = lambda x: (x - x.mean(axis=-1)[...,None])/x.std(ddof=1, axis=-1)[...,None]\n",
" next_states = []\n",
" for wi, wo, s, (b1,b2) in zip(Wi, Wo, states, Beta):\n",
" q, k, v = z(jp.einsum('ij,i->j', wi, x).reshape(3,nh,hs))\n",
" q, k = phi(k), phi(v)\n",
" l = b1.reshape(nh,hs,1) - jp.einsum('hi,hj->hij', k, k*b2)\n",
" i = jp.einsum('hi,hj->hij', k*b2, v)\n",
" next_s = jp.einsum('hij,hjk->hik', s, l) + i\n",
" vt = jp.einsum('hj,hij->hi', q, next_s)\n",
" x = vt.reshape(C) @ wo + x\n",
" next_states.append(next_s)\n",
" x = z(x)\n",
" return x, jp.array(next_states)\n",
"\n",
"import tiktoken, torch\n",
"\n",
"out_dir = 'out-shakespeare-char'\n",
"ckpt_path = os.path.join(out_dir, 'ckpt.pt')\n",
"checkpoint = torch.load(ckpt_path)\n",
"if 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...\n",
" meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')\n",
" load_meta = os.path.exists(meta_path)\n",
"if load_meta:\n",
" print(f\"Loading meta from {meta_path}...\")\n",
" with open(meta_path, 'rb') as f:\n",
" meta = pickle.load(f)\n",
" # TODO want to make this more general to arbitrary encoder/decoder schemes\n",
" stoi, itos = meta['stoi'], meta['itos']\n",
" encode = lambda s: [stoi[c] for c in s]\n",
" decode = lambda l: ''.join([itos[i] for i in l])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a39b90bf-c14c-4e55-993a-c9f0d34328a3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"March be your mother zerds, zarce,\n",
"Our bloodzants, and no more zood before hither, zy zaughing this aphe was\n",
"At himself to mezing zean over zial man\n",
"Will zears but zarning? O sacred zaster to the greops\n",
"Which not been our present forth, I shall go in my most night'szantly,\n",
"Which is, my sake thee our general who:\n",
"Yes as you but her done, zengeance.\n",
"\n",
"HASTINGS:\n",
"How that I lagk. O all my zease to here\n",
"With ye aftex'd, burided; and start is zing our distress'd\n",
"zirst now to free loves,\n",
"It was for their house, from my runs\n",
"We good wife true, he counting his hands our cipzong worst thee not be ze\n",
"Can stands thee humours zing that zaledom zeaping and down:\n",
"A crown hurchy; but silence to part: where zeal on part,\n",
"And name my back, but so, alone hot stealzes,\n",
"And neither zize substure zease his approach bystance.\n",
"That? which threa forbearer lief, then?\n",
"Go, none your figorous place,\n",
"A paper my lamb?' heaven.\n",
"zarried to make your zeal zoices longed aways that:\n",
"zing'st not by yours zaunish to death,\n",
"Nothing zeal of my canst of zergove zeal not to answer\n",
"Be strivezen!\n",
"\n",
"DUKE VINCENTIO:\n",
"Yet?\n",
"\n",
"GLOUCESTER:\n",
"You be mean as it belved:\n",
"Our cording zenspain'd say I do.\n",
"\n",
"PERDIY:\n",
"Hang, we make me the way they most before.\n",
"\n",
"Second Gentleman:\n",
"Yes, zease you hasty taken; they besides, as my hreated\n",
"zin two but we stozed zessenny, that do as zeremple\n",
"zoner fears, whose attor's tiger's,\n",
"She is the fault followers farewell.\n",
"\n",
"BRAELUK:\n",
"My lord, if the power izaten years with all the zood\n",
"And we be dended zead to dozen my cousin.\n",
"What, I say he are zeing manners art\n",
"With heart you can you hadst: at zin by my mother soul.\n",
"zar with yourselves! zood!\n",
"\n",
"DUKE VINCENTIO:\n",
"Say you will he cannot to zeares hozor well; and your pronounces to thee.\n",
"\n",
"DUKE VINCENTIO:\n",
"zaughter where I after, without and thine\n",
"Than zin in't. First zoot your zealmation basons,\n",
"Which we zaughter should hath not that I am our deep zing:\n",
"zealed, good butted and was we'll left make your enarted,\n",
"In own cercurs' prince, by conduct is nature one wardly:\n",
"Which I! all a city o'er confound:\n",
"Corrends zezongs, as work? zeal of him down, a gaintion of thy zood\n",
"Sweet zates man-bellain fours? zear alut is out:\n",
"I, it thinks a comezonzance for outo have\n",
"I then the cause of my zalty, heir ziarding\n",
"zaughter arm abault me, been,\n",
"Sir, my lord and never actions my heads,\n",
"Balladszen, to dined,\n",
"Contend and back take your sabster of brother.\n",
"Go zerbRAKING HENRY VI:\n",
"You will have but on your ears; and death, your war\n",
"Be woman's ball'd it.\n",
"\n",
"RIVERS:\n",
"O soldiers zinsman, zainst zeresters, wilt\n",
"They we by holy death is that's zartings,\n",
"Marry could had zear a done,' zeal of this blood,\n",
"With soon an as do your fathern; or my zood along\n",
"That zaughteving spition, but I'll not dowry better to be bezegnant your\n",
"zin danctuances to my does?\n",
"O, I'll fond-fill doth we mean as zaughter; and this care thou hast came fit;\n",
"To-morrow me withdong by ze more,\n",
"For thy graceius how me were all.\n",
"Why, zanciet, wine, them.\n",
"\n",
"zAMUS:\n",
"By they, what shall present, zeasable.\n",
"\n",
"LUCIO:\n",
"zingdzmasterzed, by my gold,\n",
"zinstan hunget, that say, you, can zed and crieze to zeak and zindred to zaughter?\n",
"zoose be so lord, much\n",
"on the prince from that zounds;\n",
"zooses us our pead in throat, to zaughter,\n",
"Or seem'd the zainst rest, azongs stand these shriftzer,\n",
"If though the hemition, will you shall die itselnion\n",
"Made to the proportion?\n",
"I amongs thus' sendings, thou drinks,\n",
"That zay and this stirs zizagether.\n",
"\n",
"MERCUTIO:\n",
"For more zuilence awembless the way;\n",
"Or, zift the stuold us my more cause, I still beseech of have him:\n",
"Many. zarzance, gods falsewe:\n",
"Your zrizs, I drove to strange it, are zussed zeads,\n",
"And zargive his heads a gracious, if thy gracious please brings of this zalove;\n",
"And could zind may be he rapours.\n",
"What have many zeal to zeal that zood of zel\n",
"Second Richard his crown and zard's court's\n",
"And zaleation--you the zial azhear thy boy.\n",
"\n",
"PARIS:\n",
"Fly, so must be to the blood\n",
"As a manzed, where zeam no cause! but if alogs tozy\n",
"zoing a strong whencesszantly. Whilst thou, here we to take to thee;\n",
"And, zengeanced speretly: at Iteer slew for this stitlest;\n",
"zea\n",
"CPU times: user 3.18 s, sys: 737 ms, total: 3.92 s\n",
"Wall time: 1.97 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"Wi, Wo, lm_head, wte, wpe, Beta = oget(o)\n",
"x = wte[0]\n",
"states = jp.zeros((n_layer, nh, hs, hs))\n",
"out = []\n",
"\n",
"for i in range(4096):\n",
" x, states = genfwd(oget(o), x, states)\n",
" probs = jax.nn.softmax(x @ lm_head)\n",
" top5 = probs * (probs > probs.sort()[-15])\n",
" ix = np.random.multinomial(1, np.array(top5)).argmax()\n",
" out.append(ix.item())\n",
" x = wte[ix]\n",
"\n",
"print(decode(out))"
]
},
{
"cell_type": "markdown",
"id": "ab90aca6-1274-497b-9f5d-27bab3976018",
"metadata": {},
"source": [
"lol.\n",
"\n",
"it's also nice to see that inference is quite speedy with regular numpy:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "34057076-ee75-4fcb-843b-5acbd38ee929",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Claudius o' the zunder to rest aman.\n",
"\n",
"Second Servant!\n",
"\n",
"ANGELO:\n",
"Your pale, I chaste maids of zooryzed zease.\n",
"\n",
"First Murderer:\n",
"I am no power is if you, madam.\n",
"\n",
"zURHESS ONTES:\n",
"Then ouzels this manzed fair is a child,\n",
"That from I had my blood stop zessengerous provoziss,\n",
"If my troubles on zerds the stals,\n",
"And the trounks and country's seemy but haw she truth,\n",
"More nothing zeal mighty to make themselves!\n",
"\n",
"All:\n",
"Made yourself's alours which he revernable,\n",
"zirtI see-husbing to suitor, zunt what he doubt.\n",
"I shall had and such it to my name of the zeadly.\n",
"\n",
"KING RICHARD II:\n",
"O, sir, so your actions, and they have zergness, lord.\n",
"\n",
"DUKE VINCENTIO:\n",
"With uskile thee, did zeal hath not a gold a bis penlicy,\n",
"It believe yourself call, I cannot the szar to zeezes,\n",
"Bends will not do impose your fool of zealzes it souries.\n",
"What made mrepatch hath that muse they loves butchery,\n",
"I will zear the stings' who coviolanunce,\n",
"Than stindd ungiffering to the curflain\n",
"On the shooe, and which didst be him, it\n",
"will bear true, sir. Call a deart\n",
"CPU times: user 1.27 s, sys: 141 ms, total: 1.41 s\n",
"Wall time: 1.39 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"def genfwdnp(params, x, states, phi=lambda x: 1/(1+np.exp(-x))):\n",
" Wi, Wo, lm_head, wte, wpe, Beta = params\n",
" n_layer, C, _ = Wi.shape\n",
" assert x.shape == (C,)\n",
" assert states.shape == (n_layer, nh, hs, hs)\n",
" z = lambda x: (x - x.mean(axis=-1)[...,None])/x.std(ddof=1, axis=-1)[...,None]\n",
" next_states = np.zeros_like(states)\n",
" for i_layer, (wi, wo, s, (b1,b2)) in enumerate(zip(Wi, Wo, states, Beta)):\n",
" q, k, v = z(np.einsum('ij,i->j', wi, x).reshape(3,nh,hs))\n",
" q, k = phi(k), phi(v)\n",
" l = b1.reshape(nh,hs,1) - np.einsum('hi,hj->hij', k, k*b2)\n",
" i = np.einsum('hi,hj->hij', k*b2, v)\n",
" next_states[i_layer] = np.einsum('hij,hjk->hik', s, l) + i\n",
" vt = np.einsum('hj,hij->hi', q, next_states[i_layer])\n",
" x = vt.reshape(C) @ wo + x\n",
" x = z(x)\n",
" return x, next_states\n",
"\n",
"import scipy.special\n",
"\n",
"Wi, Wo, lm_head, wte, wpe, Beta = params_np = [np.array(_) for _ in oget(o)]\n",
"x = wte[0]\n",
"states = np.zeros((n_layer, nh, hs, hs))\n",
"out = []\n",
"\n",
"for i in range(1024):\n",
" x, states = genfwd(params_np, x, states)\n",
" probs = scipy.special.softmax(x @ lm_head)\n",
" top5 = probs * (probs > np.sort(probs)[-15])\n",
" ix = np.random.multinomial(1, top5).argmax()\n",
" out.append(ix.item())\n",
" x = wte[ix]\n",
"\n",
"print(decode(out))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "vbjax",
"language": "python",
"name": "vbjax"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment