@nhira
Last active March 2, 2024 20:18
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"last_runtime": {
"build_target": "//learning/grp/tools/ml_python:ml_notebook",
"kind": "private"
},
"private_outputs": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# High performance LLMs - 2024 (Session 1)\n",
"\n",
"This notebook is a companion to Rafi Witten's class (content & recording on [GitHub](https://github.com/rwitten/HighPerfLLMs2024/tree/main/s01)). In this first session, we train a simple (limited) neural network using the [lm1b](https://www.statmt.org/lm-benchmark/) dataset. \n",
"\n",
"We learn:\n",
"1. How to set up a training loop\n",
"1. Why encoding matters\n",
"1. How to observe the network as training continues"
],
"metadata": {
"id": "QwCqI3dxLuYV"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-uXcr0UvkXsp"
},
"outputs": [],
"source": [
"import tensorflow_datasets as tfds\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import flax.linen as nn\n",
"from flax.training import train_state\n",
"\n",
"import numpy as np\n",
"\n",
"import optax\n",
"\n",
"BATCH_IN_SEQUENCES = 384\n",
"SEQUENCE_LENGTH = 128\n",
"\n",
"VOCAB_DIM = 256\n",
"EMBED_DIM = 512\n",
"FF_DIM = 2048\n",
"\n",
"# byte value (0-255) fed at position 0 to prompt the model to propose results\n",
"# it must occur rarely in the dataset (0 can't be used: it pads every sequence)\n",
"BLANK = 255\n",
"\n",
"LAYERS = 4\n",
"\n",
"HEAD_DEPTH = 128\n",
"NUM_HEADS = 4\n",
"\n",
"LEARNING_RATE = 1e-3\n"
]
},
{
"cell_type": "code",
"source": [
"class OurModel(nn.Module):\n",
"  @nn.compact\n",
"  def __call__(self, x):\n",
"    '''\n",
"    x is [BATCH, SEQUENCE]\n",
"    '''\n",
"    embedding = self.param(\n",
"        'embedding',\n",
"        nn.initializers.normal(1),\n",
"        (VOCAB_DIM, EMBED_DIM),\n",
"        jnp.float32,\n",
"    )\n",
"    x = embedding[x] ##OUTPUT should be [BATCH, SEQUENCE, EMBED]\n",
"\n",
"    for i in range(LAYERS):\n",
"      feedforward = self.param(\n",
"          'feedforward_' + str(i),\n",
"          nn.initializers.lecun_normal(),\n",
"          (EMBED_DIM, FF_DIM),\n",
"          jnp.float32,\n",
"      )\n",
"      x = x @ feedforward\n",
"      x = jax.nn.relu(x)\n",
"      embed = self.param(\n",
"          'embed_' + str(i),\n",
"          nn.initializers.lecun_normal(),\n",
"          (FF_DIM, EMBED_DIM),\n",
"          jnp.float32,\n",
"      )\n",
"      x = x @ embed\n",
"      x = jax.nn.relu(x)\n",
"\n",
"    return x @ embedding.T\n"
],
"metadata": {
"id": "AcdDu5elnzfc"
},
"execution_count": null,
"outputs": []
},
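{
"cell_type": "markdown",
"source": [
"As a quick check on model size, the sketch below counts the parameters `OurModel` declares, using the shapes passed to `self.param` and the values from the configuration cell. It is plain arithmetic (the helper `count_params` is hypothetical) and does not inspect the Flax parameter tree; since the model has no biases, the tree returned by `model.init` should hold the same total."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"def count_params(vocab_dim, embed_dim, ff_dim, layers):\n",
"  # one embedding table, reused (transposed) as the output projection\n",
"  embedding = vocab_dim * embed_dim\n",
"  # each layer: an EMBED->FF matrix and an FF->EMBED matrix, no biases\n",
"  per_layer = embed_dim * ff_dim + ff_dim * embed_dim\n",
"  return embedding + layers * per_layer\n",
"\n",
"# values from the configuration cell (VOCAB_DIM, EMBED_DIM, FF_DIM, LAYERS)\n",
"total = count_params(256, 512, 2048, 4)\n",
"print(f'{total:,} parameters')  # 8,519,680 parameters\n",
"print(f'{total * 4 / 1e6:.1f} MB as float32')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},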
{
"cell_type": "code",
"source": [
"def convert_to_ascii(string_array, max_length):\n",
"  # one byte per position, truncated to max_length and zero-padded\n",
"  result = np.zeros((len(string_array), max_length), dtype=np.uint8)\n",
"  for i, string in enumerate(string_array):\n",
"    for j, char in enumerate(string):\n",
"      if j >= max_length:\n",
"        break\n",
"      result[i, j] = char\n",
"  return result\n",
"\n"
],
"metadata": {
"id": "lK3QghYTobhr"
},
"execution_count": null,
"outputs": []
},
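{
"cell_type": "markdown",
"source": [
"Why encoding matters: TFDS returns each lm1b sentence as bytes, and `convert_to_ascii` stores one byte per position. Assuming the text is UTF-8 encoded, ASCII characters take one position each but any other character spans several, so the 256-entry vocabulary covers every input at the cost of longer sequences. The sketch below uses a hypothetical helper, `to_byte_ids`, which takes Python strings and encodes them itself, then truncates and zero-pads like `convert_to_ascii`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def to_byte_ids(strings, max_length):\n",
"  # hypothetical helper: UTF-8 encode, truncate to max_length, zero-pad the rest\n",
"  result = np.zeros((len(strings), max_length), dtype=np.uint8)\n",
"  for i, s in enumerate(strings):\n",
"    data = s.encode('utf-8')[:max_length]\n",
"    result[i, :len(data)] = list(data)\n",
"  return result\n",
"\n",
"# chr(233) is 'e' with an acute accent, which UTF-8 stores as two bytes\n",
"ids = to_byte_ids(['Hi', 'caf' + chr(233)], 6)\n",
"print(ids.tolist())  # [[72, 105, 0, 0, 0, 0], [99, 97, 102, 195, 169, 0]]"
],
"metadata": {},
"execution_count": null,
"outputs": []
},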
{
"cell_type": "markdown",
"source": [
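"Sequences are zero-padded (`chr(0)`) up to `SEQUENCE_LENGTH`, so to prompt the model we use a character we expect to be rare in our input (`chr(255)`). For each sequence, we replace the first character with this ``BLANK`` character; the model then has to propose the character that belongs there, while every other position is passed through unchanged (the commented-out line in `prepare_input` would instead shift the sequence right by one, turning the task into next-character prediction)."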
],
"metadata": {
"id": "ImqQKgOCtThj"
}
},
{
"cell_type": "code",
"source": [
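"def prepare_input(np_array):\n",
"  # match the batch's actual shape (the last batch of ds can be smaller)\n",
"  zero_array = np.zeros(np_array.shape, dtype=np.uint8)\n",
"  # next-character variant (shift right by one):\n",
"  # zero_array[:, 1:SEQUENCE_LENGTH] = np_array[:, 0:SEQUENCE_LENGTH-1]\n",
"  # active variant: keep positions 1.. unchanged, so the model learns to copy\n",
"  zero_array[:, 1:SEQUENCE_LENGTH] = np_array[:, 1:SEQUENCE_LENGTH]\n",
"  zero_array[:, 0] = BLANK\n",
"  return zero_array\n"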
],
"metadata": {
"id": "PchIJdaOob1U"
},
"execution_count": null,
"outputs": []
},
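{
"cell_type": "markdown",
"source": [
"To make the two variants in `prepare_input` concrete, this sketch applies each to a toy sequence. `prepare_copy` and `prepare_shift` are hypothetical, batch-size-agnostic stand-ins for the active line and the commented-out line; the default `blank=255` matches `BLANK` in the configuration cell. In the copy variant every target except position 0 already sits at the same position in the input; in the shift variant position j holds character j-1, so predicting the target means predicting the next character."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def prepare_copy(targets, blank=255):\n",
"  # active variant: inputs equal targets except position 0\n",
"  inputs = targets.copy()\n",
"  inputs[:, 0] = blank\n",
"  return inputs\n",
"\n",
"def prepare_shift(targets, blank=255):\n",
"  # commented-out variant: shift right by one, blank at position 0\n",
"  inputs = np.zeros_like(targets)\n",
"  inputs[:, 1:] = targets[:, :-1]\n",
"  inputs[:, 0] = blank\n",
"  return inputs\n",
"\n",
"toy_batch = np.array([[72, 105, 33, 0]], dtype=np.uint8)  # 'Hi!' plus one padding byte\n",
"print('copy: ', prepare_copy(toy_batch)[0].tolist())   # [255, 105, 33, 0]\n",
"print('shift:', prepare_shift(toy_batch)[0].tolist())  # [255, 72, 105, 33]"
],
"metadata": {},
"execution_count": null,
"outputs": []
},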
{
"cell_type": "code",
"source": [
"def print_output_from_logits(output_logits, output_index):\n",
"  # most likely byte at each position of sequence output_index\n",
"  _o = jax.numpy.argmax(output_logits, axis=2)[output_index]\n",
"  print(\" ==> output-a: \", _o[0:18])\n",
"  _s = \"\"\n",
"  for _c in _o:\n",
"    # bytes below 32 (control characters, including 0 padding) print as '_'\n",
"    _s += chr(_c) if (_c >= 32) else \"_\"\n",
"  print(\" ==> output-s: \", _s)\n",
"\n",
"def try_model(params, model, input_index=0):\n",
"  # uses the global `example` batch left over from the training loop\n",
"  print(\" Trying with input:\", example['text'][input_index])\n",
"  _expected = convert_to_ascii(example['text'].numpy(), SEQUENCE_LENGTH)\n",
"  print(\" ==> input: \", _expected[input_index][0:18])\n",
"  _i = prepare_input(_expected)\n",
"  _o = model.apply(params, _i)\n",
"  print_output_from_logits(_o, input_index)\n",
"\n",
"def calculate_loss(params, model, inputs, outputs):\n",
"  proposed_outputs = model.apply(params, inputs)\n",
"  one_hot = jax.nn.one_hot(outputs, VOCAB_DIM)\n",
"  loss = optax.softmax_cross_entropy(proposed_outputs, one_hot)\n",
"  return jnp.mean(loss)\n"
],
"metadata": {
"id": "kZ7ex-3xocDW"
},
"execution_count": null,
"outputs": []
},
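{
"cell_type": "markdown",
"source": [
"`calculate_loss` averages softmax cross-entropy over every position in the batch, padding included. A useful reference point for the loss values printed during training: a model that gives all 256 byte values the same score has a loss of ln(256) ≈ 5.545. The NumPy sketch below re-implements the same formula (log-softmax, then the negative log-probability of the target byte) so it runs without JAX; `np_softmax_cross_entropy` is a hypothetical stand-in, not the Optax function."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def np_softmax_cross_entropy(logits, targets, vocab_dim=256):\n",
"  # logits: [..., vocab_dim]; targets: integer ids of shape logits.shape[:-1]\n",
"  shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability\n",
"  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))\n",
"  one_hot = np.eye(vocab_dim)[targets]\n",
"  return -(one_hot * log_probs).sum(axis=-1)\n",
"\n",
"# equal scores for every byte: the loss is ln(256) whatever the targets are\n",
"uniform_logits = np.zeros((2, 3, 256))\n",
"toy_targets = np.array([[72, 105, 0], [84, 0, 0]])\n",
"print(np_softmax_cross_entropy(uniform_logits, toy_targets).mean(), np.log(256))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},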
{
"cell_type": "markdown",
"source": [
"## Training\n",
"Let's start by loading our dataset."
],
"metadata": {
"id": "cnx9E734l5kh"
}
},
{
"cell_type": "code",
"source": [
"ds = tfds.load('lm1b', split='train', shuffle_files=False)\n",
"ds = ds.batch(BATCH_IN_SEQUENCES)\n",
"print(len(ds))"
],
"metadata": {
"id": "XLC53oGJpsjS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Next, let's use [Optax](https://github.com/google-deepmind/optax) (Adam, specifically) to optimize our training."
],
"metadata": {
"id": "G7B22TtKmFfa"
}
},
{
"cell_type": "code",
"source": [
"rngkey = jax.random.key(0)\n",
"model = OurModel()\n",
"_params = model.init(rngkey, jnp.ones((BATCH_IN_SEQUENCES, SEQUENCE_LENGTH), dtype = jnp.uint8))\n",
"tx = optax.adam(learning_rate = LEARNING_RATE)\n",
"state = train_state.TrainState.create(\n",
" apply_fn = model.apply,\n",
" params = _params,\n",
" tx = tx\n",
")"
],
"metadata": {
"id": "VVra0jY8p4Tt"
},
"execution_count": null,
"outputs": []
},
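{
"cell_type": "markdown",
"source": [
"`optax.adam` keeps two running averages for every parameter. As a sketch of what a single update does, here is the Adam rule in plain NumPy with Optax's default `b1=0.9`, `b2=0.999`, `eps=1e-8` (the function `adam_step` is hypothetical, not part of Optax). On the first, bias-corrected step the update for each parameter is close to `LEARNING_RATE` in size, even when the gradients differ by orders of magnitude."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def adam_step(param, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):\n",
"  # t is the 1-based step count used for bias correction\n",
"  m = b1 * m + (1 - b1) * grad        # running mean of gradients\n",
"  v = b2 * v + (1 - b2) * grad ** 2   # running mean of squared gradients\n",
"  m_hat = m / (1 - b1 ** t)\n",
"  v_hat = v / (1 - b2 ** t)\n",
"  return param - lr * m_hat / (np.sqrt(v_hat) + eps), m, v\n",
"\n",
"toy_param = np.array([1.0, 1.0])\n",
"toy_grad = np.array([100.0, 0.01])  # gradients four orders of magnitude apart\n",
"new_param, m, v = adam_step(toy_param, toy_grad, np.zeros(2), np.zeros(2), t=1)\n",
"print(toy_param - new_param)  # both entries close to 1e-3"
],
"metadata": {},
"execution_count": null,
"outputs": []
},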
{
"cell_type": "markdown",
"source": [
"The ``break`` on the last line of the training cell stops the loop once step 1,000 has run, so we can try out the model in the next few cells.\n",
"\n",
"We also keep track of the lowest loss so far in ``minloss``; each time it improves, we call ``try_model`` to see what the model has learned since the previous minimum. For example, in the first 10 or so iterations the model is still learning to \"copy\" the input over to the output, and this \"copy\" operation keeps improving until roughly iteration 500.\n",
"\n",
"In the first position, the model disproportionately predicts 'T'. The input there is always ``BLANK``, and with no attention the model cannot look at the rest of the sentence, so its best strategy is to predict whichever character most often starts a sentence. The choice of 'T' over 'A' or a double quote most likely reflects the dataset (we check this below).\n",
"\n",
"If training continues past roughly 3,000 iterations (raise the limit in the ``break``), the loss stops converging and the model actually \"loses\" its ability to \"copy\"."
],
"metadata": {
"id": "Gy6wLbFHEyHX"
}
},
{
"cell_type": "code",
"source": [
"# track steps toward convergence (each time a new minimum loss is found)\n",
"minloss = 100\n",
"for _step, example in enumerate(ds):\n",
"  outputs = convert_to_ascii(example['text'].numpy(), SEQUENCE_LENGTH)\n",
"  inputs = prepare_input(outputs)\n",
"\n",
"  loss, grad = jax.value_and_grad(calculate_loss)(state.params, model, inputs, outputs)\n",
"  state = state.apply_gradients(grads = grad)\n",
"  if (((_step % 100) == 0) or (loss < minloss)):\n",
"    if (loss < minloss):\n",
"      print(f\"{_step} -> {loss} **new minimum**\")\n",
"      minloss = loss\n",
"      try_model(state.params, model)\n",
"    else:\n",
"      print(f\"{_step} -> {loss} ({minloss})\")\n",
"\n",
"  # stop once step 1000 has run (steps 0 through 1000)\n",
"  if _step>=1000: break"
],
"metadata": {
"id": "anvlBxSpp-e1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"So how prevalent are T's in the first position in our dataset?"
],
"metadata": {
"id": "mVJr-c-ZWFM-"
}
},
{
"cell_type": "code",
"source": [
"# our distribution table (one counter per possible byte value, 0-255)\n",
"_distribution = [0] * VOCAB_DIM\n",
"_total_count = 0\n",
"# iterate through the first 1000 batches of our dataset\n",
"for _i, _data in enumerate(ds):\n",
"  # for each sequence in each batch\n",
"  for _j, _sent in enumerate(_data['text'].numpy()):\n",
"    # update the counter for the character in the first position\n",
"    _distribution[_sent[0]] += 1\n",
"    _total_count += 1\n",
"  if _i>=999: break\n",
"\n",
"# print characters seen at least a third as often as the most frequent one\n",
"for _i, _count in enumerate(_distribution):\n",
"  if (_count > (0.33 * max(_distribution))):\n",
"    print(f\"{_i} {chr(_i)} {_count} ({round(_count*100/_total_count, 2)}%)\")\n"
],
"metadata": {
"id": "dLkVcqcnER_U"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We saw this in the observations during training, but let's confirm that our model doesn't take character position into account: each character is processed on its own, so ``abc`` and ``bca`` produce the same output for each character, just in a different order.\n",
"\n",
"To test, we'll compare the results for input ``[187, 134, 12]`` vs. ``[134, 12, 187]``."
],
"metadata": {
"id": "Lcc7jsvsCQ0l"
}
},
{
"cell_type": "code",
"source": [
"sequence_1 = [187, 134, 12]\n",
"sequence_2 = [134, 12, 187]\n",
"\n",
"# let's get the output for sequence_1\n",
"print(\"Sequence 1\")\n",
"_i = np.zeros( (BATCH_IN_SEQUENCES,SEQUENCE_LENGTH), dtype = jnp.uint8)\n",
"_i[0, 0:3] = sequence_1\n",
"print(_i[0])\n",
"_o_sequence_1 = model.apply(state.params, _i)\n",
"print(\"187 = \", _o_sequence_1[0][0][0:6])\n",
"print(\"134 = \", _o_sequence_1[0][1][0:6])\n",
"print(\" 12 = \", _o_sequence_1[0][2][0:6])\n",
"\n",
"# let's get the output for sequence_2\n",
"print(\"\\nSequence 2\")\n",
"_i = np.zeros( (BATCH_IN_SEQUENCES,SEQUENCE_LENGTH), dtype = jnp.uint8)\n",
"_i[0, 0:3] = sequence_2\n",
"print(_i[0])\n",
"_o_sequence_2 = model.apply(state.params, _i)\n",
"print(\"134 = \", _o_sequence_2[0][0][0:6])\n",
"print(\" 12 = \", _o_sequence_2[0][1][0:6])\n",
"print(\"187 = \", _o_sequence_2[0][2][0:6])"
],
"metadata": {
"id": "D8tmw2bSNuMM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"For each sequence above, the first line shows the input and the following three lines show the model's output for each of the three characters: logits (unnormalized scores, before softmax) over the vocabulary, where each vocabulary entry happens to be a single byte for us.\n",
"\n",
"For brevity, we only print the first 6 logits, but notice how they match: the results for 134, for example, are the same regardless of which position the character is in.\n",
"\n",
"But are they really identical?"
],
"metadata": {
"id": "eXlWN3L-ncJF"
}
},
{
"cell_type": "code",
"source": [
"# are the results for sequence_1/char identical to sequence_2/char?\n",
"print(\"Are results for 187 identical? \", (_o_sequence_1[0][0] == _o_sequence_2[0][2]).all())\n",
"print(\"Are results for 134 identical? \", (_o_sequence_1[0][1] == _o_sequence_2[0][0]).all())\n",
"print(\"Are results for 12 identical? \", (_o_sequence_1[0][2] == _o_sequence_2[0][1]).all())"
],
"metadata": {
"id": "ImcJ7DlMpRr-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"So what does that mean? Our model isn't learning about \"words\" at all: it maps each input character to an output independently of its neighbors, much like a substitution cipher.\n",
"\n",
"Let's see some more examples ..."
],
"metadata": {
"id": "xGKhjrizq-aL"
}
},
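{
"cell_type": "markdown",
"source": [
"The substitution-cipher view can be made literal. Every layer acts on one position at a time, so the output for a byte depends only on that byte: evaluating the network once on all 256 possible inputs gives a lookup table that reproduces it (up to floating-point rounding), and reordering the input just reorders the output rows. The sketch uses a small random NumPy stand-in with the same structure as `OurModel` (embedding, ReLU feed-forward layers, output projection through the transposed embedding); `tiny_model` and its sizes are hypothetical."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"vocab, embed_dim, ff_dim = 256, 16, 32  # toy sizes, not the notebook's\n",
"toy_embedding = rng.normal(size=(vocab, embed_dim))\n",
"w_in = rng.normal(size=(embed_dim, ff_dim)) / np.sqrt(embed_dim)\n",
"w_out = rng.normal(size=(ff_dim, embed_dim)) / np.sqrt(ff_dim)\n",
"\n",
"def tiny_model(tokens):\n",
"  # one-layer NumPy stand-in for OurModel: no attention, no position information\n",
"  x = toy_embedding[tokens]\n",
"  x = np.maximum(x @ w_in, 0)\n",
"  x = np.maximum(x @ w_out, 0)\n",
"  return x @ toy_embedding.T\n",
"\n",
"# evaluating every possible byte once gives a table that *is* the model\n",
"table = tiny_model(np.arange(vocab))\n",
"seq = np.array([187, 134, 12])\n",
"perm = [1, 2, 0]  # gives [134, 12, 187], the second sequence tested above\n",
"print(np.allclose(tiny_model(seq), table[seq]))              # True\n",
"print(np.allclose(tiny_model(seq[perm]), table[seq][perm]))  # True"
],
"metadata": {},
"execution_count": null,
"outputs": []
},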
{
"cell_type": "code",
"source": [
"for _i in range(10):\n",
" try_model(state.params, model, _i)"
],
"metadata": {
"id": "5z_tM7VLc8Ij"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Conclusion\n",
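"We learned how to use JAX to train a simple neural net and observe its capabilities. We also saw that training on more data does not always help: after enough steps the loss can stop converging (\"lose convergence\"), and the model can lose abilities it had already learned."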
],
"metadata": {
"id": "XIqZeLm5sVGz"
}
}
]
}