let's build gpt2
[00:00:00.000 --> 00:00:04.320] Hi everyone. So today we are going to be continuing our Zero2Hero series
[00:00:04.320 --> 00:00:10.640] and in particular today we are going to reproduce the GPT2 model, the 124 million version of it.
[00:00:10.640 --> 00:00:17.440] So when OpenAI released GPT2, this was 2019 and they released it with this blog post.
[00:00:17.440 --> 00:00:23.040] On top of that they released this paper and on top of that they released this code on GitHub,
[00:00:23.040 --> 00:00:29.600] so OpenAI/GPT2. Now when we talk about reproducing GPT2, we have to be careful because in particular
[00:00:29.600 --> 00:00:34.880] in this video we're going to be reproducing the 124 million parameter model. So the thing to
[00:00:34.880 --> 00:00:41.040] realize is that there's always a miniseries when these releases are made, so there are the GPT2
[00:00:41.040 --> 00:00:46.800] miniseries made up of models at different sizes and usually the biggest model is called the GPT2.
[00:00:46.800 --> 00:00:52.160] But basically the reason we do that is because you can put the model sizes on the x-axis of
[00:00:52.160 --> 00:00:56.720] plots like this and on the y-axis you put a lot of downstream metrics that you're interested in,
[00:00:56.720 --> 00:01:02.000] like translation, summarization, question answering and so on, and you can chart out the scaling laws.
[00:01:02.000 --> 00:01:06.800] So basically as the model size increases you're getting better and better at downstream metrics.
[00:01:06.800 --> 00:01:14.480] And so in particular for GPT2, if we scroll down in the paper, there are four models in the GPT2
[00:01:14.480 --> 00:01:22.720] miniseries starting at 124 million all the way up to 1558 million. Now the reason my numbers,
[00:01:22.720 --> 00:01:27.520] the way I say them, disagree with this table is that this table is wrong. If you actually go to the
[00:01:27.520 --> 00:01:34.400] GPT2 github repo they sort of say that there was an error in how they added up the parameters,
[00:01:34.400 --> 00:01:40.320] but basically this is the 124 million parameter model, etc. So the 124 million parameter had 12
[00:01:40.320 --> 00:01:47.840] layers in the transformer and it had 768 channels in the transformer, 768 dimensions. And I'm going
[00:01:47.840 --> 00:01:51.680] to be assuming some familiarity with what these terms mean because I covered all of this in my
[00:01:51.680 --> 00:01:57.200] previous video, "Let's build GPT: from scratch". So I covered that in a previous
[00:01:57.200 --> 00:02:02.640] video in this playlist. Now if we do everything correctly and everything works out well, by the
[00:02:02.640 --> 00:02:07.360] end of this video we're going to see something like this, where we're looking at the validation loss,
[00:02:07.360 --> 00:02:13.760] which basically measures how good we are at predicting the next token in a sequence on some
[00:02:13.760 --> 00:02:18.880] validation data that the model has not seen during training. And we see that we go from doing that
[00:02:18.880 --> 00:02:23.680] task not very well, because we're initializing from scratch, all the way to doing that task quite well
[00:02:23.680 --> 00:02:29.920] by the end of the training. And hopefully we're going to beat the GPT2 124M model.
[00:02:29.920 --> 00:02:35.680] Now previously when they were working on this, this is already five years ago. So this was probably
[00:02:35.680 --> 00:02:40.560] a fairly complicated optimization at the time and the GPUs and the compute was a lot smaller.
[00:02:40.560 --> 00:02:45.920] Today you can reproduce this model in roughly an hour, or probably less even, and it will cost
[00:02:45.920 --> 00:02:50.640] you about 10 bucks if you want to do this on the cloud, cloud compute, a sort of computer that
[00:02:50.640 --> 00:02:56.640] you can all rent. And if you pay $10 for that computer, you wait about an hour or less, you can
[00:02:56.640 --> 00:03:02.720] actually achieve a model that is as good as this model that OpenAI released. And one more thing
[00:03:02.720 --> 00:03:08.800] to mention is unlike many other models, OpenAI did release the weights for GPT2. So those weights
[00:03:08.800 --> 00:03:14.640] are all available in this repository. But the GPT2 paper is not always as good with all of the
[00:03:14.640 --> 00:03:19.600] details of the training. So in addition to the GPT2 paper, we're going to be referencing the GPT3
[00:03:19.600 --> 00:03:25.200] paper, which is a lot more concrete in a lot of the hyper parameters and optimization settings and
[00:03:25.200 --> 00:03:31.680] so on. And it's not a huge departure in the architecture from the GPT2 version of the model.
[00:03:31.680 --> 00:03:37.040] So we're going to be referencing both GPT2 and GPT3 as we try to reproduce GPT2 124M.
[00:03:37.040 --> 00:03:42.960] So let's go. So the first thing I would like to do is actually start at the end or at the target.
[00:03:42.960 --> 00:03:48.480] So in other words, let's load the GPT2 124M model as it was released by OpenAI
[00:03:48.480 --> 00:03:52.800] and maybe take it for a spin. Let's sample some tokens from it. Now the issue with that is when
[00:03:52.800 --> 00:03:58.160] you go to the codebase of GPT2 and you go into the source and you click on model.py,
[00:03:58.160 --> 00:04:03.680] you'll realize that actually this is using TensorFlow. So the original GPT2 code here was written in
[00:04:03.680 --> 00:04:12.000] TensorFlow, which is, you know, not, let's just say not used as much anymore. So we'd like to use
[00:04:12.000 --> 00:04:16.480] PyTorch because it's a lot friendlier, easier, and I just personally like it a lot more.
[00:04:16.480 --> 00:04:20.240] The problem with that is the initial code is in TensorFlow. We'd like to use PyTorch.
[00:04:20.240 --> 00:04:24.000] So instead, to get the target, we're going to use the hugging face transformers
[00:04:24.000 --> 00:04:30.480] code, which I like a lot more. So when you go into the transformers source: transformers, models,
[00:04:30.480 --> 00:04:35.840] gpt2, modeling_gpt2.py, you will see that they have the GPT2 implementation of that
[00:04:35.840 --> 00:04:43.840] transformer here in this file. And it's like medium readable, but not fully readable.
[00:04:43.840 --> 00:04:49.440] But what it does is it did all the work of converting all those weights
[00:04:49.440 --> 00:04:53.680] from TensorFlow to PyTorch friendly. And so it's much easier to load and work with.
[00:04:53.680 --> 00:05:01.280] So in particular, we can look at the GPT2 model here and we can load it using hugging face transformers.
[00:05:01.280 --> 00:05:07.600] So swinging over, this is what that looks like. From transformers we import the
[00:05:07.600 --> 00:05:15.440] GPT2LMHeadModel and then call from_pretrained("gpt2"). Now, one awkward thing about this is that
[00:05:15.440 --> 00:05:20.160] when you do GPT2 as the model that we're loading, this actually is the 124 million
[00:05:20.160 --> 00:05:26.560] parameter model. If you want the actual GPT2, the 1.5 billion, then you actually want to do
[00:05:26.560 --> 00:05:33.440] gpt2-xl. So this is the 124M, our target. Now what we're doing is when we actually get this,
[00:05:33.440 --> 00:05:38.160] we're initializing the PyTorch nn.Module as defined here in this class.
[00:05:38.160 --> 00:05:44.800] From it, I want to get just the state dict, which is just the raw tensors. So we just have
[00:05:44.800 --> 00:05:51.040] the tensors of that file. And by the way, here, this is a Jupyter notebook. But this is Jupyter
[00:05:51.040 --> 00:05:55.920] notebook running inside VS code. So I like to work with it all in a single
[00:05:55.920 --> 00:06:02.480] sort of interface. So I like to use VS code. So this is the Jupyter notebook extension inside VS code.
[00:06:02.480 --> 00:06:10.320] So we'll get the state dict. This is just a dict. So we can print the key and the value,
[00:06:10.320 --> 00:06:14.160] which is the tensor. And let's just look at the shapes. So these are sort of the
[00:06:14.160 --> 00:06:23.760] different parameters inside the GPT2 model and their shape. So the wte weight for the token embedding
[00:06:25.840 --> 00:06:35.040] is of size 50,257 by 768. Where this is coming from is that we have 50,257 tokens in the GPT2
[00:06:35.040 --> 00:06:40.960] vocabulary. And the tokens, by the way, these are exactly the tokens that we spoke about in
[00:06:40.960 --> 00:06:46.560] the previous video on my tokenization series. So the previous videos, just before this, I go into
[00:06:46.560 --> 00:06:52.480] a ton of detail on tokenization. GPT2 tokenizer happens to have this many tokens. For each token,
[00:06:54.000 --> 00:07:00.800] we have a 768 dimensional embedding that is the distributed representation that stands in
[00:07:00.800 --> 00:07:07.280] for that token. So each token is a little string piece. And then these 768 numbers are the vector
[00:07:07.280 --> 00:07:12.960] that represents that token. And so this is just our lookup table for tokens. And then here,
[00:07:12.960 --> 00:07:19.760] we have the lookup table for the positions. So because GPT2 has a maximum sequence length of 1024,
[00:07:20.400 --> 00:07:27.120] we have up to 1024 positions that each token can be attending to in the past. And every one of
[00:07:27.120 --> 00:07:36.160] those positions in GPT2 has a fixed vector of 768 that is learned by optimization. And so this is
[00:07:36.160 --> 00:07:41.840] the position embedding and the token embedding. And then everything here is just the other
[00:07:41.840 --> 00:07:47.520] weights and biases and everything else of this transformer. So when you just take, for example,
[00:07:47.520 --> 00:07:52.560] the positional embeddings and flatten them out and take just the first 20 elements, you can see that these
[00:07:52.560 --> 00:07:59.120] are just parameters. These are weights, floats, that we can just take and plot. So these
[00:07:59.120 --> 00:08:03.920] are the position embeddings. And we get something like this, and you can see that this has structure.
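For reference, a minimal sketch of what this loading and plotting can look like (assuming the transformers and matplotlib packages; the specific channels plotted are arbitrary choices of mine):

```python
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel

model_hf = GPT2LMHeadModel.from_pretrained("gpt2")  # "gpt2" is the 124M model
sd_hf = model_hf.state_dict()                       # just the raw tensors

for k, v in sd_hf.items():
    print(k, v.shape)   # e.g. transformer.wte.weight torch.Size([50257, 768])

# the learned absolute position embeddings, a 1024 x 768 table
plt.imshow(sd_hf["transformer.wpe.weight"], cmap="gray")

# a few individual channels (columns) as a function of position, as discussed next
plt.figure()
for ch in (150, 200, 250):  # arbitrary columns
    plt.plot(sd_hf["transformer.wpe.weight"][:, ch])
```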
[00:08:03.920 --> 00:08:11.040] And it has structure because what we have here really is every row in this visualization
[00:08:11.040 --> 00:08:17.680] is a different position, a fixed absolute position in the range from zero to 1024.
[00:08:17.680 --> 00:08:24.160] And each row here is the representation of that position. And so it has structure because
[00:08:24.160 --> 00:08:30.000] these positional embeddings end up learning these sinusoids and cosines that sort of like
[00:08:30.000 --> 00:08:36.560] represent each of these positions. And each row here stands in for that position and is processed
[00:08:36.560 --> 00:08:43.280] by the transformer to recover all the relative positions and sort of realize which token is where
[00:08:43.280 --> 00:08:46.720] and attend to them depending on their position, not just their content.
[00:08:46.720 --> 00:08:54.240] So when we actually just look into an individual column inside these, and I just grabbed three
[00:08:54.240 --> 00:09:01.840] random columns, you'll see that, for example, here we are focusing on a single channel,
[00:09:02.480 --> 00:09:13.280] and we're looking at what that channel is doing as a function of position, from 0 to 1023.
[00:09:13.280 --> 00:09:18.880] And we can see that some of these channels basically like respond more or less to different
[00:09:18.880 --> 00:09:24.800] parts of the position spectrum. So this green channel really likes to fire for everything after
[00:09:25.360 --> 00:09:32.720] 200 up to 800, but a lot less outside that range, and it has a sharp drop-off here near zero.
[00:09:32.720 --> 00:09:36.160] So who knows what these embeddings are doing and why they are the way they are.
[00:09:36.160 --> 00:09:39.680] You can tell, for example, that because they're a bit more jagged and they're kind of noisy,
[00:09:39.680 --> 00:09:44.720] you can tell that this model was not fully trained. And the more trained this model was,
[00:09:44.720 --> 00:09:48.480] the more you would expect to smooth this out. And so this is telling you that this is a little
[00:09:48.480 --> 00:09:54.560] bit of an under trained model. But in principle, actually, these curves don't even have to be smooth.
[00:09:54.560 --> 00:09:58.720] This should just be totally random noise. And in fact, in the beginning of the optimization,
[00:09:58.720 --> 00:10:04.400] it is complete random noise because this position embedding table is initialized completely at random.
[00:10:04.400 --> 00:10:08.960] So in the beginning, you have jaggedness. And the fact that you end up with something smooth is
[00:10:08.960 --> 00:10:14.000] already kind of impressive that that just falls out of the optimization. Because in principle,
[00:10:14.000 --> 00:10:18.240] you shouldn't even be able to get any single graph out of this that makes sense. But we actually
[00:10:18.240 --> 00:10:22.480] get something that looks a little bit noisy. But for the most part, it looks sinusoidal like.
[00:10:24.000 --> 00:10:30.240] In the original transformer paper, the attention is all you need paper. The positional embeddings
[00:10:30.240 --> 00:10:36.160] are actually initialized and fixed, if I remember correctly, to sinusoids and cosines of different
[00:10:36.160 --> 00:10:41.200] frequencies. And that's the positional encoding and it's fixed. But in GPT2, these are just parameters
[00:10:41.200 --> 00:10:45.680] and they're trained from scratch, just like any other parameter. And that seems to work
[00:10:45.680 --> 00:10:50.400] about as well. And so what they do is they kind of like recover these sinusoidal like features
[00:10:50.400 --> 00:10:55.920] during the optimization. We can also look at any of the other matrices here. So
[00:10:55.920 --> 00:11:04.320] here I took the first layer of the transformer and looking at like one of its weights and just
[00:11:04.320 --> 00:11:10.560] the first block of 300 by 300. And you see some structure, but like, again, like who knows what
[00:11:10.560 --> 00:11:15.040] any of this is. If you're into mechanistic interpretability, you might get a real kick out
[00:11:15.040 --> 00:11:19.440] of trying to figure out like what is going on, what is this structure, and what does this all mean.
[00:11:19.440 --> 00:11:22.880] But we're not going to be doing that in this video. But we definitely see that there's some
[00:11:22.880 --> 00:11:27.360] interesting structure and that's kind of cool. What we're most interested in is we've loaded
[00:11:27.360 --> 00:11:32.800] the weights of this model that was released by OpenAI. And now, using the Hugging Face transformers,
[00:11:32.800 --> 00:11:39.520] we can not just get all the raw weights, but we can also get what they call a pipeline
[00:11:39.520 --> 00:11:46.000] and sample from it. So this is the prefix: "Hello, I'm a language model,". And then we're sampling
[00:11:46.960 --> 00:11:52.960] 30 tokens. And we're getting five sequences. And I ran this, and this is what it produced.
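For reference, a hedged sketch of that pipeline call (the arguments follow the standard transformers text-generation pipeline; the seed value is an assumption):

```python
from transformers import pipeline, set_seed

generator = pipeline("text-generation", model="gpt2")
set_seed(42)  # assumed seed, for reproducibility
for out in generator("Hello, I'm a language model,", max_length=30, num_return_sequences=5):
    print(out["generated_text"])
```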
[00:11:52.960 --> 00:11:59.440] Hello, I'm a language model. But what I'm really doing is making a human-readable document.
[00:11:59.440 --> 00:12:03.520] There are other languages, but those are thought that thought. So you can read through these if you
[00:12:03.520 --> 00:12:08.800] like. But basically, these are five different completions of the same prefix from this GPT-2
[00:12:08.800 --> 00:12:17.280] 124M. Now, if I go here, I took this example from here. And sadly, even though we are fixing
[00:12:17.280 --> 00:12:23.920] the seed, we are getting different generations from the snippet than what they got. So presumably,
[00:12:23.920 --> 00:12:30.320] the code changed. But what we see though at this stage that's important is that we are getting
[00:12:30.320 --> 00:12:36.400] coherent text. So we've loaded the model successfully. We can look at all its parameters. And the keys
[00:12:36.400 --> 00:12:42.720] tell us where in the model these come from. And we want to actually write our own GPT2 class so
[00:12:42.720 --> 00:12:46.400] that we have full understanding of what's happening there. We don't want to be working with something
[00:12:46.400 --> 00:12:51.520] like the modeling GPT2.py, because it's just too complicated. We want to write this from scratch
[00:12:51.520 --> 00:12:57.120] ourselves. So we're going to be implementing the GPT model here in parallel. And as our first task,
[00:12:57.120 --> 00:13:04.480] let's load the GPT-2 124M into the class that we're going to develop here from scratch. That's
[00:13:04.480 --> 00:13:10.080] going to give us confidence that we can load the OpenAI model. And therefore, there's a setting of
[00:13:10.080 --> 00:13:14.560] weights that is exactly the 124M model. But then of course, what we're going to do is we're going
[00:13:14.560 --> 00:13:20.880] to initialize the model from scratch instead, and try to train it ourselves on a bunch of documents
[00:13:20.880 --> 00:13:24.880] that we're going to get. And we're going to try to surpass that model. So we're going to get
[00:13:24.880 --> 00:13:28.000] different weights and everything's going to look different, hopefully better even.
[00:13:28.000 --> 00:13:33.360] But we're going to have a lot of confidence that because we can load the OpenAI model,
[00:13:33.360 --> 00:13:37.680] we are in the same model family and model class, and we just have to rediscover a good setting
[00:13:37.680 --> 00:13:44.400] of the weights, but from scratch. So let's now write the GPT2 model, and let's load the weights,
[00:13:44.400 --> 00:13:48.960] and make sure that we can also generate text that looks coherent. Okay, so let's now swing
[00:13:48.960 --> 00:13:53.280] over to the Attention Is All You Need paper that started everything. And let's scroll over to the model
[00:13:53.280 --> 00:13:59.200] architecture, the original transformer. Now remember that GPT2 is slightly modified from the original
[00:13:59.200 --> 00:14:05.680] transformer. In particular, we do not have the encoder. GPT2 is a decoder only transformer,
[00:14:05.680 --> 00:14:10.640] as we call it. So this entire encoder here is missing. And in addition to that, this cross
[00:14:10.640 --> 00:14:16.480] attention here that was using that encoder is also missing. So we delete this entire part.
[00:14:16.480 --> 00:14:21.680] Everything else stays almost the same, but there are some differences that we're going to
[00:14:21.680 --> 00:14:29.360] sort of look at here. So there are two main differences. When we go to the GPT2 paper,
[00:14:29.360 --> 00:14:35.680] under 2.3, Model, we notice that first, there's a reshuffling of the layer norms. So they change
[00:14:35.680 --> 00:14:43.680] place. And second, an additional layer normalization was added here after the final self-attention block.
[00:14:43.680 --> 00:14:48.960] So basically, all the layer norms here, instead of being after the MLP or after the attention,
[00:14:48.960 --> 00:14:54.000] they swing before it. And an additional layer norm gets added here right before the final classifier.
[00:14:54.000 --> 00:15:01.440] So now let's implement some of the first sort of skeleton and end modules here in our GPT and
[00:15:01.440 --> 00:15:06.720] end module. And in particular, we're going to try to match up to this schema here that is used by
[00:15:06.720 --> 00:15:11.120] hugging face transformers, because that will make it much easier to load these weights from this
[00:15:11.120 --> 00:15:17.440] state dict. So we want something that reflects this schema here. So here's what I came up with.
[00:15:19.760 --> 00:15:25.920] Basically, we see that the main container here that has all the modules is called transformer.
[00:15:25.920 --> 00:15:30.000] So I'm reflecting that within an end module dict. And this is basically a module that
[00:15:30.000 --> 00:15:36.080] allows you to index into the sub-modules using keys, just like a dictionary, strings.
[00:15:36.080 --> 00:15:41.360] Within it, we have the weights of the token embeddings, wte, and that's an
[00:15:41.360 --> 00:15:46.560] nn.Embedding. And the weights of the position embeddings, wpe, which is also just an nn.Embedding.
[00:15:46.560 --> 00:15:51.280] And if you remember, an embedding is really just a fancy little wrapper module around just a single
[00:15:51.280 --> 00:15:59.360] array of numbers, a single block of numbers, just like this. It's a single tensor.
[00:15:59.360 --> 00:16:05.600] And embedding is a glorified wrapper around a tensor that allows you to
[00:16:05.600 --> 00:16:11.760] access its elements by indexing into the rows. Now, in addition to that, we see here that we have
[00:16:11.760 --> 00:16:17.840] a .h, and this is indexed using numbers instead of strings.
[00:16:17.840 --> 00:16:23.600] So there's .h.0, .h.1, .h.2, et cetera, all the way up to .h.11.
[00:16:23.600 --> 00:16:28.560] And that's because there are 12 layers here in this transformer. So to reflect that,
[00:16:28.560 --> 00:16:34.160] I'm creating also an h, which I think probably stands for hidden. And instead of a ModuleDict,
[00:16:34.160 --> 00:16:39.040] this is a ModuleList, so we can index it using integers exactly as we see here: .0,
[00:16:39.040 --> 00:16:46.640] .1, .2, et cetera. And the ModuleList has n_layer Blocks, and the Block is an nn.Module that is yet to be
[00:16:46.640 --> 00:16:52.720] defined; we'll write it in a bit. In addition to that, following the GPT-2 paper, we need an
[00:16:52.720 --> 00:16:58.720] additional final layer norm that we're going to put in there. And then we have the final classifier,
[00:16:58.720 --> 00:17:07.680] the language model head, which projects from 768, the number of embedding dimensions in this GPT,
[00:17:07.680 --> 00:17:13.760] all the way to the vocab size, which is 50,257. And GPT-2 uses no bias for this final
[00:17:13.760 --> 00:17:22.080] sort of projection. So this is the skeleton. And you can see that it reflects this. So the WTE
[00:17:22.080 --> 00:17:26.720] has the token embeddings. Here it's called output embedding, but it's really the token embeddings.
[00:17:26.720 --> 00:17:32.320] The wpe is the positional encodings. Those two pieces of information, as we saw previously, are
[00:17:32.320 --> 00:17:38.640] going to be added and then go into the transformer. The .h is all of the blocks in gray. And the
[00:17:38.640 --> 00:17:45.840] ln_f is this new layer norm that gets added here by the GPT-2 model. And lm_head is this linear part
[00:17:45.840 --> 00:17:53.040] here. So that's the skeleton of the GPT two. We now have to implement the block. Okay, so let's
[00:17:53.040 --> 00:17:59.760] now recurse to the block itself. So we want to define the block. So I'll start putting them here.
[00:18:00.320 --> 00:18:06.960] So the block, I like to write it out like this. These are some of the initializations. And then
[00:18:06.960 --> 00:18:12.560] this is the actual forward pass of what this block computes. And notice here that there's a change
[00:18:12.560 --> 00:18:19.200] from the transformer again that is mentioned in the GPT two paper. So here, the layer normalizations
[00:18:19.200 --> 00:18:23.360] are after the application of attention or feed forward. In addition to that note,
[00:18:23.360 --> 00:18:28.480] that the normalizations are inside the residual stream. You see how feed forward is applied,
[00:18:28.480 --> 00:18:33.920] and then this arrow goes through and through the normalization. So that means that your residual
[00:18:33.920 --> 00:18:40.080] pathway has normalizations inside them. And this is not very good or desirable. You actually
[00:18:40.080 --> 00:18:45.600] prefer to have a single clean residual stream all the way from supervision all the way down to
[00:18:45.600 --> 00:18:52.480] the inputs, the tokens. And this is very desirable and nice because the gradients that flow from the
[00:18:52.480 --> 00:18:58.480] top, if you remember from your micro grad, addition just distributes gradients during the backward
[00:18:58.480 --> 00:19:05.600] stage to both of its branches equally. So addition is a branch in the gradients. And so that means
[00:19:05.600 --> 00:19:11.840] that the gradients from the top flow straight to the inputs, the tokens, through the residual pathways
[00:19:11.840 --> 00:19:16.000] unchanged. But then in addition to that, the gradient also flows through the blocks. And the
[00:19:16.000 --> 00:19:20.560] blocks, you know, contribute their own contribution over time and kick in and change the optimization
[00:19:20.560 --> 00:19:27.280] over time. But basically, clean residual pathway is desirable from an optimization perspective. And
[00:19:27.280 --> 00:19:33.040] then this is the pre normalization version, where you see that our X first goes through the layer
[00:19:33.040 --> 00:19:39.120] normalization, and then the attention, and then goes back out to go to the layer normalization
[00:19:39.120 --> 00:19:44.960] number two, and the multi-layer perceptron, sometimes also referred to as a feed forward network or an
[00:19:44.960 --> 00:19:50.800] FFN. And then that goes into the residual stream again. And then one more thing that is kind of
[00:19:50.800 --> 00:19:56.000] interesting to note is that recall that attention is a communication operation. It is where all the
[00:19:56.000 --> 00:20:01.760] tokens, and there's 1024 tokens lined up in a sequence. And this is where the tokens communicate,
[00:20:01.760 --> 00:20:09.280] this is where they exchange information. So attention is an aggregation function. It's a pooling function.
[00:20:09.280 --> 00:20:18.000] It's a weighted sum function. It is a reduce operation. Whereas MLP, this MLP here happens
[00:20:18.000 --> 00:20:22.160] at every single token individually. There's no information being collected or exchanged between
[00:20:22.160 --> 00:20:28.240] the tokens. So the attention is the reduce, and the MLP is the map. And what you end up with is
[00:20:28.240 --> 00:20:33.520] that the transformers just ends up just being a repeated application of map reduce. If you want
[00:20:33.520 --> 00:20:38.480] to think about it that way. So this is where they communicate, and this is where they think
[00:20:38.480 --> 00:20:42.640] individually about the information that they gathered, and every one of these blocks,
[00:20:42.640 --> 00:20:49.520] iteratively refines the representation inside the residual stream. So this is our block.
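As a reference, here is a minimal sketch of that pre-norm block as just described (CausalSelfAttention and MLP are the yet-to-be-written modules; the ln_1 / ln_2 / attn / mlp names mirror the Hugging Face keys):

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)  # communication: tokens exchange information (the "reduce")
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)                   # computation on each token individually (the "map")

    def forward(self, x):
        # layer norms come before attention/MLP; the residual additions stay outside them,
        # keeping a clean residual stream from the top of the network down to the inputs
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
```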
[00:20:49.520 --> 00:20:55.200] It's slightly modified from this picture. Okay, so let's now move on to the MLP.
[00:20:55.200 --> 00:21:01.280] So the MLP block I implemented as follows. It is relatively straightforward. We basically have
[00:21:01.280 --> 00:21:08.400] two linear projections here that are sandwiched in between the GELU nonlinearity. So nn.GELU,
[00:21:08.400 --> 00:21:15.040] with approximate='tanh'. Now when we swing over to the PyTorch documentation,
[00:21:15.040 --> 00:21:20.240] this is nn.GELU, and it has this format, and it has two versions: the original version of GELU,
[00:21:20.240 --> 00:21:25.600] which we'll step into in a bit, and the approximate version of GELU, which we can request using approximate='tanh'.
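Here is a sketch of that MLP module (the 4x hidden expansion is the standard GPT-2 choice and an assumption on my part; c_fc and c_proj mirror the Hugging Face key names):

```python
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)    # expand
        self.gelu = nn.GELU(approximate='tanh')                    # tanh-approximate GELU, as in GPT-2
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)  # project back

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))
```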
[00:21:26.640 --> 00:21:33.360] So as you can see, just as a preview here, GELU is basically like a ReLU, except there's no flat,
[00:21:33.360 --> 00:21:39.520] exactly flat tail here at exactly zero. But otherwise, it looks very much like a slightly
[00:21:39.520 --> 00:21:46.320] smoother ReLU. It comes from this paper here, Gaussian Error Linear Units, and you can step
[00:21:46.320 --> 00:21:50.960] through this paper, and there's some mathematical kind of reasoning that leads to interpretation,
[00:21:50.960 --> 00:21:56.560] at least to this specific formulation. It has to do with stochastic regularizers, and the expectation
[00:21:56.560 --> 00:22:00.560] of a modification to adaptive dropout, so you can read through all of that if you'd like here.
[00:22:00.560 --> 00:22:05.920] And there's a little bit of the history as to why there's an approximate version of GELU,
[00:22:05.920 --> 00:22:09.040] and that comes from this issue here, as far as I can tell.
[00:22:09.040 --> 00:22:16.880] And in this issue, Dan Hendrycks mentions that at the time when they developed this nonlinearity,
[00:22:16.880 --> 00:22:22.000] the erf function, which you need to evaluate the exact GELU, was very slow in TensorFlow,
[00:22:22.000 --> 00:22:26.960] so they ended up basically developing this approximation. And this approximation then ended up being
[00:22:26.960 --> 00:22:31.680] picked up by BERT and by GPT-2, et cetera. But today, there's no real good reason to use the
[00:22:31.680 --> 00:22:37.200] approximate version. You'd prefer to just use the exact version, because my expectation is
[00:22:37.200 --> 00:22:42.640] that there's no big difference anymore, and this is kind of like a historical kind of quirk.
[00:22:42.640 --> 00:22:49.760] But we are trying to reproduce GPT-2 exactly, and GPT-2 used the tanh approximate version,
[00:22:49.760 --> 00:22:56.720] so we prefer to stick with that. Now, one other reason to actually just intuitively use GELU
[00:22:56.720 --> 00:23:02.000] instead of ReLU is that in previous videos we've spoken about the dead ReLU
[00:23:02.000 --> 00:23:08.640] neuron problem, where in this tail of the ReLU, if it's exactly flat at zero, any activations that fall
[00:23:08.640 --> 00:23:13.280] there will get exactly zero gradient. There's no change, there's no adaptation, there's no
[00:23:13.280 --> 00:23:18.560] development of the network, if any of these activations end in this flat region. But the
[00:23:18.560 --> 00:23:23.120] GELU always contributes a local gradient, and so there's always going to be a change, always
[00:23:23.120 --> 00:23:28.080] going to be an adaptation, and sort of smoothing it out ends up empirically working better in practice,
[00:23:28.080 --> 00:23:33.360] as demonstrated in this paper, and also as demonstrated by it being picked up by the BERT paper, GPT-2
[00:23:33.360 --> 00:23:39.280] paper, and so on. So for that reason, we adopt this nonlinearity here in the GPT-2
[00:23:39.280 --> 00:23:44.960] reproduction. Now, in more modern networks like Llama 3 and so on, this nonlinearity
[00:23:44.960 --> 00:23:51.520] further changes to SwiGLU and other variants like that, but for GPT-2, they used this approximate GELU.
[00:23:51.520 --> 00:23:57.440] Okay, and finally, we have the attention operation. So let me paste in my attention.
[00:23:57.440 --> 00:24:04.400] So I know this is a lot, so I'm going to go through this a bit quickly, but
[00:24:04.400 --> 00:24:08.160] not too slowly, because we have covered this in the previous video, and I would just point you there.
[00:24:08.160 --> 00:24:14.480] So this is the attention operation. Now, in the previous video, you will remember, this is not
[00:24:14.480 --> 00:24:20.560] just attention. This is multi-headed attention, right? And so in the previous video, we had this
[00:24:20.560 --> 00:24:26.400] multi-headed attention module, and this implementation made it obvious that these heads are not actually
[00:24:26.400 --> 00:24:32.720] that complicated. There's basically in parallel inside every attention block. There's multiple
[00:24:32.720 --> 00:24:38.480] heads, and they're all functioning in parallel, and their outputs are just being concatenated,
[00:24:38.480 --> 00:24:43.520] and that becomes the output of the multi-headed attention. So the heads are just kind of like
[00:24:43.520 --> 00:24:49.840] parallel streams, and their outputs get concatenated. And so it was very simple and made the head
[00:24:49.840 --> 00:24:53.440] be kind of like fairly straightforward in terms of its implementation.
[00:24:53.440 --> 00:24:59.440] What happens here is that instead of having two separate modules, and indeed many more modules
[00:24:59.440 --> 00:25:06.400] that get concatenated, all of that is just put into a single self-attention module. And instead,
[00:25:06.400 --> 00:25:13.520] I'm being very careful in doing a bunch of transpose split tensor gymnastics to make this very
[00:25:13.520 --> 00:25:17.360] efficient in PyTorch. But fundamentally and algorithmically, nothing is different from the
[00:25:17.360 --> 00:25:26.240] implementation we saw before in this GitHub repository. So to remind you very briefly, and I don't want
[00:25:26.240 --> 00:25:32.880] to go into this in too much detail, but we have these tokens lined up in a sequence, and there are
[00:25:32.880 --> 00:25:39.040] 1,024 of them. And then each token at this stage of the attention emits three vectors:
[00:25:39.040 --> 00:25:46.560] the query, the key, and the value. And first what happens here is that the queries and the keys
[00:25:46.560 --> 00:25:52.960] have to multiply each other to get sort of the attention amount, like how interesting they find
[00:25:52.960 --> 00:25:57.440] each other. So they have to interact multiplicatively. So what we're doing here is calculating the
[00:25:57.440 --> 00:26:03.440] qkv and then splitting it. And then there's a bunch of gymnastics, as I mentioned here. And the way this
[00:26:03.440 --> 00:26:10.160] works is that we're basically making the number of heads, nh, into a batch dimension. And so it's
[00:26:10.160 --> 00:26:16.800] a batch dimension just like B, so that in these operations that follow PyTorch treats B and H as
[00:26:16.800 --> 00:26:22.880] batches. And it applies all the operations on all of them in parallel in both the batch and the heads.
[00:26:24.080 --> 00:26:28.800] And the operations that get applied are, number one, the queries and the keys interact to give us
[00:26:28.800 --> 00:26:35.280] our attention. This is the autoregressive mask that makes sure that the tokens only attend to
[00:26:35.280 --> 00:26:42.160] tokens before them and never to tokens in the future. The softmax here normalizes the attention,
[00:26:42.160 --> 00:26:48.080] so it sums to one always. And then recall from the previous video that doing the attention matrix
[00:26:48.080 --> 00:26:53.280] multiply with the values is basically a way to do a weighted sum of the values of the tokens that
[00:26:53.280 --> 00:26:58.400] we found interesting at every single token. And then the final transpose contiguous
[00:26:58.400 --> 00:27:03.760] and view is just reassembling all of that again. And this actually performs the concatenation
[00:27:03.760 --> 00:27:09.760] operation. So you can step through this slowly if you'd like, but it is equivalent mathematically
[00:27:09.760 --> 00:27:14.960] to our previous implementation; it's just more efficient in PyTorch. So that's why I chose this
[00:27:14.960 --> 00:27:20.240] implementation instead. Now in addition to that, I'm being careful with how I name my variables.
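Here is a condensed sketch of that attention module, following the description above (the c_attn / c_proj / bias names follow the Hugging Face convention mentioned next; treat the details as a sketch rather than the exact on-screen code):

```python
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, in one batched linear
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # the autoregressive mask, registered as a buffer (not a parameter), following OpenAI's naming
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                          .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)                       # (B, T, 3C)
        q, k, v = qkv.split(self.n_embd, dim=2)
        # make the head dimension a batch dimension: (B, nh, T, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))     # attention scores
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))  # causal mask
        att = F.softmax(att, dim=-1)               # normalize so each row sums to one
        y = att @ v                                # weighted sum of the values
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble (this is the "concatenation")
        return self.c_proj(y)
```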
[00:27:20.240 --> 00:27:26.560] So for example, c_attn here is the same c_attn as in their code. And so actually our keys should basically exactly
[00:27:26.560 --> 00:27:30.560] follow the schema of the Hugging Face transformers code. And that will make it very easy for us to
[00:27:30.560 --> 00:27:36.800] now port over all the weights from exactly this sort of naming conventions, because all of our
[00:27:36.800 --> 00:27:42.800] variables are named the same thing. But at this point, we have finished the GPT-2 implementation.
[00:27:42.800 --> 00:27:48.000] And what that allows us to do is we don't have to basically use this file from Hugging Face,
[00:27:48.000 --> 00:27:58.560] which is fairly long. This is 2,000 lines of code. Instead, we just have less than 100
[00:27:58.560 --> 00:28:03.120] lines of code. And this is the complete GPT2 implementation. So at this stage, we should just
[00:28:03.120 --> 00:28:08.320] be able to take over all the weights, set them, and then do generation. So let's see what that
[00:28:08.320 --> 00:28:12.640] looks like. Okay, so here, I've also changed the GPTConfig so that the numbers here,
[00:28:12.640 --> 00:28:18.400] the parameters, agree with the GPT-2 124M model. So the maximum sequence length, which I call block
[00:28:18.400 --> 00:28:26.880] size here, is 1024. The number of tokens is 50,257, which, if you watched my tokenizer video, you know that
[00:28:26.880 --> 00:28:35.920] this is 50,000 BPE merges, 256 byte tokens, the leaves of the BPE tree, and one special end
[00:28:35.920 --> 00:28:41.520] of text token that delimits different documents and can start generation as well. And there are
[00:28:41.520 --> 00:28:46.720] 12 layers, there are 12 heads in the attention, and the dimension of the transformer is 768.
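As a sketch, a config dataclass with exactly those numbers might look like this (the field names are the ones used throughout these sketches):

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024    # maximum sequence length
    vocab_size: int = 50257   # 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
```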
[00:28:46.720 --> 00:28:53.920] So here's how we can now load the parameters from hugging face to our code here and initialize
[00:28:53.920 --> 00:29:00.240] the GPT class with those parameters. So let me just copy paste a bunch of code here. And I'm
[00:29:00.240 --> 00:29:07.360] not going to go through this code too slowly, because honestly, it's not
[00:29:07.360 --> 00:29:10.960] that interesting. It's not that exciting. We're just loading the weights. So it's kind of dry.
[00:29:10.960 --> 00:29:15.760] But as I mentioned, there are four models in this mini series of GPT2. This is some of the
[00:29:15.760 --> 00:29:21.440] Jupyter code that we had here on the right; I'm just porting it over. These are
[00:29:21.440 --> 00:29:26.320] the hyper parameters of the GPT2 models. We're creating the config object and creating our own
[00:29:26.320 --> 00:29:31.600] model. And then what's happening here is we're creating the state dict, both for our model and
[00:29:31.600 --> 00:29:37.920] for the hugging face model. And then what we're doing here is we're going over the hugging face
[00:29:37.920 --> 00:29:45.600] model keys. And we're copying over those tensors. And in the process, we are kind of ignoring a
[00:29:45.600 --> 00:29:50.720] few of the buffers. They're not parameters, they're buffers. So for example, attn.bias,
[00:29:50.720 --> 00:29:56.000] that's just used for the autoregressive mask. And so we are ignoring some of those masks.
[00:29:56.000 --> 00:30:01.680] And that's it. And then one additional kind of annoyance is that this comes from the TensorFlow
[00:30:01.680 --> 00:30:06.400] repo, and I'm not sure why; it's a little bit annoying, but some of the weights are transposed
[00:30:06.400 --> 00:30:11.920] from what PyTorch would want. And so manually, I hard coded the weights that should be transposed.
[00:30:11.920 --> 00:30:18.240] And then we transpose them if that is so. And then we return this model. So from_pretrained
[00:30:18.240 --> 00:30:26.480] is a constructor, or a class method in Python, that returns the GPT object if we just
[00:30:26.480 --> 00:30:30.880] give it the model type, which in our case is "gpt2", the smallest model that we're interested in.
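For reference, a compressed sketch of what such a from_pretrained classmethod can look like, written as a method of the GPT class sketched earlier; the buffer filtering and the transposed-weight list follow the description above, but the details are assumptions rather than the exact on-screen code:

```python
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel

class GPT(nn.Module):
    # ... __init__ and forward as sketched earlier ...

    @classmethod
    def from_pretrained(cls, model_type):
        # hyperparameters of the four models in the GPT-2 miniseries
        config_args = {
            'gpt2':        dict(n_layer=12, n_head=12, n_embd=768),   # 124M params
            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024),  # 350M params
            'gpt2-large':  dict(n_layer=36, n_head=20, n_embd=1280),  # 774M params
            'gpt2-xl':     dict(n_layer=48, n_head=25, n_embd=1600),  # 1558M params
        }[model_type]
        config = GPTConfig(vocab_size=50257, block_size=1024, **config_args)
        model = cls(config)
        sd = model.state_dict()

        # load the Hugging Face model and copy its tensors over
        sd_hf = GPT2LMHeadModel.from_pretrained(model_type).state_dict()
        # these come from TensorFlow-style Conv1D modules and are transposed vs. nn.Linear
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        for k in sd_hf:
            if k.endswith('.attn.masked_bias') or k.endswith('.attn.bias'):
                continue  # skip buffers (the autoregressive mask); they are not parameters
            with torch.no_grad():
                if any(k.endswith(t) for t in transposed):
                    sd[k].copy_(sd_hf[k].t())
                else:
                    sd[k].copy_(sd_hf[k])
        return model
```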
[00:30:31.760 --> 00:30:37.200] So this is the code. And this is how you would use it. And we can pop open the terminal here
[00:30:37.200 --> 00:30:44.960] in VS Code. And we can run python train_gpt2.py and fingers crossed.
[00:30:44.960 --> 00:30:53.680] Okay, so we didn't crash. And so we can load the weights and the biases and everything else
[00:30:53.680 --> 00:30:58.080] into our nn.Module. But now let's also get additional confidence that this is working.
[00:30:58.080 --> 00:31:02.160] And let's try to actually generate from this model. Okay, now before we can actually generate
[00:31:02.160 --> 00:31:06.080] from this model, we have to be able to forward it. We didn't actually write that code yet.
[00:31:06.080 --> 00:31:12.640] So here's the forward function. So the input to the forward is going to be our indices,
[00:31:12.640 --> 00:31:20.160] our tokens, token indices. And they are always of shape B by T. And so we have batch dimension
[00:31:20.160 --> 00:31:27.120] of B. And then we have the time dimension of up to T. And the T cannot be more than the block size.
[00:31:27.120 --> 00:31:32.800] The block size is the maximum sequence length. So B by T indices are arranged sort of like a
[00:31:32.800 --> 00:31:39.520] two dimensional layout. And remember that basically every single row of this is of size up to block
[00:31:39.520 --> 00:31:46.320] size. And this is T tokens that are in a sequence. And then we have B independent sequences stacked
[00:31:46.320 --> 00:31:51.760] up in a batch so that this is efficient. Now here we are forwarding the position embeddings
[00:31:51.760 --> 00:31:55.760] and the token embeddings. And this code should be very recognizable from the previous lecture.
[00:31:56.320 --> 00:32:02.080] So we basically use torch.arange, which is kind of like a version of range, but for PyTorch.
[00:32:02.080 --> 00:32:08.880] And we're iterating from zero to T and creating these position indices.
[00:32:08.880 --> 00:32:15.120] And then we are making sure that they're on the same device as idx, because we're not going to be training
[00:32:15.120 --> 00:32:19.040] on just the CPU. That's going to be too inefficient. We want to be training on the GPU, and that's going
[00:32:19.040 --> 00:32:24.960] to come in a bit. Then we have the position embeddings and the token embeddings and the addition
[00:32:24.960 --> 00:32:29.360] operation of those two. Now notice that position embeddings are going to be identical for every
[00:32:29.360 --> 00:32:36.880] single row of of input. And so there's broadcasting hidden inside this plus where we have to create
[00:32:36.880 --> 00:32:41.200] an additional dimension here. And then these two add up because the same position embeddings
[00:32:41.200 --> 00:32:46.240] apply to every single row of our examples stacked up in a batch. Then we forward the
[00:32:46.240 --> 00:32:52.320] transformer blocks and finally the last layer norm and the lm_head. So what comes out after
[00:32:52.320 --> 00:32:59.200] forward is the logits. And if the input was B by T indices, then at every single B by T,
[00:32:59.200 --> 00:33:06.720] we will calculate the logits for what token comes next in a sequence. So what is the token
[00:33:06.720 --> 00:33:14.080] B, T plus one, the one on the right of this token. And vocab size here is the number of
[00:33:14.080 --> 00:33:19.120] possible tokens. And so therefore this is the tensor that we're going to obtain. And these
[00:33:19.120 --> 00:33:25.680] logits are just a softmax away from becoming probabilities. So this is the forward pass
[00:33:25.680 --> 00:33:30.000] of the network. And now we can get logits, and so we're going to be able to generate from the model imminently.
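Here is a sketch of that forward pass as a method on the GPT class sketched earlier (names follow that skeleton; the assert is my addition):

```python
import torch
import torch.nn as nn

class GPT(nn.Module):
    # ... __init__ as sketched earlier ...

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.config.block_size, "sequence length exceeds block size"
        # position indices, created on the same device as idx to avoid device mismatches
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # shape (T,)
        pos_emb = self.transformer.wpe(pos)  # (T, n_embd), broadcast across the batch
        tok_emb = self.transformer.wte(idx)  # (B, T, n_embd)
        x = tok_emb + pos_emb                # broadcasting hides the extra batch dimension
        for block in self.transformer.h:     # the 12 transformer blocks
            x = block(x)
        x = self.transformer.ln_f(x)         # final layer norm
        logits = self.lm_head(x)             # (B, T, vocab_size)
        return logits
```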
[00:33:30.000 --> 00:33:35.440] Okay, so now we're going to try to set up the identical thing on the left here
[00:33:35.440 --> 00:33:41.120] that matches Hugging Face on the right. So here we sampled from the pipeline and we sampled
[00:33:41.120 --> 00:33:46.560] five times up to 30 tokens with a prefix of "Hello, I'm a language model,". And these are the
[00:33:46.560 --> 00:33:51.040] completions that we achieved. So we're going to try to replicate that on the left here. So num_return_sequences
[00:33:51.040 --> 00:33:54.960] is 5 and max_length is 30. So the first thing we do, of course, is we initialize
[00:33:54.960 --> 00:34:00.480] our model, then we put it into evaluation mode. Now this is a good practice to put the model into
[00:34:00.480 --> 00:34:05.680] eval when you're not going to be training it, you're just going to be using it. And I don't
[00:34:05.680 --> 00:34:09.760] actually know if this is doing anything right now for the following reason. Our model up above
[00:34:09.760 --> 00:34:16.080] here contains no modules or layers that actually have a different behavior at training or evaluation
[00:34:16.080 --> 00:34:20.960] time. So for example, dropout, batch norm, and a bunch of other layers have this kind of behavior.
[00:34:20.960 --> 00:34:24.960] But all of these layers that we've used here should be identical in both training and evaluation
[00:34:24.960 --> 00:34:32.080] time. So potentially model.eval() does nothing, but then I'm not actually sure if this is the case
[00:34:32.080 --> 00:34:37.200] and maybe pytorch internals do some clever things depending on the evaluation mode inside here.
[00:34:37.200 --> 00:34:43.680] The next thing we're doing here is we are moving the entire model to CUDA. So we're moving this
[00:34:43.680 --> 00:34:49.840] all of the tensors to GPU. So I'm SSH'd here to a cloud box and I have a bunch of GPUs on this box.
[00:34:49.840 --> 00:34:55.600] And here I'm moving the entire model and all of its members and all of its tensors and everything
[00:34:55.600 --> 00:35:01.600] like that. Everything gets shipped off to basically a whole separate computer that is sitting on the
[00:35:01.600 --> 00:35:06.880] GPU. And the GPU is connected to the CPU and they can communicate, but it's basically a whole separate
[00:35:06.880 --> 00:35:11.840] computer with its own computer architecture. And it's really geared toward parallel processing tasks
[00:35:11.840 --> 00:35:17.360] like those of running neural networks. So I'm doing this so that the model lives on the GPU,
[00:35:17.360 --> 00:35:21.840] a whole separate computer, and it's just going to make our code a lot more efficient,
[00:35:21.840 --> 00:35:28.320] because all of this stuff runs a lot more efficiently on the GPU. So that's the model itself.
[00:35:28.320 --> 00:35:35.280] Now, the next thing we want to do is we want to start with this as the prefix when we do the
[00:35:35.280 --> 00:35:42.160] generation. So let's actually create those prefix tokens. So here's the code that I've written.
[00:35:42.160 --> 00:35:47.520] We're going to import the tiktoken library from OpenAI and we're going to get the GPT-2 encoding.
[00:35:47.520 --> 00:35:54.800] So that's the tokenizer for GPT-2. And then we're going to encode this string and get a list of
[00:35:54.800 --> 00:36:00.400] integers which are the tokens. Now these integers here should actually be fairly straightforward
[00:36:00.400 --> 00:36:05.360] because we can just copy paste this string. And we can sort of inspect what it is in
[00:36:05.360 --> 00:36:09.680] the Tiktokenizer app. So just pasting that in, these are the tokens that are going to come out.
[00:36:09.680 --> 00:36:16.720] So this list of integers is what we expect tokens to become. And as you recall, if you saw my video,
[00:36:16.720 --> 00:36:21.600] of course, all the tokens, they're just little string chunks, right? So these are, this is the
[00:36:21.600 --> 00:36:28.800] chunking of this string into GPT-2 tokens. So once we have those tokens, it's a list of integers,
[00:36:28.800 --> 00:36:33.760] we can create a torch tensor out of it. In this case, it's eight tokens. And then we're going to
[00:36:33.760 --> 00:36:40.800] replicate these eight tokens for five times to get five rows of eight tokens. And that is our
[00:36:40.800 --> 00:36:50.160] initial input X, as I call it here. And it lives on the GPU as well. So X now is this idx that we
[00:36:50.160 --> 00:36:57.360] can put into forward to get our logits so that we know what comes as the sixth token,
[00:36:57.360 --> 00:37:04.720] sorry, as the ninth token in every one of these five rows. Okay, and we are now ready to generate.
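A small sketch of building that prefix batch, per the description above (tiktoken's GPT-2 encoding, five rows of the same eight-token prefix):

```python
import tiktoken
import torch

enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode("Hello, I'm a language model,")  # a list of 8 integers
tokens = torch.tensor(tokens, dtype=torch.long)      # shape (8,)
x = tokens.unsqueeze(0).repeat(5, 1)                 # (5, 8): num_return_sequences copies
x = x.to('cuda')                                     # ship the batch to the GPU
```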
[00:37:04.720 --> 00:37:10.400] So let me paste in one more code block here. So what's happening here in this code block is
[00:37:10.400 --> 00:37:17.600] we have these X, which is of size B by T, right? So batch by time. And we're going to be in every
[00:37:17.600 --> 00:37:23.200] iteration of this loop, we're going to be adding a column of new indices into each one of these rows,
[00:37:23.200 --> 00:37:28.240] right? And so these are the new indices, and we're appending them to the sequence as we're
[00:37:28.240 --> 00:37:34.640] sampling. So with each loop iteration, we get one more column into X. And all of the operations
[00:37:34.640 --> 00:37:38.320] happening in the context manager of torch.no_grad, this is just telling PyTorch that we're
[00:37:38.320 --> 00:37:42.560] not going to be calling .backward() on any of this. So it doesn't have to cache all the
[00:37:42.560 --> 00:37:46.720] intermediate tensors, it's not going to have to prepare in any way for a potential backward
[00:37:46.720 --> 00:37:53.520] later. And this saves a lot of space and also possibly some time. So we get our logits,
[00:37:53.520 --> 00:37:58.960] we get the logits at only the last location, we throw away all the other logits, we don't need
[00:37:58.960 --> 00:38:05.920] them, we only care about the last column's logits. So this is being wasteful. But this is just kind
[00:38:05.920 --> 00:38:12.640] of like an inefficient implementation of sampling. So it's correct but inefficient. So we get the last
[00:38:12.640 --> 00:38:16.960] column of logits, pass it through a softmax to get our probabilities. Then here, I'm doing
[00:38:16.960 --> 00:38:21.520] top-k sampling with k of 50. And I'm doing that because this is the Hugging Face default. So just
[00:38:21.520 --> 00:38:30.000] looking at the Hugging Face docs here for pipeline, there are a bunch of kwargs that go into Hugging
[00:38:30.000 --> 00:38:36.720] Face. And I mean, that's kind of a lot, honestly. But I guess the important one that I noticed is
[00:38:36.720 --> 00:38:43.440] that they're using top K by default, which is 50. And what that does is that, so that's being used
[00:38:43.440 --> 00:38:47.760] here as well. And what that does is basically we want to take our probabilities, and we only want
[00:38:47.760 --> 00:38:54.080] to keep the top 50 probabilities. And anything that is lower than the 50th probability, we just
[00:38:54.080 --> 00:38:59.680] clamp to zero and renormalize. And so that way we are never sampling very rare tokens.
[00:38:59.680 --> 00:39:04.640] The tokens we're going to be sampling are always in the top 50 of most likely tokens.
[00:39:04.640 --> 00:39:09.040] And this helps keep the model kind of on track. And it doesn't blabber on and it doesn't get lost,
[00:39:09.040 --> 00:39:14.640] and doesn't go off the rails as easily. And it kind of like sticks in the vicinity of likely
[00:39:14.640 --> 00:39:19.280] tokens a lot better. So this is the way to do it in PyTorch. And you can step through it if you like,
[00:39:19.280 --> 00:39:23.600] I don't think it's super insightful. So I'll speed through it. But roughly speaking, we get this new
[00:39:23.600 --> 00:39:31.440] column of tokens. We append them onto X, and basically the columns of X grow until this while loop gets
[00:39:31.440 --> 00:39:40.800] tripped. And then finally, we have an entire X of size 5 by 30 in this case, in this example.
[00:39:40.800 --> 00:39:46.480] And we can just basically print all those individual rows. So I'm getting all the rows.
[00:39:46.480 --> 00:39:51.040] I'm getting all the tokens that are sampled. And I'm using the decode function from the tiktoken
[00:39:51.040 --> 00:39:57.120] tokenizer to get back the string, which we can print.
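Putting those pieces together, a sketch of this sampling loop (top-k of 50 per the Hugging Face default mentioned above; the seed is arbitrary and the tensor names follow the description):

```python
import tiktoken
import torch
from torch.nn import functional as F

max_length = 30
torch.manual_seed(42)  # arbitrary seed
while x.size(1) < max_length:
    with torch.no_grad():
        logits = model(x)                  # (B, T, vocab_size)
    logits = logits[:, -1, :]              # keep only the logits at the last position
    probs = F.softmax(logits, dim=-1)
    # top-k sampling: keep the 50 most likely tokens, renormalize, sample one
    topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
    ix = torch.multinomial(topk_probs, 1)          # (B, 1), index into the top-k
    xcol = torch.gather(topk_indices, -1, ix)      # (B, 1), the actual token id
    x = torch.cat((x, xcol), dim=1)                # append one new column to x

# decode and print every row
enc = tiktoken.get_encoding("gpt2")
for i in range(x.size(0)):
    print(">", enc.decode(x[i, :max_length].tolist()))
```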
[00:39:59.280 --> 00:40:01.600] So: terminal, new terminal, and let me run python train_gpt2.py.
[00:40:01.600 --> 00:40:14.640] Okay. So these are the generations that we're getting. Hello, I'm a language model, not a program.
[00:40:14.640 --> 00:40:21.600] New line, new line, etc. Hello, I'm a language model. And one of the main things that bothers
[00:40:21.600 --> 00:40:26.000] me when they create languages is how easy it becomes to create something that I mean, so this
[00:40:26.000 --> 00:40:30.160] will just like blabber on right in all these cases. Now one thing you will notice is that these
[00:40:30.160 --> 00:40:36.960] generations are not the generations of Hugging Face here. And I can't find the discrepancy, to be
[00:40:36.960 --> 00:40:41.120] honest, and I didn't fully go through all these options, but probably there's something else hiding
[00:40:41.120 --> 00:40:45.680] in here in addition to the top-p. So I'm not able to match it up, but just for correctness,
[00:40:45.680 --> 00:40:51.920] down here below in the Jupyter notebook, using the Hugging Face model (so this is the Hugging
[00:40:51.920 --> 00:41:00.560] Face model here), I replicated the code. And if I do this and I run that, then I am getting the
[00:41:00.560 --> 00:41:07.280] same results. So basically, the model internals are not wrong. It's just I'm not 100% sure what
[00:41:07.280 --> 00:41:12.000] the pipeline does in hugging face. And that's why we're not able to match them up. But otherwise,
[00:41:12.000 --> 00:41:17.600] the code is correct. And we've loaded all the tensors correctly. So we're initializing the model
[00:41:17.600 --> 00:41:22.560] correctly and everything here works. So long story short, we've ported all the weights, we
[00:41:22.560 --> 00:41:28.080] initialized the GPT-2 (this is the exact OpenAI GPT-2), and it can generate sequences and they
[00:41:28.080 --> 00:41:34.400] look sensible. And now here, of course, we're initializing with the GPT-2 model weights. But now
[00:41:34.400 --> 00:41:38.880] we want to initialize from scratch from random numbers. And we want to actually train the model
[00:41:38.880 --> 00:41:46.080] that will give us sequences as good as or better than these ones in quality. And so that's what we
[00:41:46.080 --> 00:41:50.400] turn to next. So it turns out that using the random model is actually fairly straightforward,
[00:41:50.400 --> 00:41:57.520] because PyTorch already initializes our model randomly by default. So when we create the
[00:41:57.520 --> 00:42:03.520] GPT model in the constructor, all of these layers and modules have random
[00:42:03.520 --> 00:42:08.800] initializers that are there by default. So when these linear layers get created and so on,
[00:42:08.800 --> 00:42:12.880] there are default constructors, for example using the Xavier initialization that we saw in the past,
[00:42:13.520 --> 00:42:19.440] to construct the weights of these layers. And so creating a random model instead of a GPT-2
[00:42:19.440 --> 00:42:24.480] model is actually fairly straightforward. And we would just come here. And instead we would
[00:42:24.480 --> 00:42:31.360] create model = GPT(GPTConfig()). So we're using the default config, GPTConfig, and the default
[00:42:31.360 --> 00:42:37.760] config uses the 124M parameters. So this is the random model initialization. And we can run it.
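That is, the only change is a one-liner (a sketch; the pretrained alternative from earlier is shown for contrast):

```python
model = GPT(GPTConfig())               # random initialization, 124M-scale config
# model = GPT.from_pretrained('gpt2')  # the earlier pretrained-weights path
```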
[00:42:43.120 --> 00:42:49.120] And we should be able to get results. Now the results here, of course, are total garbage
[00:42:49.120 --> 00:42:53.440] garble. And that's because it's a random model. And so we're just getting all these random token
[00:42:53.440 --> 00:42:59.200] string pieces chunked up totally at random. So that's what we have right now. Now, one more
[00:42:59.200 --> 00:43:03.120] thing I wanted to point out, by the way, is in case you do not have CUDA available, because you
[00:43:03.120 --> 00:43:08.400] don't have a GPU, you can still follow along with what we're doing here, to some extent.
[00:43:09.440 --> 00:43:13.360] And probably not to the very end, because by the end, we're going to be using multiple GPUs and
[00:43:13.360 --> 00:43:17.600] actually doing a serious training run. But for now, you can actually follow along decently. Okay.
[00:43:17.600 --> 00:43:22.960] So one thing that I like to do in PyTorch is I like to auto detect the device that is available
[00:43:22.960 --> 00:43:30.480] to you. So in particular, you could do that like this. So here we are trying to detect the device to
[00:43:30.480 --> 00:43:35.360] run on that has the highest compute capability. You can think about it that way. So by default,
[00:43:35.360 --> 00:43:39.200] we start with CPU, which of course is available everywhere, because every single computer will
[00:43:39.200 --> 00:43:45.440] have a CPU. But then we can try to detect: do we have a GPU? If so, we use CUDA. And then if you don't
[00:43:45.440 --> 00:43:52.080] have CUDA, do you at least have MPS? MPS is the backend for Apple Silicon. So if you have a MacBook
[00:43:52.080 --> 00:43:56.720] that is fairly new, you probably have Apple Silicon on the inside. And then that has a GPU that is
[00:43:56.720 --> 00:44:01.680] actually fairly capable, depending on which MacBook you have. And so you can use MPS, which will be
[00:44:01.680 --> 00:44:06.720] potentially faster than CPU. And so we can print the device here. Now, once we have the device,
[00:44:06.720 --> 00:44:14.720] we can actually use it in place of CUDA. So we just swap it in. And notice that here, when we call
[00:44:14.720 --> 00:44:22.720] model on X, if this X here is on CPU, instead of GPU, then it will work fine, because here in the
[00:44:22.720 --> 00:44:28.880] forward, which is where PyTorch will come, when we create pos, we are careful to use the device
[00:44:28.880 --> 00:44:34.480] of idx to create this tensor as well. And so there won't be any mismatch where one tensor is on
[00:44:34.480 --> 00:44:40.880] CPU, one is on GPU, and that you can't combine those. But here we are carefully initializing on
[00:44:40.880 --> 00:44:47.680] the correct device, as indicated by the input to this model. So this will auto detect device.
[00:44:47.680 --> 00:44:53.040] For me, this will be, of course, GPU. So using device CUDA.
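(For reference, a minimal sketch of the device auto-detection described above; the exact variable names are incidental.)

```python
import torch

# Prefer CUDA, then Apple Silicon's MPS backend, and fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")
```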
[00:44:55.600 --> 00:45:00.480] But you can also run with, as I mentioned, another device. And it's not going to be too
[00:45:00.480 --> 00:45:08.160] much slower. So if I override device here, if I override device equals CPU, then
[00:45:08.160 --> 00:45:15.120] we'll still print CUDA, of course. But now we're actually using CPU, one, two, three,
[00:45:15.120 --> 00:45:22.640] four, five, six, okay, about six seconds. And actually, we're not using Torch compile and
[00:45:22.640 --> 00:45:27.040] stuff like that, which will speed everything up a lot as well. But you can follow along,
[00:45:27.040 --> 00:45:33.120] even on a CPU, I think, to a decent extent. So that's a note on that. Okay, so I do want to loop
[00:45:33.120 --> 00:45:38.160] around eventually into what it means to have different devices in PyTorch. And what it is exactly
[00:45:38.160 --> 00:45:43.840] that PyTorch does in the background for you, when you do something like module.to(device),
[00:45:43.840 --> 00:45:48.640] or where you take a torch tensor and do a .to(device). And what exactly happens and how that
[00:45:48.640 --> 00:45:53.600] works? But for now, I'd like to get to training. And I'd like to start training the model. And for
[00:45:53.600 --> 00:45:59.040] now, let's just say the device makes code go fast. And let's go into how we can actually train the
[00:45:59.040 --> 00:46:04.000] model. So to train the model, we're going to need some data set. And for me, the best debugging
[00:46:04.000 --> 00:46:09.600] simplest data set that I like to use is the tiny Shakespeare data set. And it's available at this
[00:46:09.600 --> 00:46:16.000] URL. So you can wget it, or you can just search tiny Shakespeare data set. And so I have it in my
[00:46:16.000 --> 00:46:23.040] file system; if I just ls, there's input.txt. So I already downloaded it. And here I'm reading the data set,
[00:46:23.040 --> 00:46:28.880] getting the first 1000 characters and printing the first 100. Now remember that GPT two has
[00:46:28.880 --> 00:46:34.480] roughly a compression ratio, the tokenizer has a compression ratio of roughly three to one.
[00:46:34.480 --> 00:46:39.840] So 1000 characters is roughly 300 tokens that will come out of this slice that
[00:46:39.840 --> 00:46:47.120] we're currently getting. So this is the first few characters. And if you want to get a few more
[00:46:47.120 --> 00:46:54.480] statistics on this, we can do word count on input dot txt. So we can see that this is 40,000 lines,
[00:46:54.480 --> 00:47:00.960] about 200,000 words in this data set, and about 1 million bytes in this file. And knowing that
[00:47:00.960 --> 00:47:05.600] this file is only ASCII characters, there's no crazy Unicode here as far as I know. And so
[00:47:05.600 --> 00:47:10.960] every ASCII character is encoded with one byte. And so this is the same number, roughly a million
[00:47:10.960 --> 00:47:17.040] characters inside this data set. So that's the data set size by default, very small and minimal
[00:47:17.040 --> 00:47:22.320] data set for debugging to get us off the ground. In order to tokenize this data set, we're going to
[00:47:22.320 --> 00:47:30.800] get the tiktoken encoding for GPT-2, encode the data, the first 1000 characters, and then I'm
[00:47:30.800 --> 00:47:36.000] only going to print the first 24 tokens. So these are the tokens as a list of integers.
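(A minimal sketch of this tokenization step, assuming input.txt has already been downloaded as described.)

```python
import tiktoken

# Read the first 1,000 characters of tiny Shakespeare and tokenize them
# with the GPT-2 BPE; print the first 24 token ids.
with open("input.txt", "r") as f:
    text = f.read()
data = text[:1000]
enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode(data)
print(tokens[:24])
```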
[00:47:36.000 --> 00:47:41.680] And if you can read GPT two tokens, you will see that 198 here, you'll recognize that as the
[00:47:41.680 --> 00:47:46.320] slash n character. So that is a newline. And then here, for example, we have two new lines. So
[00:47:46.320 --> 00:47:53.120] that's 198 twice here. So this is just the tokenization of the first 24 tokens. So what we want to do
[00:47:53.120 --> 00:47:59.120] now is we want to actually process these token sequences and feed them into a transformer. And
[00:47:59.120 --> 00:48:05.680] in particular, we want them, we want to rearrange these tokens into this idx variable that we're
[00:48:05.680 --> 00:48:09.440] going to be feeding into the transformer. So we don't want a single very long one dimensional
[00:48:09.440 --> 00:48:16.880] sequence. We want an entire batch, where each sequence is up to, it's basically T tokens,
[00:48:16.880 --> 00:48:21.840] and T cannot be larger than the maximum sequence length. And then we have these T,
[00:48:22.640 --> 00:48:29.600] T long sequence of tokens. And we have B independent examples of sequences. So how can we create a
[00:48:29.600 --> 00:48:35.360] B by T tensor that we can feed into the forward out of these one dimensional sequences. So here's
[00:48:35.360 --> 00:48:41.440] my favorite way to to achieve this. So if we take torch, and then we create a tensor object out of
[00:48:41.440 --> 00:48:46.720] this list of integers and just the first 24 tokens, my favorite way to do this is basically you do a
[00:48:46.720 --> 00:48:54.800] dot view of, for example, four by six, which multiply to 24. And so it's just a two dimensional
[00:48:54.800 --> 00:48:58.560] rearrangement of these tokens. And you'll notice that when you view this one dimensional sequence
[00:48:58.560 --> 00:49:07.280] as two dimensional four by six here, the first six tokens up to here end up being the first row,
[00:49:07.280 --> 00:49:13.200] the next six tokens here end up being the second row, and so on. And so basically, it's just going
[00:49:13.200 --> 00:49:21.440] to stack up the every six tokens in this case, as independent rows, and it creates a batch of
[00:49:21.440 --> 00:49:27.840] tokens in this case. And so for example, if we are token 25, in the transformer, when we feed this
[00:49:27.840 --> 00:49:33.440] in, and this becomes the idx, this token is going to see these three tokens, and it's going to try
[00:49:33.440 --> 00:49:40.640] to predict that 198 comes next. So in this way, we are able to create this two dimensional batch,
[00:49:40.640 --> 00:49:46.480] that's quite nice. Now, in terms of the label that we're going to need for the target to calculate
[00:49:46.480 --> 00:49:50.880] the loss function, how do we get that? Well, we could write some code inside the forward pass,
[00:49:50.880 --> 00:49:55.840] because we know that the next token in a sequence, which is the label, is just to the right of us.
[00:49:55.840 --> 00:50:02.160] But you'll notice that actually we for this token at the very end, 13, we don't actually have the
[00:50:02.160 --> 00:50:08.000] next correct token because we didn't load it. So we actually didn't get enough information here.
[00:50:09.200 --> 00:50:14.560] So I'll show you my favorite way of basically getting these batches. And I like to personally
[00:50:14.560 --> 00:50:19.840] have not just the input to the transformer, which I like to call x, but I also like to create the
[00:50:19.840 --> 00:50:26.000] labels tensor, which is of the exact same size as x, but contains the targets at every single
[00:50:26.000 --> 00:50:30.880] position. And so here's the way that I like to do that. I like to make sure that I fetch plus one
[00:50:30.880 --> 00:50:38.800] token, because we need the ground truth for the very last token, for 13. And then when we're creating
[00:50:38.800 --> 00:50:44.800] the input, we take everything up to the last token, not including and view it as four by six. And when
[00:50:44.800 --> 00:50:51.440] we're creating targets, we do the buffer, but starting at index one, not index zero. So we're
[00:50:51.440 --> 00:50:55.840] skipping the first element and we view it in the exact same size. And then when I print this,
[00:50:55.840 --> 00:51:02.320] here's what happens, where we see that basically as an example for this token 25,
[00:51:02.320 --> 00:51:08.320] its target was 198. And that's now just stored at the exact same position in the target tensor,
[00:51:08.320 --> 00:51:16.320] which is 198. And also this last token 13 now has its label, which is 198. And that's just because
[00:51:16.320 --> 00:51:22.400] we loaded this plus one here. So basically, this is the way I like to do it. You take long sequences,
[00:51:22.400 --> 00:51:29.280] you view them in two dimensional terms, so that you get batches of time. And then we make sure to
[00:51:29.280 --> 00:51:36.080] load one additional token. So we basically load a buffer of tokens of b times t plus one. And then
[00:51:36.080 --> 00:51:40.800] we sort of offset things and view them. And then we have two tensors, one of them is the input to
[00:51:40.800 --> 00:51:47.440] the transformer. And the other exactly is the labels. And so let's now reorganize this code and
[00:51:47.440 --> 00:51:53.760] create a very simple data loader object that tries to basically load these tokens and
[00:51:53.760 --> 00:51:59.040] feed them to the transformer and calculate the loss. Okay, so I reshuffled the code here,
[00:51:59.040 --> 00:52:04.240] accordingly. So as you can see here, I'm temporarily overriding to run on CPU.
[00:52:05.040 --> 00:52:09.520] And importing tiktoken and all of this should look familiar, we're loading 1000 characters.
[00:52:09.520 --> 00:52:13.920] I'm setting BT to just be four and 32 right now, just because we're debugging,
[00:52:13.920 --> 00:52:18.640] we just want to have a single batch that's very small. And all of this should now look familiar
[00:52:18.640 --> 00:52:23.280] and follows what we did on the right. And then here, we get the we create the model
[00:52:23.280 --> 00:52:30.720] and get the logits. And so here, as you see, I already ran this only runs in a few seconds.
[00:52:30.720 --> 00:52:38.640] But because we have a batch of four by 32, our logits are now size four by 32 by 50,257.
[00:52:38.640 --> 00:52:44.480] So those are the logits for what comes next at every position. And now we have the labels,
[00:52:44.480 --> 00:52:49.440] which are stored in Y. So now is the time to calculate the loss, and then do the backward pass,
[00:52:49.440 --> 00:52:55.040] and then the optimization. So let's first calculate loss. Okay, so to calculate the loss,
[00:52:55.040 --> 00:52:59.840] we're going to adjust the forward function of this nn.Module, the model. And in particular,
[00:52:59.840 --> 00:53:03.040] we're not just going to be returning logits, but also we're going to return the loss.
[00:53:03.040 --> 00:53:08.400] And we're going to not just be passing in the input indices, but also the targets in y.
[00:53:08.400 --> 00:53:14.800] And now we will not print just logits.shape anymore. We're actually going to print the loss
[00:53:14.800 --> 00:53:19.200] and then sys.exit(0) so that we skip some of the sampling logic.
[00:53:19.200 --> 00:53:26.080] So now let's swing up to the forward function, which gets called there, because now we also have
[00:53:26.080 --> 00:53:32.720] these optional targets. And when we get the targets, we can also calculate the loss.
[00:53:32.720 --> 00:53:38.720] And remember that we want to basically return the logits and the loss, and loss by default is None. But
[00:53:38.720 --> 00:53:50.400] let's put this here. If targets is not None, then we want to calculate the loss. And Copilot is
[00:53:50.400 --> 00:53:55.600] already getting excited here and calculating the what looks to be correct loss. It is using the
[00:53:55.600 --> 00:54:03.440] cross entropy loss as is documented here. So this is a function in PyTorch under the functional.
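(Roughly, the relevant piece looks like the sketch below; in the actual model this sits inline at the end of the forward pass rather than as a separate helper.)

```python
import torch.nn.functional as F

def compute_loss(logits, targets=None):
    # logits: (B, T, vocab_size), targets: (B, T) token ids or None
    loss = None
    if targets is not None:
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),  # flatten to (B*T, vocab_size)
            targets.view(-1),                  # flatten to (B*T,)
        )
    return loss
```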
[00:54:03.440 --> 00:54:09.520] Now, what is actually happening here, because it looks a little bit scary. Basically,
[00:54:09.520 --> 00:54:14.560] F.cross_entropy does not like multi-dimensional inputs. It can't take a B by T by vocab size.
[00:54:14.560 --> 00:54:19.280] So what's happening here is that we are flattening out of this three dimensional tensor into just
[00:54:19.280 --> 00:54:23.600] two dimensions. The first dimension is going to be calculated automatically, and it's going to be
[00:54:23.600 --> 00:54:30.160] B times T. And then the last dimension is vocab size. So basically, this is flattening out this
[00:54:30.160 --> 00:54:36.560] three dimensional tensor of logits to just be two dimensional B times T, all individual examples
[00:54:36.560 --> 00:54:42.800] and vocab size on in terms of the length of each row. And then it's also flattening out the
[00:54:42.800 --> 00:54:47.360] targets, which are also two dimensional at this stage. But we're going to just flatten them out.
[00:54:47.360 --> 00:54:52.080] So they're just a single tensor of B times T. And this can then pass into cross entropy to
[00:54:52.080 --> 00:54:57.360] calculate a loss, which we return. So this should basically at this point run,
[00:54:57.360 --> 00:55:04.320] because it's not too complicated. So let's run it. And let's see if we should be printing the loss.
[00:55:04.320 --> 00:55:14.320] And here we see that we printed 11 roughly. And so
[00:55:16.960 --> 00:55:22.080] and notice that this is the tensor of a single element, which is this number 11. Now, we also
[00:55:22.080 --> 00:55:27.520] want to be able to calculate a reasonable kind of starting point for a random linearized network.
[00:55:27.520 --> 00:55:33.760] So we covered this in previous videos, but our vocabulary size is 50,257. At initialization of
[00:55:33.760 --> 00:55:40.960] the network, you would hope that every vocab element is getting roughly a uniform probability,
[00:55:41.680 --> 00:55:46.880] so that we're not favoring at initialization, any token way too much. We're not confidently
[00:55:46.880 --> 00:55:51.760] wrong at initialization. So we're hoping is that the probability of any arbitrary token is roughly
[00:55:51.760 --> 00:55:58.560] one over 50,257. And now we can sanity check the loss, because remember that the cross entropy
[00:55:58.560 --> 00:56:05.920] loss is just basically the negative log likelihood. So if we now take this probability, and we take
[00:56:05.920 --> 00:56:11.440] it through the natural logarithm, and then we do the negative, that is the loss we expect at
[00:56:11.440 --> 00:56:17.040] initialization, and we covered this in previous videos. So I would expect something around 10.82,
[00:56:17.040 --> 00:56:21.360] and we're seeing something around that level. So it's not way off. This is roughly the loss we
[00:56:21.360 --> 00:56:26.480] expect at initialization. So that tells me that at initialization our probability distribution
[00:56:26.480 --> 00:56:32.160] is roughly diffuse. It's a good starting point. And we can now perform the optimization and
[00:56:32.160 --> 00:56:35.680] tell the network which elements, you know, should follow correctly in what order.
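(The sanity check itself is a one-liner; this just spells out the arithmetic mentioned above.)

```python
import math

vocab_size = 50257
# Cross entropy of a roughly uniform distribution over the vocabulary:
expected_loss = -math.log(1.0 / vocab_size)
print(expected_loss)  # ~10.82, close to the ~11 we measured
```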
[00:56:35.680 --> 00:56:41.200] So at this point, we can do a loss.backward(), calculate the gradients and do an optimization.
[00:56:41.200 --> 00:56:47.200] So let's get to that. Okay, so let's do the optimization now. So here we have
[00:56:47.200 --> 00:56:54.320] the loss is this is how we get the loss. But now basically we want a little for loop here. So for
[00:56:54.320 --> 00:57:00.400] i in range, let's do 50 steps or something like that. Let's create an optimizer object in PyTorch.
[00:57:02.000 --> 00:57:08.640] And so here we are using the Adam optimizer, which is an alternative to the stochastic gradient
[00:57:08.640 --> 00:57:13.760] descent optimizer, SGD, that we were using. So SGD is a lot simpler, Adam is a bit more involved.
[00:57:13.760 --> 00:57:19.440] And I actually specifically like the AdamW variation, because in my opinion, it kind of just like fixes
[00:57:19.440 --> 00:57:26.080] a bug. So AdamW is a bug fix of Adam is what I would say. When we go to the documentation for
[00:57:26.080 --> 00:57:34.080] AdamW. Oh my gosh. We see that it takes quite a few parameters and it's a little bit more
[00:57:34.080 --> 00:57:38.720] complicated than the SGD we were looking at before, because in addition to basically
[00:57:38.720 --> 00:57:43.440] updating the parameters with the gradient scaled by the learning rate, it keeps these buffers
[00:57:43.440 --> 00:57:49.200] around and it keeps two buffers, the M and the V, which it calls the first and the second moment.
[00:57:49.200 --> 00:57:53.040] So something that looks a bit like momentum and something that looks a bit like RMSProp,
[00:57:53.040 --> 00:57:57.120] if you're familiar with it. But you don't have to be, it's just kind of like a normalization
[00:57:57.120 --> 00:58:01.520] that happens on each gradient element individually and speeds up the optimization,
[00:58:01.520 --> 00:58:05.360] especially for language models. But I'm not going to go into the detail right here.
[00:58:05.360 --> 00:58:11.680] We're going to treat this a bit of a black box and it just optimizes the objective faster
[00:58:11.680 --> 00:58:16.560] than SGD, which is what we've seen in the previous lectures. So let's use it as a black box in our
[00:58:16.560 --> 00:58:23.680] case. Create the optimizer object and then go through the optimization.
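(A sketch of that loop, assuming the model, x, y and device from the surrounding code; the 3e-4 learning rate is the default we settle on a bit later.)

```python
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
    optimizer.zero_grad()        # gradients accumulate, so reset them first
    logits, loss = model(x, y)
    loss.backward()              # deposits (+=) gradients into .grad
    optimizer.step()             # update the parameters to decrease the loss
    print(f"step {i}, loss: {loss.item()}")
```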
[00:58:23.680 --> 00:58:34.800] The first thing to always make sure the copilot did not forget to zero the gradients. So always
[00:58:34.800 --> 00:58:39.440] remember that you have to start with a zero gradient. Then when you get your loss and you do a dot
[00:58:39.440 --> 00:58:45.840] backward, dot backward adds to gradients. So it deposits gradients. It always does a plus equals
[00:58:45.840 --> 00:58:50.080] on whatever the gradients are, which is why you must set them to zero. So this accumulates
[00:58:50.080 --> 00:58:56.960] the gradient from this loss. And then we call the step function on the optimizer to update the
[00:58:56.960 --> 00:59:05.360] parameters and to decrease the loss. And then we print the step and the loss dot item is used
[00:59:05.360 --> 00:59:11.440] here because loss is a tensor with a single element dot item will actually convert that to a single
[00:59:11.440 --> 00:59:17.520] float. And this float will live on the CPU. So this gets to some of the internals again
[00:59:17.520 --> 00:59:24.000] of the devices, but loss is a is a tensor with a single element and it lives on GPU for me because
[00:59:24.000 --> 00:59:29.680] I'm using GPUs. When you call dot item, pytorch behind the scenes will take that one dimensional
[00:59:29.680 --> 00:59:36.640] tensor ship it back to the CPU memory and convert it into a float that we can just print. So this
[00:59:36.640 --> 00:59:46.320] is the optimization and this should probably just work. Let's see what happens. Actually,
[00:59:46.320 --> 00:59:51.600] sorry, let me instead of using CPU override, let me delete that. So this is a bit faster for me
[00:59:51.600 --> 01:00:02.960] and it runs on CUDA. Oh, expected all tensors to be on the same device but found at least two devices,
[01:00:02.960 --> 01:00:09.280] cuda:0 and cpu. So cuda:0 is the zeroth GPU, because I actually have eight GPUs on this box.
[01:00:09.280 --> 01:00:17.280] So the 0th GPU on my box and CPU. And a model we have moved to device, but when I was writing
[01:00:17.280 --> 01:00:22.640] this code, I actually introduced the bug because buf we never moved to device. And you have to be
[01:00:22.640 --> 01:00:30.400] careful because you can't just do buf.to(device). It's not stateful. It doesn't convert it
[01:00:30.400 --> 01:00:37.280] in place. It instead returns a pointer to new memory which is on the device. So you see how
[01:00:37.280 --> 01:00:42.080] we can just do model.to(device), but that does not apply to tensors. You have to do buf equals
[01:00:42.080 --> 01:00:48.960] buf.to(device). And then this should work. Okay.
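(A tiny self-contained illustration of that difference between modules and tensors.)

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

buf = torch.tensor([1, 2, 3, 4])
buf.to(device)        # NOT stateful: the returned tensor is discarded, buf stays on CPU
buf = buf.to(device)  # correct: rebind the name to the new tensor on the device

layer = torch.nn.Linear(4, 4)
layer.to(device)      # modules, in contrast, are moved in place (and also returned)
```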
[01:00:48.960 --> 01:00:54.720] So what do we expect to see? We expect to see a reasonable loss in the beginning and then we
[01:00:54.720 --> 01:00:59.120] continue to optimize just a single batch. And so we want to see that we can overfit this single
[01:00:59.120 --> 01:01:04.240] batch. We can we can crush this little batch and we can perfectly predict the indices on just this
[01:01:04.240 --> 01:01:10.960] little batch. And indeed that is roughly what we're seeing here. So we started off at roughly
[01:01:10.960 --> 01:01:16.240] 10.82, 11 in this case. And then as we continue optimizing on this single batch without loading
[01:01:16.240 --> 01:01:20.560] new examples, we are making sure that we can overfit a single batch. And we are getting to
[01:01:20.560 --> 01:01:25.280] very, very low loss. So the transformer is memorizing this single individual batch.
[01:01:26.080 --> 01:01:30.320] And one more thing I didn't mention is the learning rate here is 3e-4,
[01:01:30.320 --> 01:01:36.400] which is a pretty good default for most optimizations that you want to run at a very early debugging
[01:01:36.400 --> 01:01:43.840] stage. So this is our simple inner loop. And we are overfitting a single batch. And this looks good.
[01:01:43.840 --> 01:01:47.760] So now what what comes next is we don't just want to overfit a single batch. We actually want
[01:01:47.760 --> 01:01:52.960] to do an optimization. So we actually need to iterate these xy batches and create a little data
[01:01:52.960 --> 01:01:57.760] loader that makes sure that we're always getting a fresh batch and that we're actually optimizing
[01:01:57.760 --> 01:02:02.160] a reasonable objective. So let's do that next. Okay, so this is what I came up with, and I wrote
[01:02:02.160 --> 01:02:08.320] a little DataLoaderLite. So what this data loader does is we're importing tiktoken up here,
[01:02:08.320 --> 01:02:14.800] reading the entire text file from this single input.txt, tokenizing it, and then we're just
[01:02:14.800 --> 01:02:20.880] printing the number of tokens in total, and the number of batches in a single epoch of iterating
[01:02:20.880 --> 01:02:26.160] over this dataset. So how many unique batches do we output before we loop back around the beginning
[01:02:26.160 --> 01:02:32.320] of the document and start reading it again. So we start off at position zero, and then we simply
[01:02:32.320 --> 01:02:38.320] walk the document in batches of b times t. So we take chunks of b times t, and then always advance
[01:02:38.320 --> 01:02:45.280] by b times t. And it's important to note that we're always advancing our position by exactly b times
[01:02:45.280 --> 01:02:50.640] t. But when we're fetching the tokens, we're actually fetching from current position to b times
[01:02:50.640 --> 01:02:57.440] t plus one. And we need that plus one, because remember, we need the target token for the last
[01:02:57.440 --> 01:03:05.760] token in the current batch. And so that way we can do the x, y exactly as we did it before. And
[01:03:05.760 --> 01:03:12.480] if we are to run out of data, we'll just loop back around to zero. So this is one way to write
[01:03:12.480 --> 01:03:18.800] a very, very simple data loader. That simply just goes through the file in chunks. And this
[01:03:18.800 --> 01:03:24.800] is good enough for us for current purposes. And we're going to complexify it later. And now we'd like
[01:03:24.800 --> 01:03:29.440] to come back around here, and we'd like to actually use our data loader. So the import tiktoken has
[01:03:29.440 --> 01:03:36.320] moved up. And actually all of this is now useless. So instead we just want a train loader for the
[01:03:36.320 --> 01:03:43.440] training data. And we want to use the same hyperparameters as before. So batch size was 4 and time was 32.
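(Roughly what that DataLoaderLite looks like, assuming input.txt as before; details like the print statements are incidental.)

```python
import tiktoken
import torch

class DataLoaderLite:
    def __init__(self, B, T):
        self.B, self.T = B, T
        with open("input.txt", "r") as f:
            text = f.read()
        enc = tiktoken.get_encoding("gpt2")
        self.tokens = torch.tensor(enc.encode(text))
        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")
        self.current_position = 0  # state: where we are in the token stream

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets, shifted by one
        self.current_position += B * T
        # if the next batch would run off the end, loop back to the beginning
        if self.current_position + B * T + 1 > len(self.tokens):
            self.current_position = 0
        return x, y

train_loader = DataLoaderLite(B=4, T=32)
```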
[01:03:44.320 --> 01:03:50.320] And then here, we need to get the x, y for the current batch. So let's see if Copilot gets it,
[01:03:50.320 --> 01:03:56.720] because this is simple enough. So we call the next batch. And then we make sure that we have to
[01:03:56.720 --> 01:04:05.040] move our tensors from CPU to the device. So here, when I converted the tokens,
[01:04:05.040 --> 01:04:10.800] notice that I didn't actually move these tokens to the GPU, I left them on the CPU, which is default.
[01:04:12.080 --> 01:04:16.640] And that's just because I'm trying not to waste too much memory on the GPU. In this case, this is a
[01:04:16.640 --> 01:04:22.960] tiny data set that would fit. But it's fine to just ship it to GPU right now for our purposes
[01:04:22.960 --> 01:04:29.120] right now. So we get the next batch, we keep the data loader simple CPU class. And then here,
[01:04:29.120 --> 01:04:37.040] we actually ship it to the GPU and do all the computation. And let's see if this runs. So Python
[01:04:37.040 --> 01:04:42.560] trained to be to the pie. And what do we expect to see before this actually happens? What we
[01:04:42.560 --> 01:04:47.680] expect to see is now we're actually getting the next batch. So we expect to not overfit a single
[01:04:47.680 --> 01:04:54.640] batch. And so I expect our loss to come down, but not too much. And that's because I still
[01:04:54.640 --> 01:05:01.600] expect it to come down because, of the 50,257 tokens, many of those tokens never occur in our data set.
[01:05:01.600 --> 01:05:06.240] So there's some very easy gains to be made here in the optimization by, for example, taking the
[01:05:06.240 --> 01:05:11.520] biases of all the logits that never occur and driving them to negative infinity. And that would
[01:05:11.520 --> 01:05:16.080] basically just, it's just that all of these crazy unicodes or different languages, those tokens
[01:05:16.080 --> 01:05:20.480] never occur. So their probability should be very low. And so the gains that we should be seeing
[01:05:20.480 --> 01:05:26.000] are along the lines of basically deleting the usage of tokens that never occur. That's probably
[01:05:26.000 --> 01:05:30.800] most of the loss gain that we're going to see at this scale right now. But we shouldn't come to
[01:05:31.680 --> 01:05:38.160] zero because we are only doing 50 iterations. And I don't think that's enough to do an epoch
[01:05:38.160 --> 01:05:46.640] right now. So let's see what we got. We, we have 338,000 tokens, which makes sense with our
[01:05:46.640 --> 01:05:52.720] three to one compression ratio, because there are one million characters. So one epoch with the
[01:05:52.720 --> 01:05:59.840] current setting of B and T will take 2,600 batches. And we're only doing 50 batches of optimization
[01:05:59.840 --> 01:06:06.800] in here. So we start off in a familiar territory, as expected, and then we seem to come down to
[01:06:06.800 --> 01:06:13.440] about 6.6. So basically, things seem to be working okay right now with respect to our expectations.
[01:06:13.440 --> 01:06:19.120] So that's good. Okay, next, I want to actually fix a bug that we have in our code. It's not a
[01:06:19.120 --> 01:06:27.040] major bug, but it is a bug with respect to how GPT-2 training should happen. So the bug is the
[01:06:27.040 --> 01:06:31.120] following. We were not being careful enough when we were loading the weights from hugging face and
[01:06:31.120 --> 01:06:39.680] we actually missed a little detail. So if we come here, notice that the shape of these two tensors is
[01:06:39.680 --> 01:06:47.040] the same. So this one here is the token embedding at the bottom of the transformer. Right, so and
[01:06:47.040 --> 01:06:53.440] this one here is the language modeling head at the top of the transformer. And both of these are
[01:06:53.440 --> 01:07:00.160] basically two dimensional tensors and their shape is identical. So here, the first one is the output
[01:07:00.160 --> 01:07:05.040] embedding, the token embedding. And the second one is this linear layer at the very top, the classifier
[01:07:05.040 --> 01:07:14.640] layer. Both of them are of shape 50,257 by 768. This one here is giving us our token embeddings
[01:07:14.640 --> 01:07:21.040] at the bottom. And this one here is taking the 768 channels of the transformer and trying to upscale
[01:07:21.040 --> 01:07:28.240] that to 50,257 to get the logits for the next token. So they're both the same shape. But more
[01:07:28.240 --> 01:07:35.200] than that, actually, if you look at comparing their elements in PyTorch, this is an element
[01:07:35.200 --> 01:07:39.600] wise equality. So then we use .all(), and we see that every single element is identical.
[01:07:39.600 --> 01:07:46.080] And more than that, we see that if we actually look at the data pointer, this is what this is a
[01:07:46.080 --> 01:07:51.360] way in PyTorch to get the actual pointer to the data and the storage, we see that actually the
[01:07:51.360 --> 01:07:56.480] pointer is identical. So not only are these two separate tensors that happen to have the same shape
[01:07:56.480 --> 01:08:02.240] and elements, they're actually pointing to the identical tensor. So what's happening here is
[01:08:02.240 --> 01:08:10.000] that this is a common weight tying scheme that actually comes from the original, from the original
[01:08:10.000 --> 01:08:15.040] attention is all you need paper, and actually even the reference before it. So if we come here,
[01:08:15.040 --> 01:08:25.920] embeddings and softmax in the attention is all you need paper, they mention that in our model,
[01:08:25.920 --> 01:08:30.720] we share the same weight matrix between the two embedding layers and the pre-softmax linear
[01:08:30.720 --> 01:08:37.520] transformation, similar to [30]. So this is an awkward way to phrase that these two are shared,
[01:08:37.520 --> 01:08:42.000] and they're tied, and they're the same matrix. And the [30] reference is this paper.
[01:08:42.000 --> 01:08:49.440] So this came out in 2017. And you can read the full paper, but basically it argues for this
[01:08:49.440 --> 01:08:54.880] weight tying scheme. And I think intuitively the idea for why you might want to do this
[01:08:54.880 --> 01:09:00.080] comes from this paragraph here. And basically you can observe that
[01:09:03.280 --> 01:09:10.080] you actually want these two matrices to behave similarly in the following sense. If two tokens
[01:09:10.080 --> 01:09:14.560] are very similar semantically, like maybe one of them is all lowercase and the other one is
[01:09:14.560 --> 01:09:18.240] all uppercase, or it's the same token in the different language or something like that,
[01:09:18.240 --> 01:09:22.160] if you have similarity between two tokens, presumably you would expect that they are
[01:09:22.160 --> 01:09:27.600] nearby in the token embedding space. But in the exact same way, you'd expect that if you have
[01:09:27.600 --> 01:09:32.720] two tokens that are similar semantically, you'd expect them to get the same probabilities
[01:09:33.280 --> 01:09:40.240] at the output of a transformer, because they are semantically similar. And so both positions
[01:09:40.240 --> 01:09:45.120] in the transformer at the very bottom and at the top have this property that similar tokens
[01:09:45.120 --> 01:09:50.720] should have similar embeddings or similar weights. And so this is what motivates their
[01:09:50.720 --> 01:09:54.400] exploration here. And they kind of, you know, I don't want to go through the entire paper.
[01:09:54.400 --> 01:10:00.160] And you can go through it, but this is what they observe. They also observe that if you look at
[01:10:00.160 --> 01:10:06.960] the output embeddings, they also behave like word embeddings. If you just kind of try to use those
[01:10:06.960 --> 01:10:12.800] weights as word embeddings. So they kind of observe this similarity. They try to tie them,
[01:10:12.800 --> 01:10:18.160] and they observe that they can get much better performance in that way. And so this was adopted
[01:10:18.160 --> 01:10:23.680] in the Attention Is All You Need paper. And then it was used again in GPT-2 as well. So
[01:10:25.200 --> 01:10:30.160] I couldn't find it in the Transformers implementation, I'm not sure where they tie those embeddings,
[01:10:30.160 --> 01:10:37.920] but we can find it in the original GPT-2 code released by OpenAI. So this is the OpenAI GPT-2
[01:10:37.920 --> 01:10:42.880] source, model.py. And here where they are forwarding this model, and this is in TensorFlow, but
[01:10:42.880 --> 01:10:50.560] that's okay, we see that they get the WTE token embeddings. And then here is the encoder of the
[01:10:50.560 --> 01:10:55.840] token embeddings and the position. And then here at the bottom, they use the WTE again
[01:10:55.840 --> 01:11:02.800] to do the logits. So when they get the logits, it's a matmul of this output from the transformer
[01:11:02.800 --> 01:11:10.240] and the WTE tensor is reused. And so the WTE tensor basically is used twice on the bottom of the
[01:11:10.240 --> 01:11:15.920] transformer and on the top of the transformer. And in the backward pass, we'll get gradients
[01:11:15.920 --> 01:11:22.080] contributions from both branches. And these gradients will add up on the WTE tensor.
[01:11:22.080 --> 01:11:27.440] So we'll get a contribution from the classifier layer. And then at the very end of the transformer,
[01:11:27.440 --> 01:11:33.760] we'll get a contribution at the bottom of it flowing again into the WTE tensor.
[01:11:33.760 --> 01:11:39.920] So we are currently not sharing WTE in our code, but we want to do that.
[01:11:42.160 --> 01:11:49.840] So weight sharing scheme. And one way to do this, let's see if Copilot gets it.
[01:11:49.840 --> 01:11:58.800] Oh, it does. Okay. So this is one way to do it. Basically, relatively straightforward,
[01:11:58.800 --> 01:12:04.640] what we're doing here is we're taking the WTE dot weight, and we're simply redirecting
[01:12:05.520 --> 01:12:13.120] it to point to the lm_head's weight. So this basically copies the data pointer, right, it copies the
[01:12:13.120 --> 01:12:20.960] reference. And now the old value of wte.weight becomes orphaned, and PyTorch will
[01:12:20.960 --> 01:12:28.640] clean it up. And so we are only left with a single tensor. And it's going to be used twice
[01:12:28.640 --> 01:12:34.640] in the forward pass. And this is, to my knowledge, all that's required. So we should be able to
[01:12:35.520 --> 01:12:40.960] use this. And this should probably train. We're just going to basically be using this exact same
[01:12:40.960 --> 01:12:48.960] tensor twice. And we weren't being careful with tracking the likelihoods. But according to the
[01:12:48.960 --> 01:12:52.240] paper, and according to the results, you'd actually expect slightly better results doing this.
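(A minimal, self-contained sketch of what the tying amounts to; in the model it is the single line self.transformer.wte.weight = self.lm_head.weight inside the constructor.)

```python
import torch.nn as nn

vocab_size, n_embd = 50257, 768
wte = nn.Embedding(vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)

# Redirect the embedding's weight to the classifier's weight: both modules
# now share one and the same (vocab_size, n_embd) parameter tensor.
wte.weight = lm_head.weight
assert wte.weight.data_ptr() == lm_head.weight.data_ptr()
```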
[01:12:52.240 --> 01:12:57.600] And in addition to that, one other reason that this is very, very nice for us is that this is a
[01:12:57.600 --> 01:13:06.160] ton of parameters, right? What is the size of here? It's 768 times 50,257. So this is 40 million
[01:13:06.160 --> 01:13:13.840] parameters. And this is a 124 million parameter model. So 40 divided by 124. So this is like 30% of
[01:13:13.840 --> 01:13:19.440] the parameters are being saved using this weight tying scheme. And so this might be one of the
[01:13:19.440 --> 01:13:23.120] reasons that this is working slightly better. If you're not training the model long enough,
[01:13:23.840 --> 01:13:28.240] because of the weight tying, you don't have to train as many parameters. And so you become more efficient
[01:13:28.240 --> 01:13:34.240] in terms of the training process, because you have fewer parameters. And you're putting in this
[01:13:34.240 --> 01:13:38.960] inductive bias that these two embeddings should share similarities between tokens.
[01:13:38.960 --> 01:13:45.040] So this is the weight tying scheme. And we saved a ton of parameters. And we expect our model to work
[01:13:45.040 --> 01:13:48.960] slightly better because of this scheme. Okay, next, I would like us to be a bit more careful
[01:13:48.960 --> 01:13:53.920] with the initialization and to try to follow the way GPT-2 initialized their model. Now,
[01:13:53.920 --> 01:13:59.360] unfortunately, the GPT two paper and the GPT three paper are not very explicit about initialization.
[01:13:59.360 --> 01:14:04.240] So we kind of have to read between lines. And instead of going to the paper, which is quite vague,
[01:14:04.240 --> 01:14:10.480] there's a bit of information in the code that OpenAI released. So when we go to the model.py,
[01:14:10.480 --> 01:14:16.160] we see that when they initialize their weights, they are using the standard deviation of 0.02.
[01:14:16.720 --> 01:14:22.160] And that's how they, they, so this is a normal distribution for the weights. And the standard
[01:14:22.160 --> 01:14:29.040] deviation is 0.02. For the bias, they initialize that with zero. And then when we scroll down here,
[01:14:29.040 --> 01:14:39.040] why is this not scrolling? The token embeddings are initialized at 0.02. And position embeddings
[01:14:39.040 --> 01:14:45.120] at 0.01 for some reason. So those are the initializations. And we'd like to mirror that GPT-2 initialization
[01:14:46.320 --> 01:14:51.520] in our module here. So here's a snippet of code that I sort of came up with very quickly.
[01:14:51.520 --> 01:14:58.880] So what's happening here is at the end of our initializer for the GPT module, we're calling the
[01:14:58.880 --> 01:15:06.240] apply function of nn.Module. And that iterates all the submodules of this module and applies
[01:15:06.240 --> 01:15:12.400] the _init_weights function on them. And so what's happening here is that we're iterating
[01:15:12.400 --> 01:15:17.680] all the modules here. And if they are an nn.Linear module, then we're going to make sure to
[01:15:17.680 --> 01:15:23.360] initialize the weight using a normal with the standard deviation of 0.02. If there's a bias in
[01:15:23.360 --> 01:15:28.640] this layer, we make sure to initialize that to zero. Note that zero initialization for the bias
[01:15:28.640 --> 01:15:35.760] is not actually the pytorch default. By default, the bias here is initialized with a uniform. So
[01:15:35.760 --> 01:15:41.120] that's interesting. So we make sure to use zero. And for the embedding, we're just going to get 0.02.
[01:15:41.120 --> 01:15:46.960] And keep it the same. So we're not going to change it to 0.01 for positional, because it's about the
[01:15:46.960 --> 01:15:52.240] same. And then if you look through our model, the only other layer that requires initialization,
[01:15:52.240 --> 01:15:56.640] and that has parameters, is the layer norm. And the pytorch default initialization
[01:15:56.640 --> 01:16:01.200] sets the scale in the layer norm to be one, and the offset in the layer norm to be zero.
[01:16:01.200 --> 01:16:04.640] So that's exactly what we want. And so we're just going to keep it that way.
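(A sketch of that initializer as a standalone function; in the video it is written as a _init_weights method on the GPT module and invoked via self.apply(self._init_weights).)

```python
import torch.nn as nn

def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)  # zero bias, overriding PyTorch's uniform default
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    # LayerNorm is left at PyTorch's default: scale 1, offset 0

toy = nn.Sequential(nn.Embedding(50257, 768), nn.Linear(768, 768), nn.LayerNorm(768))
toy.apply(init_weights)  # apply() visits every submodule
```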
[01:16:05.520 --> 01:16:15.440] And so this is the default initialization, if we are following the, where is it, the GPT-2 source
[01:16:15.440 --> 01:16:20.960] code that they released. I would like to point out, by the way, that typically the standard deviation
[01:16:20.960 --> 01:16:25.440] here on this initialization, if you follow the Xavier initialization, would be one over the square
[01:16:25.440 --> 01:16:31.120] root of the number of features that are incoming into this layer. But if you'll notice, actually,
[01:16:31.120 --> 01:16:37.200] 0.02 is basically consistent with that, because the d_model sizes inside these transformers for GPT-2
[01:16:37.200 --> 01:16:44.640] are roughly 768, 1600, etc. So one over the square root of, for example, 768 gives us 0.036.
[01:16:44.640 --> 01:16:55.680] If we plug in 1600, we get 0.025. If we plug in three times that, 0.014, etc. So basically 0.02 is
[01:16:55.680 --> 01:17:04.640] roughly in the vicinity of reasonable values for these initializations anyway. So it's not
[01:17:04.640 --> 01:17:11.920] completely crazy to be hard coding 0.02 here. But you'd like typically something that grows with
[01:17:11.920 --> 01:17:16.640] the model size instead. But we will keep this because that is the GPT two initialization per
[01:17:16.640 --> 01:17:21.120] their source code. But we are not fully done yet on initialization because there's one more caveat
[01:17:21.120 --> 01:17:27.920] here. So here: a modified initialization, which accounts for the accumulation on the residual path
[01:17:27.920 --> 01:17:32.560] with model depth, is used. We scale the weights of residual layers at initialization by a factor of
[01:17:32.560 --> 01:17:37.680] one over the square root of N, where N is the number of residual layers. So this is what the GPT-2 paper
[01:17:37.680 --> 01:17:43.360] says. So we have not implemented that yet. And we can do so now. Now I'd like to actually kind of
[01:17:43.360 --> 01:17:48.560] like motivate a little bit what they mean here, I think. So here's roughly what they mean.
[01:17:50.240 --> 01:17:56.240] If you start out with zeros in your residual stream, remember that each residual stream is a
[01:17:56.240 --> 01:18:02.800] is of this form where we continue adding to it x is x plus something, some kind of contribution.
[01:18:02.800 --> 01:18:09.840] So every single block of the residual network contributes some amount and it gets added.
[01:18:09.840 --> 01:18:17.040] And so what ends up happening is that the variance of the activations in the residual stream grows.
[01:18:17.840 --> 01:18:23.840] So here's a small example, if we start at zero, and then we for 100 times, we have sort of this
[01:18:23.840 --> 01:18:33.040] residual stream of 768 zeros. And then 100 times we add torch.randn, which is a normal distribution,
[01:18:33.040 --> 01:18:38.320] zero mean, one standard deviation. If we add to it, then by the end, the residual stream has
[01:18:38.320 --> 01:18:44.240] grown to have standard deviation of 10. And that's just because we're always adding
[01:18:46.320 --> 01:18:52.560] these numbers. And so this scaling factor that they use here exactly compensates for that growth.
[01:18:52.560 --> 01:19:00.080] So if we take n, and we basically scale down every one of these contributions into the residual
[01:19:00.080 --> 01:19:06.560] stream by one over the square root of n. So one over the square root of n is n to the negative 0.5,
[01:19:06.560 --> 01:19:14.000] right? Because n to the 0.5 is the square root, and then one over the square root is n negative 0.5.
[01:19:14.000 --> 01:19:19.120] If we scale it in this way, then we see that we actually get 1.
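(The little experiment being described, as runnable code.)

```python
import torch

n = 100
x = torch.zeros(768)
for _ in range(n):
    x += torch.randn(768)            # add unit-variance noise n times
print(x.std())                       # ~10, i.e. roughly sqrt(n)

x = torch.zeros(768)
for _ in range(n):
    x += n**-0.5 * torch.randn(768)  # scale each contribution by 1/sqrt(n)
print(x.std())                       # ~1
```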
[01:19:19.120 --> 01:19:27.040] So this is a way to control the growth of activations inside the residual stream
[01:19:27.040 --> 01:19:32.320] in the forward pass. And so we'd like to initialize in the same way, where these weights
[01:19:32.320 --> 01:19:40.080] that are at the end of each block, so this c_proj layer, the GPT-2 paper proposes to scale
[01:19:40.080 --> 01:19:43.360] down those weights by one over the square root of the number of residual layers.
[01:19:43.360 --> 01:19:47.040] So one crude way to implement this is the following.
[01:19:47.040 --> 01:19:53.600] I don't know if this is PyTorch-sanctioned, but it works for me. What we do in the initialization is:
[01:19:53.600 --> 01:20:06.880] self.c_proj.NANOGPT_SCALE_INIT = 1. So we're setting kind of like a flag for this module.
[01:20:07.680 --> 01:20:10.960] There must be a better way in PyTorch, right, but I don't know it.
[01:20:10.960 --> 01:20:16.960] Okay, so we're basically attaching this flag and trying to make sure that it doesn't conflict
[01:20:16.960 --> 01:20:23.440] with anything previously. And then when we come down here, this std should be 0.02 by default.
[01:20:24.880 --> 01:20:33.600] But then, if hasattr(module, 'NANOGPT_SCALE_INIT'), then std times-equals —
[01:20:33.600 --> 01:20:42.480] Copilot is not guessing this correctly — we want one over the square root of the number of residual layers. So
[01:20:42.480 --> 01:20:48.240] the number of residual layers here is two times self.config.n_layer,
[01:20:49.520 --> 01:20:56.800] and then this to the power of negative 0.5. So we want to scale down
[01:20:56.800 --> 01:21:03.440] that standard deviation. And this should be correct and implements that. I should clarify,
[01:21:03.440 --> 01:21:07.280] by the way, that the two times number of layers comes from the fact that every single one of
[01:21:07.280 --> 01:21:12.000] our layers in the transformer actually has two blocks that add to the residual pathway, right?
[01:21:12.000 --> 01:21:15.760] We have the attention and then the MLP. So that's where the two times comes from.
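(Putting the two pieces together, roughly: the flag is set in the block's constructor and consumed in the weight initializer. Shown here as a standalone function with n_layer passed in; in the video it is a method that reads self.config.n_layer.)

```python
import torch.nn as nn

# In the Block (or MLP / attention) constructor, something like:
#     self.c_proj.NANOGPT_SCALE_INIT = 1

def init_weights(module, n_layer=12):
    if isinstance(module, nn.Linear):
        std = 0.02
        if hasattr(module, "NANOGPT_SCALE_INIT"):
            # 2 * n_layer residual additions: attention and MLP in each block
            std *= (2 * n_layer) ** -0.5
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
```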
[01:21:16.800 --> 01:21:22.320] And the other thing to mention is that what's slightly awkward, but we're not going to fix it,
[01:21:22.320 --> 01:21:29.600] is that because we are weight sharing the wte and the lm_head, in this iteration over all
[01:21:29.600 --> 01:21:34.000] sub modules, we're going to actually come around to that tensor twice. So we're going to first
[01:21:34.000 --> 01:21:38.800] initialize it as an embedding with 0.02. And then we're going to come back around it again in a
[01:21:38.800 --> 01:21:44.960] linear and initialize it again using 0.02. And it's going to be 0.02 because the lm_head is, of course,
[01:21:44.960 --> 01:21:50.240] not flagged for scaling. So it's not going to come here. It's just going to be basically initialized twice
[01:21:50.240 --> 01:21:55.920] using the identical same initialization. But that's OK. And then scrolling over here,
[01:21:55.920 --> 01:22:04.640] I added some code here so that we have reproducibility to set the seeds. And now we should be able to
[01:22:04.640 --> 01:22:11.440] python train_gpt2.py and let this run. And as far as I know, this is the GPT2 initialization
[01:22:12.240 --> 01:22:20.080] in the way we've implemented right now. So this looks reasonable to me. Okay, so at this point,
[01:22:20.080 --> 01:22:24.560] we have the GPT2 model. We have some confidence that it's correctly implemented. We've initialized
[01:22:24.560 --> 01:22:28.640] it properly. And we have a data loader that's iterating through data batches. And we can train.
[01:22:28.640 --> 01:22:33.760] So now comes the fun part. I'd like us to speed up the training by a lot. So we're getting our
[01:22:33.760 --> 01:22:38.640] money's worth with respect to the hardware that we are using here. And we're going to speed up
[01:22:38.640 --> 01:22:44.320] the training by quite a bit. Now you always want to start with what hardware do you have? What does
[01:22:44.320 --> 01:22:51.840] it offer? And are you fully utilizing it? So in my case, if we go to NVIDIA SMI, we can see that
[01:22:51.840 --> 01:23:02.640] I have eight GPUs. And each one of those GPUs is an A100 SXM with 80 gigabytes. So this is the GPU that
[01:23:02.640 --> 01:23:09.440] I have available to me in this box. Now, when I want to spin up these kinds of boxes, by the way,
[01:23:09.440 --> 01:23:16.800] my favorite place to go to is Lambda Labs. They do sponsor my development and that of my projects.
[01:23:16.800 --> 01:23:22.000] But this is my favorite place to go. And this is where you can spin up one of these machines,
[01:23:22.000 --> 01:23:26.160] and you pay per hour. And it's very, very simple. So I like to spin them up and then
[01:23:26.160 --> 01:23:31.360] connect VS code to it. And that's how I develop. Now, when we look at the A100s that are available
[01:23:31.360 --> 01:23:40.640] here, the A100 80 gigabyte SXM is the GPU that I have here. And we have a bunch of numbers here for
[01:23:40.640 --> 01:23:48.960] how many calculations you can expect out of this GPU. So when I come over here, and I break in
[01:23:48.960 --> 01:23:54.800] right after here. So I'm breaking in right after we calculate the logits and the loss.
[01:23:56.560 --> 01:24:03.680] And the interesting thing I'd like you to note is when I do logits.dtype, this prints
[01:24:03.680 --> 01:24:09.600] torch.float32. So by default in PyTorch, when you create tensors, and this is the case for all
[01:24:09.600 --> 01:24:13.680] the activations and for the parameters of the network and so on, by default, everything is in
[01:24:13.680 --> 01:24:21.200] float 32. That means that every single number activation or weight and so on is using a float
[01:24:21.200 --> 01:24:27.520] representation that has 32 bits. And that's actually quite a bit of memory. And it turns out empirically
[01:24:27.520 --> 01:24:32.400] that for deep learning as a computational workload, this is way too much. And deep learning and the
[01:24:32.400 --> 01:24:38.960] training of these networks can tolerate significantly lower precision. Not all computational workloads
[01:24:38.960 --> 01:24:46.000] can tolerate small precision. So for example, if we go back to the data sheet, you'll see that
[01:24:46.000 --> 01:24:51.360] actually these GPUs support up to FP64. And this is quite useful, I understand for a lot of
[01:24:51.360 --> 01:24:56.560] scientific computing applications. And there they really need this. But we don't need that
[01:24:56.560 --> 01:25:03.520] much precision for deep learning training. So currently we are here, FP 32. And with this code
[01:25:03.520 --> 01:25:10.800] as it is right now, we expect to get at at most 19.5 teraflops of performance. That means we're
[01:25:10.800 --> 01:25:16.480] doing 19.5 trillion operations, floating point operations. So this is floating point multiply,
[01:25:16.480 --> 01:25:23.680] add, most likely. And so these are the floating point operations.
[01:25:23.680 --> 01:25:32.160] Now notice that if we are willing to go down in precision, so TF 32 is a lower precision format,
[01:25:32.160 --> 01:25:36.560] we're going to see in a second, you can actually get an 8x improvement here. And if you're willing
[01:25:36.560 --> 01:25:42.400] to go down to float16 or bfloat16, you can actually get 16x the performance,
[01:25:42.400 --> 01:25:48.080] all the way to 312 teraflops. You see here that NVIDIA likes to cite numbers that have
[01:25:48.080 --> 01:25:54.240] a asterisk here. This asterisk says with sparsity. But we are not going to be using sparsity in
[01:25:54.240 --> 01:25:58.800] our code. And I don't know that this is very widely used in the industry right now. So most
[01:25:58.800 --> 01:26:04.240] people look at this number here without sparsity. And you'll notice that we could have got even
[01:26:04.240 --> 01:26:12.800] more here. But this is int8. And int8 is used for inference, not for training. Because int8 has a
[01:26:12.800 --> 01:26:24.240] it basically has uniform spacing. And we actually require a float so that we get a better match
[01:26:24.240 --> 01:26:30.720] to the normal distributions that occur during training of neural nets, where both activations
[01:26:30.720 --> 01:26:36.320] and weights are distributed as a normal distribution. And so floating points are really important to
[01:26:36.320 --> 01:26:43.200] match that representation. So we're not typically using int8 for training, but we are using it for
[01:26:43.200 --> 01:26:49.280] inference. And if we bring down the precision, we can get a lot more teraflops out of the tensor
[01:26:49.280 --> 01:26:53.840] cores available in the GPUs. We'll talk about that in a second. But in addition to that, if all
[01:26:53.840 --> 01:26:59.680] of these numbers have fewer bits of representation, it's going to be much easier to move them around.
[01:27:00.240 --> 01:27:03.920] And that's where we start to get into the memory bandwidth and the memory of the model.
[01:27:03.920 --> 01:27:08.800] So not only do we have a finite capacity of the number of bits that our GPU can store,
[01:27:08.800 --> 01:27:13.360] but in addition to that, there's a speed with which you can access this memory.
[01:27:13.360 --> 01:27:20.080] And you have a certain memory bandwidth. It's a very precious resource. And in fact, many of the
[01:27:20.080 --> 01:27:25.520] deep learning workloads for training are memory bound. And what that means is actually that the
[01:27:25.520 --> 01:27:30.400] tensor cores that do all these extremely fast multiplications, most of the time they're waiting
[01:27:30.400 --> 01:27:37.120] around, idle, because we can't feed them with data fast enough. We can't load the data fast
[01:27:37.120 --> 01:27:41.440] enough for memory. So typical utilizations of your hardware, if you're getting 60%
[01:27:41.440 --> 01:27:48.400] utilization, you're actually doing extremely well. So half of the time in a well-tuned application,
[01:27:48.400 --> 01:27:52.960] your tensor cores are not doing multiplies because the data is not available. So the memory
[01:27:52.960 --> 01:27:58.080] bandwidth here is extremely important as well. And if we come down in the precision for all the
[01:27:58.080 --> 01:28:04.240] floats, all the numbers, weights and activations, suddenly require less memory. So we can store more
[01:28:04.240 --> 01:28:09.840] and we can access it faster. So everything speeds up and it's amazing. And now let's reap the
[01:28:09.840 --> 01:28:16.000] benefits of it. And let's first look at the tensor float 32 format. Okay, so first of all,
[01:28:16.000 --> 01:28:22.160] what are tensor cores? Well, tensor cores, tensor core is just an instruction in the A100
[01:28:22.160 --> 01:28:28.480] architecture, right? So what it does is it does basically a little four by four matrix multiply.
[01:28:28.480 --> 01:28:36.160] So this is just matrix multiplication here of four by four matrices. And there are multiple
[01:28:36.160 --> 01:28:42.560] configurations as to what precision any of these matrices are. In what precision the internal
[01:28:42.560 --> 01:28:47.760] accumulate happens. And then what is the output precision, input precision, etc. So there's a few
[01:28:47.760 --> 01:28:53.280] switches, but it's basically a four by four multiply. And then anytime we have any operations that
[01:28:53.280 --> 01:28:59.360] require matrix multiplication, they get broken up into these into this instruction of little four
[01:28:59.360 --> 01:29:03.680] by four multiply. And so everything gets broken up into this instruction because it's the fastest
[01:29:03.680 --> 01:29:08.560] way to multiply matrices. And it turns out that most of the computational work that we're doing up
[01:29:08.560 --> 01:29:14.240] above, all of it really is matrix multiplication. Most of the work computationally happens in the
[01:29:14.240 --> 01:29:22.560] linear layers, linear, linear, etc. There's a few things sandwiched in between. So there's some
[01:29:22.560 --> 01:29:27.680] additions in residuals, there's some GELU nonlinearities, there's some layer norms, etc.
[01:29:27.680 --> 01:29:32.720] But if you just time them, you'll see that these are nothing. Like basically, the entire transformer
[01:29:32.720 --> 01:29:39.520] is just a bunch of matrix multiplications really. And especially at this small scale 124 million
[01:29:39.520 --> 01:29:45.120] parameter model. Actually, the biggest matrix multiplication by far is the classifier layer at
[01:29:45.120 --> 01:29:52.800] the top. That is a massive matrix multiply of going from 768 to 50,257. And that matrix multiply
[01:29:52.800 --> 01:29:59.280] dominates anything else that happens in that network roughly speaking. So it's matrix multiplies
[01:29:59.280 --> 01:30:05.040] that become a lot faster, which are hidden inside our linear layers. And they're accelerated through
[01:30:05.040 --> 01:30:09.760] tensor cores. Now, the best reference I would say for tensor cores is basically just go to the
[01:30:09.760 --> 01:30:17.840] A100 architecture white paper. And then it's pretty detailed. But I think it's
[01:30:17.840 --> 01:30:24.000] relatively readable, mostly, if you roughly understand what's happening. So figure nine,
[01:30:24.000 --> 01:30:30.000] tensor float 32. So this is the explanation basically for TF32 and what happens here.
[01:30:31.200 --> 01:30:35.840] And you see that there's many configuration options here available. So the input operands
[01:30:35.840 --> 01:30:44.400] and what precision are they in? The accumulator, which is basically the internal representation
[01:30:44.400 --> 01:30:48.480] within the instruction, when you do the accumulate of this matrix multiplication.
[01:30:48.480 --> 01:30:54.960] So the intermediate plus equals of the intermediate little vector multiplies here,
[01:30:54.960 --> 01:31:01.520] that all happens in FP32. And then this is an 8x improvement, as I mentioned, to the
[01:31:01.520 --> 01:31:06.960] ops that we get. So TF32 specifically, we're looking at this row here. And the way this works is
[01:31:06.960 --> 01:31:19.040] normally FP32 has 32 bits. TF32 keeps the exact same sign and exponent bits: we have one sign bit,
[01:31:19.040 --> 01:31:25.840] we have eight exponent bits, but the mantissa bits get cropped in the float. And so basically,
[01:31:25.840 --> 01:31:33.920] we end up with just 19 bits instead of 32 bits, because the last 13 bits get truncated, they get
[01:31:33.920 --> 01:31:40.160] dropped. And all of this is internal to the instruction. So none of it is visible to anything
[01:31:40.160 --> 01:31:45.520] in our PyTorch. None of our PyTorch code will change. All of the numbers will look identical.
[01:31:45.520 --> 01:31:51.280] It's just that when you call the tensor core instruction internally in the hardware,
[01:31:51.280 --> 01:31:57.840] it will crop out these 13 bits. And that allows it to calculate this little matrix
[01:31:57.840 --> 01:32:04.080] multiply significantly faster, eight X faster. Now, of course, this speed up comes at a cost.
[01:32:04.080 --> 01:32:09.360] And the cost is that we are reducing the precision. Our accumulate is still in FP32. Our output
[01:32:09.360 --> 01:32:16.160] is FP32. Our inputs are FP32. But internally, things get truncated in the operands to perform
[01:32:16.160 --> 01:32:20.240] the operation faster. And so our results are starting to be a bit more approximate.
[01:32:20.240 --> 01:32:23.520] But empirically, when you actually train with this, you basically can't tell the difference.
[01:32:23.520 --> 01:32:29.200] So the reason I like TF32 is because if you can tolerate a little bit of a precision fudge,
[01:32:29.200 --> 01:32:36.320] then this is pretty much free: none of your code sees this. It's fully internal to the operation.
[01:32:36.320 --> 01:32:40.880] And the operation, to you, just goes 8x faster, and it's a bit more approximate.
[01:32:40.880 --> 01:32:46.320] And so it's a pretty sweet spot, I would say an optimization. And let's see what that looks like
[01:32:46.320 --> 01:32:53.280] first. So I've set up our code to just time the iterations. So import time, I changed the hyper-
[01:32:53.280 --> 01:32:57.920] parameters so that we have something a bit more reflective of the kind of workload that we want to
[01:32:57.920 --> 01:33:03.360] run, because we want to do a fairly large run at the end of this. So let's use batch size 16.
[01:33:03.360 --> 01:33:08.480] And let's now use the actual GPT-2 maximum sequence length of 1024 tokens.
[01:33:08.480 --> 01:33:16.640] So this is the configuration. And then for 50 iterations, I'm just doing something very lazy
[01:33:16.640 --> 01:33:22.240] here. I'm doing time.time() to get the current time. And then this is the optimization loop.
[01:33:22.240 --> 01:33:30.400] And now I want to time how long this takes. Now, one issue with working with GPUs is that
[01:33:31.040 --> 01:33:39.200] as your CPU, when your CPU runs, it's just scheduling work on GPU. It's ordering some work, right?
[01:33:39.200 --> 01:33:44.080] And so it sends a request and then it continues running. And so we can actually, it can happen
[01:33:44.080 --> 01:33:51.360] sometimes that we sort of speed through this. And we queue up a lot of kernels to run on the GPU.
[01:33:51.360 --> 01:33:55.520] And then the CPU sort of like gets here and takes the time. But actually the GPU is still
[01:33:55.520 --> 01:34:00.560] running, because it takes time to actually work through the work that was scheduled to run.
[01:34:01.440 --> 01:34:06.080] And so you're just building up a queue for the GPU. And so actually, if you need to,
[01:34:06.080 --> 01:34:13.120] you want to wait, and this will wait for the GPU to finish all the work that was scheduled to run
[01:34:13.120 --> 01:34:18.960] up above here. And then we can actually take the time. So basically, we're waiting for the GPU to
[01:34:18.960 --> 01:34:26.320] stop this iteration, take the time, and then we're going to just print it. So here I'm going to run
[01:34:26.320 --> 01:34:32.000] the training loop. And here on the right, I'm watching NVIDIA SMI. So we start off at zero.
[01:34:32.000 --> 01:34:37.520] We're not using the GPU. And then by default, PyTorch will use GPU zero. So we see that it gets
[01:34:37.520 --> 01:34:44.800] filled up. And we're using 35 gigabytes out of 80 gigabytes available. And then here on the left,
[01:34:44.800 --> 01:34:52.240] we see that because we cranked up the batch size, now it's only 20 batches to do a single epoch on
[01:34:52.240 --> 01:34:57.840] our tiny Shakespeare. And we see that we're seeing roughly 1000 milliseconds per iteration here.
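For reference, a minimal sketch of the timing setup being described, with the wait on the GPU done via torch.cuda.synchronize() and including the tokens-per-second readout that gets added a little later. The actual forward/backward/step for one batch is elided, and B and T are the batch size and sequence length used here (assumes a CUDA device):

```python
import time
import torch

B, T = 16, 1024                        # batch size and sequence length used here

for step in range(50):
    t0 = time.time()
    # ... one optimization step (forward, loss, backward, optimizer.step()) goes here ...
    torch.cuda.synchronize()           # wait for the GPU to finish the queued work
    t1 = time.time()
    dt = (t1 - t0) * 1000              # milliseconds per iteration
    tokens_per_sec = (B * T) / (t1 - t0)
    print(f"step {step}: dt {dt:.2f}ms, tok/sec {tokens_per_sec:.0f}")
```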
[01:34:57.840 --> 01:35:05.200] So the first iteration sometimes is slower. And that's because PyTorch might be doing a lot of
[01:35:05.200 --> 01:35:10.240] initializations here on the very first iteration. And so it's probably initializing all these tensors
[01:35:10.240 --> 01:35:15.680] and buffers to hold the gradients. And I'm not sure all the work that happens here. But
[01:35:15.680 --> 01:35:19.360] this could be a slower iteration. When you're timing your logic, you always want to be careful
[01:35:19.360 --> 01:35:26.240] with that. But basically, we're seeing 1000 milliseconds per iteration. And so this will run for roughly
[01:35:26.240 --> 01:35:32.880] 50 seconds as we have it right now. So that's our baseline in float 32. One more thing I wanted
[01:35:32.880 --> 01:35:38.080] to mention is that if this doesn't fit into your GPU, and you're getting out of memory errors,
[01:35:38.080 --> 01:35:43.120] then start decreasing your batch size until things fit. So instead of 16, try eight or four,
[01:35:43.120 --> 01:35:49.120] or whatever you need to fit the batch into your GPU. And if you have a bigger GPU, you can actually
[01:35:49.120 --> 01:35:55.680] potentially get away with 32 and so on. By default, you want to basically max out the batch size that
[01:35:55.680 --> 01:36:01.760] fits on your GPU. And you want to keep it to nice numbers. So use numbers that have lots of powers
[01:36:01.760 --> 01:36:09.280] of two in them. So 16 is a good number; 8, 24, 32, 48, these are nice numbers. But don't
[01:36:09.280 --> 01:36:14.160] do something like 17, because that will run very inefficiently on the GPU. And we're going to see
[01:36:14.160 --> 01:36:22.080] that a bit later as well. So for now, let's just stick with 16, 1024. And the one thing that I added
[01:36:22.080 --> 01:36:28.480] also here and I ran it again, is I'm calculating tokens per second throughput during training,
[01:36:28.480 --> 01:36:34.720] because we might end up changing the batch size around over time. But tokens per second is the
[01:36:34.720 --> 01:36:39.280] objective measure that we actually really care about. How many tokens of data are we training on?
[01:36:39.280 --> 01:36:43.760] And what is the throughput of tokens that we're getting in our optimization? So right now we're
[01:36:43.760 --> 01:36:49.760] processing and training on roughly 16,300 tokens per second. And that's a bit more objective
[01:36:49.760 --> 01:36:56.880] metric. Okay, so let's now enable TF 32. Now luckily PyTorch makes this fairly easy for us.
[01:36:56.880 --> 01:37:04.000] And to enable TF 32, you just need to do a single line. And it's this. And when we go to the PyTorch
[01:37:04.000 --> 01:37:08.880] documentation here for this function, basically this tells PyTorch what kind of kernels to run.
[01:37:09.520 --> 01:37:16.160] And by default, I believe it is 'highest', the highest precision for matmul. And that means that everything
[01:37:16.160 --> 01:37:21.360] happens in float 32, just like it did before. But if we set it to 'high', as we do right now,
[01:37:21.360 --> 01:37:30.000] matrix multiplications will now use tensor float 32, when available. My GPU is the A100. So it's
[01:37:30.000 --> 01:37:36.400] an Ampere-series GPU. And therefore, TF32 is available. If you have an older GPU, this might not be available
[01:37:36.400 --> 01:37:41.760] for you. But for my GPU, it's available. And so what I expect PyTorch to do is that every single
[01:37:41.760 --> 01:37:46.800] place where we see an nn.Linear, inside there, there's a matrix multiplication. And I expect
[01:37:46.800 --> 01:37:53.600] that matrix multiplication now to be running on tensor cores, utilizing the TF32 precision.
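A minimal sketch of that single line, assuming it is placed near the top of the training script:

```python
import torch

# Allow float32 matmuls to run on TF32 tensor cores where the hardware supports it
# (Ampere and newer). The default, "highest", keeps full FP32 matmul precision.
torch.set_float32_matmul_precision("high")
```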
[01:37:53.600 --> 01:38:00.720] So this is the single line of change that is, I believe, necessary. And let's rerun this. Now,
[01:38:00.720 --> 01:38:06.480] we saw that in terms of the throughput that is promised to us, we're supposed to be getting
[01:38:06.480 --> 01:38:18.560] 8x roughly. So let's see what happens. And that 8x came from here, right? 8x. And it also came from
[01:38:18.560 --> 01:38:26.560] looking at it here: 156 TFLOPS instead of 19.5. Okay, so what actually happened?
[01:38:27.440 --> 01:38:36.480] So we're seeing that our throughput improved roughly 3x, not 8x. So we are going from 1000 milliseconds,
[01:38:36.480 --> 01:38:41.360] we're going down to 300 milliseconds. And our throughput is now about 50,000 tokens per second.
[01:38:41.360 --> 01:38:46.240] So we have roughly a 3x instead of an 8x. So what happened? And basically, what's happening
[01:38:46.240 --> 01:38:54.240] here is, again, a lot of these workloads are memory bound. And so even though the TF 32 offers in
[01:38:54.240 --> 01:39:01.920] principle, a lot faster throughput, all of these numbers everywhere are still float 32s. And it's
[01:39:01.920 --> 01:39:06.560] float 32 numbers that are being shipped all over the place through the memory system. And it's just
[01:39:06.560 --> 01:39:10.560] costing us way too much time to shuttle around all this data. And so even though we've made the
[01:39:10.560 --> 01:39:15.680] multiply itself much faster, we are memory bound, and we're not actually seeing the full benefit
[01:39:15.680 --> 01:39:22.800] that would come from this napkin math here. That said, we are getting roughly a 3x faster throughput.
[01:39:22.800 --> 01:39:30.400] And this is free. A single line of code in PyTorch, and all your variables are still float 32 everywhere.
[01:39:30.400 --> 01:39:34.800] It just runs faster, and it's slightly more approximate, but we're not going to notice it
[01:39:34.800 --> 01:39:44.480] basically. So that's TF 32. Okay, so let's now continue. So we've exercised this row, and we saw
[01:39:44.480 --> 01:39:49.440] that we can crop out some of the precision inside the operation itself. But we saw that we're still
[01:39:49.440 --> 01:39:53.360] memory bound. We're still moving around all these floats, right, and we're paying that
[01:39:53.360 --> 01:39:58.320] cost because of this. So let's now decrease the amount of stuff that we're going to be moving
[01:39:58.320 --> 01:40:04.560] around. And we're going to do that by dropping down to bfloat16. So we're only going to be
[01:40:04.560 --> 01:40:10.000] maintaining 16 bits per float. And we're going to use bfloat16, and I'll explain in a bit
[01:40:10.000 --> 01:40:16.640] the FP16 difference. And we're going to be in this row. So when we go back to the documentation
[01:40:16.640 --> 01:40:25.280] here for the A100, we see here the precisions that are available. And this is the original
[01:40:25.280 --> 01:40:32.960] FP32. TF32 crops out some of the precision. And then here in BF16, you see that it is very similar
[01:40:32.960 --> 01:40:39.600] to TF 32, but it's even more aggressive in cropping off the precision, the mantissa of this float.
[01:40:39.600 --> 01:40:44.800] So the important thing with bfloat16 is that the exponent bits and the sign bit, of course,
[01:40:44.800 --> 01:40:50.160] remain unchanged. So if you're familiar with your float numbers, and I think this
[01:40:50.160 --> 01:40:56.720] should probably be an entire video by itself, the exponent sets the range that you can represent
[01:40:56.720 --> 01:41:02.640] of your numbers. And the precision is how much precision you have for your numbers. And so
[01:41:02.640 --> 01:41:10.160] the range of numbers is identical, but we have fewer possibilities within that range,
[01:41:10.720 --> 01:41:14.400] because we are truncating the mantissa. So we have less precision in that range.
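To make the range-versus-precision point concrete, here is a tiny sketch using torch.finfo (FP16, discussed just below, is included for comparison):

```python
import torch

# Compare the representable range (max) and the relative precision (eps)
# of the float formats discussed here.
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    print(dtype, "max:", info.max, "eps:", info.eps)

# float32:  max ~3.4e38, eps ~1.2e-07
# bfloat16: max ~3.4e38, eps ~7.8e-03   (same range as float32, less precision)
# float16:  max  65504,  eps ~9.8e-04   (more mantissa precision, but reduced range)
```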
[01:41:14.400 --> 01:41:20.720] What that means is that things are actually fairly nice, because we have the original range of
[01:41:20.720 --> 01:41:26.880] numbers that are representable in float. But we just have less precision for it. And the difference
[01:41:26.880 --> 01:41:33.760] with FP16 is that they actually touch and change the range. So FP16 cannot represent the full range
[01:41:33.760 --> 01:41:39.360] of FP32. It has a reduced range. And that's where you start to actually run into issues,
[01:41:39.360 --> 01:41:44.560] because now you need these gradient scalers and things like that. And I'm not going to go into
[01:41:44.560 --> 01:41:50.720] detail on that in this video, because that's a whole video by itself. But FP16 actually historically
[01:41:50.720 --> 01:41:56.800] came first; it was available in the Volta series before Ampere. And so FP16 came first,
[01:41:56.800 --> 01:42:01.200] and everyone started to train in FP16. But everyone had to use all these gradient scaling
[01:42:01.200 --> 01:42:05.920] operations, which are kind of annoying. And it's an additional source of state and complexity.
[01:42:05.920 --> 01:42:11.920] And the reason for that was because the exponent range was reduced in FP16. So that's the IEEE
[01:42:11.920 --> 01:42:18.240] FP16 spec. And then they came out with BF16 in Ampere. And they made it much simpler,
[01:42:18.240 --> 01:42:22.480] because we're just truncating the mantissa, we have the exact same range, and we do not need gradient
[01:42:22.480 --> 01:42:28.960] scalers. So everything is much, much simpler. Now, when we do use BF16, though, we are impacting
[01:42:28.960 --> 01:42:35.120] the numbers that we might be seeing in our PyTorch code. This change is not just local
[01:42:35.120 --> 01:42:42.480] to the operation itself. So let's see how that works. There's some documentation here that,
[01:42:42.480 --> 01:42:46.960] so I think this is probably the best page to explain how to use mixed precision in PyTorch.
[01:42:46.960 --> 01:42:52.960] Because there are many other tutorials and so on, even within the PyTorch documentation,
[01:42:52.960 --> 01:42:56.320] and they are a lot more confusing. And so I recommend specifically this one,
[01:42:57.200 --> 01:43:01.680] because there are five other copies that I would not recommend. And then when we come here,
[01:43:01.680 --> 01:43:06.560] ignore everything else, ignore everything about the gradient scalers,
[01:43:06.560 --> 01:43:15.280] and only look at torch.autocast. And basically, this also comes down to a single line of code at the end.
[01:43:15.280 --> 01:43:22.720] So this is the context manager that we want. And we want to use that in our network. When you
[01:43:22.720 --> 01:43:30.400] click into the torch.autocast autocasting docs, it has a bit more guidance for you.
[01:43:30.400 --> 01:43:36.560] So it's telling you: do not call bfloat16() on any of your tensors, just use autocast,
[01:43:36.560 --> 01:43:42.480] and only surround the forward pass of the model and the loss calculation. Those are the only
[01:43:42.480 --> 01:43:46.400] two things that you should be surrounding. Leave the backward and the optimizer step alone.
[01:43:46.400 --> 01:43:51.360] So that's the guidance that comes from the PyTorch team. So we're going to follow that guidance.
[01:43:51.360 --> 01:43:55.680] And for us, because the loss calculation is inside of the model forward pass for us,
[01:43:55.680 --> 01:44:01.440] we are going to be doing this. And then we don't want to be using torch.float16,
[01:44:01.440 --> 01:44:04.720] because if we do that, we need to start using gradient scalers as well.
[01:44:04.720 --> 01:44:09.120] So we are going to be using bfloat16. This is only possible to do on Ampere.
[01:44:09.120 --> 01:44:14.400] But this means that the changes are extremely minimal, like basically just this one line of code.
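A minimal sketch of what that looks like, using a toy stand-in model rather than the actual GPT module from the video (assumes a CUDA device):

```python
import torch

model = torch.nn.Linear(8, 8).cuda()     # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
x = torch.randn(4, 8, device="cuda")
y = torch.randn(4, 8, device="cuda")

optimizer.zero_grad()
# Only the forward pass and the loss calculation go under autocast;
# leave loss.backward() and optimizer.step() outside, per the PyTorch guidance.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)
    loss = torch.nn.functional.mse_loss(out, y)
loss.backward()
optimizer.step()
print(out.dtype, model.weight.dtype)      # bfloat16 activations, float32 parameters
```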
[01:44:16.800 --> 01:44:24.480] Let me first break in here before we actually run this. So right after the logits,
[01:44:24.480 --> 01:44:30.000] I'd like to show you that different from the TF32 that we saw, this is actually going to impact
[01:44:30.000 --> 01:44:38.160] our tensors. So this logits tensor, if we now look at this and we look at the dtype,
[01:44:38.160 --> 01:44:44.800] we suddenly see that this is now BF16. It's not float 32 anymore. So our activations have been
[01:44:44.800 --> 01:44:52.000] changed. The activations tensor is now BF16, but not everything has changed. So model.transformer.wte,
[01:44:52.000 --> 01:45:01.680] this is the token embedding table. It has a .weight inside it. And the dtype of
[01:45:01.680 --> 01:45:08.880] this weight, this parameter, is still torch.float32. So our parameters seem to still be in float 32,
[01:45:08.880 --> 01:45:14.480] but our activations, the logits are now in BF16. So clearly, this is why we get the mixed
[01:45:14.800 --> 01:45:20.000] precision. Some things pytorch is keeping in float 32. Some things pytorch is converting
[01:45:20.000 --> 01:45:28.640] to lower precision. And what gets converted, at what point is not super clear. I remember scrolling
[01:45:28.640 --> 01:45:40.720] down. Is it here? Okay, I can't find it. I thought it was here. Okay, there we go.
[01:45:42.000 --> 01:45:49.040] So there are a few docs on when you're using this autocast, what gets converted to BF16 and when.
[01:45:49.040 --> 01:45:54.880] So for example, only these matrix multiply like operations get converted to BF16. But a lot of
[01:45:54.880 --> 01:45:59.440] operations remain in float 32. So in particular, a lot of normalizations, like layer norms and
[01:45:59.440 --> 01:46:05.760] things like that, not all of those layers might be converted. So only some layers selectively
[01:46:05.760 --> 01:46:14.000] would be running BF16. But things like softmax, layer norms, log softmax, so loss function
[01:46:14.000 --> 01:46:18.480] calculations, a lot of those things might remain in float 32, because they are more susceptible to
[01:46:18.480 --> 01:46:27.120] precision changes. Matrix multiplies are fairly robust to precision changes. So some parts of the
[01:46:27.120 --> 01:46:35.040] network are impacted more or less by the precision change. So basically, only some parts of
[01:46:35.040 --> 01:46:40.240] the model are running in reduced precision. Let's take it for a spin. And let's actually see what kind
[01:46:40.240 --> 01:46:52.800] of improvement we achieve here. Okay, so we used to be 333 milliseconds. We're now 300.
[01:46:52.800 --> 01:46:58.800] And we used to be somewhere around 50,000 tokens per second. We're now 55. So we're definitely running
[01:46:58.800 --> 01:47:04.400] faster, but maybe not a lot faster. And that's because there are still many, many bottlenecks
[01:47:04.400 --> 01:47:09.280] in our GPT-2, we're just getting started. But we have dropped down the precision as far as we can
[01:47:09.280 --> 01:47:15.200] with my current GPU, which is an A100, using PyTorch autocast. Unfortunately, I don't
[01:47:15.200 --> 01:47:21.760] actually exactly know what PyTorch AutoCast does. I don't actually know exactly what's in BF16,
[01:47:21.760 --> 01:47:27.280] what's in float 32, we could go in and we could start scrutinize it. But these are the kinds of
[01:47:27.280 --> 01:47:33.600] rules that PyTorch has internally. And unfortunately, they don't document it very well. So we're not
[01:47:33.600 --> 01:47:39.520] gonna go into that in too much detail. But for now, we are training in BF16. We do not need a gradient
[01:47:39.520 --> 01:47:46.880] scaler. And the reason things are running faster is because we are able to run tensor cores in
[01:47:46.880 --> 01:47:53.920] BF16 now. That means we are in this row. But we are also paying in precision for this.
[01:47:53.920 --> 01:48:01.520] So we expect slightly less accurate results with respect to the original FP32. But empirically,
[01:48:01.520 --> 01:48:06.800] in many cases, this is a worth-it kind of trade-off, because it allows you to run faster. And you
[01:48:06.800 --> 01:48:15.440] could, for example, train longer and make up for that precision decrease. So that's BF16 for now.
[01:48:15.440 --> 01:48:20.960] Okay, so as we can see, we are currently at about 300 milliseconds per iteration. And we're now
[01:48:20.960 --> 01:48:24.960] gonna reach for some really heavy weapons in the PyTorch arsenal. And in particular,
[01:48:24.960 --> 01:48:30.480] we're going to introduce torch.compile. So torch.compile is really quite incredible infrastructure
[01:48:30.480 --> 01:48:36.240] from the PyTorch team. And it's basically a compiler for neural networks. Like it's almost like GCC
[01:48:36.240 --> 01:48:44.480] for C/C++ code. This is like the GCC of neural nets. It came out a while ago and it's extremely
[01:48:44.480 --> 01:48:51.040] simple to use. The way to use torch.compile is to do this. It's a single line of code to compile
[01:48:51.040 --> 01:48:57.280] your model and return it. Now this line of code will cost you compilation time. But as you might
[01:48:57.280 --> 01:49:01.920] guess, it's going to make the code a lot faster. So let's actually run that. And because this will
[01:49:01.920 --> 01:49:06.720] take some time to run, let's currently remember we're 300 milliseconds. And we'll see what happens.
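As a sketch, that single line looks like this, with a toy stand-in for the model:

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU())  # stand-in model
# One line: compile the model. The first iteration pays the compilation cost;
# subsequent iterations run the optimized, kernel-fused code.
model = torch.compile(model)
```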
[01:49:06.720 --> 01:49:11.680] Now, while this is running, I'd like to explain a little bit of what Torch.compile does under the
[01:49:11.680 --> 01:49:18.160] hood. So feel free to read this page of PyTorch. But basically, there's no real good reason for
[01:49:18.160 --> 01:49:22.960] you to not use torch.compile in your PyTorch code. I kind of feel like you should be using it almost
[01:49:22.960 --> 01:49:29.120] by default if you're not debugging, and you want your code to run really fast. And there's one
[01:49:29.120 --> 01:49:32.880] line here in the torch.compile docs that I found that actually kind of gets at why this is faster.
[01:49:32.880 --> 01:49:39.680] Speedup comes from reducing Python overhead and GPU read writes. So let me unpack that a little bit.
[01:49:39.680 --> 01:49:46.480] Okay, here we are. Okay, so we're going from 300 milliseconds. We're now running at 129 milliseconds.
[01:49:47.760 --> 01:49:54.080] So this is 300 divided 129, about 2.3x improvement from a single line of code in PyTorch.
[01:49:54.080 --> 01:49:59.520] So quite incredible. So what is happening? What's happening under the hood? Well, when you pass
[01:49:59.520 --> 01:50:06.720] the model to torch.compile, what we have here in this nn.Module, this is really just the algorithmic
[01:50:06.720 --> 01:50:13.200] description of what we'd like to happen in our network. And Torch compile will analyze the entire
[01:50:13.200 --> 01:50:18.480] thing. And it will look at what operations you'd like to use. And with the benefit of knowing
[01:50:18.480 --> 01:50:23.680] exactly what's going to happen, it doesn't have to run in what's called eager mode. It doesn't
[01:50:23.680 --> 01:50:30.160] have to just kind of like go layer by layer, like the Python interpreter normally would start at the
[01:50:30.160 --> 01:50:36.480] forward. And the Python interpreter will go, okay, let's do this operation. And then let's do that
[01:50:36.480 --> 01:50:41.840] operation. And it kind of materializes all the operations as it goes through. So these
[01:50:42.400 --> 01:50:48.000] calculations are dispatched and run in this order. And the Python interpreter and this code doesn't
[01:50:48.000 --> 01:50:52.640] know what kind of operations are going to happen later. But Torch compile sees your entire code
[01:50:52.640 --> 01:50:58.160] at the same time. And it's able to know what operations you intend to run. And it will kind of
[01:50:58.160 --> 01:51:02.800] optimize that process. The first thing we'll do is we'll take out the Python interpreter from
[01:51:02.800 --> 01:51:07.840] the forward pass entirely. And it will kind of compile this entire neural net as a single object
[01:51:07.840 --> 01:51:12.720] with no Python interpreter involved. So it knows exactly what's going to run, it will just run that.
[01:51:12.720 --> 01:51:20.400] And it's all going to be running an efficient code. The second thing that happens is this read
[01:51:20.400 --> 01:51:24.480] write that they mentioned very briefly. So a good example of that, I think, is the
[01:51:24.480 --> 01:51:30.240] GELU nonlinearity that we've been looking at. So here we use the nn.GELU. Now this here
[01:51:30.240 --> 01:51:37.600] is me basically just breaking up the nn.GELU, which you remember has this formula. So this
[01:51:37.600 --> 01:51:42.160] here is the equivalent implementation to what's happening inside GELU; algorithmically, it's identical.
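For reference, a sketch of that broken-up implementation next to the built-in module; the constants are the standard tanh-approximation GELU:

```python
import math
import torch

def manual_gelu(x):
    # tanh-approximation GELU written out as separate elementwise ops;
    # without fusion, each op dispatches its own kernel and round-trips through memory
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

x = torch.randn(1024)
ref = torch.nn.GELU(approximate="tanh")(x)
assert torch.allclose(manual_gelu(x), ref, atol=1e-6)
```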
[01:51:42.160 --> 01:51:49.600] Now, by default, if we just were using this instead of nn.GELU here, what would happen
[01:51:49.600 --> 01:51:54.320] without torch.compile? Well, the Python interpreter would make its way here. And then it would be,
[01:51:54.320 --> 01:52:00.160] okay, well, there's an input. Well, let me first, let me raise this input to the third power.
[01:52:00.160 --> 01:52:04.080] And it's going to dispatch a kernel that takes your input and raises to the third power.
[01:52:04.800 --> 01:52:11.360] And that kernel will run. And when this kernel runs, what ends up happening is this input is
[01:52:11.360 --> 01:52:17.360] stored in the memory of the GPU. So here's a helpful example of the layout of what's happening,
[01:52:17.360 --> 01:52:22.560] right? You have your CPU, this is in every single computer. There's a few cores in there,
[01:52:22.560 --> 01:52:29.200] and you have your RAM, your memory. And the CPU can talk to the memory, and this is all well known.
[01:52:29.200 --> 01:52:33.120] But now we've added the GPU. And the GPU is a slightly different architecture, of course,
[01:52:33.120 --> 01:52:38.560] they can communicate. And it's different in that it's got a lot more cores than a CPU.
[01:52:38.560 --> 01:52:43.600] All of those cores are individually a lot simpler, too. But it also has memory, right?
[01:52:43.600 --> 01:52:50.640] This high bandwidth memory. Sorry if I'm botching it, HBM. I don't even know what that stands for.
[01:52:50.640 --> 01:52:55.440] I'm just realizing now. But this is the memory, and it's very equivalent to
[01:52:55.440 --> 01:53:01.520] RAM, basically, in the computer. And what's happening is that input is living in the memory.
[01:53:02.160 --> 01:53:11.840] And when you do input cubed, this has to travel to the GPU, to the cores, and to all the caches
[01:53:11.840 --> 01:53:18.960] and registers on the actual chip of this GPU. And it has to calculate all the elements to the third,
[01:53:18.960 --> 01:53:24.960] and then it saves the result back to the memory. And it's this travel time that actually causes
[01:53:24.960 --> 01:53:30.960] a lot of issues. So here, remember this memory bandwidth? We can communicate about two terabytes
[01:53:30.960 --> 01:53:37.440] per second, which is a lot. But also, we have to traverse this link, and it's very slow. So here
[01:53:37.440 --> 01:53:42.080] on the GPU, we're on chip, and everything is super fast within the chip. But going to the memory is
[01:53:42.080 --> 01:53:48.480] extremely expensive, takes extremely long amount of time. And so we load the input, do the calculations,
[01:53:48.480 --> 01:53:54.400] and load back the output. And this round trip takes a lot of time. And now right after we do that,
[01:53:54.400 --> 01:54:01.360] we multiply by this constant. So what happens then is we dispatch another kernel, and then the result
[01:54:01.360 --> 01:54:06.640] travels back, all the elements get multiplied by constant, and then the result travel back to
[01:54:06.640 --> 01:54:13.360] the memory. And then we take the result, and we add back input. And so this entire thing, again,
[01:54:13.360 --> 01:54:19.360] travels to the GPU, adds the inputs and gets written back. So we're making all these
[01:54:19.360 --> 01:54:24.560] round trips from the memory to actually where the computation happens. Because all the tensor cores
[01:54:24.560 --> 01:54:29.360] and ALUs and everything like that is all stored on the chip in the GPU. So we're doing a ton of
[01:54:29.360 --> 01:54:35.920] round trips. And PyTorch, without using torch.compile, doesn't know to optimize this, because
[01:54:35.920 --> 01:54:40.240] it doesn't know what kind of operations you're running later. You're just telling it, raise the
[01:54:40.240 --> 01:54:44.400] power to the third, then do this, then do that, and it will just do that in that sequence.
[01:54:44.400 --> 01:54:48.960] But torch.compile sees your entire code. It will come here, and it will realize, wait, all of these
[01:54:48.960 --> 01:54:54.320] are element-wise operations. And actually, what I'm going to do is I'm going to do a single
[01:54:54.320 --> 01:54:59.920] trip of input to the GPU. Then for every single element, I'm going to do all of these operations
[01:54:59.920 --> 01:55:06.320] while that memory is on the GPU, or chunks of it, rather. And then I'm going to write back a single
[01:55:06.320 --> 01:55:10.160] time. So we're not going to have these round trips. And that's one example of what's called
[01:55:10.160 --> 01:55:15.360] kernel fusion, and it is a major way in which everything is sped up. So basically, if you have the
[01:55:15.360 --> 01:55:19.280] benefit of foresight, and you know exactly what you're going to compute, you can optimize your
[01:55:19.280 --> 01:55:23.680] round trips to the memory, and you're not going to pay the memory bandwidth cost.
[01:55:23.680 --> 01:55:28.240] And that's fundamentally what makes some of these operations a lot faster, and what they mean by
[01:55:28.240 --> 01:55:36.480] read writes here. So let me erase this, because we are not using it. And yeah, we should be using
[01:55:36.480 --> 01:55:43.440] torch.compile. And our code is now significantly faster, and we're doing about 125,000 tokens per
[01:55:43.440 --> 01:55:49.200] second. But we still have a long way to go. Before we move on, I wanted to supplement the discussion
[01:55:49.200 --> 01:55:53.440] a little bit with a few more figures, because this is a complicated topic, but it's worth
[01:55:53.440 --> 01:55:58.000] understanding on the high level what's happening here. And I could probably spend an entire video
[01:55:58.000 --> 01:56:04.240] of like two hours on this, but just the preview of that basically. So this chip here, that is the
[01:56:04.240 --> 01:56:11.680] GPU. This chip is where all the calculations happen, mostly. But this chip also does have some memory
[01:56:11.680 --> 01:56:18.880] in it. But most of the memory by far is here in the high bandwidth memory, HBM, and is connected,
[01:56:18.880 --> 01:56:26.560] they're connected. But these are two separate chips, basically. Now, here, this is a zoom in of kind of
[01:56:26.560 --> 01:56:33.200] this cartoon diagram of a GPU. And what we're seeing here is number one, you see this HBM. I
[01:56:33.200 --> 01:56:38.080] realize it's probably very small for you. But on the sides here, it says HBM. And so that's the
[01:56:38.080 --> 01:56:45.360] links to the HBM. Now the HBM is again, off chip. On the chip, there are a large number of these
[01:56:45.360 --> 01:56:51.840] streaming multiprocessors. Every one of these is an SM. There's 120 of them in total. And this is
[01:56:51.840 --> 01:56:57.280] where a lot of the calculations happen. And this is a zoom in of a single individual SM. It has
[01:56:57.280 --> 01:57:01.120] these four quadrants. And see, for example, tensor core, this is where a lot of the matrix
[01:57:01.120 --> 01:57:06.160] multiply stuff happens. But there's all these other units to do all different kinds of calculations
[01:57:06.160 --> 01:57:13.440] for FP64, FP32, and for integers, and so on. Now, so we have all this logic here to the
[01:57:13.440 --> 01:57:18.400] calculations. But in addition to that, on the chip, there is memory sprinkled throughout the chip.
[01:57:18.400 --> 01:57:25.520] So L2 cache is some amount of memory that lives on the chip. And then on the SMs themselves,
[01:57:25.520 --> 01:57:31.360] there's L1 cache. I realize it's probably very small for you, but this blue bar is L1. And there's
[01:57:31.360 --> 01:57:38.080] also registers. And so there is memory stored here. But the way this memory is stored is very
[01:57:38.080 --> 01:57:44.400] different from the way memory is stored in HBM. This is a very different implementation,
[01:57:44.400 --> 01:57:49.520] just in terms of what the silicon looks like.
[01:57:49.520 --> 01:57:55.280] So here, in HBM, you would be using DRAM with transistors and capacitors. And here, on chip, it's a very different
[01:57:55.280 --> 01:58:04.240] implementation with SRAM. But long story short, there is memory
[01:58:04.240 --> 01:58:10.320] inside the chip, but it's not a lot of memory. That's the critical point. So this is an example
[01:58:10.320 --> 01:58:15.760] diagram of a slightly different GPU, just like here, where it shows that, for example, typical
[01:58:15.760 --> 01:58:22.800] numbers for CPU DRAM memory, which is this thing here, you might have one terabyte of memory, but it
[01:58:22.800 --> 01:58:27.040] would be extremely expensive to access, especially for a GPU. You have to go through the CPU here.
[01:58:27.040 --> 01:58:33.600] Now, next we have the HBM. So we have tens of gigabytes of HBM memory on a typical GPU here.
[01:58:33.600 --> 01:58:40.400] But as I mentioned, very expensive to access. And then on the chip itself, everything is extremely
[01:58:40.400 --> 01:58:46.080] fast within the chip, but we only have a couple of tens of megabytes of memory collectively throughout
[01:58:46.080 --> 01:58:51.840] the chip. And so there's just not enough space because the memory is very expensive on the chip.
[01:58:51.840 --> 01:58:55.920] And so there's not a lot of it, but it is lightning fast to access in relative terms.
[01:58:55.920 --> 01:59:02.240] And so basically, whenever we have these kernels, the more accurate picture of what's happening here
[01:59:02.240 --> 01:59:08.480] is that we take these inputs which live by default on the global memory, and now we need to perform
[01:59:08.480 --> 01:59:15.680] some calculation. So we start streaming the data from the global memory to the chip.
[01:59:15.680 --> 01:59:20.160] We perform the calculations on the chip and then stream it back and store it back to the global
[01:59:20.160 --> 01:59:26.160] memory. And so if we don't have torch compile, we are streaming the data through the chip doing
[01:59:26.160 --> 01:59:30.160] the calculations and saving to the memory. And we're doing those round trips many, many times.
[01:59:30.160 --> 01:59:35.760] But if it's torch compiled, then we start streaming the memory as before. But then while
[01:59:35.760 --> 01:59:43.040] we're on the chip, we have a chunk of the data that we're trying to process. So that chunk now
[01:59:43.040 --> 01:59:47.680] lives on the chip. While it's on the chip, it's extremely fast to operate on. So if we have kernel
[01:59:47.680 --> 01:59:52.880] fusion, we can do all the operations right there in an element-wise fashion. And those are very
[01:59:52.880 --> 01:59:59.280] cheap. And then we do a single round trip back to the global memory. So operator fusion basically
[01:59:59.280 --> 02:00:04.000] allows you to keep your chunk of data on the chip and do lots of calculations on it before you
[02:00:04.000 --> 02:00:10.800] write it back. And that gives huge savings. And that's why torch compile ends up being a lot faster,
[02:00:10.800 --> 02:00:16.160] or that's one of the major reasons. So again, just a very brief intro to the memory hierarchy,
[02:00:16.160 --> 02:00:21.200] and roughly what torch compile does for you. Now torch compile is amazing, but there are
[02:00:21.200 --> 02:00:26.880] operations that torch compile will not find. And an amazing example of that is flash attention
[02:00:26.880 --> 02:00:33.440] to which we turn next. So flash attention comes from this paper from Stanford in 2022.
[02:00:33.440 --> 02:00:41.360] And it's this incredible algorithm for performing attention, and running it a lot faster. So flash
[02:00:41.360 --> 02:00:48.640] attention will come here, and we will take out these four lines. And flash attention implements
[02:00:48.640 --> 02:00:54.880] these four lines really, really quickly. And how does it do that? Well, flash attention is a
[02:00:54.880 --> 02:01:01.760] kernel fusion operation. So you see here, we have in this diagram, they're showing PyTorch,
[02:01:01.760 --> 02:01:06.880] and you have these four operations, they're including dropout, but we are not using dropout
[02:01:06.880 --> 02:01:12.880] here. So we just have these four lines of code here. And instead of those, we are fusing them
[02:01:12.880 --> 02:01:19.520] into a single fused kernel of flash attention. So it's a kernel fusion algorithm, but it's a
[02:01:19.520 --> 02:01:24.720] kernel fusion that torch compile cannot find. And the reason that it cannot find it is that it
[02:01:24.720 --> 02:01:30.160] requires an algorithmic rewrite of how attention is actually implemented here in this case.
[02:01:30.160 --> 02:01:35.280] And what's remarkable about it is that flash attention, actually, if you just count the number
[02:01:35.280 --> 02:01:41.920] of flops, flash attention does more flops than this attention here. But flash attention is
[02:01:41.920 --> 02:01:49.280] actually significantly faster. In fact, they cite 7.6 times faster potentially. And that's because
[02:01:49.280 --> 02:01:55.520] it is very mindful of the memory hierarchy, as I described it just now. And so it's very mindful
[02:01:55.520 --> 02:02:00.800] about what's in high bandwidth memory, what's in the shared memory. And it is very careful
[02:02:00.800 --> 02:02:06.240] with how it orchestrates the computation, such that we have fewer reads and writes to the
[02:02:06.240 --> 02:02:10.720] high bandwidth memory. And so even though we're doing more flops, the expensive part is the load
[02:02:10.720 --> 02:02:16.960] and store into HBM. And that's what they avoid. And so in particular, they do not ever materialize
[02:02:16.960 --> 02:02:23.200] this N-by-N attention matrix, this att here. Flash attention is designed such that this
[02:02:23.200 --> 02:02:28.880] matrix never gets materialized at any point, and it never gets read or written to the HBM.
[02:02:29.440 --> 02:02:34.480] And this is a very large matrix, right? Because this is where all the queries and keys interact,
[02:02:34.480 --> 02:02:42.240] and we're sort of getting, for each head, for each batch element, we're getting a T by T matrix
[02:02:42.240 --> 02:02:47.600] of attention, which is a million numbers, even for a single head, at a single batch index,
[02:02:47.600 --> 02:02:54.480] so basically this is a ton of memory, and this is never materialized. And the way that this is
[02:02:54.480 --> 02:03:00.720] achieved is that basically the fundamental algorithmic rewrite here relies on this online
[02:03:00.720 --> 02:03:05.680] softmax trick, which was proposed previously, and I'll show you the paper in a bit. And the
[02:03:05.680 --> 02:03:12.320] online softmax trick coming from a previous paper shows how you can incrementally evaluate
[02:03:12.320 --> 02:03:18.880] a softmax without having to sort of realize all of the inputs to the softmax for the normalization.
[02:03:18.880 --> 02:03:23.520] And you do that by having these intermediate variables m and l, and there's an update to them
[02:03:23.520 --> 02:03:29.840] that allows you to evaluate the softmax in an online manner.
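As an illustration of the idea (not the flash attention kernel itself), here is a minimal sketch of an online softmax that keeps a running maximum m and a running normalizer l:

```python
import torch

def online_softmax(x):
    # One streaming pass over a 1D tensor, maintaining a running max (m)
    # and a running normalizer (l) that gets rescaled whenever the max changes.
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    for xi in x:
        m_new = torch.maximum(m, xi)
        l = l * torch.exp(m - m_new) + torch.exp(xi - m_new)
        m = m_new
    return torch.exp(x - m) / l          # produce the probabilities

x = torch.randn(16)
assert torch.allclose(online_softmax(x), torch.softmax(x, dim=0), atol=1e-6)
```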
[02:03:29.840 --> 02:03:34.160] Now, flash attention: actually, recently FlashAttention-2 came out as well, so I have that paper up here as well,
[02:03:34.160 --> 02:03:39.600] that has additional gains to how it calculates flash attention. And the original paper that
[02:03:39.600 --> 02:03:44.240] this is based on basically is this online normalization calculation for softmax,
[02:03:44.240 --> 02:03:49.520] and remarkably it came out of NVIDIA, and it came out really early, like 2018,
[02:03:49.520 --> 02:03:55.760] so this is four years before flash attention. And this paper says that we propose a way
[02:03:55.760 --> 02:03:59.760] to compute the classical softmax with fewer memory accesses, and hypothesize that this
[02:03:59.760 --> 02:04:05.600] reduction in memory accesses should improve softmax performance on actual hardware. And so
[02:04:05.600 --> 02:04:10.480] they are extremely correct in this hypothesis. But it's really fascinating to me that they're
[02:04:10.480 --> 02:04:15.280] from NVIDIA, and that they had this realization, but they didn't actually take it to the actual
[02:04:15.280 --> 02:04:20.800] flash attention that had to come four years later from Stanford. So I don't fully understand the
[02:04:20.800 --> 02:04:26.560] historical how this happened historically, but they do basically propose this online update to
[02:04:26.560 --> 02:04:32.800] the softmax right here. And this is fundamentally what they reuse here to calculate the softmax
[02:04:32.800 --> 02:04:36.960] in a streaming manner. And then they realize that they can actually fuse all the other operations
[02:04:36.960 --> 02:04:42.240] with the online softmax calculation into a single fused kernel, flash attention,
[02:04:42.240 --> 02:04:48.960] and that's what we are about to use. So a great example I think of being aware of memory hierarchy,
[02:04:48.960 --> 02:04:53.040] the fact that flops don't matter, the entire memory access pattern matters,
[02:04:53.040 --> 02:04:57.760] and that torch compile is amazing, but there are many optimizations that are still available to us
[02:04:57.760 --> 02:05:03.120] that potentially torch compile cannot find. Maybe one day it could, but right now it seems like
[02:05:03.120 --> 02:05:07.360] a lot to ask. So here's what we're going to do. We're going to use flash attention,
[02:05:08.000 --> 02:05:13.680] and the way to do that basically in PyTorch is we are going to comment out these four lines,
[02:05:13.680 --> 02:05:19.600] and we're going to replace them with a single line. And here we are calling this compound
[02:05:19.600 --> 02:05:28.240] operation in PyTorch called scaled dot product attention. And PyTorch will call flash attention
[02:05:28.240 --> 02:05:33.680] when you use it in this way. I'm not actually 100% sure why torch compile doesn't realize that
[02:05:33.680 --> 02:05:39.040] these four lines should just call flash attention in this exact way. We have to do it manually for it,
[02:05:39.040 --> 02:05:47.440] which in my opinion is a little bit odd, but here we are. So you have to use this compound op.
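A sketch of that single replacement line, with q, k, v assumed to already be shaped (B, n_head, T, head_dim) as in the attention module:

```python
import torch
import torch.nn.functional as F

B, n_head, T, head_dim = 4, 12, 64, 64   # small toy sizes for the sketch
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_head, T, head_dim)
v = torch.randn(B, n_head, T, head_dim)

# Single fused call; with is_causal=True PyTorch applies the causal mask internally
# and can dispatch to a flash-attention kernel on supported GPUs.
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(y.shape)                            # torch.Size([4, 12, 64, 64])
```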
[02:05:47.440 --> 02:05:53.920] And let's wait for a few moments before torch.compile gets around to it. And then let's remember
[02:05:53.920 --> 02:06:00.560] that we achieved 6.05661, I have it here. That's the loss we are expecting to see,
[02:06:01.360 --> 02:06:07.840] and we took 130 milliseconds before this change. So we're expecting to see the exact same result
[02:06:07.840 --> 02:06:14.320] by iteration 49, but we expect to see faster runtime, because flash attention is just an
[02:06:14.320 --> 02:06:18.320] algorithmic rewrite, and it's a faster kernel, but it doesn't actually change any of the computation,
[02:06:18.320 --> 02:06:23.440] and we should have the exact same optimization. So okay, so we're a lot faster. We're at about 95
[02:06:23.440 --> 02:06:32.720] milliseconds, and we achieve 6.058. Okay, so they're basically identical up to a floating point
[02:06:32.720 --> 02:06:39.120] fudge factor. So it's the identical computation, but it's significantly faster, going from 130 to
[02:06:39.120 --> 02:06:48.800] roughly 96, and so this is 96 divided by 130-ish, so this is maybe a 27-ish percent improvement.
[02:06:50.240 --> 02:06:56.560] So really interesting, and that is flash attention. Okay, we are now getting to one of my favorite
[02:06:56.560 --> 02:07:02.000] optimizations, and it is simultaneously the dumbest and the most brilliant optimization,
[02:07:02.000 --> 02:07:07.760] and it's always a little bit surprising to me. Anyway, so basically, I mentioned a few minutes
[02:07:07.760 --> 02:07:15.120] ago that there are some numbers that are nice, and some numbers that are ugly. So 64 is a beautiful
[02:07:15.120 --> 02:07:21.120] nice number. 128 is even nicer. 256 is beautiful. What makes these numbers beautiful is that there
[02:07:21.120 --> 02:07:27.360] are many powers of two inside them. You can divide by two many times, and examples of ugly numbers
[02:07:27.360 --> 02:07:32.560] are like 13 and 17 and something like that, prime numbers, numbers that are not even and so on.
[02:07:32.560 --> 02:07:36.880] And so pretty much you always want to use nice numbers in all of your code that deals with neural
[02:07:36.880 --> 02:07:43.280] networks or CUDA, because everything in CUDA works in sort of like powers of two, and lots of kernels
[02:07:43.840 --> 02:07:50.240] are written in terms of powers of two, and there are lots of blocks of sizes 16 and 64 and so on.
[02:07:50.240 --> 02:07:54.720] So everything is written in those terms, and you always have special case handling for all kinds
[02:07:54.720 --> 02:08:01.680] of logic that when your inputs are not made of nice numbers. So let's see what that looks like.
[02:08:01.680 --> 02:08:08.400] Basically, scan your code and look for ugly numbers; that's roughly the heuristic. So three times is kind
[02:08:08.400 --> 02:08:14.640] of ugly. I'm not 100% sure, maybe this can be improved, but this is ugly and not ideal.
[02:08:14.640 --> 02:08:24.880] Four times is nice. So that's nice. 1024 is very nice. That's a power of two.
[02:08:24.880 --> 02:08:32.800] 12 is a little bit suspicious. Not too many powers of two. 768 is great. 50,257 is a really,
[02:08:32.800 --> 02:08:40.560] really ugly number. First of all, it's odd, and there are not too many powers of two in there.
[02:08:40.560 --> 02:08:46.480] So this is a very ugly number, and it's highly suspicious. And then when we scroll down, all these
[02:08:46.480 --> 02:08:53.680] numbers are nice. And then here we have mostly nice numbers except for 25. So in this configuration
[02:08:53.680 --> 02:08:59.280] of GPT-2 XL, the number of heads is 25. That's a really ugly number. That's an odd number. And
[02:09:00.080 --> 02:09:03.440] actually this did cause a lot of headaches for us recently when we're trying to optimize some
[02:09:03.440 --> 02:09:10.240] kernels to run fast and it required a bunch of special case handling. So basically,
[02:09:10.240 --> 02:09:15.040] we have some ugly numbers, and some of them are easier to fix than others. And in particular,
[02:09:15.040 --> 02:09:19.760] the vocab size being 50,257, that's a very ugly number, very suspicious, and we're going to
[02:09:19.760 --> 02:09:25.040] fix it. Now when you fix these things, one of the easy ways to do that is you basically
[02:09:26.320 --> 02:09:32.240] increase the number until it's the nearest nice number with many powers of two in it. So here's a much nicer number.
[02:09:32.240 --> 02:09:42.880] It's 50,304. And why is that? Because 50,304 can be divided by 8 or by 16 or by 32, 64.
[02:09:42.880 --> 02:09:50.320] It can even be divided by 128, I think. Yeah. So it's a very nice number. So what we're going to do
[02:09:50.320 --> 02:09:55.200] here is: this is the GPTConfig. And you see that we initialize vocab size to 50,257.
[02:09:55.760 --> 02:10:05.520] Let's override just that element to be 50,304. Okay. So everything else stays the same,
[02:10:05.520 --> 02:10:10.800] we're just increasing our vocabulary size. So we're adding, it's almost like we're adding fake
[02:10:10.800 --> 02:10:16.800] tokens. So that vocab size has powers of two inside it. Now, actually, what I'm doing here,
[02:10:16.800 --> 02:10:20.640] by the way, is I'm increasing the amount of computation that our network will be doing.
[02:10:20.640 --> 02:10:25.040] If you just count the flops on like, do the math of how many flops we're doing, we're going to be
[02:10:25.040 --> 02:10:31.120] doing more flops. And we still have to think through whether this doesn't break anything.
[02:10:31.120 --> 02:10:40.080] But if I just run this, let's see what we get. Currently, this ran in maybe 96.5 milliseconds
[02:10:40.080 --> 02:10:44.720] per step. I'm just kind of eyeballing it. And let's see what kind of result we're going to get.
[02:10:44.720 --> 02:10:52.480] While this is compiling, let's think through whether our code actually works. Okay. When we
[02:10:52.480 --> 02:10:56.880] increase the vocab size like this, let's look at where vocab size is actually used.
[02:10:56.880 --> 02:11:02.560] So we swing up to the in it. And we see that it's used inside the embedding table, of course.
[02:11:02.560 --> 02:11:06.400] So all the way at the bottom of the transformer. And it's used at the classifier layer, all the
[02:11:06.400 --> 02:11:12.480] way at the top of the transformer. So in two places. And let's take a look. And we're running at 93.
[02:11:12.480 --> 02:11:20.400] So 93 milliseconds instead of 96.5. So we are seeing a roughly 4% improvement here.
[02:11:21.280 --> 02:11:27.920] By doing more calculations. And the reason for this is we fixed it, we made an ugly number
[02:11:27.920 --> 02:11:33.280] into a nice number. I'm going to come to the explanation for that in a little bit.
[02:11:33.280 --> 02:11:36.640] But for now, let's just convince ourselves that we're not breaking anything when we do this.
[02:11:36.640 --> 02:11:42.320] So first of all, we've made the WTE, the embedding table for the tokens. We've made it larger.
[02:11:42.320 --> 02:11:48.080] It's almost like we introduced more tokens at the bottom. And these tokens are never used,
[02:11:48.080 --> 02:11:54.880] because the GPT-2 tokenizer only has tokens up to 50,256. And so we'll never index into
[02:11:54.880 --> 02:12:00.000] the rows that we've added. So we're wasting a little bit of space here by creating memory that's
[02:12:00.000 --> 02:12:04.960] never going to be accessed, never going to be used, et cetera. Now that's not fully correct,
[02:12:04.960 --> 02:12:10.400] because this WTE weight ends up being shared and ends up being used in the classifier here at the
[02:12:10.400 --> 02:12:15.920] end. So what is that doing to the classifier here? Well, what that's doing is we're
[02:12:15.920 --> 02:12:20.800] predicting additional dimensions of the classifier now. And we're predicting probabilities for tokens
[02:12:20.800 --> 02:12:26.160] that will, of course, never be present in the training set. And so therefore, the network has
[02:12:26.160 --> 02:12:32.080] to learn that these probabilities have to be driven to zero. And so the logits that the network
[02:12:32.080 --> 02:12:37.760] produces have to drive those dimensions of the output to negative infinity. But that's no different
[02:12:37.760 --> 02:12:43.840] from all the other tokens that are already in our dataset, or rather that are not in our dataset.
[02:12:43.840 --> 02:12:49.680] So Shakespeare only probably uses, let's say, 1,000 tokens out of the 50,257 tokens. So most
[02:12:49.680 --> 02:12:53.920] of the tokens are already being driven to zero probability by the optimization. We've just introduced
[02:12:53.920 --> 02:12:58.800] a few more tokens now that in a similar manner will never be used and have to be driven to zero
[02:12:58.800 --> 02:13:07.600] in probability. So functionally, though, nothing breaks, we're using a bit more extra memory. But
[02:13:07.600 --> 02:13:12.720] otherwise, this is a harmless operation, as far as I can tell. And we're adding calculation,
[02:13:12.720 --> 02:13:16.800] but it's running faster. And it's running faster because, as I mentioned, in CUDA,
[02:13:16.800 --> 02:13:24.960] so many kernels use block tiles. And these block tiles are usually nice numbers. So powers of 2.
[02:13:24.960 --> 02:13:31.040] So calculations are done in like chunks of 64 or chunks of 32. And when
[02:13:31.040 --> 02:13:37.760] your desired calculation doesn't neatly fit into those block tiles, there are all kinds of boundary
[02:13:37.760 --> 02:13:43.600] kernels that can kick in to like do the last part. So basically, in a lot of kernels,
[02:13:43.600 --> 02:13:48.080] they will chunk up your input, and they will do the nice part first. And then they have a whole
[02:13:48.080 --> 02:13:54.880] second phase, where they come back to anything that remains. And then they process
[02:13:54.880 --> 02:13:58.880] the remaining part. And the kernels for that could be very inefficient. And so you're basically
[02:13:58.880 --> 02:14:04.640] spinning up all this extra compute, and it's extremely inefficient. So you might as well pad
[02:14:04.640 --> 02:14:11.120] your inputs and make it fit nicely. And usually that empirically ends up actually running faster.
[02:14:11.120 --> 02:14:18.880] So this is another example of a 4% improvement that we've added. And this is something that also
[02:14:18.880 --> 02:14:23.200] torch.compile did not find for us. You would hope that torch.compile at some point could
[02:14:23.200 --> 02:14:28.800] figure an optimization like this out. But for now, this is it. And I also have to point out that
[02:14:28.800 --> 02:14:34.480] we're using PyTorch nightly. So that's why we're only seeing 4%. If you're using PyTorch 2.3.1,
[02:14:34.640 --> 02:14:39.040] or earlier, you would actually see something like a 30% improvement, just from this change,
[02:14:39.040 --> 02:14:48.000] from changing it from 50,257 to 50,304. So again, one of my favorite examples also
[02:14:48.000 --> 02:14:52.240] of having to understand the under the hood and how it all works, and to know what kinds of things
[02:14:52.240 --> 02:14:56.880] to tinker with to push the performance of your code. Okay, so at this point, we have improved
[02:14:56.880 --> 02:15:01.840] the performance by about 11x, right? Because we started at about 1000 milliseconds per step,
[02:15:01.840 --> 02:15:08.240] and we're now down to like 93 milliseconds. So that's quite good. And we're doing a much better
[02:15:08.240 --> 02:15:13.840] job of utilizing our GPU resources. So I'm going to now turn to more algorithmic changes
[02:15:13.840 --> 02:15:18.320] and improvements to the actual optimization itself. And what we would like to do is we'd
[02:15:18.320 --> 02:15:23.200] like to follow the hyper parameters that are mentioned in the GPT-2 or GPT-3 paper.
[02:15:23.200 --> 02:15:30.160] Now, sadly, the GPT-2 paper doesn't actually say too much. It's very nice of them that they released the
[02:15:30.160 --> 02:15:34.720] model weights and the code, but the paper itself is extremely vague as to the optimization details.
[02:15:34.720 --> 02:15:40.240] The code itself that they released as well, the code we've been looking at, this is just the
[02:15:40.240 --> 02:15:44.640] inference code. So there's no training code here and very few hyper parameters. So this doesn't
[02:15:44.640 --> 02:15:51.360] also tell us too much. So for that, we have to turn to the GPT-3 paper. And in the appendix of the
[02:15:51.360 --> 02:15:58.800] GPT-3 paper, they have a lot more hyper parameters here for us to use. And the GPT-3 paper in general
[02:15:58.800 --> 02:16:04.880] is a lot more detailed as to all the small details that go into the model training,
[02:16:04.880 --> 02:16:10.560] but GPT-3 models were never released. So GPT-2, we have the weights, but no details,
[02:16:10.560 --> 02:16:16.960] and GPT-3, we have lots of details, but no weights. But roughly speaking, GPT-2 and GPT-3
[02:16:16.960 --> 02:16:22.880] architectures are very, very similar. And basically, there are very few changes. The context length
[02:16:22.880 --> 02:16:28.720] was expanded from 1024 to 2048. And that's kind of like the major change. And some of the hyper
[02:16:28.720 --> 02:16:32.720] parameters around the transformer have changed. But otherwise, they're pretty much the same model.
[02:16:32.720 --> 02:16:37.520] It's just that GPT-3 was trained for a lot longer on a bigger data set and has a lot more
[02:16:37.520 --> 02:16:46.880] thorough evaluations. And the GPT-3 model is 175 billion instead of 1.6 billion in the GPT-2.
[02:16:46.880 --> 02:16:51.600] So long story short, we're going to go to the GPT-3 paper to follow along some of the hyperparameters.
[02:16:52.160 --> 02:16:58.880] So to train all the versions of GPT-3, we use Adam with beta 1, beta 2 of 0.9 and 0.95.
[02:16:58.880 --> 02:17:04.480] So let's move over here and make sure that the betas parameter, which you can see here
[02:17:04.480 --> 02:17:13.680] defaults to 0.9 and 0.999, is actually set to 0.9 and 0.95. And then the epsilon parameter,
[02:17:13.680 --> 02:17:19.680] you can see the default is 1e-8, and this is also 1e-8. Let's just
[02:17:19.680 --> 02:17:26.640] put it in so that we're explicit. Now, next up, they say we clip the global norm of the
[02:17:26.640 --> 02:17:32.160] gradient at 1.0. So what this is referring to is that once we calculate the gradients
[02:17:32.160 --> 02:17:37.600] right after loss.backward(), we basically have the gradients at all the parameter tensors.
[02:17:37.600 --> 02:17:43.680] And what people like to do is basically clip them to have some kind of a maximum norm.
[02:17:44.800 --> 02:17:50.080] So in PyTorch, this is fairly easy to do. It's one line of code here that we have to insert right
[02:17:50.080 --> 02:17:57.040] after we calculate the gradients. And what this utility function is doing is it's calculating the
[02:17:57.040 --> 02:18:02.800] global norm of the parameters. So every single gradient on all the parameters,
[02:18:02.800 --> 02:18:08.480] you square it and you add it all up and you take a big square root of that. And that's the norm of
[02:18:08.480 --> 02:18:15.440] the parameter vector, basically. It's the length of it, if you like to look at it that way.
[02:18:15.440 --> 02:18:20.480] And we are basically making sure that its length is no more than 1.0 and we're going to clip it.
[02:18:20.480 --> 02:18:25.920] And the reason that people like to use this is that sometimes you can get unlucky during the
[02:18:25.920 --> 02:18:31.120] optimization. Maybe it's a bad data batch or something like that. And if you get very unlucky in the
[02:18:31.120 --> 02:18:36.160] batch, you might get really high loss and really high loss could lead to a really high gradient.
[02:18:36.160 --> 02:18:42.320] And this could basically shock your model and shock the optimization. So people like to use a
[02:18:42.320 --> 02:18:49.920] gradient norm clipping to prevent the model from basically getting too big of shocks in terms of
[02:18:49.920 --> 02:18:55.440] the gradient magnitude and the upper bounded in this way. It's a bit of a hacky solution,
[02:18:55.440 --> 02:19:01.440] it's a bit like a patch on top of deeper issues, but people still do it fairly frequently.
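As a rough sketch (not the verbatim video code), the two settings discussed so far look like this, assuming the model and training loop we have been building:

```python
import torch

# AdamW with the GPT-3 betas and epsilon, instead of the PyTorch defaults
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4,
                              betas=(0.9, 0.95), eps=1e-8)

# ... then inside the training loop, right after loss.backward():
# clip the global norm of all gradients to 1.0; the pre-clip norm is returned
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
```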
[02:19:02.000 --> 02:19:08.080] Now the clip grad norm returns the norm of the gradient, which I like to always visualize
[02:19:08.080 --> 02:19:14.480] because it is useful information. And sometimes you can look at the norm of the gradient. And if
[02:19:14.480 --> 02:19:18.880] it's well behaved, things are good. If it's climbing, things are bad and they're destabilizing during
[02:19:18.880 --> 02:19:22.880] training, sometimes you could get a spike in the norm. And that means there's some kind of
[02:19:22.880 --> 02:19:33.840] an issue or instability. So the norm here will be norm, and let's format it with a .4f or something like that.
[02:19:33.840 --> 02:19:43.120] And I believe this is just a float. And so we should be able to print that. So that's global
[02:19:43.120 --> 02:19:49.920] gradient clipping. Now they go into the details of the learning rate scheduler. So they don't just
[02:19:49.920 --> 02:19:55.200] use a fixed learning rate like we do here of 3e-4, but there's actually basically
[02:19:55.200 --> 02:20:03.280] a cosine decay learning rate schedule. It's got a warm up and it's got a cosine decay to 10%
[02:20:03.280 --> 02:20:11.760] over some horizon. And so we're going to implement this in a second. I just like to see the norm
[02:20:11.760 --> 02:20:17.200] printed here. Okay, there we go. So what happened here is the norm is actually really high in the
[02:20:17.200 --> 02:20:22.640] beginning 30 or so. And you see that as we continue training, it kind of stabilizes
[02:20:22.640 --> 02:20:31.200] at values below one. And it's not that uncommon for the norm to be high in the very
[02:20:31.200 --> 02:20:35.280] first few stages. Basically what's happening here is the model is completely random. And so
[02:20:35.280 --> 02:20:39.520] there's a ton of learning happening very early in the network. But that learning is kind of like,
[02:20:39.520 --> 02:20:45.920] you know, it's mostly learning the biases of the output tokens. And so it's a bit of an unstable
[02:20:45.920 --> 02:20:50.880] time. But the network usually stabilizes in the very first few iterations. So this looks relatively
[02:20:50.880 --> 02:20:55.760] reasonable to me, except it looks a little bit funky that we go from 28
[02:20:55.760 --> 02:21:04.240] to 62 and then to 10. It's not completely insane, but it's just kind of a little bit funky. Okay,
[02:21:04.240 --> 02:21:08.480] so let's now get to the learning rate scheduler. So the learning rate schedule that's used here in
[02:21:08.480 --> 02:21:15.280] GPT three is what's called a cosine decay learning schedule with warm up. And the way this looks is
[02:21:15.280 --> 02:21:21.680] that the learning rate is basically starts right at around zero, linearly ramps up over some amount
[02:21:21.680 --> 02:21:28.080] of time, and then comes down with this cosine sort of form, and comes down to some kind of a
[02:21:28.080 --> 02:21:33.280] minimum learning rate that's up to you. So here the minimum learning rate is zero. But here in the
[02:21:33.280 --> 02:21:38.880] paper, they said that they use cosine decay for learning rate down to 10% of its value over the
[02:21:38.880 --> 02:21:46.160] first 260 billion tokens, and then training continues at 10% after that. And there's a linear warm-up
[02:21:46.160 --> 02:21:52.080] over the first 375 million tokens. So that's about the learning rate. So let's now implement this.
[02:21:52.080 --> 02:21:59.040] So I already implemented it here. And the way this works is, let me scroll down first here,
[02:21:59.040 --> 02:22:04.400] I changed our training loop a little bit. So this was a 4i in max steps. I just change it to step
[02:22:04.400 --> 02:22:11.120] now, so that we have the notion of a step is a single optimization step in the for loop. And then
[02:22:11.120 --> 02:22:18.000] here, I get the LR for this step of the optimization using a new function I call get LR. And then in
[02:22:18.000 --> 02:22:22.320] PyTorch to set the learning rate, I think this is the way to set the learning rate. It's a little
[02:22:22.320 --> 02:22:27.280] bit gnarly, because you have to basically there's notion of different parameter groups that could
[02:22:27.280 --> 02:22:31.040] exist in the optimizer. And so you actually have to iterate over them, even though we currently have
[02:22:31.040 --> 02:22:36.640] a single parameter group only. And you have to set the LR in this for loop kind of style,
[02:22:36.640 --> 02:22:41.920] is my impression right now. So we have this loop where, for each group, we set the learning rate,
[02:22:41.920 --> 02:22:47.600] and then on the bottom also printing it. So that's all the changes I made to this loop.
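Concretely, the change to the loop is roughly this, using the names from our script:

```python
# determine the learning rate for this step and set it on every parameter group
lr = get_lr(step)
for param_group in optimizer.param_groups:
    param_group['lr'] = lr
optimizer.step()
print(f"step {step} | lr: {lr:.4e}")
```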
[02:22:47.600 --> 02:22:52.240] And then of course the get LR is my scheduler. Now it's worth pointing out that PyTorch actually
[02:22:52.240 --> 02:22:56.880] has learning rate schedulers, and you can use them. And I believe there's a cosine learning rate
[02:22:56.880 --> 02:23:02.640] schedule in PyTorch. I just don't really love using that code, because honestly,
[02:23:02.640 --> 02:23:08.240] it's like five lines of code. And I fully understand what's happening inside these lines. So I don't
[02:23:08.240 --> 02:23:13.200] love to use abstractions where they're kind of inscrutable. And then I don't know what they're
[02:23:13.200 --> 02:23:19.360] doing. So, personal style. So the max learning rate here is, let's say, 3e-4,
[02:23:19.360 --> 02:23:26.000] but we're going to see that in GPT three here, they have a table of what the maximum learning
[02:23:26.000 --> 02:23:36.880] rate is for every model size. So for this one, basically the 12-layer, 768-channel GPT-3. So the
[02:23:36.880 --> 02:23:42.720] GPT-3 Small is roughly like a GPT-2 124M. We see that here, they use a learning rate
[02:23:42.720 --> 02:23:47.200] of 6e-4. So we could actually go higher. In fact, we may want to try to follow
[02:23:47.200 --> 02:23:54.400] that. So I'll just set the max LR here to 6e-4. Then the minimum learning rate is
[02:23:55.360 --> 02:24:01.600] 10% of that, per the description in the paper, some number of steps that we're going to warm up over,
[02:24:01.600 --> 02:24:06.000] and then the maximum steps of the optimization, which I now also use in the for loop down here.
[02:24:06.000 --> 02:24:12.240] And then you can go over this code if you like. It's not terribly insightful or interesting.
[02:24:12.240 --> 02:24:17.600] I'm just modulating based on the iteration number of which learning rate there should be.
[02:24:17.600 --> 02:24:24.560] So this is the warm up region. This is the region after the optimization. And then this is the region
[02:24:24.560 --> 02:24:29.280] sort of in between. And this is where I calculate the cosine learning rate schedule. And you can
[02:24:29.280 --> 02:24:33.360] step through this in detail if you'd like. But this is basically implementing this curve.
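Here is a sketch of that scheduler; the constants (max_lr, warmup_steps, max_steps) are placeholder values for illustration:

```python
import math

max_lr = 6e-4
min_lr = max_lr * 0.1   # decay down to 10% of the max, per the GPT-3 paper
warmup_steps = 10
max_steps = 50

def get_lr(it):
    # 1) linear warmup for warmup_steps steps (it + 1 so step 0 isn't exactly zero)
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # 2) past the decay horizon, hold at the minimum learning rate
    if it > max_steps:
        return min_lr
    # 3) in between, cosine decay from max_lr down to min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)
```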
[02:24:33.360 --> 02:24:39.040] And I ran this already. And this is what that looks like.
[02:24:39.040 --> 02:24:49.040] So when we now run, we start at some very low number. Now note that we don't start exactly at zero,
[02:24:49.040 --> 02:24:52.880] because that would be not useful to update with a learning rate of zero. That's why there's an
[02:24:52.880 --> 02:24:58.000] it plus one, so that on the zero iteration, we are not using exactly zero. We're using something
[02:24:58.000 --> 02:25:03.040] very, very low. Then we linearly warm up to maximum learning rate, which in this case was
[02:25:03.040 --> 02:25:08.400] 3e-4 when I ran it, but now would be 6e-4. And then it starts to decay
[02:25:08.400 --> 02:25:15.040] all the way down to 3e-5, which was at the time 10% of the original learning rate.
[02:25:15.040 --> 02:25:18.720] Now one thing we are not following exactly is that they mentioned that
[02:25:21.280 --> 02:25:26.960] let me see if I can find it again. We're not exactly following what they did because
[02:25:26.960 --> 02:25:34.400] they mentioned that their training horizon is 300 billion tokens. And they come down to 10%
[02:25:34.400 --> 02:25:41.440] of the initial learning rate at 260 billion. And then they train after 260 billion with 10%. So basically
[02:25:41.440 --> 02:25:46.560] their decay time is less than the max steps time, whereas for us, they're exactly equal.
[02:25:46.560 --> 02:25:54.400] So it's not exactly faithful, but it's okay. This is okay for us and for our purposes right now.
[02:25:54.400 --> 02:26:00.320] And we're just going to use this ourselves. I don't think it makes too big of a difference,
[02:26:00.320 --> 02:26:05.440] honestly. I should point out that what learning rate schedule you use is totally up to you.
[02:26:05.440 --> 02:26:11.280] There's many different types. Cosine learning rate has been popularized a lot by GPT-2 and GPT-3,
[02:26:11.280 --> 02:26:15.760] but people have come up with all kinds of other learning rate schedules. And this is kind of
[02:26:15.760 --> 02:26:20.880] an active area of research as to which one is the most effective at training these networks.
[02:26:20.880 --> 02:26:28.160] Okay, next up, the paper talks about the gradual batch size increase. So there's a ramp on the
[02:26:28.160 --> 02:26:32.800] batch size that is linear. And you start with very small batch size and you ramp up to a big
[02:26:32.800 --> 02:26:37.600] batch size over time. We're going to actually skip this and we're not going to work with it.
[02:26:37.600 --> 02:26:42.000] And the reason I don't love to use it is that it complicates a lot of the arithmetic because
[02:26:42.000 --> 02:26:46.400] you are changing the number of tokens that you're processing at every single step of the optimization.
[02:26:46.400 --> 02:26:50.800] And I like to keep that math very, very simple. Also, my understanding is that this is not like
[02:26:50.800 --> 02:26:57.280] a major improvement. And also, my understanding is that this is not like an algorithmic optimization
[02:26:57.280 --> 02:27:02.080] improvement. It's more of a systems and speed improvement. And roughly speaking, this is because
[02:27:02.080 --> 02:27:10.000] in the early stages of the optimization, again, the model is in a very atypical setting.
[02:27:10.000 --> 02:27:15.520] And mostly what you're learning is that you're mostly learning to ignore the tokens that don't
[02:27:15.520 --> 02:27:20.880] come up in your training set very often. You're learning very simple biases and that kind of a
[02:27:20.880 --> 02:27:27.120] thing. And so every single example that you put through your network is basically just telling you
[02:27:27.120 --> 02:27:31.200] use these tokens and don't use these tokens. And so the gradients from every single example are
[02:27:31.200 --> 02:27:36.960] actually extremely highly correlated. They all look roughly the same in the early parts of
[02:27:36.960 --> 02:27:40.800] the optimization because they're all just telling you that these tokens don't appear and these tokens
[02:27:40.800 --> 02:27:46.240] do appear. And so because the gradients are all very similar and they're highly correlated,
[02:27:46.240 --> 02:27:51.840] then why are you doing batch sizes of like millions when if you do a batch size of 32K,
[02:27:51.840 --> 02:27:57.040] you're basically getting the exact same gradient early on in the training. And then later in the
[02:27:57.040 --> 02:28:01.600] optimization, once you've learned all the simple stuff, that's where the actual work starts. And
[02:28:01.600 --> 02:28:06.160] that's where the gradients become more decorrelated across examples. And that's where they actually offer
[02:28:06.160 --> 02:28:12.400] you sort of statistical power in some sense. So we're going to skip this just because it kind of
[02:28:12.400 --> 02:28:19.440] complicates things. And we're going to go to: data are sampled without replacement during training,
[02:28:19.440 --> 02:28:24.880] until an epoch boundary is reached. So without replacement means that they're not
[02:28:24.880 --> 02:28:31.840] sampling from some fixed pool, taking a sequence, training on it, but then also returning
[02:28:31.840 --> 02:28:37.600] the sequence to the pool; they are exhausting a pool. So when they draw a sequence, it's gone until
[02:28:37.600 --> 02:28:44.400] the next epoch of training. So we're already doing that because our data loader iterates over chunks
[02:28:44.400 --> 02:28:49.920] of data. So there's no replacement. They don't become eligible to be drawn again until the next
[02:28:49.920 --> 02:28:57.600] epoch. So we're basically already doing that. All models use a weight decay of 0.1 to provide
[02:28:57.600 --> 02:29:02.560] a small amount of regularization. So let's implement a weight decay. And you see here
[02:29:02.560 --> 02:29:07.040] that I've already kind of made the changes. And in particular, instead of creating the optimizer
[02:29:07.040 --> 02:29:14.320] right here, I'm creating a new configure optimizer function inside the model. And I'm passing in some
[02:29:14.320 --> 02:29:18.880] of the hyper parameters instead. So let's look at the configure optimizers, which is supposed to
[02:29:18.880 --> 02:29:28.400] return the optimizer object. Okay, so it looks complicated, but it's actually really simple.
[02:29:28.400 --> 02:29:33.280] We're just being very careful. And there are a few settings here to go through.
[02:29:33.280 --> 02:29:37.600] The most important thing with respect to this line is that you see there's a weight decay
[02:29:37.600 --> 02:29:45.200] parameter here. And I'm passing that into, well, I'm passing that into something called
[02:29:45.200 --> 02:29:50.720] optim_groups that eventually ends up going into the AdamW optimizer. And the weight decay
[02:29:50.720 --> 02:29:57.680] that's by default used in AdamW here is 0.01. So it's 10 times lower than what's used in the
[02:29:57.680 --> 02:30:04.160] GPT-3 paper here. So the weight decay basically ends up making its way into the AdamW optimizer via the
[02:30:04.160 --> 02:30:09.280] optimizer groups. Now what else is going on here in this function? So the two things that are happening
[02:30:09.280 --> 02:30:14.000] here that are important is that I'm splitting up the parameters into those that should be weight
[02:30:14.000 --> 02:30:18.960] decay and those that should not be weight decay. So in particular, it is common to not
[02:30:18.960 --> 02:30:25.840] weight decay biases and any other sort of one dimensional tensors. So the one dimensional
[02:30:25.840 --> 02:30:31.840] tensors are in the no-decay params. And these are also things like layer norm
[02:30:31.840 --> 02:30:36.240] scales and biases. It doesn't really make sense to weight decay those. You mostly want to weight
[02:30:36.240 --> 02:30:41.920] decay the weights that participate in matrix multiplications. And you want to potentially
[02:30:41.920 --> 02:30:48.000] weight decay the embeddings. And we've covered in previous video why it makes sense to decay the
[02:30:48.000 --> 02:30:52.320] weights because you can sort of think of it as a regularization because when you're pulling down
[02:30:52.320 --> 02:30:58.080] all the weights, you're forcing the optimization to use more of the weights. And you're not allowing
[02:30:58.080 --> 02:31:03.920] any one of the weights individually to be way too large. You're forcing the network
[02:31:03.920 --> 02:31:09.280] to kind of distribute the work across more channels because there's sort of like a pool of gravity
[02:31:09.280 --> 02:31:16.240] on the weights themselves. So that's why we are separating it in those ways here. We're only
[02:31:16.240 --> 02:31:23.120] decaying the embeddings and the matmul-participating weights. We're printing the number of parameters
[02:31:23.120 --> 02:31:27.360] that we're decaying and not. Most of the parameters will be decayed. And then one more thing that
[02:31:27.360 --> 02:31:35.040] we're doing here is I'm doing another optimization here. Previous versions of AdamW did not have this option,
[02:31:35.040 --> 02:31:40.160] but later versions of PyTorch introduced it. And that's why I'm guarding it with an
[02:31:40.160 --> 02:31:48.160] inspect.signature check, which is basically checking if this fused keyword argument is present inside AdamW.
[02:31:48.160 --> 02:31:53.840] And then if it is present, I'm going to end up using it and passing it in here, because some
[02:31:53.840 --> 02:32:00.320] earlier versions do not have fused=. So here's AdamW's fused argument. It did not used to exist
[02:32:00.320 --> 02:32:06.240] and it was added later. And there's some docs here for what's happening. And basically they say that
[02:32:06.240 --> 02:32:11.200] by default they do not use fused because it is relatively new and we want to give it sufficient
[02:32:11.200 --> 02:32:16.240] bake-in time. So by default, they don't use fused. But fused is a lot faster when it is available
[02:32:16.240 --> 02:32:21.840] and when you're running on CUDA. And what that does is instead of iterating in a for loop over all
[02:32:21.840 --> 02:32:28.000] the parameter tensors and updating them, that would launch a lot of kernels, right? And so fused
[02:32:28.000 --> 02:32:34.000] just means that all those kernels are fused into a single kernel. You get rid of a lot of overhead
[02:32:34.000 --> 02:32:42.000] and, a single time over all the parameters, you call a kernel that updates them. And so it's just basically
[02:32:42.000 --> 02:32:48.480] kernel fusion for the AdamW update instead of iterating over all the tensors. So that's the
[02:32:48.480 --> 02:32:54.080] configure optimizers function that I like to use. And we can rerun and we're not going to see any
[02:32:54.080 --> 02:32:59.760] major differences from what we saw before, but we are going to see some prints coming from here.
[02:32:59.760 --> 02:33:05.840] So let's just take a look at what they look like. So we see that number of decay tensors is 50
[02:33:05.840 --> 02:33:10.000] and it's most of the parameters. The number of non-decayed tensors is 98. And these are
[02:33:10.000 --> 02:33:16.000] the biases and the layer norm parameters mostly. And there's only 100,000 of those. So most of
[02:33:16.000 --> 02:33:22.160] it is decayed. And then we are using the fused implementation of Adam W, which will be a lot faster.
[02:33:22.160 --> 02:33:26.720] So if you have it available, I would advise you to use it. I'm not actually 100% sure why they
[02:33:26.720 --> 02:33:31.680] don't default to it. It seems fairly benign and harmless. And also because we are using the
[02:33:31.680 --> 02:33:38.080] fused implementation, I think this is why we have dropped. Notice that the running time used to be
[02:33:38.080 --> 02:33:43.760] 93 milliseconds per step, and we're now down to 90 milliseconds per step because of using the fused
[02:33:43.760 --> 02:33:51.440] AdamW optimizer. So in a single commit here, we are introducing fused AdamW, getting improvements
[02:33:51.440 --> 02:33:56.800] on the time. And we're adding or changing the weight decay, but we're only weight decaying the
[02:33:56.800 --> 02:34:00.800] two dimensional parameters, the embeddings, and the matrices that participate in linears.
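A condensed sketch of this configure_optimizers idea (the parameter names here are illustrative, and the rest of the script is assumed):

```python
import inspect
import torch

def configure_optimizers(model, weight_decay, learning_rate, device):
    # split parameters: 2D and higher tensors (matmul weights, embeddings) get
    # weight decay, 1D tensors (biases, layernorm scales) do not
    params = [p for p in model.parameters() if p.requires_grad]
    decay_params = [p for p in params if p.dim() >= 2]
    nodecay_params = [p for p in params if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    # use the fused AdamW kernel only if this PyTorch build supports it and we're on CUDA
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and 'cuda' in device
    extra_args = dict(fused=True) if use_fused else dict()
    return torch.optim.AdamW(optim_groups, lr=learning_rate,
                             betas=(0.9, 0.95), eps=1e-8, **extra_args)
```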
[02:34:00.800 --> 02:34:10.320] So that is this and we can take this out. And yeah, that is it for this line. One more quick note
[02:34:10.320 --> 02:34:14.720] before we continue here, I just want to point out that the relationship between weight decay,
[02:34:14.720 --> 02:34:20.320] learning rate, batch size, the Adam parameters beta 1, beta 2, the epsilon and so on, these are
[02:34:20.320 --> 02:34:27.520] very complicated mathematical relationships in the optimization literature. And for the most part,
[02:34:27.520 --> 02:34:31.680] in this video, I'm just trying to copy-paste the settings that OpenAI used. But this is a
[02:34:31.680 --> 02:34:37.840] complicated and quite deep topic. And yeah, in this video, I just want to copy the parameters because
[02:34:37.840 --> 02:34:41.520] it's a whole different video to really talk about that in detail and give it a proper
[02:34:41.520 --> 02:34:46.320] justice instead of just high level intuitions. And now the next thing that I want to move on to
[02:34:46.320 --> 02:34:51.360] is this paragraph here, which, by the way, we're going to turn back around to when we
[02:34:51.360 --> 02:34:57.760] improve our data loader. For now, I want to swing back around to this table,
[02:34:57.760 --> 02:35:07.120] where you will notice that for different models, we, of course, have different
[02:35:07.120 --> 02:35:11.840] hyperparameters for the transformer that dictate the size of the transformer network. We also have
[02:35:11.840 --> 02:35:15.200] a different learning rate. So we're seeing the pattern that the bigger networks are trained by
[02:35:15.200 --> 02:35:21.840] slightly lower learning rates. And we also see this batch size, where in the small networks,
[02:35:21.840 --> 02:35:25.360] they use a smaller batch size, and then the bigger networks, they use a bigger batch size.
[02:35:25.360 --> 02:35:32.720] Now the problem for us is we can't just use a 0.5 million batch size, because if I just try to
[02:35:32.720 --> 02:35:47.920] come in here and try to set this B, well, where's my B? Okay, B equals 16.
[02:35:47.920 --> 02:35:54.320] If I try to set it, well, we have to be careful, it's not 0.5 million, because this is the batch size
[02:35:54.320 --> 02:36:02.320] in the number of tokens. Every single one of our rows is 1024 tokens. So 0.5e6, half a million,
[02:36:02.320 --> 02:36:10.480] divided by 1024, would need a batch size of about 488. So the problem is I can't come in here and set
[02:36:10.480 --> 02:36:17.600] this to 488, because my GPU would explode. This would not fit for sure. And so,
[02:36:17.600 --> 02:36:23.600] but we still want to use this batch size, because again, as I mentioned, the batch size is correlated
[02:36:23.600 --> 02:36:28.160] with all the other optimization hyperparameters, and the learning rates and so on. So we want to
[02:36:28.160 --> 02:36:33.200] have a faithful representation of all the hyper parameters. And therefore, we need to use a batch
[02:36:33.200 --> 02:36:39.360] size of 0.5 million roughly. But the question is, how do we use 0.5 million if we only have a small
[02:36:39.360 --> 02:36:44.640] GPU? Well, for that, we need to use what's called gradient accumulation. So we're going to turn
[02:36:44.640 --> 02:36:50.480] to that next, and it allows us to simulate in a serial way, any arbitrary batch size that we set.
[02:36:50.480 --> 02:36:56.720] And so we can do a batch size of 0.5 million, we just have to run longer, and we have to process
[02:36:56.720 --> 02:37:02.960] multiple sequences, and basically add up all the gradients from them to simulate a batch size of
[02:37:02.960 --> 02:37:07.360] 0.5 million. So let's turn to that next. Okay, so I started the implementation right here,
[02:37:07.360 --> 02:37:13.040] just by adding these lines of code. And basically, what I did is first, I set the total batch size
[02:37:13.040 --> 02:37:18.560] that we desire. So this is roughly 0.5 million. And I used a nice number, a power of two,
[02:37:18.560 --> 02:37:24.480] because 2 to the 19 is 524,288. So it's roughly 0.5 million, and some nice number.
[02:37:25.360 --> 02:37:32.160] Now our micro batch size, as we call it now, is 16. So this is going to be, we still have B by T
[02:37:32.160 --> 02:37:37.200] indices that go into the transformer and do forward backward. But we're not going to do an
[02:37:37.200 --> 02:37:41.840] update, right? We're going to do many forward backwards. We're going to, and those gradients
[02:37:41.840 --> 02:37:46.160] are all going to plus equals on the parameter gradients, they're all going to add up. So we're
[02:37:46.160 --> 02:37:51.280] going to do forward backward grad accum steps number of times. And then we're going to do a single
[02:37:51.280 --> 02:37:57.280] update once all that is accumulated. So in particular, our micro batch size is just now
[02:37:57.280 --> 02:38:01.840] controlling how many tokens, how many rows we're processing in a single go of a forward backward.
[02:38:01.840 --> 02:38:09.280] So here, we are doing 16 times 1024, we're doing 16,384
[02:38:09.280 --> 02:38:17.040] tokens per forward backward. And we are supposed to be doing two to the 19,
[02:38:17.600 --> 02:38:24.000] oops, what am I doing? Two to the 19 in total. So the grad accum will be 32.
[02:38:24.000 --> 02:38:32.560] So therefore, grad accum here will work out to 32. And we have to do 32 forward backwards,
[02:38:32.560 --> 02:38:38.640] and then a single update. Now we see that we have about 100 milliseconds for a single forward
[02:38:38.640 --> 02:38:45.520] backward. So doing 32 of them will be, we'll make every step roughly three seconds, just napkin math.
[02:38:47.120 --> 02:38:51.360] So that's grad accum steps, but now we'd actually like to implement that. So we're going to swing
[02:38:51.360 --> 02:38:59.280] over to our training loop, because now this part here, and this part here, the forward and the
[02:38:59.280 --> 02:39:06.000] backward, we have to now repeat this 32 times before we do everything else that follows. So let's
[02:39:06.000 --> 02:39:10.800] see how we can implement that. So let's come over here. And actually, we do have to load a new batch
[02:39:10.800 --> 02:39:15.600] every single time. So let me move that over here. And now this is where we have the inner loop. So
[02:39:15.600 --> 02:39:25.040] for micro_step in range(grad_accum_steps), we do this. And remember that loss.backward() always
[02:39:25.040 --> 02:39:29.520] deposits gradients. So inside loss.backward(), there's always a plus-equals on the
[02:39:29.520 --> 02:39:34.720] gradients. So in every single loss.backward(), gradients will add up on the gradient tensors.
[02:39:34.720 --> 02:39:43.120] So we loss.backward(), and then we get all the gradients over there. And then we normalize
[02:39:43.120 --> 02:39:48.800] and everything else should just follow. So we're very close. But actually, there's like
[02:39:48.800 --> 02:39:55.360] subtle and deep issue here. And this is actually incorrect. So I invite you to think about why
[02:39:55.360 --> 02:40:01.440] this is not yet sufficient. And let me fix it then. Okay, so I brought back the Jupyter Notebook. So
[02:40:01.440 --> 02:40:07.120] we can think about this carefully in a simple toy setting and see what's happening. So let's
[02:40:07.120 --> 02:40:12.320] create a very simple neural net that takes a vector of 16 numbers and returns a single number.
[02:40:13.600 --> 02:40:21.520] And then here I'm creating some random examples x and some targets y. And then we are using the
[02:40:21.520 --> 02:40:28.640] mean squared loss here to calculate the loss. So basically what this is is four individual
[02:40:28.640 --> 02:40:34.080] examples. And we're just doing simple regression with the mean squared loss over those four examples.
[02:40:34.080 --> 02:40:38.800] Now when we calculate the loss and we loss.backward() and look at the gradient,
[02:40:39.360 --> 02:40:45.600] this is the gradient that we achieve. Now for the loss objective here, notice that in MSE loss,
[02:40:45.600 --> 02:40:52.800] the default for the loss function's reduction is mean. So we're calculating the average loss,
[02:40:52.800 --> 02:41:01.440] the mean loss here over the four examples. So this is the exact loss objective. And this is the
[02:41:01.440 --> 02:41:06.400] average, the one over four, because there are four independent examples here. And then we have the
[02:41:07.040 --> 02:41:12.000] four examples and their mean squared error, the squared error, and then this makes it the mean
[02:41:12.000 --> 02:41:18.240] squared error. So therefore, we are, we calculate the squared error, and then we normalize it to
[02:41:18.240 --> 02:41:22.880] make it the mean over the examples. And there's four examples here. So now when we come to the
[02:41:22.880 --> 02:41:31.200] gradient accumulation version of it, this here is the gradient accumulation version of it,
[02:41:31.200 --> 02:41:36.320] where we have grad accum steps of four, and I reset the gradient and do grad accum steps of
[02:41:36.320 --> 02:41:40.880] four. And now I'm evaluating all the examples individually instead and calling
[02:41:40.880 --> 02:41:45.440] loss.backward() on them many times. And then we're looking at the gradient that we achieve from that. So
[02:41:45.440 --> 02:41:51.280] basically now we forward our function, calculate the exact same loss, do a backward, and we do that
[02:41:51.280 --> 02:41:56.560] four times. And when we look at the gradient, you'll notice that the gradients don't match.
[02:41:56.560 --> 02:42:04.400] So here we did a single batch of four. And here we did four gradient accumulation steps
[02:42:04.400 --> 02:42:09.760] of batch size one. And the gradients are not the same. And basically the reason that they're not
[02:42:09.760 --> 02:42:15.600] the same is exactly because this mean squared error gets lost. This one quarter in this loss
[02:42:15.600 --> 02:42:22.240] gets lost, because what happens here is the loss objective for every one of the loops is just a
[02:42:22.240 --> 02:42:27.840] mean squared error, which in this case, because there's only a single example, is just this term
[02:42:27.840 --> 02:42:33.840] here. So that was the loss in the zeroth iteration, the same in the first, the third, and so on. And then
[02:42:33.840 --> 02:42:39.840] when you do the loss.backward(), we're accumulating gradients. And what happens is that accumulating
[02:42:39.840 --> 02:42:48.080] the gradient is basically equivalent to doing a sum in the loss. So our loss actually here is
[02:42:48.080 --> 02:42:55.760] this without the factor of one quarter outside of it. So we're missing the normalizer. And therefore
[02:42:55.760 --> 02:43:00.240] our gradients are off. And so the way to fix this or one of them is basically we can actually come
[02:43:00.240 --> 02:43:07.680] here and we can say loss equals loss divide four. And what happens now is that we're introducing
[02:43:07.680 --> 02:43:12.560] where we're scaling our loss, we're introducing them one quarter in front of all of these places.
[02:43:12.560 --> 02:43:20.000] So all the individual losses are now scaled by one quarter. And then when we backward,
[02:43:20.000 --> 02:43:25.520] all of these accumulate with a sum. But now there's a one quarter inside every one of these
[02:43:25.520 --> 02:43:33.760] components. And now our losses will be equivalent. So when I run this, you see that the gradients
[02:43:33.760 --> 02:43:39.360] are now identical. So long story short, with this simple example, when you step through it,
[02:43:39.360 --> 02:43:44.560] you can see that basically the reason that this is not correct is because in the same way as
[02:43:44.560 --> 02:43:54.560] here in the MSE loss, the loss that we're calculating here in the model is using a reduction of mean
[02:43:54.560 --> 02:44:01.760] as well. So where's the loss? There, the cross entropy. And by default, the reduction here in
[02:44:01.760 --> 02:44:07.520] cross entropy is also, I don't know why they don't show it, but it's the mean, the mean loss over all
[02:44:07.520 --> 02:44:15.280] the B by T elements, right? So there's a reduction by mean in there. And if we're just doing this
[02:44:15.280 --> 02:44:20.160] gradient accumulation here, we're missing that. And so the way to fix this is to simply compensate
[02:44:20.160 --> 02:44:24.320] for the number of gradient accumulation steps, and we can in the same way divide this loss.
[02:44:24.320 --> 02:44:30.960] So in particular here, what we're doing is loss equals loss divided by the
[02:44:30.960 --> 02:44:36.960] gradient accumulation steps. Even Copilot, sorry, gets the modification. But in the same way
[02:44:36.960 --> 02:44:42.240] exactly, we are scaling down the loss so that when we do loss.backward(), which basically corresponds
[02:44:42.240 --> 02:44:49.200] to a sum in the objective, we are summing up the already normalized loss. And therefore,
[02:44:49.200 --> 02:44:54.800] when we sum up the losses divided by grad accum steps, we are recovering the additional normalizer.
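Here is a self-contained toy version of that notebook experiment; the exact numbers are illustrative:

```python
import torch

torch.manual_seed(0)
net = torch.nn.Linear(16, 1)          # simple regression: 16 numbers -> 1 number
x = torch.randn(4, 16)                # 4 independent examples
y = torch.randn(4, 1)

# version 1: one batch of 4; reduction='mean' supplies the 1/4 normalizer
net.zero_grad()
torch.nn.functional.mse_loss(net(x), y).backward()
grad_batch = net.weight.grad.clone()

# version 2: 4 accumulation steps of batch size 1, scaling each loss by 1/4
net.zero_grad()
for i in range(4):
    loss = torch.nn.functional.mse_loss(net(x[i:i+1]), y[i:i+1])
    loss = loss / 4        # compensate for the missing mean over the 4 examples
    loss.backward()        # gradients += on every call
print(torch.allclose(grad_batch, net.weight.grad))  # True
```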
[02:44:54.800 --> 02:45:02.320] And so now this will be equivalent to the original sort of optimization,
[02:45:02.320 --> 02:45:06.560] because the gradient will come out the same. Okay, so I had to do a few more touch ups,
[02:45:06.560 --> 02:45:10.800] and I launched the optimization here. So in particular, one thing we want to do,
[02:45:10.800 --> 02:45:16.080] because we want to print things nicely, is well, first of all, we need to create like an accumulator
[02:45:16.080 --> 02:45:19.920] over the loss, we can't just print the loss, because we'd be printing only the final loss
[02:45:19.920 --> 02:45:24.800] at the final micro step. So instead, we have a loss_accum, which I initialize at zero,
[02:45:24.800 --> 02:45:32.560] and then I accumulate the loss into it. And I'm using detach() so that I'm detaching the tensor
[02:45:32.560 --> 02:45:38.560] from the graph, and I'm just trying to keep track of the values. So I'm making these leaf nodes
[02:45:38.560 --> 02:45:43.840] when I add them. So that's loss_accum, and then we're printing that here instead of loss.
[02:45:43.840 --> 02:45:49.120] And then in addition to that, I have to account for the grad accum steps inside the tokens processed,
[02:45:49.120 --> 02:45:53.920] because now the tokens processed per step is B times T times the gradient accumulation steps.
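Putting it together, the inner loop now looks roughly like this; it's a sketch, with names following the script we have been building:

```python
optimizer.zero_grad()
loss_accum = 0.0
for micro_step in range(grad_accum_steps):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    logits, loss = model(x, y)
    loss = loss / grad_accum_steps   # compensate for the 'mean' reduction
    loss_accum += loss.detach()      # track the full-batch loss just for printing
    loss.backward()                  # gradients accumulate (+=) across micro steps
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
tokens_processed = B * T * grad_accum_steps   # tokens per optimizer step
```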
[02:45:53.920 --> 02:46:00.800] So long story short, here we have the optimization, and it looks reasonable, right? We're starting at a
[02:46:00.800 --> 02:46:07.280] good spot, we calculated the grad accum steps to be 32. And we're getting about three seconds here,
[02:46:07.280 --> 02:46:17.280] right? And so this looks pretty good. Now if you'd like to verify that your optimization
[02:46:17.280 --> 02:46:21.280] and the implementation here is correct, and you're working on this on the side, well, now because we have the
[02:46:21.280 --> 02:46:26.640] total batch size and the gradient accumulation steps, our setting of B is purely a performance
[02:46:26.640 --> 02:46:32.080] optimization kind of setting. So if you have a big GPU, you can actually increase this to 32,
[02:46:32.080 --> 02:46:36.320] and you'll probably go a bit faster. If you have a very small GPU, you can try eight or four,
[02:46:36.880 --> 02:46:41.200] but in any case, you should be getting the exact same optimization and the same answers up to
[02:46:41.200 --> 02:46:47.360] what a floating point error, because the gradient accumulation kicks in and can handle everything
[02:46:47.360 --> 02:46:54.080] serially as necessary. So that's it for gradient accumulation, I think. Okay, so now is the time
[02:46:54.080 --> 02:46:59.120] to bring out the heavy weapons. You've noticed that so far, we've only been using a single GPU
[02:46:59.120 --> 02:47:04.960] for training, but actually I am paying for 8 GPUs here. And so we should be putting all of them to
[02:47:04.960 --> 02:47:11.760] work. And in particular, they're all going to collaborate and optimize over tokens at the same
[02:47:11.760 --> 02:47:18.160] time and communicate so that they're all kind of collaborating on the optimization. For this,
[02:47:18.160 --> 02:47:22.160] we are going to be using the distributed data parallel from PyTorch. There's also a legacy
[02:47:22.160 --> 02:47:27.840] data parallel, which I recommend you not use. And that's kind of like legacy. Distributed data
[02:47:27.840 --> 02:47:34.240] parallel works in a very simple way. We have eight GPUs. So we're going to launch eight processes.
[02:47:34.240 --> 02:47:40.320] And each process is going to be assigned a GPU. And for each process, the training loop and
[02:47:40.320 --> 02:47:44.640] everything we've worked on so far is going to look pretty much the same. Each GPU, as far as
[02:47:44.640 --> 02:47:49.680] it's concerned, is just working on exactly what we've built so far. But now secretly, there's eight
[02:47:49.680 --> 02:47:55.120] of them, and they're all going to be processing slightly different parts of the data. And we're
[02:47:55.120 --> 02:48:00.080] going to add one more part where once they all calculate their gradients, there's one more part
[02:48:00.080 --> 02:48:06.400] where we do a average of those gradients. And so that's how they're going to be collaborating
[02:48:06.400 --> 02:48:12.720] on the computational workload here. So to use all eight of them, we're not going to be launching
[02:48:12.720 --> 02:48:19.440] our script anymore with just python train_gpt2.py. We're going to be running it with a special
[02:48:19.440 --> 02:48:25.760] command called torchrun in PyTorch. We'll see it in a bit. And torchrun, when it runs our
[02:48:25.760 --> 02:48:32.560] Python script, will actually make sure to run eight of them in parallel. And it creates these
[02:48:32.560 --> 02:48:39.040] environment variables where each of these processes can look up basically which one
[02:48:39.040 --> 02:48:45.840] of the processes it is. So for example, torchrun will set RANK, LOCAL_RANK, and WORLD_SIZE environment
[02:48:45.840 --> 02:48:52.000] variables. And so this is a bad way to detect whether DDP is running. So if we're using torch
[02:48:52.000 --> 02:48:58.800] run, if DDP is running, then we have to make sure that CUDA is available, because I don't know that
[02:48:58.800 --> 02:49:06.960] you can run this on CPU anymore, or that that makes sense to do. This is some setup code here.
[02:49:06.960 --> 02:49:12.080] The important part is that there's a world size, which for us will be eight. That's the total number
[02:49:12.080 --> 02:49:19.680] of processes running. There's a rank which is each process will basically run the exact same code
[02:49:19.680 --> 02:49:25.040] at the exact same time roughly. But all the process, the only difference between these processes
[02:49:25.040 --> 02:49:31.920] is that they all have a different DDP rank. So the GPU zero will have DDP rank of zero,
[02:49:31.920 --> 02:49:37.200] GPU one will have a rank of one, et cetera. So otherwise they're all running the exact
[02:49:37.200 --> 02:49:42.000] same script. It's just that DDP rank will be a slightly different integer. And that is the way
[02:49:42.000 --> 02:49:47.120] for us to coordinate that they don't, for example, run on the same data. We want them to run on
[02:49:47.120 --> 02:49:53.600] different parts of the data and so on. Now local rank is something that is only used in a multi-node
[02:49:53.600 --> 02:50:00.480] setting. We only have a single node with 8 GPUs. And so local rank is the rank of the GPU on a single
[02:50:00.480 --> 02:50:06.960] node. So from zero to seven as an example. But for us, we're mostly going to be running on a single
[02:50:06.960 --> 02:50:12.720] box. So the things we care about are rank and world size. The world size is eight. And the rank will be whatever
[02:50:12.720 --> 02:50:18.400] it is depending on the GPU that this particular instantiation of the script runs on.
[02:50:18.400 --> 02:50:27.520] Now here, we make sure that, according to the local rank, we are setting the device to be
[02:50:27.520 --> 02:50:33.360] cuda colon local rank, and the colon part indicates which GPU to use if there is more than one GPU.
[02:50:33.360 --> 02:50:40.400] So depending on the local rank of this process, it's going to use just the appropriate GPU. So
[02:50:40.400 --> 02:50:45.520] there's no collisions on which GPUs being used by which process. And finally, there's a Boolean
[02:50:45.520 --> 02:50:51.440] variable that I like to create, which is the DDP rank equals equals zero. So the master process
[02:50:51.440 --> 02:50:56.320] is arbitrarily process number zero. And it does a lot of the printing, logging, checkpointing,
[02:50:56.320 --> 02:51:00.800] et cetera. And the other processes are thought of mostly as a compute processes that are assisting.
[02:51:00.800 --> 02:51:05.680] And so master process zero will have some additional work to do. All the other processes
[02:51:05.680 --> 02:51:10.720] will mostly just be doing forward backwards. And if we're not using DDP and none of these
[02:51:10.720 --> 02:51:15.440] variables are set, we revert back to single GPU training. So that means that we only have rank
[02:51:15.440 --> 02:51:22.480] zero. The world size is just one. And we are the master process. And we try to auto detect the
[02:51:22.480 --> 02:51:28.400] device. And this all works as normal. So so far, all we've done is we've initialized DDP.
[02:51:28.400 --> 02:51:32.960] And in the case where we're running with torch run, which we'll see in a bit,
[02:51:32.960 --> 02:51:37.520] there's going to be eight copies running in parallel. Each one of them will have a different
[02:51:37.520 --> 02:51:42.880] rank. And now we have to make sure that everything happens correctly afterwards.
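A sketch of that setup logic, using the environment variables torchrun provides:

```python
import os
import torch
from torch.distributed import init_process_group

ddp = int(os.environ.get('RANK', -1)) != -1   # is this a torchrun-launched DDP run?
if ddp:
    assert torch.cuda.is_available(), "this DDP path assumes CUDA"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0   # rank 0 does the logging, checkpointing, etc.
else:
    # vanilla single-process run
    ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1
    master_process = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
```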
[02:51:42.880 --> 02:51:48.160] So the tricky thing with running multiple processes is you always have to imagine that
[02:51:48.160 --> 02:51:53.360] there's going to be eight processes running in parallel. So as you read the code now,
[02:51:53.360 --> 02:51:58.800] you have to imagine there's eight, you know, eight Python interpreters running down these lines of
[02:51:58.800 --> 02:52:03.760] code. And the only difference between them is that they have a different DDP rank. So they all
[02:52:03.760 --> 02:52:09.360] come here, they all pick the exact same seed. They all make all of these calculations completely
[02:52:09.360 --> 02:52:14.560] unaware of the other copies running, roughly speaking, right? So they all make the exact same
[02:52:14.560 --> 02:52:19.760] calculations. And now we have to adjust these calculations to take into account that there's
[02:52:19.760 --> 02:52:26.400] actually like a certain world size and certain ranks. So in particular, these micro batches and
[02:52:26.400 --> 02:52:32.800] sequence lengths, these are all just per GPU, right? So now there's going to be num processes of them
[02:52:32.800 --> 02:52:38.560] running in parallel. So we have to adjust this, right? Because the grad accum steps now are going to
[02:52:38.560 --> 02:52:49.520] be the total batch size divided by B times T times the DDP world size. Because each process will do B times T,
[02:52:49.520 --> 02:52:54.640] and there's this many of them. And so in addition to that, we want to make sure
[02:52:54.640 --> 02:53:01.040] that this fits nicely into the total batch size, which for us it will be, because 16 times 1024 times
[02:53:01.760 --> 02:53:09.760] eight GPUs is 131,072. Okay. And that divides 524,288 evenly. This means that our
[02:53:09.760 --> 02:53:16.800] grad accum will be four with the current settings, right? So there's going to be 16 times 1024 tokens processed
[02:53:16.800 --> 02:53:23.120] on each GPU. And then there's 8 GPUs. So we're going to be doing about 131,000 tokens in a single forward
[02:53:23.120 --> 02:53:30.480] backward on the 8 GPUs. So we're going to make sure that this fits nicely so that we can derive
[02:53:30.480 --> 02:53:36.640] a nice number of gradient accumulation steps. And yeah, let's just adjust the comment here to say times DDP world size.
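In code, the bookkeeping is roughly this, with the numbers matching the discussion:

```python
total_batch_size = 524288            # 2**19 tokens, roughly 0.5M
B, T = 16, 1024                      # micro batch size and sequence length per GPU
assert total_batch_size % (B * T * ddp_world_size) == 0
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
# with 8 GPUs: 16 * 1024 * 8 = 131,072 tokens per forward-backward -> 4 accum steps
```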
[02:53:36.640 --> 02:53:46.400] Okay. So each GPU calculates this. Now this is where we start to run into
[02:53:46.400 --> 02:53:52.480] issues, right? So each process is going to come by a print, and they're all going to print.
[02:53:52.480 --> 02:53:57.680] So we're going to have eight copies of these prints. So one way to deal with this is exactly
[02:53:57.680 --> 02:54:03.920] this master process variable that we have. So if master process, then guard this. And that's
[02:54:03.920 --> 02:54:07.920] just so that we just print this a single time, because otherwise all the processes would have
[02:54:07.920 --> 02:54:13.840] computed the exact same variables, and there's no need to print this eight times. Before getting
[02:54:13.840 --> 02:54:19.280] into the data loader, we're going to have to refactor it, obviously. Maybe at this point,
[02:54:19.280 --> 02:54:25.360] we should do some prints, and just take it out for a spin and exit at this point. So import sys,
[02:54:27.360 --> 02:54:44.480] and sys.exit, and print "I am GPU" and the DDP rank, and print "bye".
[02:54:44.480 --> 02:54:51.920] So now let's try to run this and just see how this works. So let's take it for a spin,
[02:54:51.920 --> 02:54:56.720] just so we see what it looks like. So normally we used to launch python train_gpt2.py like
[02:54:56.720 --> 02:55:02.400] this. Now we're going to run it with torchrun. And this is what it looks like. So torchrun, standalone,
[02:55:02.400 --> 02:55:07.760] the number of processes, which for example is eight for us, because we have 8 GPUs, and then train_gpt2.py,
[02:55:07.760 --> 02:55:13.600] so something like torchrun --standalone --nproc_per_node=8 train_gpt2.py. And torchrun, again, will run eight of these.
[02:55:13.600 --> 02:55:20.720] So let's just see what happens. So first, it gets a little busy. So there's a lot going on here. So
[02:55:20.720 --> 02:55:26.480] first of all, there's some warnings from distributed. And I don't actually know that these mean anything.
[02:55:26.480 --> 02:55:30.400] I think this is just like, the code is setting up and the processes are coming online. And we're
[02:55:30.400 --> 02:55:36.400] seeing some preliminary failures to connect while the processes come up. I'm not 100% sure about that.
[02:55:36.400 --> 02:55:44.560] But we start to then get into actual prints. So all the processes went down here. And then the first
[02:55:44.560 --> 02:55:52.000] print actually comes from process five, just by chance. And then it printed. So process five
[02:55:52.000 --> 02:55:59.920] basically got here first, it said I am GPU five, bye. And then these prints come from
[02:55:59.920 --> 02:56:05.760] the master process. So process five just finished first for whatever reason, it just depends on how
[02:56:05.760 --> 02:56:11.600] the operating system scheduled the processes to run. Then GPU zero ended, then GPU three and two.
[02:56:11.600 --> 02:56:19.840] And then probably process five or something like that has exited. And DDP really doesn't like that
[02:56:19.840 --> 02:56:28.240] because we didn't properly dispose of the multi-GPU setting. And so the process group has not been
[02:56:28.240 --> 02:56:33.840] destroyed before we destruct. So it really doesn't like that. And in an actual application, we would
[02:56:33.840 --> 02:56:40.400] want to call destroy process group, so that we clean up DDP properly. And so it doesn't like that
[02:56:40.400 --> 02:56:46.240] too much. And then the rest of the GPUs finish. And that's it. So basically, we can't guarantee
[02:56:46.240 --> 02:56:50.240] when these processes are running, it's totally arbitrary, but they are running in parallel.
[02:56:50.240 --> 02:56:58.800] We don't want that to be printing. And next up, let's erase this. Next up, we want to make
[02:56:58.800 --> 02:57:03.520] sure that when we create the DataLoaderLite, we need to now make it aware of this multi-process
[02:57:03.520 --> 02:57:10.160] setting, because we don't want all the processes to be loading the exact same data. We want every
[02:57:10.160 --> 02:57:14.720] process to get its own chunk of data so that they're all working on different parts of the data set,
[02:57:14.720 --> 02:57:21.280] of course. So let's adjust that. So one particularly simple and naive way to do this is we have
[02:57:21.280 --> 02:57:27.840] to make sure that we pass in the process rank and the number of processes to the data loader. And then we come up here,
[02:57:27.840 --> 02:57:33.360] we see that we now take the rank and the number of processes and we save them. Now the current position will not be
[02:57:33.360 --> 02:57:40.480] zero, because what we want is we want to stride out all the processes. So one way to do this is
[02:57:40.480 --> 02:57:45.760] we basically take self.B times self.T, and then multiply it by the process rank.
[02:57:45.760 --> 02:57:52.880] So process rank zero will start at zero, but process rank one now starts at B times T.
[02:57:52.880 --> 02:57:58.880] Process rank two is starts at two times B times T, etc. So that is the initialization.
[02:57:58.880 --> 02:58:05.840] Now we still, they still do this identically, but now when we advance, we don't advance by B
[02:58:05.840 --> 02:58:12.400] times T, we advance by B times T times number of processes, right? So basically,
[02:58:12.400 --> 02:58:20.000] the total number of tokens that we're consuming is B times T times number processes, and they all
[02:58:20.000 --> 02:58:27.680] go off to a different rank. And the position has to advance by the entire chunk. And then
[02:58:27.680 --> 02:58:34.560] here, if B times T times self.num_processes, plus one, would exceed the number of tokens,
[02:58:34.560 --> 02:58:38.640] then we're going to loop. And when we loop, we want to, of course, loop in the exact same way.
[02:58:38.640 --> 02:58:46.480] So we sort of like reset back. So this is the simplest change that I can find for kind of a
[02:58:46.480 --> 02:58:52.160] very simple distributed DataLoaderLite. And you can notice that if the process rank is zero,
[02:58:52.160 --> 02:58:56.640] and the number of processes is one, then the whole thing will be identical to what we had before.
[02:58:56.640 --> 02:59:00.800] But now we can have actually multiple processes running, and this should work fine.
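A sketch of that striding logic, assuming the tokens are already loaded as one long 1D tensor:

```python
class DataLoaderLite:
    def __init__(self, B, T, process_rank, num_processes, tokens):
        self.B, self.T = B, T
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.tokens = tokens                            # 1D tensor of token ids
        self.current_position = B * T * process_rank    # each rank starts offset

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position + B*T + 1]
        x = buf[:-1].view(B, T)    # inputs
        y = buf[1:].view(B, T)     # targets, shifted by one token
        # advance by the full chunk consumed across all processes
        self.current_position += B * T * self.num_processes
        # if the next stride would run off the end, reset back to this rank's start
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_position = B * T * self.process_rank
        return x, y
```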
[02:59:03.360 --> 02:59:07.920] So that's the data loader. Okay, so next up, once they've all initialized the data loader,
[02:59:07.920 --> 02:59:15.440] they come here and they all create a GPT model. So we create eight GPT models on eight processes.
[02:59:15.440 --> 02:59:19.440] But because the seeds are fixed here, they all create the same identical model.
[02:59:19.440 --> 02:59:25.280] They all move it to the device of their rank, and they all compile the model. And because the
[02:59:25.280 --> 02:59:30.080] models are identical, there are eight identical compilations happening in parallel, but that's okay.
[02:59:31.040 --> 02:59:35.440] Now, none of this changes because that is on a per step basis. And we're currently working
[02:59:35.440 --> 02:59:41.280] kind of within step because we need to just all the changes we're making are kind of like a
[02:59:41.280 --> 02:59:46.720] within step changes. Now, the important thing here is when we construct the model,
[02:59:46.720 --> 02:59:51.120] we actually have a bit of work to do here; this get logits comment is deprecated, so: create model.
[02:59:51.120 --> 02:59:58.160] We need to actually wrap the model into the distributed data parallel container.
[02:59:59.360 --> 03:00:06.080] So this is how we wrap the model into the DDP container. And these are the docs for DDP. And
[03:00:06.080 --> 03:00:10.240] they're quite extensive. And there's a lot of caveats and a lot of things to be careful with
[03:00:10.240 --> 03:00:15.920] because everything complexifies times 10 when multiple processes are involved. But roughly
[03:00:15.920 --> 03:00:20.320] speaking, this device_ids argument, I believe, has to be passed in. Now, unfortunately, the docs for what
[03:00:20.320 --> 03:00:27.360] device_ids is are extremely unclear. So when you actually come here, the comment for what
[03:00:27.360 --> 03:00:34.880] device_ids is turns out to be roughly nonsensical. But I'm pretty sure it's supposed to be the DDP local rank.
[03:00:34.880 --> 03:00:41.840] So not the DDP rank, the local rank. So this is what you pass in here. This wraps the model.
[03:00:41.840 --> 03:00:46.720] And in particular, what DDP does for you is in a forward pass, it actually behaves identically.
[03:00:46.720 --> 03:00:51.920] So my understanding of it is nothing should be changed in the forward pass. But in the backward
[03:00:51.920 --> 03:00:58.400] pass, as you are doing the backward pass, in the simplest setting, once the backward pass is over
[03:00:58.400 --> 03:01:05.600] on each independent GPU, each independent GPU has the gradients for all the parameters. And what DDP
[03:01:05.600 --> 03:01:11.040] does for you is once the backward pass is over, it will call what's called all reduce. And it
[03:01:11.040 --> 03:01:18.960] basically does an average across all the ranks of their gradients. And then it will deposit that
[03:01:18.960 --> 03:01:24.000] average on every single rank. So every single single rank will end up with the average
[03:01:24.000 --> 03:01:29.120] on it. And so basically, that's the communication. It just synchronizes and averages the gradients.
[03:01:29.120 --> 03:01:36.000] And that's what DDP offers you. Now DDP actually is a little bit more involved in that because
[03:01:36.000 --> 03:01:40.880] as you are doing the backward pass through the layers in the transformer, it actually can dispatch
[03:01:40.880 --> 03:01:46.880] communications for the gradient while the backward pass is still happening. So there's overlap of the
[03:01:46.880 --> 03:01:51.360] communication of the gradients and the synchronization of them and the backward pass.
[03:01:51.360 --> 03:01:57.840] And this is just more efficient and to do it that way. So that's what DDP does for you.
[03:01:57.840 --> 03:02:04.880] Forward is unchanged and backward is mostly unchanged and we're tacking on this average as we'll see
[03:02:04.880 --> 03:02:11.840] in a bit. Okay, so now let's go to the optimization. Nothing here changes. Let's go to the optimization
[03:02:11.840 --> 03:02:17.760] here, the inner loop and think through the synchronization of these gradients in the DDP. So basically,
[03:02:17.760 --> 03:02:22.640] by default, what happens, as I mentioned, is when you do loss.backward() here, it will do the backward
[03:02:22.640 --> 03:02:28.800] pass and then it will synchronize the gradients. The problem here is, because of the gradient
[03:02:28.800 --> 03:02:35.360] accumulation steps loop here, we don't actually want to do the synchronization after every single
[03:02:35.360 --> 03:02:40.320] loss.backward(), because we are just depositing gradients. And we're doing that serially. And we
[03:02:40.320 --> 03:02:44.560] just want them adding up. And we don't want to synchronize every single time; that would be extremely
[03:02:44.560 --> 03:02:49.920] wasteful. So basically, we want to add them up, and then it's only on the very
[03:02:49.920 --> 03:02:55.680] last step, when micro_step becomes grad_accum_steps minus one, only at that last
[03:02:55.680 --> 03:03:03.920] step that we want to actually do the all-reduce to average the gradients. So to do that, we come here
[03:03:03.920 --> 03:03:11.040] and the official sanctioned way, by the way, is to do this no sync context manager. So PyTorch says,
[03:03:11.040 --> 03:03:16.720] this is a context manager to disable gradient synchronization across DDP processes. So within
[03:03:16.720 --> 03:03:22.800] this context, gradients will be accumulated. And basically, when you do no sync, there will be no
[03:03:22.800 --> 03:03:28.880] communication. So they are telling us to do with ddp.no_sync(), do the gradient accumulation,
[03:03:28.880 --> 03:03:34.320] accumulate the grads, and then, outside of that context, do another forward with another input and that backward.
[03:03:34.320 --> 03:03:39.600] And I just really don't love this. I just really don't like it. The fact that you have to copy
[03:03:39.600 --> 03:03:44.000] paste your code here and use a context manager. And it's just super ugly. So when I went to this
[03:03:44.000 --> 03:03:51.360] source code here, you can see that when you enter, you simply toggle this variable, this
[03:03:51.360 --> 03:03:58.560] required backward grad sync. And this is being toggled around and changed. And this is the
[03:03:58.560 --> 03:04:05.200] variable that basically, if you step through it, is being toggled to determine if the gradient is
[03:04:05.200 --> 03:04:10.720] going to be synchronized. So I actually just kind of like to use that directly. So instead,
[03:04:10.720 --> 03:04:18.000] what I like to do is the following, right here, before the loss.backward: if we are using DDP,
[03:04:18.880 --> 03:04:25.760] then basically, we only want to synchronize, we only want this variable to be true,
[03:04:25.760 --> 03:04:33.200] when it is the final iteration. In all the other iterations inside the micro steps, we want it to be
[03:04:33.200 --> 03:04:39.280] false. So I just toggle it like this. So require_backward_grad_sync should only turn on when the
[03:04:39.280 --> 03:04:46.480] micro step is the last step. And so I'm toggling this variable directly. And I hope that that
[03:04:46.480 --> 03:04:51.040] impacts loss.backward. And this is a naughty thing to do because, you know, they could probably
[03:04:51.040 --> 03:04:56.400] change DDP and this variable will go away. But for now, I believe this works. And it
[03:04:56.400 --> 03:05:01.200] allows me to avoid the use of context managers and code duplication. I'm just toggling the variable,
[03:05:01.200 --> 03:05:04.800] and then loss.backward will not synchronize on most of the steps, and it will synchronize on the
[03:05:04.800 --> 03:05:15.120] very last step.
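A sketch of that inner loop with the toggle; the surrounding variables follow the training loop we've been building, and note that require_backward_grad_sync is an internal DDP attribute, so this is a hack that could break in future PyTorch versions:

```python
loss_accum = 0.0
for micro_step in range(grad_accum_steps):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits, loss = model(x, y)
    loss = loss / grad_accum_steps  # scale so the accumulated gradient is an average
    loss_accum += loss.detach()
    if ddp:
        # only sync (all-reduce) gradients on the last micro step; earlier backwards just accumulate
        model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
    loss.backward()
```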
[03:05:15.120 --> 03:05:22.320] And so once this is over, and we come out, every single rank will suddenly, magically have the average of all the gradients that are stored on all the ranks. So now we have to think
[03:05:22.320 --> 03:05:29.600] through whether that is what we want, and also whether this suffices and how it works with the
[03:05:29.600 --> 03:05:34.480] loss and with loss_accum. So let's think that through. And the problem I'm
[03:05:34.480 --> 03:05:40.160] getting at is that we've averaged the gradients, which is great, but loss_accum has not been
[03:05:40.160 --> 03:05:45.840] impacted yet, and this is outside of the DDP container, so that is not being averaged.
[03:05:45.840 --> 03:05:51.200] And so here, when we are printing loss_accum, well, presumably, we're only going to be printing on
[03:05:51.200 --> 03:05:56.080] the master process, rank zero. And it's just going to be printing the losses that it saw on its
[03:05:56.080 --> 03:06:01.680] process. But instead, we want it to print the loss over all the processes, the average of
[03:06:01.680 --> 03:06:06.880] that loss, because we did an average of the gradients. So we want the average of the loss as well. So simply
[03:06:06.880 --> 03:06:14.560] here after this, this is the code that I've used in the past. And instead of loss, we want
[03:06:14.560 --> 03:06:25.280] loss_accum. So if ddp, again, then, this is PyTorch distributed, which I imported. What did I import?
[03:06:25.280 --> 03:06:34.000] Oh, gosh. So this file is starting to get out of control, huh? So: import torch
[03:06:34.000 --> 03:06:42.400] distributed as dist. So dist.all_reduce, and we're doing the average on loss_accum. And so this
[03:06:42.400 --> 03:06:47.840] loss_accum tensor exists on all the ranks; when we call all_reduce with average, it creates the
[03:06:47.840 --> 03:06:53.440] average of those numbers, and it deposits that average on all the ranks. So all the ranks after
[03:06:53.440 --> 03:07:01.200] this call will now contain loss_accum averaged up. And so when we print here on the master process,
[03:07:01.200 --> 03:07:05.920] the loss_accum is identical on all the other ranks as well. So here, if master_process,
[03:07:05.920 --> 03:07:12.640] we want to print like this.
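Concretely, the averaging is just a couple of lines; a sketch, assuming the ddp flag and master_process from earlier:

```python
import torch.distributed as dist

if ddp:
    # every rank holds its own loss_accum; this replaces it in place with the mean across ranks
    dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
if master_process:
    print(f"step {step} | loss: {loss_accum.item():.6f}")
```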
[03:07:12.640 --> 03:07:20.160] Okay, and finally, we have to be careful because we're now processing even more tokens. So times ddp_world_size, that's the number of tokens that we've processed,
[03:07:20.160 --> 03:07:29.440] up above. And everything else should be fine. The only other thing to be careful with is, as I
[03:07:29.440 --> 03:07:34.000] mentioned, you want to destroy the process group so that we are nice to NCCL, and it's not going
[03:07:34.000 --> 03:07:42.480] to complain to us when we exit here. So that should be it.
[03:07:42.480 --> 03:07:46.720] Let's try to take it for a spin. Okay, so I launched the script, and it should be printing
[03:07:46.720 --> 03:07:51.440] here imminently. We're now training with eight GPUs at the same time. So the gradient accumulation
[03:07:51.440 --> 03:07:59.840] steps is not 32, it's now divided by eight, and it's just four. So otherwise, this is what the optimization
[03:07:59.840 --> 03:08:08.000] looks like. And while we're going really fast, so we're processing 1.5 million tokens per second
[03:08:08.000 --> 03:08:12.880] now. So these are some serious numbers. And the tiny Shakespeare data set is so tiny that we're
[03:08:12.880 --> 03:08:17.760] just doing like so many epochs over it, most likely. But this is roughly what it looks like.
[03:08:19.600 --> 03:08:24.320] One thing that I had to fix, by the way, is that this was model.configure_optimizers,
[03:08:24.320 --> 03:08:29.360] which now doesn't work because model now is a DDP model. So instead, this has to become
[03:08:29.360 --> 03:08:36.080] raw_model.configure_optimizers, where raw_model is something I create here. So right after I wrapped
[03:08:36.080 --> 03:08:43.280] the model into DDP, I have to create the raw model, which in the case of DDP is model.module;
[03:08:43.280 --> 03:08:49.040] that is where DDP stores the raw nn.Module of GPT-2 as we have it, which contains the
[03:08:49.040 --> 03:08:53.600] configure_optimizers function that we want to call. So that's one thing that I had to fix.
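As a sketch, right after the DDP wrap it could look like this (the exact arguments to configure_optimizers here are illustrative, not necessarily the exact signature in the script):

```python
# keep a handle to the underlying GPT module, whether or not it is wrapped in DDP
raw_model = model.module if ddp else model
optimizer = raw_model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device=device)
```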
[03:08:53.600 --> 03:08:58.480] Otherwise, this seems to run. Now, one thing you'll notice is that when you actually compare
[03:08:58.480 --> 03:09:04.160] this run and the numbers in it to the just running a single GPU, you'll notice that this is single
[03:09:04.160 --> 03:09:11.600] GPU run with 32 gradacum. The numbers won't exactly match up. And it's kind of a boring reason for
[03:09:11.600 --> 03:09:16.560] why that happens. The reason for that is that in the data loader, we're basically just iterating
[03:09:16.560 --> 03:09:21.200] through batches in a slightly different way, because now we're looking for an entire page of data.
[03:09:21.200 --> 03:09:28.320] And if that page for all the GPUs, if that chunk exceeds the number of tokens, we just loop. And so
[03:09:28.320 --> 03:09:35.200] actually the single-GPU and the multi-GPU processes will end up resetting in a slightly different manner,
[03:09:35.200 --> 03:09:38.640] as our batches are slightly different. And so we get slightly different numbers.
[03:09:39.200 --> 03:09:43.920] But one way to convince yourself that this is okay, is just make the total batch size much
[03:09:43.920 --> 03:09:52.560] smaller, and the B and the T. And so I think I used 4 times 1024 times 8. So I used
[03:09:52.560 --> 03:09:59.040] 32,768 as the total batch size. And then I made sure that the single GPU will do eight
[03:09:59.040 --> 03:10:04.240] gradient accumulation steps, matching the multi-GPU run. And then you're reducing the boundary effects
[03:10:04.240 --> 03:10:08.960] of the data loader. And you'll see that the numbers match up. So long story short, we're now
[03:10:08.960 --> 03:10:13.840] going really, really fast. The optimization is mostly consistent with the GPT-2 and GPT-3 hyper
[03:10:13.840 --> 03:10:20.640] parameters. And we have outgrown our tiny Shakespeare file. And we want to upgrade it. So let's move
[03:10:20.640 --> 03:10:25.040] on to that next. So let's now take a look at what datasets were used by GPT-2 and GPT-3.
[03:10:25.040 --> 03:10:31.600] So GPT two used this web text data set that was never released. There's an attempt at
[03:10:31.600 --> 03:10:36.160] reproducing it called open web text. So basically, roughly speaking, what they say here in the paper
[03:10:36.160 --> 03:10:42.880] is that they took all outbound links from Reddit with at least three karma. And that
[03:10:42.880 --> 03:10:46.880] was kind of like their starting point and they collected all the web pages and all the text in
[03:10:46.880 --> 03:10:52.320] them. And so this was 45 million links. And this ended up being 40 gigabytes of text. So
[03:10:52.320 --> 03:10:58.640] so that's roughly what GPT two says about its data set. So it's basically outbound links from
[03:10:58.640 --> 03:11:03.920] Reddit. Now when we go over to GPT three, there's a training data set section. And that's where they
[03:11:03.920 --> 03:11:11.360] start to talk about Common Crawl, which is used a lot more. Actually, I think even GPT-2
[03:11:11.360 --> 03:11:16.640] talked about Common Crawl. But basically, it's not a very high quality dataset all by itself,
[03:11:16.640 --> 03:11:21.120] because it's extremely noisy. This is a completely random subset of the internet. And it's much
[03:11:21.120 --> 03:11:25.520] worse than you think. So people go into great lengths to filter common crawl, because there's
[03:11:25.520 --> 03:11:30.480] good stuff in it. But most of it is just like ad spam and random tables and numbers and stock
[03:11:30.480 --> 03:11:39.760] tickers. And it's just a total mess. So that's why people like to train on these data mixtures
[03:11:39.760 --> 03:11:45.760] that they curate and are careful with. So a large chunk of these data mixtures typically will be
[03:11:45.760 --> 03:11:50.880] common crawl, like for example, 50% of the tokens will be common crawl. But then here in GPT three,
[03:11:50.880 --> 03:11:55.280] they're also using WebText2 from before. So that's the Reddit outbound links. But they're also adding,
[03:11:55.280 --> 03:11:59.840] for example, books, and they're adding Wikipedia. There are many other things you can decide to add.
[03:12:00.560 --> 03:12:05.280] Now, this data set for GPT three was also never released. So today, some of the datasets that
[03:12:05.280 --> 03:12:09.840] I'm familiar with that are quite good and would be representative of something along these lines.
[03:12:09.840 --> 03:12:15.760] Our number one, the red pajama data set, or more specifically, for example, the slim pajama subset
[03:12:15.760 --> 03:12:21.280] of the red pajama data set, which is a cleaned and de-duplicated version of it. And just to give
[03:12:21.280 --> 03:12:27.600] you a sense, again, it's a bunch of common crawl, C four, which is also, as far as I know, more common
[03:12:27.600 --> 03:12:33.600] crawl, but processed differently. And then we have GitHub, Books, arXiv, Wikipedia, StackExchange.
[03:12:33.600 --> 03:12:37.920] These are the kinds of datasets that would go into these data mixtures. Now, specifically the one
[03:12:37.920 --> 03:12:44.080] that I like that came out recently is called fine web data set. So this is an attempt to basically
[03:12:44.080 --> 03:12:50.320] collect really high quality common crawl data and filter it, in this case, to 15 trillion tokens.
[03:12:50.320 --> 03:12:56.000] And then in addition to that, more recently, hugging face released this fine web EDU subset,
[03:12:56.000 --> 03:13:02.160] which is 1.3 trillion of educational and 5.4 trillion of high educational content.
[03:13:02.160 --> 03:13:08.000] So basically, they're trying to filter common crawl to very high quality educational subsets.
[03:13:08.000 --> 03:13:15.040] And this is the one that we will use. There's a long web page here on fine web, and they go into a
[03:13:15.040 --> 03:13:19.120] ton of detail about how they process the data, which is really fascinating reading, by the way.
[03:13:19.120 --> 03:13:22.560] And I would definitely recommend if you're interested into data mixtures and so on,
[03:13:22.560 --> 03:13:27.600] and how data gets processed at these scales, I'll look at this page. And more specifically,
[03:13:27.600 --> 03:13:33.200] we'll be working with the fine web EDU, I think. And it's basically educational content from the
[03:13:33.200 --> 03:13:41.920] internet. They show that training on educational content in their metrics works really, really well.
[03:13:41.920 --> 03:13:49.200] And we're going to use this sample 10 billion tokens subsample of it, because we're not going
[03:13:49.200 --> 03:13:55.120] to be training on trillions of tokens. We're just going to train on 10 billion sample off the fine
[03:13:55.120 --> 03:14:00.240] web EDU, because empirically, in my previous few experiments, this actually suffices to really
[03:14:00.240 --> 03:14:05.520] get close to GPT two performance. And it's simple enough to work with. And so let's work with the
[03:14:05.520 --> 03:14:12.960] sample 10 BT. So our goal will be to download it, process it, and make sure that our data loader
[03:14:12.960 --> 03:14:19.840] can work with it. So let's get to that. Okay, so I introduced another file here that will
[03:14:19.840 --> 03:14:26.320] basically download fine web EDU from hugging face datasets. It will pre process and pre tokenize
[03:14:26.320 --> 03:14:36.480] all of the data, and it will save data shards to a folder on a local disk. And so while this is running,
[03:14:36.480 --> 03:14:41.840] just wanted to briefly mention that you can kind of look through the data set viewer here,
[03:14:41.840 --> 03:14:45.280] just to get a sense of what's in here. And it's kind of interesting. I mean, it's a,
[03:14:45.280 --> 03:14:50.160] it basically looks like it's working fairly well. Like it's talking about nuclear energy in France,
[03:14:50.160 --> 03:14:58.320] it's talking about Mexican America, some Mac, Pi J's, et cetera. So actually, it seems like
[03:14:58.320 --> 03:15:03.680] their filters are working pretty well. The filters here, by the way, were applied automatically using
[03:15:04.560 --> 03:15:11.200] Llama 3 70B, I believe. And so basically, LLMs are judging which content is educational, and that
[03:15:11.200 --> 03:15:15.760] ends up making it through the filter. So that's pretty cool. Now in terms of the script itself,
[03:15:15.760 --> 03:15:20.800] I'm not going to go through the full script, because it's not as interesting and not as LLM
[03:15:20.800 --> 03:15:25.040] centric. But when you run this, basically, number one, we're going to load the data set,
[03:15:25.040 --> 03:15:30.800] which this is all hugging face code, running this, you're going to need to pip install datasets.
[03:15:33.360 --> 03:15:39.360] So it's downloading the dataset, then it is tokenizing all of the documents inside this dataset.
[03:15:39.360 --> 03:15:44.960] Now, when we tokenize the documents, you'll notice that to tokenize a single document,
[03:15:44.960 --> 03:15:52.080] we first start the tokens with the end of text token. And this is a special token in the GPT-2
[03:15:52.080 --> 03:15:59.040] tokenizer, as you know. So 50,256 is the ID of the end of text. And this is what begins a document,
[03:15:59.040 --> 03:16:03.440] even though it's called end of text. But this is the first token that begins a document.
[03:16:03.440 --> 03:16:09.680] Then we extend with all of the tokens of that document, then we create a NumPy array out of that.
[03:16:09.680 --> 03:16:18.320] We make sure that all the tokens are between, oh, okay, let me debug this. Okay, so apologies for
[03:16:18.320 --> 03:16:22.880] that. It just had to do with me using a float division in Python, it must be integer division,
[03:16:22.880 --> 03:16:29.360] so that this is an int and everything is nice. Okay, but basically, the tokenization here is
[03:16:29.360 --> 03:16:36.320] relatively straightforward and returns the tokens as np.uint16. We're using uint16 to save a little bit of space,
[03:16:36.320 --> 03:16:43.440] because 2 to the 16 minus 1 is 65,535. So the GPT-2 max token ID is well below that.
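A minimal sketch of that per-document tokenization, assuming the tiktoken GPT-2 encoding (the helper name and the "text" field are illustrative):

```python
import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")
eot = enc.eot_token  # 50256, the <|endoftext|> token, used here as a document delimiter

def tokenize(doc):
    # every document starts with the end-of-text token, followed by its own tokens
    tokens = [eot]
    tokens.extend(enc.encode_ordinary(doc["text"]))
    tokens_np = np.array(tokens)
    assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token ids must fit in uint16"
    return tokens_np.astype(np.uint16)  # uint16 halves the disk footprint vs int32
```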
[03:16:43.440 --> 03:16:47.920] And then here, there's a bunch of multi processing code, and it's honestly not that exciting,
[03:16:47.920 --> 03:16:52.640] so I'm not going to step through it. But we're loading the dataset, we're tokenizing it,
[03:16:52.640 --> 03:16:59.520] and we're saving everything to shards. And the shards are NumPy files. So just storing a NumPy
[03:16:59.520 --> 03:17:08.480] array, which is very, very similar to torch tensors. And the first shard, 000, is a validation shard,
[03:17:08.480 --> 03:17:14.320] and all the other shards are training shards. And as I mentioned, they all have 100 million tokens
[03:17:14.320 --> 03:17:21.680] in them exactly. And it just makes it easier to work with to shard the files,
[03:17:21.680 --> 03:17:25.840] because if we just have a single massive file, sometimes it can be hard to work with on the disk,
[03:17:25.840 --> 03:17:32.480] and so sharding keeps it less messy from that perspective. And yeah, so we'll just let
[03:17:32.480 --> 03:17:38.960] this run. This will be probably 30ish minutes or so, and then we're going to come back to actually
[03:17:38.960 --> 03:17:43.120] train on this data. And we're going to be actually doing some legit pre training in this case. This
[03:17:43.120 --> 03:17:49.360] is a good data set. We're doing lots of tokens per second. We have eight GPUs, the code is ready.
[03:17:49.360 --> 03:17:53.440] And so we're actually going to be doing a serious training run. So let's get back in a minute.
[03:17:53.440 --> 03:18:01.040] Okay, so we're back. So if we ls the edu fineweb folder, we see that there's now 100 shards in it.
[03:18:01.040 --> 03:18:07.600] And that makes sense because each shard is 100 million tokens. So 100 shards of that is 10 billion
[03:18:07.600 --> 03:18:12.960] tokens in total. Now swinging over to the main file, I made some adjustments to our data loader
[03:18:12.960 --> 03:18:19.120] again. And that's because we're not running with Shakespeare anymore. We want to use the find web
[03:18:19.120 --> 03:18:23.680] shards. And so you'll see some code here that additionally basically can load these shards.
[03:18:23.680 --> 03:18:31.600] We load the UN16 NumPy file. We convert it to a torch.long tensor, which is what a lot of the
[03:18:31.600 --> 03:18:38.400] layers up top expect by default. And then here we're just enumerating all the shards. I also added
[03:18:38.400 --> 03:18:44.000] a split to DataLoaderLite. So we can load the split train, but also the split val, the zero
[03:18:44.000 --> 03:18:50.400] split. And then we can load the shards. And then here we also have not just a current position now,
[03:18:50.400 --> 03:18:56.720] but also the current shard. So we have a position inside a shard. And then when we run out of tokens
[03:18:56.720 --> 03:19:02.880] in a single shard, we first advance the shard and loop if we need to. And then we get the tokens
[03:19:02.880 --> 03:19:09.920] and readjust the position. So this data loader will now iterate over all the shards as well.
[03:19:09.920 --> 03:19:15.200] So I changed that.
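Here's a rough sketch of that shard handling, extending the earlier DataLoaderLite sketch; the load_tokens helper, the folder name, and the file-naming convention are assumptions based on the description:

```python
import os
import numpy as np
import torch

def load_tokens(filename):
    npt = np.load(filename)                    # the uint16 token ids written by the preprocessing script
    return torch.tensor(npt.astype(np.int64))  # torch.long, which the embedding layer expects

class DataLoaderLite:
    def __init__(self, B, T, process_rank, num_processes, split, data_root="edu_fineweb10B"):
        assert split in {"train", "val"}
        self.B, self.T = B, T
        self.process_rank, self.num_processes = process_rank, num_processes
        names = sorted(os.listdir(data_root))
        self.shards = [os.path.join(data_root, n) for n in names if split in n]
        self.reset()

    def reset(self):
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x, y = buf[:-1].view(B, T), buf[1:].view(B, T)
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            # ran out of tokens in this shard: advance to the next shard, wrapping around
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = B * T * self.process_rank
        return x, y
```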
[03:19:15.200 --> 03:19:24.000] And then the other thing that I did while the data was processing is that our train loader now has split train, of course. And down here I set up some numbers. So we are doing 2 to
[03:19:24.000 --> 03:19:34.320] the 19 tokens per step. And we want to do roughly 10 billion tokens,
[03:19:34.320 --> 03:19:39.680] because that's how many unique tokens we have. So if we did 10 billion tokens, then divide that
[03:19:39.680 --> 03:19:46.000] by two to the 19, we see that this is 19,073 steps. So that's where that's from. And then
[03:19:46.000 --> 03:19:50.880] the GPT three paper says that they warm up the learning rate over 375 million tokens.
[03:19:51.760 --> 03:20:01.760] So I came here and 375 e6 tokens divide two to the 19 is 715 steps. So that's why warm up steps
[03:20:01.760 --> 03:20:07.680] is set to 715. So this will exactly match the warm up schedule that GPT three used.
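For reference, the arithmetic behind those two numbers:

```python
total_batch_size = 2**19                         # 524,288 tokens per optimizer step
max_steps = 10_000_000_000 // total_batch_size   # 19,073 steps, roughly one epoch of 10B tokens
warmup_steps = 375_000_000 // total_batch_size   # 715 steps, matching GPT-3's 375M-token warmup
print(max_steps, warmup_steps)                   # 19073 715
```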
[03:20:07.680 --> 03:20:13.520] And I think 715 by the way is very mild. And this could be made significantly more aggressive,
[03:20:13.520 --> 03:20:19.280] probably even like 100 is good enough. But it's okay, let's leave it for now so that we have the
[03:20:19.280 --> 03:20:26.560] exact hyper parameters of GPT three. So I fix that. And then that's pretty much it. We can,
[03:20:26.560 --> 03:20:34.720] we can run. So we have our script here. And we can launch. And actually, sorry, let me do one more thing.
[03:20:34.720 --> 03:20:47.920] Excuse me. For my GPU, I can actually fit a bigger batch size. And I believe I can fit 64 on my GPU
[03:20:48.800 --> 03:20:51.120] as a micro batch size. So let me try that.
[03:20:51.120 --> 03:21:01.680] I could be misremembering. But that means 64 times 1024 per GPU, and then we have eight GPUs. So that
[03:21:01.680 --> 03:21:07.680] means we would not even be doing gradient accumulation, if this fits, because this just multiplies out to
[03:21:07.680 --> 03:21:14.800] the full total batch size. So no gradient accumulation. And that would run pretty quickly if that fits.
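A quick sanity check of that claim, using the batch-size bookkeeping from our script (variable names as assumed earlier):

```python
total_batch_size = 524288   # 2**19 tokens per optimizer step
B, T = 64, 1024             # micro batch size and sequence length per GPU
ddp_world_size = 8          # number of GPUs
assert total_batch_size % (B * T * ddp_world_size) == 0
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
print(grad_accum_steps)     # 1, i.e. no gradient accumulation needed
```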
[03:21:19.680 --> 03:21:31.840] Let's go. Let's go. I mean, if this works, then this is basically a serious pre-training run.
[03:21:31.840 --> 03:21:36.480] We're not logging. We're not evaluating the validation split. We're not running any
[03:21:36.480 --> 03:21:42.800] evaluations yet. So we haven't crossed our t's and dotted our i's. But if we let this run
[03:21:42.800 --> 03:21:48.000] for a while, we're going to actually get a pretty good model. And the model that might even be
[03:21:48.000 --> 03:21:55.360] on par with or better than GPT-2 124M. Okay. So it looks like everything is going great.
[03:21:55.360 --> 03:21:58.160] We're processing 1.5 million tokens per second.
[03:21:58.160 --> 03:22:08.720] Everything here looks good. We're doing 330 milliseconds per iteration. And we have to do a total of,
[03:22:10.320 --> 03:22:20.560] where are we printing that, 19,073 steps. So 19,073 times 0.33 is this many seconds, this many minutes.
[03:22:20.560 --> 03:22:30.560] So this will run for about 1.7 hours. So about a 1.7-hour run like this. And we don't even have to use
[03:22:30.560 --> 03:22:34.720] gradient accumulation, which is nice. And you might not have that luxury in your GPU. In that case,
[03:22:34.720 --> 03:22:38.800] just start decreasing the batch size until things fit. But keep it to nice numbers.
[03:22:41.040 --> 03:22:44.640] So that's pretty exciting. We're currently warming up the learning rate. So you see that
[03:22:44.640 --> 03:22:49.360] it's still very low, 1e-4. So this will ramp up over the next few steps all the way
[03:22:49.360 --> 03:22:58.080] to 6e-4 here. Very cool. So now what I'd like to do is let's cross the t's and dot
[03:22:58.080 --> 03:23:03.360] our i's. Let's evaluate on the validation split. And let's try to figure out how we can run evals,
[03:23:03.360 --> 03:23:09.040] how we can do logging, how we can visualize our losses, and all the good stuff. So let's get to
[03:23:09.040 --> 03:23:13.600] that before we actually do the run. Okay, so I've adjusted the code so that we're evaluating on
[03:23:13.600 --> 03:23:18.400] the validation split. So creating the val loader, just by passing in split equals val,
[03:23:18.400 --> 03:23:22.000] that will basically create a data loader just for the validation shard.
[03:23:22.000 --> 03:23:28.560] The other thing I did is in the data loader, I introduced a new function reset, which is called
[03:23:28.560 --> 03:23:34.240] at init. And it basically resets the data loader. And that is very useful because when we come to
[03:23:34.240 --> 03:23:40.720] the main training loop now, so this is the code I've added. And basically every 100 iteration,
[03:23:40.720 --> 03:23:46.480] including the zero iteration, we put the model into evaluation mode, we reset the val loader,
[03:23:46.480 --> 03:23:55.360] and then, with no gradients involved, we're going to basically accumulate the loss over, say, 20
[03:23:55.360 --> 03:24:02.720] steps, and then average it all up and print out the validation loss. And so that basically
[03:24:02.720 --> 03:24:07.840] is the exact same logic as the training, roughly, but there's no loss.backward. It's only
[03:24:07.840 --> 03:24:12.160] inference. We're just measuring the loss. We're adding it up. Everything else otherwise applies
[03:24:12.160 --> 03:24:17.760] and is exactly as we've seen it before. And so this will print the validation loss every 100th
[03:24:17.760 --> 03:24:24.160] iteration, including the very first iteration. So that's nice.
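A sketch of that evaluation block; the variable names follow the training loop we've been building, and val_loss_steps = 20 matches the 20 steps mentioned above:

```python
if step % 100 == 0:  # also fires on the very first (zeroth) iteration
    model.eval()
    val_loader.reset()
    with torch.no_grad():
        val_loss_accum = 0.0
        val_loss_steps = 20
        for _ in range(val_loss_steps):
            x, y = val_loader.next_batch()
            x, y = x.to(device), y.to(device)
            logits, loss = model(x, y)
            val_loss_accum += (loss / val_loss_steps).detach()  # average over the 20 mini-batches
    if master_process:
        print(f"validation loss: {val_loss_accum.item():.4f}")
```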
[03:24:24.160 --> 03:24:29.760] That will tell us a little bit about how much we're overfitting. That said, we have roughly infinite data,
[03:24:29.760 --> 03:24:34.240] so we're mostly expecting our train and val loss to be about the same. But the other reason
[03:24:34.240 --> 03:24:39.280] I'm kind of interested in this is because we can take the GPT-2 124M as OpenAI released it,
[03:24:39.280 --> 03:24:43.680] we can initialize from it, and we can basically see what kind of loss it achieves on the validation
[03:24:43.680 --> 03:24:48.560] loss as well. And that gives us kind of an indication as to how well that model
[03:24:48.560 --> 03:24:54.560] generalizes to the FineWeb-EDU validation split. But that said,
[03:24:54.560 --> 03:24:58.160] it's not a super fair comparison to GPT two because it was trained on a very different data
[03:24:58.160 --> 03:25:02.720] distribution. But it's still kind of like an interesting data point. And in any case, you would
[03:25:02.720 --> 03:25:08.160] always want to have a validation split in a training run like this, so that you can make sure that
[03:25:08.160 --> 03:25:14.880] you are not overfitting. And this is especially a concern if we were to make more epochs in our
[03:25:14.880 --> 03:25:20.240] training data. So for example, right now we're just doing a single epoch. But if we get to a point
[03:25:20.240 --> 03:25:25.360] where we're running multi-epoch training runs or something like that, we would want to be really careful that maybe we
[03:25:25.360 --> 03:25:30.640] are memorizing that data too much, if we have a big enough model, and our validation split would
[03:25:30.640 --> 03:25:34.880] be one way to tell whether that is happening. Okay, and in addition to that, if you remember,
[03:25:34.880 --> 03:25:39.440] at the bottom of our script, we had all of this orphaned code for sampling from way back when.
[03:25:39.440 --> 03:25:45.680] So I deleted that code and I moved it up to here. So once in a while, we simply evaluate
[03:25:45.680 --> 03:25:53.520] validation. Once in a while, we sample, we generate samples. And then we do that only every 100
[03:25:53.520 --> 03:25:57.840] steps, and we train on every single step. So that's how I have a structure right now. And
[03:25:57.840 --> 03:26:01.920] I've been running this for 1000 iterations. So here are some samples on iteration 1000.
[03:26:01.920 --> 03:26:08.320] Hello, I'm a language model, and I'm not able to get more creative.
[03:26:08.320 --> 03:26:13.920] I'm a language model and languages file you're learning about here is, or is the beginning of
[03:26:13.920 --> 03:26:21.360] a computer. Okay, so this is all still pretty garbled, but we're only at
[03:26:21.360 --> 03:26:26.800] iteration 1000. And we've only just barely reached the maximum learning rate. So the model is still
[03:26:26.800 --> 03:26:37.760] learning. We're about to get some more samples coming up in 100 steps. Okay, this is, you know, the
[03:26:37.760 --> 03:26:45.280] model is still a young baby. Okay, so basically all of this sampling code that I've put here,
[03:26:45.280 --> 03:26:49.360] everything should be familiar with to you and came from before. The only thing that I did is I
[03:26:49.360 --> 03:26:55.040] created a generator object in PyTorch, so that I have a direct control over the sampling
[03:26:55.040 --> 03:26:59.920] of the random numbers, because I don't want to impact the RNG state of the random number
[03:26:59.920 --> 03:27:04.400] generator that is the global one used for training. I want this to be completely outside of the
[03:27:04.400 --> 03:27:11.040] training loop. And so I'm using a special sampling RNG. And then I make sure to seed it
[03:27:11.040 --> 03:27:16.240] so that every single rank has a different seed. And then I pass in here, where we
[03:27:16.240 --> 03:27:20.880] consume the random numbers in multinomial, where the sampling happens, I make sure to pass in the
[03:27:20.880 --> 03:27:27.040] generator object there. Otherwise, this is identical.
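A sketch of that per-rank sampling RNG (the seed offset and the top-k value are illustrative):

```python
# a dedicated RNG for sampling, so we never disturb the global RNG state used for training
sample_rng = torch.Generator(device=device)
sample_rng.manual_seed(42 + ddp_rank)  # a different seed per rank, so ranks produce different samples

# ... inside the sampling loop, when picking the next token:
probs = torch.softmax(logits[:, -1, :], dim=-1)
topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)     # top-k sampling as before
ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # pass the generator explicitly
xcol = torch.gather(topk_indices, -1, ix)
```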
[03:27:27.040 --> 03:27:32.240] Now, the other thing is, you'll notice that we're running a bit slower. That's because I actually had to disable torch.compile to get this to sample.
[03:27:32.240 --> 03:27:37.040] And so we're running a bit slower. So for some reason it works with no torch compile,
[03:27:37.040 --> 03:27:41.520] but when I torch compile my model, I get a really scary error from PyTorch and I have no idea how
[03:27:41.520 --> 03:27:46.080] to resolve it right now. So probably by the time you see this code released or something like that,
[03:27:46.080 --> 03:27:51.120] maybe it's fixed. But for now, I'm just going to leave compile disabled. Or, if I bring back
[03:27:51.120 --> 03:27:57.760] torch compile, then you're not going to get samples. And I think I'll fix this later. By the way,
[03:27:57.760 --> 03:28:03.440] I will be releasing all this code. And actually, I've been very careful about making Git commits
[03:28:03.440 --> 03:28:08.160] every time we add something. And so I'm going to release the entire repo that starts completely
[03:28:08.160 --> 03:28:13.440] from scratch all the way to now and after this as well. And so everything should be
[03:28:13.440 --> 03:28:19.200] exactly documented in the Git commit history. And so I think that will be nice. So hopefully,
[03:28:19.200 --> 03:28:23.360] by the time you go to GitHub, this is removed and it's working. And I will have fixed the bug.
[03:28:23.360 --> 03:28:28.960] Okay, so I have the optimization running here. And it's stepping and we're on step 6000 or so,
[03:28:28.960 --> 03:28:33.200] so we're about 30% through training. Now, while this is training, I would like to introduce
[03:28:33.200 --> 03:28:39.120] one evaluation that we're going to use to supplement the validation set. And that is the HellaSwag eval.
[03:28:39.120 --> 03:28:46.160] So HellaSwag comes from this paper back in 2019. So it's a five-year-old eval now. And the way
[03:28:46.160 --> 03:28:51.520] HellaSwag works is that there is basically a sentence completion dataset. So it's a multiple choice.
[03:28:51.520 --> 03:28:57.680] For every one of these questions, we have basically a shared context like a woman is outside with
[03:28:57.680 --> 03:29:04.000] a bucket and a dog. The dog is running around trying to avoid a bath. She, A, rinses the bucket
[03:29:04.000 --> 03:29:10.800] off with soap and blow dries the dog's head, B, uses a hose to keep it from getting soapy, C,
[03:29:10.800 --> 03:29:17.600] gets the dog wet, then it runs away again, or D, gets into a bathtub with the dog. And so basically,
[03:29:17.600 --> 03:29:23.280] the idea is that these multiple choice are constructed so that one of them is a natural
[03:29:23.280 --> 03:29:32.880] continuation of the sentence and the others are not. And the others might not make sense like
[03:29:32.880 --> 03:29:37.120] uses the hose to keep it from getting soapy, that makes no sense. And so what happens is that
[03:29:37.120 --> 03:29:43.040] models that are not trained very well are not able to tell these apart, but models that have a lot
[03:29:43.040 --> 03:29:49.440] of world knowledge, and can tell a lot about the world, will be able to
[03:29:49.440 --> 03:29:55.680] pick the right completion. And these sentences are sourced from ActivityNet and from WikiHow.
[03:29:55.680 --> 03:30:04.560] And at the bottom of the paper, there's kind of like a cool chart of the kinds of domains
[03:30:04.560 --> 03:30:09.840] in WikiHow. So there's a lot of sentences from computers and electronics and home and garden.
[03:30:09.840 --> 03:30:14.320] And it has kind of a broad coverage of the kinds of things you need to know about the world in order
[03:30:14.320 --> 03:30:21.040] to find the most likely completion and the identity of that completion.
[03:30:21.040 --> 03:30:25.920] One more thing that's kind of interesting about HellaSwag is the way it was constructed
[03:30:25.920 --> 03:30:35.600] is that the incorrect options are deliberately adversarially sourced. So they're not just random
[03:30:35.600 --> 03:30:40.400] sentences. They're actually sentences generated by language models. And they're generated in such
[03:30:40.400 --> 03:30:45.360] a way that language models basically find them difficult, but humans find them easy. And so
[03:30:45.360 --> 03:30:50.400] they mentioned that humans have a 95% accuracy on this set. But at the time, state-of-the-art
[03:30:50.400 --> 03:30:57.440] language models had only 48%. And so at the time, this was a good benchmark. Now, you can read the
[03:30:57.440 --> 03:31:02.960] details of this paper to learn more. The thing to point out though is that this is five years ago.
[03:31:02.960 --> 03:31:10.960] And since then what happened to HellaSwag is that it's been totally just solved. And so now the
[03:31:10.960 --> 03:31:17.120] language models here are 96%. So basically, the last 4% is probably errors in the dataset,
[03:31:17.120 --> 03:31:21.440] or the questions are really, really hard. And so basically, this dataset is kind of crushed with
[03:31:21.440 --> 03:31:25.040] respect to language models. But back then, the best language models were only at about 50%.
[03:31:25.040 --> 03:31:32.000] But this is how far things got. But still, the reason people like HellaSwag,
[03:31:32.560 --> 03:31:39.920] and it's not used by the way in GPT-2, but in GPT-3 there is a HellaSwag eval. And lots of people use
[03:31:39.920 --> 03:31:48.560] HellaSwag. And so with GPT-3, we have results here that are cited. So we know what percent
[03:31:48.560 --> 03:31:55.360] accuracy GPT-3 attains at all these different model checkpoints for the HellaSwag eval. And the
[03:31:55.360 --> 03:32:00.960] reason people like it is because HellaSwag is a smooth eval. And it is an eval that offers,
[03:32:00.960 --> 03:32:07.360] quote, unquote, early signal. So early signal means that even small language models are going to
[03:32:07.360 --> 03:32:11.680] start at the random chance of 25%. But they're going to slowly improve. And you're going to see
[03:32:11.680 --> 03:32:18.800] 25, 26, 27, et cetera. And you can see slow improvement, even when the models are very small,
[03:32:18.800 --> 03:32:26.960] and it's very early. So it's smooth. It has early signal. And it's been around for a long time. So
[03:32:26.960 --> 03:32:32.720] that's why people will kind of like this eval. Now, the way that we're going to evaluate this
[03:32:32.720 --> 03:32:39.440] is as follows. As I mentioned, we have a shared context. And this is kind of like a multiple
[03:32:39.440 --> 03:32:44.320] choice task. But instead of giving the model a multiple choice question and asking it for A,
[03:32:44.320 --> 03:32:50.320] B, C, or D, we can't do that because these models, when they are so small, as we are seeing here,
[03:32:50.320 --> 03:32:54.720] the models can't actually do multiple choice. They don't understand the concept of associating
[03:32:54.720 --> 03:32:59.760] a label to one of the options of multiple choice. They don't understand that. So we have to give
[03:32:59.760 --> 03:33:06.240] it to them in native form. And the native form is a token completion. So here what we do is we construct a
[03:33:06.240 --> 03:33:14.240] batch of four rows by T tokens, whatever that T happens to be. Then the shared context, that is
[03:33:14.240 --> 03:33:20.080] basically the context for the four choices, the tokens of that are shared across all of the rows.
[03:33:20.640 --> 03:33:25.280] And then we have the four options. So we kind of like lay them out. And then only one of the
[03:33:25.280 --> 03:33:31.360] options is correct. In this case, it's label three, option three. So that is the correct option,
[03:33:31.360 --> 03:33:37.200] and the other options are incorrect. Now, these options might be of different lengths.
[03:33:37.200 --> 03:33:42.400] So what we do is we sort of like take the longest length, and that's the size of the batch B by T.
[03:33:42.400 --> 03:33:47.840] And then some of these here are going to be padded dimensions. So they're going to be
[03:33:47.840 --> 03:33:54.000] unused. And so we need the tokens. We need the correct label. And we need a mask
[03:33:54.000 --> 03:33:59.920] that tells us which tokens are active. And the mask is then zero for these padded areas.
[03:33:59.920 --> 03:34:06.160] So that's how we construct these batches. And then in order to get the language model to predict
[03:34:06.160 --> 03:34:10.880] A, B, C, or D, the way this works is basically we're just going to look at the tokens,
[03:34:10.880 --> 03:34:18.000] their probabilities. And we're going to pick the option that gets the highest
[03:34:18.000 --> 03:34:25.920] average probability for its tokens, because that is the most likely completion
[03:34:25.920 --> 03:34:31.440] according to the language model. So we're just going to look at the probabilities here,
[03:34:31.440 --> 03:34:37.280] and average them up across the options, and pick the one with the highest probability,
[03:34:37.280 --> 03:34:45.200] roughly speaking. So this is how we're going to do hella swag. And this is, I believe, also how
[03:34:45.200 --> 03:34:53.200] GPT-3 did it. This is how GPT-3 did it, as far as I know. But you should note that some of the
[03:34:53.200 --> 03:34:58.000] other evals where you might see HellaSwag may not do it this way. They may do it in a multiple
[03:34:58.000 --> 03:35:03.760] choice format where you sort of give the context a single time, and then the four completions.
[03:35:03.760 --> 03:35:08.560] And so the model is able to see all the four options before it picks the best possible option.
[03:35:08.560 --> 03:35:12.960] And that's actually an easier task for a model because you get to see the other options when
[03:35:12.960 --> 03:35:18.560] you're picking your choice. But unfortunately, models at our size can't do that. Only models
[03:35:18.560 --> 03:35:23.600] at a bigger size are able to do that. And so our models are actually slightly handicapped in this
[03:35:23.600 --> 03:35:28.960] way that they are not going to see the other options. They're only going to see one option at a time,
[03:35:28.960 --> 03:35:33.200] and they just have to assign probabilities. And the correct option has to win out in this metric.
[03:35:33.920 --> 03:35:38.400] All right, so let's now implement this very briefly and incorporate it into our script.
[03:35:38.400 --> 03:35:43.280] Okay, so what I've done here is I've introduced a new file called hellaswag.py,
[03:35:43.280 --> 03:35:47.040] and you can take a look into it. And I'm not going to step through all of it because
[03:35:47.040 --> 03:35:52.960] this is not exactly like deep code, deep code. It's kind of like a little bit tedious, honestly,
[03:35:52.960 --> 03:35:57.520] because what's happening is I'm downloading hella swag from GitHub, and I'm rendering all of
[03:35:57.520 --> 03:36:02.640] its examples. And there are a total of 10,000 examples. I am rendering them into this format.
[03:36:04.160 --> 03:36:10.880] And so here at the end of this render example function, you can see that I'm returning the tokens,
[03:36:10.880 --> 03:36:21.360] the tokens of this four by T array of tokens, the mask, which tells us which parts are the options,
[03:36:21.360 --> 03:36:27.120] and everything else is zero, and the label, that is the correct label. And so that allows us to
[03:36:27.120 --> 03:36:31.840] then iterate the examples and render them. And I have an evaluate function here, which can load
[03:36:31.840 --> 03:36:40.560] a GPT-2 from Hugging Face. And it runs the eval here. And it basically just calculates,
[03:36:40.560 --> 03:36:47.040] just as I described, the option that has the lowest loss, i.e. the highest probability. And
[03:36:47.040 --> 03:36:51.840] the way to do that actually is we can basically evaluate the cross entropy loss. So we're basically
[03:36:51.840 --> 03:36:57.120] evaluating the loss of predicting the next token in the sequence. And then we're looking at the row
[03:36:57.120 --> 03:37:05.040] that has the lowest average loss. And that's the option that we pick as the prediction.
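A minimal sketch of that selection, given the 4×T tokens, the mask, and the model's logits; the function name is illustrative:

```python
import torch
import torch.nn.functional as F

def get_most_likely_row(tokens, mask, logits):
    # autoregressive shift: logits at position i predict the token at position i+1
    shift_logits = logits[..., :-1, :].contiguous()
    shift_tokens = tokens[..., 1:].contiguous()
    flat_logits = shift_logits.view(-1, shift_logits.size(-1))
    flat_tokens = shift_tokens.view(-1)
    losses = F.cross_entropy(flat_logits, flat_tokens, reduction="none")
    losses = losses.view(tokens.size(0), -1)  # (4, T-1) per-token losses
    shift_mask = mask[..., 1:].contiguous()   # only count losses in the completion region
    masked_losses = losses * shift_mask
    avg_loss = masked_losses.sum(dim=1) / shift_mask.sum(dim=1)
    # the row (option) with the lowest average loss is the model's prediction
    return avg_loss.argmin().item()
```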
[03:37:05.040 --> 03:37:09.440] And then we do some stats and prints and stuff like that. So that is the way to evaluate
[03:37:09.440 --> 03:37:15.520] HellaSwag. Now, if you go up here, I'm showing that for GPT-2 124M, if you run this script,
[03:37:15.520 --> 03:37:22.720] you're going to see that hella swag gets 29.55%. So that's the performance we get here. Now,
[03:37:22.720 --> 03:37:28.560] remember that random chance is 25%. So we haven't gone too far. And GPT-2 XL,
[03:37:28.560 --> 03:37:35.840] which is the biggest GPT-2, gets all the way up to 49% roughly. So these are pretty low values
[03:37:35.840 --> 03:37:40.720] considering that today's state of the art is more like 95%. So these are definitely older models by
[03:37:40.720 --> 03:37:45.760] now. And then there's one more thing called the Eleuther harness, which is a very common piece of
[03:37:45.760 --> 03:37:50.240] infrastructure for running evals for language models. And they get slightly different numbers.
[03:37:50.240 --> 03:37:55.280] And I'm not 100% sure what the discrepancy is for these. It could be that they actually
[03:37:55.280 --> 03:37:59.840] do the multiple choice instead of just the completions. And then that could be the
[03:37:59.840 --> 03:38:06.080] discrepancy. But I'm not 100% sure about that. I'd have to take a look. But for now, our script
[03:38:06.080 --> 03:38:12.320] reports 29.55. And so that is the number that we'd like to beat if we're training a GPT-2 124M
[03:38:12.320 --> 03:38:21.120] from scratch ourselves. So now I'm going to go into actually incorporating this eval
[03:38:21.120 --> 03:38:28.880] into our main training script. And basically, because we want to evaluate it in a periodic manner,
[03:38:28.880 --> 03:38:34.880] so that we can track HellaSwag and how it evolves over time, and see when and if we cross
[03:38:35.600 --> 03:38:43.040] this 29.55 sort of region. So let's now walk through some of the changes to train_gpt2.py.
[03:38:43.040 --> 03:38:49.280] The first thing I did here is I actually made use_compile optional, kind of, and I disabled it
[03:38:49.280 --> 03:38:55.120] by default. And the problem with compile is that unfortunately it
[03:38:55.120 --> 03:38:59.520] does make our code faster, but it actually breaks the evaluation code and the sampling code. It
[03:38:59.520 --> 03:39:03.520] gives me a very gnarly message and I don't know why. So hopefully by the time you get to the
[03:39:04.160 --> 03:39:08.400] code base, when I put it up on GitHub, we're gonna fix that by then. But for now I'm running
[03:39:08.400 --> 03:39:13.360] without torch compile, which is why you see this be a bit slower. So we're running without torch
[03:39:13.360 --> 03:39:20.960] compile. I also created a log directory log where we can place our log.txt, which will record the
[03:39:20.960 --> 03:39:26.080] train loss, validation loss, and the hella swag accuracies. So a very simple text file and we're
[03:39:26.080 --> 03:39:31.520] going to open for writing so that it sort of starts empty. And then we're going to append to it.
[03:39:33.440 --> 03:39:39.520] I created a simple variable that helps tell us when we have a last step. And then basically
[03:39:39.520 --> 03:39:46.080] periodically inside this loop, every 250th iteration or at the last step, we're going to evaluate
[03:39:46.080 --> 03:39:54.560] the validation loss. And then every 250th iteration, we are going to evaluate HellaSwag,
[03:39:54.560 --> 03:40:00.320] but only if we are not using compile because compile breaks it. So I'm going to come back to
[03:40:00.320 --> 03:40:06.080] this code for evaluating hella swag in a second. And then every 250th iteration as well,
[03:40:06.080 --> 03:40:09.920] we're also going to sample from the model. And so you should recognize this as our
[03:40:09.920 --> 03:40:14.160] ancient code from way back when we started the video. And we're just sampling from the model.
[03:40:14.160 --> 03:40:22.800] And then finally here, after we validate, sample, and evaluate HellaSwag,
[03:40:22.800 --> 03:40:28.560] we actually do a training step here. And so this is one step of training, and you should be pretty
[03:40:28.560 --> 03:40:33.440] familiar with all of what this does. And at the end here, once we get our training loss,
[03:40:33.440 --> 03:40:38.320] we write it to the file. So the only thing that changed that I really added is this entire section
[03:40:38.320 --> 03:40:43.680] for hella swag eval. And the way this works is I'm trying to get all the GPUs to collaborate on
[03:40:43.680 --> 03:40:50.480] the hella swag. And so we're iterating on the examples. And then each process only picks the
[03:40:50.480 --> 03:40:56.320] examples that are assigned to it. So we take i and mod it by the world size, and it has to
[03:40:56.320 --> 03:41:02.480] be equal to the rank, otherwise we continue. And then we render an example, put it on a GPU,
[03:41:02.480 --> 03:41:08.000] we get the logits, then I created a helper function that helps us basically predict the option with
[03:41:08.000 --> 03:41:13.840] the lowest loss. So this comes here, the prediction. And then if it's correct, we sort of keep count.
[03:41:13.840 --> 03:41:19.040] And then if multiple processes were collaborating on all this, then we need to synchronize their
[03:41:19.040 --> 03:41:24.960] stats. And so the way one way to do that is to package up our statistics here into tensors,
[03:41:25.600 --> 03:41:33.120] which we can then call dist.all_reduce on, and sum them up. And then here we sort of unwrap them
[03:41:33.120 --> 03:41:38.480] from tensors so that we just have ints. And then here the master process will print and log the
[03:41:38.480 --> 03:41:47.120] HellaSwag accuracy. So that's kind of it. And that's what I'm running right here.
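A sketch of that collaboration and synchronization; iterate_examples and render_example stand in for the helpers in hellaswag.py, and their exact signatures here are assumptions, as is the (logits, loss) return convention of the model:

```python
num_correct, num_total = 0, 0
for i, example in enumerate(iterate_examples("val")):
    if i % ddp_world_size != ddp_rank:
        continue  # each rank only handles every world_size-th example
    tokens, mask, label = render_example(example)
    tokens, mask = tokens.to(device), mask.to(device)
    with torch.no_grad():
        logits, _ = model(tokens)
    pred = get_most_likely_row(tokens, mask, logits)
    num_total += 1
    num_correct += int(pred == label)

if ddp:
    # sum the per-rank counts so every rank ends up with the global totals
    num_total_t = torch.tensor(num_total, dtype=torch.long, device=device)
    num_correct_t = torch.tensor(num_correct, dtype=torch.long, device=device)
    dist.all_reduce(num_total_t, op=dist.ReduceOp.SUM)
    dist.all_reduce(num_correct_t, op=dist.ReduceOp.SUM)
    num_total, num_correct = num_total_t.item(), num_correct_t.item()

if master_process:
    print(f"HellaSwag accuracy: {num_correct}/{num_total} = {num_correct/num_total:.4f}")
```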
[03:41:47.120 --> 03:41:54.080] So you see this optimization here. And we just had a generation. And this is step 10,000 out of
[03:41:54.080 --> 03:41:59.680] about 20,000, right? So we are halfway done. And these are kinds of samples that we are getting
[03:41:59.680 --> 03:42:05.360] at this stage. So let's take a look. Hello, I'm a language model. So I'd like to use it to generate
[03:42:05.360 --> 03:42:09.920] some kinds of output. Hello, I'm a language model, and I'm a developer for a lot of companies.
[03:42:11.040 --> 03:42:21.520] A long language model. Let's see if I can find any fun one.
[03:42:21.520 --> 03:42:32.800] I don't know, you can go through this yourself, but certainly the predictions are getting less and
[03:42:32.800 --> 03:42:38.960] less random. It seems like the model is a little bit more self-aware in using language that is a bit
[03:42:38.960 --> 03:42:46.240] more specific to it being a language model. Hello, I'm a language model. And like how the language
[03:42:46.240 --> 03:42:51.440] is used to communicate, I'm a language model and are going to be speaking English and German.
[03:42:51.440 --> 03:42:57.760] So let's just wait until this optimization finishes. And we'll see what kind of samples we get.
[03:42:57.760 --> 03:43:03.680] And we're also going to look at the train loss, the val loss, and the HellaSwag accuracy and see how
[03:43:03.680 --> 03:43:10.560] we're doing with respect to GPT-2. Okay, good morning. So focusing for a moment on the Jupyter
[03:43:10.560 --> 03:43:15.520] notebook here on the right, I created a new cell that basically allows us to visualize the train,
[03:43:15.520 --> 03:43:23.040] the val, and the HellaSwag score. And you can step through this. It basically parses the log file that we
[03:43:23.040 --> 03:43:29.520] are writing. And a lot of this is just like boring matplotlib code. But basically, this is what our
[03:43:29.520 --> 03:43:40.160] optimization looks like. So we ran for 19,073 steps, which is roughly 10 billion tokens,
[03:43:40.160 --> 03:43:45.520] which is, whoops, oh my gosh, which is one epoch of the sample-10B of FineWeb-EDU.
[03:43:45.520 --> 03:43:52.000] On the left, we have the loss. And in the blue, we have the training loss. In orange, we have the
[03:43:52.000 --> 03:43:59.360] validation loss. And in red, as a horizontal line, we have the OpenAI GPT-2 124M model checkpoint,
[03:43:59.360 --> 03:44:06.640] when it's just evaluated on the validation set of this FineWeb-EDU dataset. So you can see that we are
[03:44:06.640 --> 03:44:12.560] surpassing it: the orange is below the red, so we're surpassing that model on the validation set of this dataset.
[03:44:12.560 --> 03:44:16.640] And like I mentioned, the data distribution is very different from what GPT2 trained on.
[03:44:16.640 --> 03:44:22.720] So this is not exactly fair comparison, but it's a good cross check to look at.
[03:44:22.720 --> 03:44:29.040] Now we would ideally like something that is withheld and comparable and somewhat standard.
[03:44:29.040 --> 03:44:35.680] And so for us, that is hella swag. And so on here, we see the hella swag progress we made from 25%
[03:44:35.680 --> 03:44:44.560] all the way up to here. In red, we see the OpenAI GPT-2 124M model. So it achieves this
[03:44:44.560 --> 03:44:51.920] HellaSwag accuracy here. And the GPT-3 124M model, which was trained on 300 billion tokens,
[03:44:51.920 --> 03:45:00.000] achieves the green line. So that's over here. So you see that we basically surpassed the GPT-2 124M model
[03:45:00.000 --> 03:45:07.520] right here, which is really nice. Now, interestingly, we were able to do so with
[03:45:07.520 --> 03:45:12.160] only training on 10 billion tokens, while GPT2 was trained on 100 billion tokens.
[03:45:12.960 --> 03:45:17.760] So for some reason, we were able to get away with significantly fewer tokens for training.
[03:45:17.760 --> 03:45:22.960] There are many possible reasons why we could match or surpass this accuracy
[03:45:22.960 --> 03:45:30.960] with only 10 billion training tokens. So number one, it could be that the OpenAI GPT-2 was trained on
[03:45:30.960 --> 03:45:37.520] a much wider data distribution. So in particular, fine web EDU is all English. It's not multilingual.
[03:45:38.080 --> 03:45:44.240] And there's not that much math and code. And so math and code and multilingual could have been
[03:45:44.240 --> 03:45:52.000] stealing capacity from the original GPT-2 model. And basically, that could be part of the reason
[03:45:52.000 --> 03:45:57.600] why we can get away with fewer tokens here. There are many other reasons. So for example, the HellaSwag eval is
[03:45:57.600 --> 03:46:03.600] fairly old, maybe five years or so. It is possible that aspects of hella swag in some way or even
[03:46:03.600 --> 03:46:10.080] identically have made it into the training set of fine web. We don't know for sure, but if that
[03:46:10.080 --> 03:46:13.360] was the case, then we are basically looking at the training curve instead of the validation curve.
[03:46:13.360 --> 03:46:18.880] So long story short, this is not a perfect eval and there's some caveats here. But at least we
[03:46:18.880 --> 03:46:25.920] have some confidence that we're not doing something completely wrong. And it's probably the case
[03:46:25.920 --> 03:46:30.240] that when people try to create these datasets, they try to make sure that test sets that are very
[03:46:30.240 --> 03:46:35.520] common are not part of the training set. For example, when hugging face created the fine web EDU,
[03:46:35.520 --> 03:46:40.240] they use hella swag as an eval. So I would hope that they make sure that they deduplicate and
[03:46:40.240 --> 03:46:45.760] that there's no hella swag in the training set. But we can't be sure. The other thing I wanted
[03:46:45.760 --> 03:46:50.960] to address briefly is, look at this loss curve. This looks really wrong here.
[03:46:50.960 --> 03:46:56.720] I don't actually know 100% what this is. And I suspect it's because the 10 billion sample of fine
[03:46:56.720 --> 03:47:04.960] web EDU was not properly shuffled. And there's some issue here with the data that I don't fully
[03:47:04.960 --> 03:47:10.560] understand yet. And there's some weird periodicity to it. And because we are in a very lazy way sort
[03:47:10.560 --> 03:47:15.120] of serializing all the tokens and just iterating on them from scratch without doing any permutations
[03:47:15.120 --> 03:47:21.360] or any random sampling ourselves, I think we're inheriting some of the ordering that they have
[03:47:21.360 --> 03:47:27.840] in a dataset. So this is not ideal. But hopefully by the time you get to this repo,
[03:47:27.840 --> 03:47:34.400] some of these things, by the way, will hopefully be fixed. And I will release this build-nanogpt
[03:47:34.400 --> 03:47:39.520] repo. Right now it looks a little ugly and preliminary, so hopefully by the time you get
[03:47:39.520 --> 03:47:45.360] here, it's nicer. But down here, I'm going to show errata, and I'm going to talk about some of the
[03:47:45.360 --> 03:47:51.520] things that happened after the video. And I expect that we will have fixed this small issue. But for
[03:47:51.520 --> 03:47:58.320] now, basically, this shows that our training is not completely wrong. And it shows that we're able
[03:47:58.320 --> 03:48:05.120] to surpass that accuracy with only a tenth of the token budget. And possibly it could also be that the
[03:48:05.120 --> 03:48:11.680] dataset has improved. So the original GPT-2 dataset was WebText. It's possible that
[03:48:11.680 --> 03:48:16.720] not a lot of care and attention went into that dataset. This was very early in the history of LLMs. Whereas
[03:48:16.720 --> 03:48:22.720] now there's a lot more scrutiny on good practices around deduplication, filtering, quality filtering,
[03:48:22.720 --> 03:48:27.120] and so on. And it's possible that the data set we're training on is just a higher quality per token.
[03:48:27.120 --> 03:48:32.080] And that could be giving us a boost as well. So a number of caveats to think about. But for now,
[03:48:32.080 --> 03:48:38.080] we're pretty happy with this. And yeah, now the next thing I was interested in is, as you see,
[03:48:38.080 --> 03:48:42.800] it's morning now, so this ran overnight. And I wanted to basically see how far I could
[03:48:42.800 --> 03:48:48.720] push the result. So to do an overnight run, instead of one epoch, which took
[03:48:48.720 --> 03:48:53.680] roughly two hours, I basically just did four, so that it would take eight hours while I was sleeping.
[03:48:53.680 --> 03:48:59.600] And so we did four epochs, or roughly 40 billion tokens of training, and I was trying to see how
[03:48:59.600 --> 03:49:05.600] far we could get. And so this was the only change: I reran the script, and when I plot and read
[03:49:05.600 --> 03:49:13.520] the log file at 40 billion tokens, this is what the curve looked like. Okay, so to narrate this, number one,
[03:49:13.520 --> 03:49:18.400] we are seeing this issue here with the periodicity through the different epochs and something really
[03:49:18.400 --> 03:49:25.280] weird with the FineWeb-Edu dataset, and that is to be determined. But otherwise, we are seeing
[03:49:25.280 --> 03:49:32.320] that the hella swag accuracy actually went up by a lot. And we almost made it to the GPT-3
[03:49:32.320 --> 03:49:39.120] 124M accuracy up here, but not quite. So it's too bad that I didn't sleep slightly longer.
[03:49:39.120 --> 03:49:47.520] And I think if this was a five epoch run, we may have gotten here. Now, one thing to point out is
[03:49:47.520 --> 03:49:52.880] that if you're doing multi-epoch runs, we're not actually being very careful in our data loader.
[03:49:52.880 --> 03:50:00.960] This data loader goes through the data in exactly the same format and exactly the
[03:50:00.960 --> 03:50:05.280] same order every epoch, and this is kind of suboptimal. You would want to look into extensions where you
[03:50:05.280 --> 03:50:11.360] actually permute the data randomly: you permute the documents around in every single shard on every
[03:50:11.360 --> 03:50:18.080] single new epoch, and potentially even permute the shards. And that would go a long way toward
[03:50:18.080 --> 03:50:22.560] decreasing the periodicity. And it's also better for the optimization, so that you're not seeing
[03:50:22.560 --> 03:50:28.000] things in the identical order, and you're introducing some randomness in how the
[03:50:28.000 --> 03:50:32.640] documents follow each other. Because you have to remember that in every single row, these documents
[03:50:32.640 --> 03:50:36.480] follow each other. And then there's the end of text token, and then the next document. So the
[03:50:36.480 --> 03:50:41.920] documents are currently glued together in the exact same identical manner. But we actually want
[03:50:41.920 --> 03:50:46.640] to break up the documents and shuffle them around, because the order of the documents shouldn't
[03:50:46.640 --> 03:50:51.680] matter. We basically want to break up that dependence, because it's a kind of
[03:50:51.680 --> 03:50:56.960] spurious correlation. And so our data loader is not currently doing that. And that's one improvement you could think of making; a rough sketch of the idea is shown below.
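As a rough illustration only (this is not the data loader from the video; the shard loading and the end-of-text splitting here are assumptions), a per-epoch reshuffle could look something like this:

```python
import numpy as np

def shuffled_epoch(shard_paths, eot_token, base_seed, epoch):
    """Yield one shard's tokens at a time, with shard order and document order
    re-permuted on every epoch (a sketch, not the loader used in the video)."""
    rng = np.random.default_rng(base_seed + epoch)   # new permutation each epoch, still reproducible
    for shard_idx in rng.permutation(len(shard_paths)):
        tokens = np.load(shard_paths[shard_idx])     # assumed: 1D array of token ids
        # split the shard back into documents at the <|endoftext|> token boundaries
        boundaries = np.where(tokens == eot_token)[0] + 1
        docs = [d for d in np.split(tokens, boundaries) if len(d) > 0]
        rng.shuffle(docs)                            # permute documents within the shard
        yield np.concatenate(docs)                   # re-glue them in the new random order
```

The data loader would then carve out B*T chunks from these re-glued streams exactly as before, but the documents would no longer follow each other in the same fixed order across epochs.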
[03:50:56.960 --> 03:51:03.120] The other thing to point out is we're almost matching GPT-3 accuracy
[03:51:03.120 --> 03:51:09.040] with only 40 billion tokens, while GPT-3 was trained on 300 billion tokens. So again, we're seeing about a
[03:51:09.040 --> 03:51:16.400] 10x improvement here with respect to learning efficiency, and
[03:51:16.400 --> 03:51:19.680] I don't actually know exactly what to attribute this to, other than some of the things that I
[03:51:19.680 --> 03:51:24.800] already mentioned for the previous run. The other thing I wanted to briefly mention is
[03:51:26.160 --> 03:51:31.360] the max LR here. I saw some people already play with this a little bit in a previous related
[03:51:31.360 --> 03:51:36.880] repository. And it turns out that you can actually almost like 3x this. So it's possible that the
[03:51:36.880 --> 03:51:41.200] maximum learning rate can be a lot higher. And for some reason, the GPT-3 hyperparameters that we
[03:51:41.200 --> 03:51:45.520] are inheriting are actually extremely conservative. And you can actually get away with a higher learning
[03:51:45.520 --> 03:51:52.000] rate and it would train faster. So a lot of these hyperparameters are quite tunable, and feel free
[03:51:52.000 --> 03:51:59.200] to play with them. And they're probably not set precisely correctly. And it's possible that you
[03:51:59.200 --> 03:52:05.200] can get away with doing this, basically. And if you wanted to exactly be faithful to GPT-3,
[03:52:05.200 --> 03:52:11.920] you would also want to make the following change. You'd want to come here: the sequence length
[03:52:11.920 --> 03:52:19.600] of GPT-3 is 2x; it's 2048 instead of 1024. So you would come here and change this to 2048 for T.
[03:52:19.600 --> 03:52:25.440] And then if you want the exact same number of tokens, half a million per iteration or per step,
[03:52:25.440 --> 03:52:29.440] you want to then decrease B to 32, so that everything still multiplies out to half a million.
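To make that arithmetic concrete, here is a minimal sketch using the same kinds of hyperparameter names (total_batch_size, B, T) as the training script; treat these names and values as assumptions rather than exact quotes of the code:

```python
total_batch_size = 524288            # ~0.5M tokens per optimizer step, as in the GPT-3 table

# GPT-2 style settings used so far
B, T = 64, 1024
assert total_batch_size % (B * T) == 0
print(total_batch_size // (B * T))   # 8 gradient accumulation steps

# GPT-3-faithful settings: double T, halve B, tokens per step unchanged
B, T = 32, 2048
assert total_batch_size % (B * T) == 0
print(total_batch_size // (B * T))   # still 8 gradient accumulation steps
```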
[03:52:29.440 --> 03:52:36.400] So that would give your model a sequence length equal to that of GPT-3. And in that case, basically,
[03:52:36.400 --> 03:52:44.480] the models would be roughly identical as far as I'm aware. Because again, GPT-2 and GPT-3 are very,
[03:52:44.480 --> 03:52:49.120] very similar models. Now, we can also look at some of the samples here from the model
[03:52:49.120 --> 03:52:56.000] that was trained overnight. So this is the optimization. And you see that here we stepped
[03:52:56.000 --> 03:53:04.480] all the way to step 76,290 or so. And the hella swag accuracy we achieved was 33.24.
[03:53:04.480 --> 03:53:10.320] And these are some of the samples from the model. And you can see that if you read through this
[03:53:10.320 --> 03:53:16.560] and pause the video briefly, you can see that they are a lot more coherent. And they're actually
[03:53:16.560 --> 03:53:20.000] addressing the fact that it's a language model, almost. So:
[03:53:20.000 --> 03:53:25.840] hello, I'm a language model, and I try to be as accurate as possible.
[03:53:25.840 --> 03:53:30.400] I'm a language model, not a programming language.
[03:53:30.400 --> 03:53:35.280] I know how to communicate. I use Python.
[03:53:35.280 --> 03:53:43.040] I don't know. If you pause this and look at it and then compare it to the model that was
[03:53:43.040 --> 03:53:46.720] only trained for 10 billion tokens, you will see that these are a lot more coherent.
[03:53:46.720 --> 03:53:51.040] And you can play with this yourself. One more thing I added to the code, by the way,
[03:53:51.040 --> 03:53:56.240] is this chunk of code here. So basically, right after we evaluate the validation loss,
[03:53:56.240 --> 03:54:01.520] if we are the master process, in addition to logging the validation loss, every 5,000 steps
[03:54:01.520 --> 03:54:06.080] we're also going to save the checkpoint, which is really just the state dictionary of the model.
[03:54:06.080 --> 03:54:10.720] And so checkpointing is nice just because you can save the model and later you can
[03:54:11.360 --> 03:54:16.880] use it in some way. If you wanted to resume the optimization, then in addition to saving the model,
[03:54:16.880 --> 03:54:22.240] we have to also save the optimizer state dict, because remember that the optimizer has a few
[03:54:22.240 --> 03:54:29.200] additional buffers because of Adam. So it's got the m and v moment buffers, and you need to also resume the
[03:54:29.200 --> 03:54:34.080] optimizer properly. You have to be careful with the RNG seeds, random number generators, and so on.
[03:54:34.080 --> 03:54:38.960] So if you wanted to exactly be able to resume optimization, you have to think through the state
[03:54:38.960 --> 03:54:43.440] of the training process. But if you just want to save the model, this is how you would do it.
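As a minimal sketch of what a fuller, resumable checkpoint might contain (the function name, the log_dir, and the exact dictionary keys here are assumptions, not the code shown in the video):

```python
import os
import torch

def save_checkpoint(model, optimizer, step, val_loss, log_dir="log"):
    # if the model is wrapped in DDP, you'd pass model.module here instead
    checkpoint = {
        "model": model.state_dict(),          # enough on its own if you only want the weights
        "optimizer": optimizer.state_dict(),  # AdamW moment buffers (m and v), needed to resume
        "step": step,
        "val_loss": val_loss,
        # RNG state, so a resumed run continues sampling and initializing the same way
        "rng_state_cpu": torch.get_rng_state(),
        "rng_state_cuda": torch.cuda.get_rng_state_all(),
    }
    os.makedirs(log_dir, exist_ok=True)
    torch.save(checkpoint, os.path.join(log_dir, f"model_{step:05d}.pt"))
```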
[03:54:43.440 --> 03:54:48.640] And one nice reason why you might want to do this is because you may want to evaluate the model
[03:54:48.640 --> 03:54:54.640] more carefully. So here we are only kind of winging the hella swag eval, but you may want
[03:54:54.640 --> 03:55:02.960] to use something nicer, like, for example, the EleutherAI evaluation harness
[03:55:04.000 --> 03:55:11.680] (lm-evaluation-harness). So this is another way to evaluate language models. And so it's possible that
[03:55:11.680 --> 03:55:18.880] you may want to use different infrastructure to more thoroughly evaluate the models on different
[03:55:18.880 --> 03:55:25.840] evaluations and compare them to the OpenAI GPT-2 model on many other tasks, like, for example, ones
[03:55:25.840 --> 03:55:30.480] that involve math, code, or different languages, and so on. So this is nice functionality to have.
[03:55:30.480 --> 03:55:36.960] And then the other thing I wanted to mention is that everything we've built here,
[03:55:36.960 --> 03:55:43.840] this is only the pre-training step. So the GPT here just dreams documents; it just predicts
[03:55:43.840 --> 03:55:49.520] the next token. You can't talk to it like you can talk to ChatGPT. If you wanted to
[03:55:49.520 --> 03:55:54.400] talk to the model, we would have to fine-tune it into the chat format. And that's not actually
[03:55:54.400 --> 03:55:58.960] that complicated. If you're looking at supervised fine-tuning, or SFT, really what that means is
[03:55:58.960 --> 03:56:02.960] we're just swapping out the dataset for a dataset that is a lot more conversational,
[03:56:02.960 --> 03:56:07.200] with a user-assistant, user-assistant kind of structure. And we just fine-tune on it.
[03:56:07.200 --> 03:56:13.440] And then we basically fill in the user tokens and we sample the assistant tokens. It's not
[03:56:13.440 --> 03:56:19.120] much deeper than that. Basically we swap out the dataset and continue training; a rough sketch of what such data could look like is below.
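Purely as an illustration (the delimiter strings below are made up; they are not GPT-2 vocabulary or any particular chat template), a conversational training example might be rendered roughly like this:

```python
def render_conversation(turns):
    """turns: list of (role, text) pairs, role is "user" or "assistant" (hypothetical template)."""
    parts = [f"<|{role}|>\n{text}\n" for role, text in turns]
    parts.append("<|assistant|>\n")   # at inference time, generation continues from here
    return "".join(parts)

example = render_conversation([
    ("user", "What is the capital of France?"),
    ("assistant", "Paris."),
    ("user", "And of Italy?"),
])
print(example)
# During SFT you'd tokenize strings like this and typically mask the loss on the user tokens,
# so the model is only trained to predict the assistant tokens.
```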
[03:56:19.120 --> 03:56:23.520] But for now, we're going to stop at pre-training. One more thing that I wanted to briefly show you is that,
[03:56:23.520 --> 03:56:28.800] of course, what we've built up today was building towards nano GPT, which is this repository from
[03:56:28.800 --> 03:56:34.320] earlier. But also there's actually another nano GPT implementation, and it's hiding in a
[03:56:34.320 --> 03:56:42.240] more recent project that I've been working on, called LLM.C. And LLM.C is a pure C CUDA implementation
[03:56:42.240 --> 03:56:49.040] of GPT-2 or GPT-3 training. And it just directly uses CUDA and is written as C CUDA.
[03:56:49.600 --> 03:56:54.880] Now the nano GPT here acts as reference code in PyTorch for the C implementation. So we're
[03:56:54.880 --> 03:56:59.920] trying to exactly match up the two, but we're hoping that the C CUDA is faster. And of course,
[03:56:59.920 --> 03:57:05.200] currently that seems to be the case, because it is a direct optimized implementation. So
[03:57:05.200 --> 03:57:12.000] train_gpt2.py in LLM.C is basically the nano GPT. And when you scroll through this file,
[03:57:12.000 --> 03:57:19.360] you'll find a lot of things that very much look like things that we've built up in this lecture.
[03:57:19.920 --> 03:57:25.360] And then when you look at train_gpt2.cu, this is the C CUDA implementation.
[03:57:25.360 --> 03:57:33.280] So there's a lot of MPI, NCCL, GPU, CUDA, C/C++, and you have to be familiar with that. But
[03:57:33.280 --> 03:57:40.240] when this is built up, we can actually run the two side by side. And they're going to produce
[03:57:40.240 --> 03:57:46.480] the exact same results, but LLM.C actually runs faster. So let's see that. So on the left, I have
[03:57:46.480 --> 03:57:52.960] the PyTorch nano GPT-looking thing. On the right, I have the LLM.C call. And here I'm going to
[03:57:52.960 --> 03:57:58.000] launch the two. Both of these are going to be running on a single GPU. And here I'm putting
[03:57:58.000 --> 03:58:04.000] the LLM.C on GPU 1, and this one will grab GPU 0 by default. And then
[03:58:04.000 --> 03:58:12.240] we can see here that LLM.C compiled, then allocated space, and it's stepping.
[03:58:13.520 --> 03:58:21.200] So basically, meanwhile, PyTorch is still compiling, because torch.compile is a bit slower here
[03:58:21.200 --> 03:58:28.400] than the LLM.C NVCC C CUDA compile. And so this program has already started running. And we're
[03:58:28.400 --> 03:58:33.440] still waiting here for torch.compile. Now, of course, this is a very specific implementation
[03:58:33.440 --> 03:58:38.320] for GPT-2 and GPT-3. PyTorch is a very general neural network framework. So they're not exactly
[03:58:38.320 --> 03:58:43.280] comparable. But if you're only interested in training GPT-2 and GPT-3, LLM.C is very fast.
[03:58:43.280 --> 03:58:51.920] It takes less space. It's faster to start and it's faster per step. And so PyTorch started
[03:58:51.920 --> 03:58:58.080] stepping here. And as you can see, we're running at about 223,000 tokens per second here and about
[03:58:58.080 --> 03:59:06.080] 185,000 tokens per second here. So quite a bit slower. But I don't have full confidence that I
[03:59:06.640 --> 03:59:11.440] exactly squeezed out all the juice from the PyTorch implementation. But the important thing
[03:59:11.440 --> 03:59:17.200] here is notice that if I line up the steps, you will see that the losses and the norms that are
[03:59:17.200 --> 03:59:23.120] printed between these two are identical. So on the left, we have the PyTorch implementation, and on the right,
[03:59:23.120 --> 03:59:28.880] the C CUDA implementation. And they're the same, except this one runs faster. So that's
[03:59:28.880 --> 03:59:34.480] LLM.C, which I also wanted to briefly show you. It's a parallel implementation, and it's also
[03:59:34.480 --> 03:59:39.440] something that you may want to play with or look at. And it's kind of interesting.
[03:59:39.440 --> 03:59:43.360] Okay. So at this point, I should probably start wrapping up the video because I think it's getting
[03:59:43.360 --> 03:59:48.720] way longer than anticipated. But we did cover a lot of ground and we built everything from scratch.
[03:59:48.720 --> 03:59:54.560] So as a brief summary, we were looking at the GPT-2 and GPT-3 papers.
[03:59:54.560 --> 04:00:00.560] We were looking at how you set up these training runs and all the considerations involved.
[04:00:00.560 --> 04:00:05.120] We wrote everything from scratch. And then we saw that over the duration of either a two-hour
[04:00:05.120 --> 04:00:10.960] training run or an overnight run, we can actually match the 124 million parameter checkpoints of
[04:00:10.960 --> 04:00:16.880] GPT-2 and GPT-3 to a very large extent. In principle, the code that we wrote would be
[04:00:16.880 --> 04:00:20.400] able to train even bigger models if you have the patience or the computing resources.
[04:00:20.400 --> 04:00:24.480] And so you could potentially think about training some of the bigger checkpoints as well.
[04:00:25.920 --> 04:00:30.960] There are a few remaining issues to address. What's happening with the loss here? I suspect it
[04:00:30.960 --> 04:00:36.880] has to do with the FineWeb-Edu data sampling. Why can't we turn on torch.compile? It currently
[04:00:36.880 --> 04:00:41.760] breaks generation and hella swag; what's up with that? In the data loader, we should probably be
[04:00:41.760 --> 04:00:46.640] permuting our data when we reach epoch boundaries. So there are a few more issues like that. And I
[04:00:46.640 --> 04:00:51.840] expect to be documenting some of those over time in the build-nanogpt repository here,
[04:00:52.960 --> 04:00:57.920] which I'm going to be releasing with this video. If you have any questions or would like to talk
[04:00:57.920 --> 04:01:04.160] about anything that we covered, please go to the Discussions tab so we can talk there. Or please go
[04:01:04.160 --> 04:01:08.720] to Issues or Pull Requests, depending on what you'd like to contribute.
[04:01:08.720 --> 04:01:15.040] Or also have a look at the Zero to Hero Discord, and I'm going to be hanging out there
[04:01:15.040 --> 04:01:24.160] on nanoGPT. Otherwise, for now, I'm pretty happy about where we got. And I hope you enjoyed the video.
[04:01:24.160 --> 04:01:25.680] And I will see you later.