let's build gpt2
[00:00:00.000 --> 00:00:04.320] Hi everyone. So today we are going to be continuing our Zero2Hero series | |
[00:00:04.320 --> 00:00:10.640] and in particular today we are going to reproduce the GPT2 model, the 124 million version of it. | |
[00:00:10.640 --> 00:00:17.440] So when OpenAI released GPT2, this was 2019 and they released it with this blog post. | |
[00:00:17.440 --> 00:00:23.040] On top of that they released this paper and on top of that they released this code on GitHub, | |
[00:00:23.040 --> 00:00:29.600] so OpenAI/GPT2. Now when we talk about reproducing GPT2, we have to be careful because in particular | |
[00:00:29.600 --> 00:00:34.880] in this video we're going to be reproducing the 124 million parameter model. So the thing to | |
[00:00:34.880 --> 00:00:41.040] realize is that there's always a miniseries when these releases are made, so there are the GPT2 | |
[00:00:41.040 --> 00:00:46.800] miniseries made up of models at different sizes and usually the biggest model is called the GPT2. | |
[00:00:46.800 --> 00:00:52.160] But basically the reason we do that is because you can put the model sizes on the x-axis of | |
[00:00:52.160 --> 00:00:56.720] plots like this and on the y-axis you put a lot of downstream metrics that you're interested in, | |
[00:00:56.720 --> 00:01:02.000] like translation, summarization, question answering and so on, and you can chart out the scaling laws. | |
[00:01:02.000 --> 00:01:06.800] So basically as the model size increases you're getting better and better at downstream metrics. | |
[00:01:06.800 --> 00:01:14.480] And so in particular for GPT2, if we scroll down in the paper, there are four models in the GPT2 | |
[00:01:14.480 --> 00:01:22.720] miniseries starting at 124 million all the way up to 1558 million. Now the reason my numbers, | |
[00:01:22.720 --> 00:01:27.520] the way I say them, disagree with this table is that this table is wrong. If you actually go to the | |
[00:01:27.520 --> 00:01:34.400] GPT2 github repo they sort of say that there was an error in how they added up the parameters, | |
[00:01:34.400 --> 00:01:40.320] but basically this is the 124 million parameter model, etc. So the 124 million parameter had 12 | |
[00:01:40.320 --> 00:01:47.840] layers in the transformer and it had 768 channels in the transformer, 768 dimensions. And I'm going | |
[00:01:47.840 --> 00:01:51.680] to be assuming some familiarity with what these terms mean because I covered all of this in my | |
[00:01:51.680 --> 00:01:57.200] previous video, Let's Build GPT, let's build GPT from scratch. So I covered that in a previous | |
[00:01:57.200 --> 00:02:02.640] video in this playlist. Now if we do everything correctly and everything works out well, by the | |
[00:02:02.640 --> 00:02:07.360] end of this video we're going to see something like this, where we're looking at the validation loss, | |
[00:02:07.360 --> 00:02:13.760] which basically measures how good we are at predicting the next token in a sequence on some | |
[00:02:13.760 --> 00:02:18.880] validation data that the model has not seen during training. And we see that we go from doing that | |
[00:02:18.880 --> 00:02:23.680] task not very well, because we're initializing from scratch, all the way to doing that task quite well | |
[00:02:23.680 --> 00:02:29.920] by the end of the training. And hopefully we're going to beat the GPT2 124M model. | |
[00:02:29.920 --> 00:02:35.680] Now previously when they were working on this, this is already five years ago. So this was probably | |
[00:02:35.680 --> 00:02:40.560] a fairly complicated optimization at the time and the GPUs and the compute was a lot smaller. | |
[00:02:40.560 --> 00:02:45.920] Today you can reproduce this model in roughly an hour, or probably less even, and it will cost | |
[00:02:45.920 --> 00:02:50.640] you about 10 bucks if you want to do this on the cloud, cloud compute, a sort of computer that | |
[00:02:50.640 --> 00:02:56.640] you can all rent. And if you pay $10 for that computer, you wait about an hour or less, you can | |
[00:02:56.640 --> 00:03:02.720] actually achieve a model that is as good as this model that OpenAI released. And one more thing | |
[00:03:02.720 --> 00:03:08.800] to mention is unlike many other models, OpenAI did release the weights for GPT2. So those weights | |
[00:03:08.800 --> 00:03:14.640] are all available in this repository. But the GPT2 paper is not always as good with all of the | |
[00:03:14.640 --> 00:03:19.600] details of the training. So in addition to the GPT2 paper, we're going to be referencing the GPT3 | |
[00:03:19.600 --> 00:03:25.200] paper, which is a lot more concrete in a lot of the hyper parameters and optimization settings and | |
[00:03:25.200 --> 00:03:31.680] so on. And it's not a huge departure in the architecture from the GPT2 version of the model. | |
[00:03:31.680 --> 00:03:37.040] So we're going to be referencing both GPT2 and GPT3 as we try to reproduce GPT2 124M. | |
[00:03:37.040 --> 00:03:42.960] So let's go. So the first thing I would like to do is actually start at the end or at the target. | |
[00:03:42.960 --> 00:03:48.480] So in other words, let's load the GPT2 124M model as it was released by OpenAI | |
[00:03:48.480 --> 00:03:52.800] and maybe take it for a spin. Let's sample some tokens from it. Now the issue with that is when | |
[00:03:52.800 --> 00:03:58.160] you go to the codebase of GPT2 and you go into the source and you click in on model.py, | |
[00:03:58.160 --> 00:04:03.680] you'll realize that actually this is using TensorFlow. So the original GPT2 code here was written in | |
[00:04:03.680 --> 00:04:12.000] TensorFlow, which is, you know, not, let's just say not used as much anymore. So we'd like to use | |
[00:04:12.000 --> 00:04:16.480] PyTorch because it's a lot friendlier, easier, and I just personally like it a lot more. | |
[00:04:16.480 --> 00:04:20.240] The problem with that is the initial code is in TensorFlow. We'd like to use PyTorch. | |
[00:04:20.240 --> 00:04:24.000] So instead, to get the target, we're going to use the hugging face transformers | |
[00:04:24.000 --> 00:04:30.480] code, which I like a lot more. So when you go into transformers, src/transformers/models/ | |
[00:04:30.480 --> 00:04:35.840] gpt2/modeling_gpt2.py, you will see that they have the GPT2 implementation of that | |
[00:04:35.840 --> 00:04:43.840] transformer here in this file. And it's like medium readable, but not fully readable. | |
[00:04:43.840 --> 00:04:49.440] But what it does is it did all the work of converting all those weights | |
[00:04:49.440 --> 00:04:53.680] from TensorFlow to PyTorch friendly. And so it's much easier to load and work with. | |
[00:04:53.680 --> 00:05:01.280] So in particular, we can look at the GPT2 model here and we can load it using hugging face transformers. | |
[00:05:01.280 --> 00:05:07.600] So swinging over, this is what that looks like. From transformers we import the | |
[00:05:07.600 --> 00:05:15.440] GPT2LMHeadModel, and then from_pretrained('gpt2'). Now, one awkward thing about this is that | |
[00:05:15.440 --> 00:05:20.160] when you do GPT2 as the model that we're loading, this actually is the 124 million | |
[00:05:20.160 --> 00:05:26.560] parameter model. If you want the actual GPT2, the 1.5 billion, then you actually want to do | |
[00:05:26.560 --> 00:05:33.440] gpt2-xl. So this is the 124M, our target. Now what we're doing is when we actually get this, | |
[00:05:33.440 --> 00:05:38.160] we're initializing the PyTorch nn.Module as defined here in this class. | |
[00:05:38.160 --> 00:05:44.800] From it, I want to get just the state dict, which is just the raw tensors. So we just have | |
[00:05:44.800 --> 00:05:51.040] the tensors of that file. And by the way, here, this is a Jupyter notebook. But this is Jupyter | |
[00:05:51.040 --> 00:05:55.920] notebook running inside VS code. So I like to work with it all in a single | |
[00:05:55.920 --> 00:06:02.480] sort of interface. So I like to use VS code. So this is the Jupyter notebook extension inside VS code. | |
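As a rough sketch, that loading cell looks something like this (assuming the transformers package is installed; the variable names are just illustrative):

```python
# Minimal sketch: load the pretrained GPT-2 (124M) and inspect its raw tensors.
from transformers import GPT2LMHeadModel

model_hf = GPT2LMHeadModel.from_pretrained("gpt2")  # "gpt2" is the 124M model
sd_hf = model_hf.state_dict()                       # just the raw tensors

for k, v in sd_hf.items():
    print(k, v.shape)
```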
[00:06:02.480 --> 00:06:10.320] So we'll get the state dict. This is just a dict. So we can print the key and the value, | |
[00:06:10.320 --> 00:06:14.160] which is the tensor. And let's just look at the shapes. So these are sort of the | |
[00:06:14.160 --> 00:06:23.760] different parameters inside the GPT2 model and their shape. So the wte weight, the token embedding, | |
[00:06:25.840 --> 00:06:35.040] is of size 50,257 by 768. Where this is coming from is that we have 50,257 tokens in the GPT2 | |
[00:06:35.040 --> 00:06:40.960] vocabulary. And the tokens, by the way, these are exactly the tokens that we've spoken about in | |
[00:06:40.960 --> 00:06:46.560] the previous video on my tokenization series. So the previous videos, just before this, I go into | |
[00:06:46.560 --> 00:06:52.480] a ton of detail on tokenization. GPT2 tokenizer happens to have this many tokens. For each token, | |
[00:06:54.000 --> 00:07:00.800] we have a 768 dimensional embedding that is the distributed representation that stands in | |
[00:07:00.800 --> 00:07:07.280] for that token. So each token is a little string piece. And then these 768 numbers are the vector | |
[00:07:07.280 --> 00:07:12.960] that represents that token. And so this is just our lookup table for tokens. And then here, | |
[00:07:12.960 --> 00:07:19.760] we have the lookup table for the positions. So because GPT2 has a maximum sequence length of 1024, | |
[00:07:20.400 --> 00:07:27.120] we have up to 1024 positions that each token can be attending to in the past. And every one of | |
[00:07:27.120 --> 00:07:36.160] those positions in GPT2 has a fixed vector of 768 that is learned by optimization. And so this is | |
[00:07:36.160 --> 00:07:41.840] the position embedding and the token embedding. And then everything here is just the other | |
[00:07:41.840 --> 00:07:47.520] weights and biases and everything else of this transformer. So when you just take, for example, | |
[00:07:47.520 --> 00:07:52.560] the positional embeddings and flatten it out and take just 20 elements, you can see that these | |
[00:07:52.560 --> 00:07:59.120] are just parameters. These are weights, floats, that we can just take and plot. So these | |
[00:07:59.120 --> 00:08:03.920] are the position embeddings. And we get something like this, and you can see that this has structure. | |
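A sketch of that inspection, assuming matplotlib is available and sd_hf is the state dict loaded above (the column indices are arbitrary examples):

```python
import matplotlib.pyplot as plt

wpe = sd_hf["transformer.wpe.weight"].detach()  # (1024, 768) position embedding table
print(wpe.view(-1)[:20])                        # flatten and peek at the first 20 raw floats

plt.imshow(wpe, cmap="gray")                    # every row is one absolute position, 0 to 1023
plt.show()

# a few individual columns (channels) as a function of position
for col in (150, 200, 250):
    plt.plot(wpe[:, col])
plt.show()
```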
[00:08:03.920 --> 00:08:11.040] And it has structure because what we have here really is every row in this visualization | |
[00:08:11.040 --> 00:08:17.680] is a different position, a fixed absolute position in the range from zero to 1024. | |
[00:08:17.680 --> 00:08:24.160] And each row here is the representation of that position. And so it has structure because | |
[00:08:24.160 --> 00:08:30.000] these positional embeddings end up learning these sinusoids and cosines that sort of like | |
[00:08:30.000 --> 00:08:36.560] represent each of these positions. And each row here stands in for that position and is processed | |
[00:08:36.560 --> 00:08:43.280] by the transformer to recover all the relative positions and sort of realize which token is where | |
[00:08:43.280 --> 00:08:46.720] and attend to them depending on their position, not just their content. | |
[00:08:46.720 --> 00:08:54.240] So when we actually just look into an individual column inside these, and I just grabbed three | |
[00:08:54.240 --> 00:09:01.840] random columns, you'll see that, for example, here we are focusing on a single channel, | |
[00:09:02.480 --> 00:09:13.280] and we're looking at what that channel is doing as a function of position, from zero to 1023 really. | |
[00:09:13.280 --> 00:09:18.880] And we can see that some of these channels basically like respond more or less to different | |
[00:09:18.880 --> 00:09:24.800] parts of the position spectrum. So this green channel really likes to fire for everything after | |
[00:09:25.360 --> 00:09:32.720] 200 up to 800, but not before that, where it fires a lot less, and it has a sharp drop off here near zero. | |
[00:09:32.720 --> 00:09:36.160] So who knows what these embeddings are doing and why they are the way they are. | |
[00:09:36.160 --> 00:09:39.680] You can tell, for example, that because they're a bit more jagged and they're kind of noisy, | |
[00:09:39.680 --> 00:09:44.720] you can tell that this model was not fully trained. And the more trained this model was, | |
[00:09:44.720 --> 00:09:48.480] the more you would expect to smooth this out. And so this is telling you that this is a little | |
[00:09:48.480 --> 00:09:54.560] bit of an under trained model. But in principle, actually, these curves don't even have to be smooth. | |
[00:09:54.560 --> 00:09:58.720] This should just be totally random noise. And in fact, in the beginning of the optimization, | |
[00:09:58.720 --> 00:10:04.400] it is complete random noise because this position embedding table is initialized completely at random. | |
[00:10:04.400 --> 00:10:08.960] So in the beginning, you have jaggedness. And the fact that you end up with something smooth is | |
[00:10:08.960 --> 00:10:14.000] already kind of impressive that that just falls out of the optimization. Because in principle, | |
[00:10:14.000 --> 00:10:18.240] you shouldn't even be able to get any single graph out of this that makes sense. But we actually | |
[00:10:18.240 --> 00:10:22.480] get something that looks a little bit noisy. But for the most part, it looks sinusoidal like. | |
[00:10:24.000 --> 00:10:30.240] In the original transformer paper, the attention is all you need paper. The positional embeddings | |
[00:10:30.240 --> 00:10:36.160] are actually initialized and fixed, if I remember correctly, to sinusoids and cosines of different | |
[00:10:36.160 --> 00:10:41.200] frequencies. And that's the positional encoding and it's fixed. But in GPT2, these are just parameters | |
[00:10:41.200 --> 00:10:45.680] and they're trained from scratch, just like any other parameter. And that seems to work | |
[00:10:45.680 --> 00:10:50.400] about as well. And so what they do is they kind of like recover these sinusoidal like features | |
[00:10:50.400 --> 00:10:55.920] during the optimization. We can also look at any of the other matrices here. So | |
[00:10:55.920 --> 00:11:04.320] here I took the first layer of the transformer and looking at like one of its weights and just | |
[00:11:04.320 --> 00:11:10.560] the first block of 300 by 300. And you see some structure, but like, again, like who knows what | |
[00:11:10.560 --> 00:11:15.040] any of this is, if you're into mechanistic interpretability, you might get a real kick out | |
[00:11:15.040 --> 00:11:19.440] of trying to figure out like what is going on, what is this structure, and what does this all mean. | |
[00:11:19.440 --> 00:11:22.880] But we're not going to be doing that in this video. But we definitely see that there's some | |
[00:11:22.880 --> 00:11:27.360] interesting structure and that's kind of cool. What we're most interested in is we've loaded | |
[00:11:27.360 --> 00:11:32.800] the weights of this model that was released by OpenAI. And now using the Hugging Face transformers, | |
[00:11:32.800 --> 00:11:39.520] we can not just get all the raw weights, but we can also get what they call a pipeline | |
[00:11:39.520 --> 00:11:46.000] and sample from it. So this is the prefix: "Hello, I'm a language model," and then we're sampling | |
[00:11:46.960 --> 00:11:52.960] 30 tokens. And we're getting five sequences. And I ran this, and this is what it produced. | |
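That sampling cell is roughly the standard transformers pipeline example (the seed value here is an assumption):

```python
from transformers import pipeline, set_seed

generator = pipeline("text-generation", model="gpt2")   # gpt2 = the 124M model
set_seed(42)                                            # example seed, not necessarily the one used here
out = generator("Hello, I'm a language model,", max_length=30, num_return_sequences=5)
for o in out:
    print(o["generated_text"])
```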
[00:11:52.960 --> 00:11:59.440] Hello, I'm a language model. But what I'm really doing is making a human-readable document. | |
[00:11:59.440 --> 00:12:03.520] There are other languages, but those are thought that thought. So you can read through these if you | |
[00:12:03.520 --> 00:12:08.800] like. But basically, these are five different completions of the same prefix from this GPT-2 | |
[00:12:08.800 --> 00:12:17.280] 124M. Now, if I go here, I took this example from here. And sadly, even though we are fixing | |
[00:12:17.280 --> 00:12:23.920] the seed, we are getting different generations from the snippet than what they got. So presumably, | |
[00:12:23.920 --> 00:12:30.320] the code changed. But what we see though at this stage that's important is that we are getting | |
[00:12:30.320 --> 00:12:36.400] coherent text. So we've loaded the model successfully. We can look at all its parameters. And the keys | |
[00:12:36.400 --> 00:12:42.720] tell us where in the model these come from. And we want to actually write our own GPT2 class so | |
[00:12:42.720 --> 00:12:46.400] that we have full understanding of what's happening there. We don't want to be working with something | |
[00:12:46.400 --> 00:12:51.520] like the modeling_gpt2.py, because it's just too complicated. We want to write this from scratch | |
[00:12:51.520 --> 00:12:57.120] ourselves. So we're going to be implementing the GPT model here in parallel. And as our first task, | |
[00:12:57.120 --> 00:13:04.480] let's load the GPT-2 124M into the class that we're going to develop here from scratch. That's | |
[00:13:04.480 --> 00:13:10.080] going to give us confidence that we can load the OpenAI model. And therefore, there's a setting of | |
[00:13:10.080 --> 00:13:14.560] weights that exactly is the 124M model. But then of course, what we're going to do is we're going | |
[00:13:14.560 --> 00:13:20.880] to initialize the model from scratch instead, and try to train it ourselves on a bunch of documents | |
[00:13:20.880 --> 00:13:24.880] that we're going to get. And we're going to try to surpass that model. So we're going to get | |
[00:13:24.880 --> 00:13:28.000] different weights and everything's going to look different, hopefully better even. | |
[00:13:28.000 --> 00:13:33.360] But we're going to have a lot of confidence that because we can load the OpenAI model, | |
[00:13:33.360 --> 00:13:37.680] we are in the same model family and model class, and we just have to rediscover a good setting | |
[00:13:37.680 --> 00:13:44.400] of the weights, but from scratch. So let's now write the GPT2 model, and let's load the weights, | |
[00:13:44.400 --> 00:13:48.960] and make sure that we can also generate text that looks coherent. Okay, so let's now swing | |
[00:13:48.960 --> 00:13:53.280] over to the Attention Is All You Need paper that started everything. And let's scroll over to the model | |
[00:13:53.280 --> 00:13:59.200] architecture, the original transformer. Now remember that GPT2 is slightly modified from the original | |
[00:13:59.200 --> 00:14:05.680] transformer. In particular, we do not have the encoder. GPT2 is a decoder only transformer, | |
[00:14:05.680 --> 00:14:10.640] as we call it. So this entire encoder here is missing. And in addition to that, this cross | |
[00:14:10.640 --> 00:14:16.480] attention here that was using that encoder is also missing. So we delete this entire part. | |
[00:14:16.480 --> 00:14:21.680] Everything else stays almost the same, but there are some differences that we're going to | |
[00:14:21.680 --> 00:14:29.360] sort of look at here. So there are two main differences. When we go to the GPT2 paper, | |
[00:14:29.360 --> 00:14:35.680] under 2.3, Model, we notice that first, there's a reshuffling of the layer norms. So they change | |
[00:14:35.680 --> 00:14:43.680] place. And second, an additional layer normalization was added here after the final self-attention block. | |
[00:14:43.680 --> 00:14:48.960] So basically, all the layer norms here, instead of being after the MLP or after the attention, | |
[00:14:48.960 --> 00:14:54.000] they swing before it. And an additional layer norm gets added here right before the final classifier. | |
[00:14:54.000 --> 00:15:01.440] So now let's implement some of the first sort of skeleton nn.Modules here in our GPT | |
[00:15:01.440 --> 00:15:06.720] nn.Module. And in particular, we're going to try to match up to this schema here that is used by | |
[00:15:06.720 --> 00:15:11.120] hugging face transformers, because that will make it much easier to load these weights from this | |
[00:15:11.120 --> 00:15:17.440] state dict. So we want something that reflects this schema here. So here's what I came up with. | |
[00:15:19.760 --> 00:15:25.920] Basically, we see that the main container here that has all the modules is called transformer. | |
[00:15:25.920 --> 00:15:30.000] So I'm reflecting that within an nn.ModuleDict. And this is basically a module that | |
[00:15:30.000 --> 00:15:36.080] allows you to index into the sub-modules using keys, just like a dictionary, strings. | |
[00:15:36.080 --> 00:15:41.360] Within it, we have the weights of the token embeddings, wte, and that's | |
[00:15:41.360 --> 00:15:46.560] an nn.Embedding. And the weights of the position embeddings, wpe, which is also just an nn.Embedding. | |
[00:15:46.560 --> 00:15:51.280] And if you remember, an nn.Embedding is really just a fancy little wrapper module around just a single | |
[00:15:51.280 --> 00:15:59.360] array of numbers, a single block of numbers, just like this. It's a single tensor. | |
[00:15:59.360 --> 00:16:05.600] An nn.Embedding is a glorified wrapper around a tensor that allows you to | |
[00:16:05.600 --> 00:16:11.760] access its elements by indexing into the rows. Now, in addition to that, we see here that we have | |
[00:16:11.760 --> 00:16:17.840] a .h, and this is indexed using numbers instead of indexed using strings. | |
[00:16:17.840 --> 00:16:23.600] So there's a .h.0, .h.1, .h.2, et cetera, all the way up till .h.11. | |
[00:16:23.600 --> 00:16:28.560] And that's because there are 12 layers here in this transformer. So to reflect that, | |
[00:16:28.560 --> 00:16:34.160] I'm creating also an h, I think that probably stands for hidden. And instead of a ModuleDict, | |
[00:16:34.160 --> 00:16:39.040] this is an nn.ModuleList. So we can index it using integers exactly as we see here, .0, | |
[00:16:39.040 --> 00:16:46.640] .1, .2, et cetera. And the ModuleList has n_layer Blocks, and the Block is an nn.Module that is yet to be | |
[00:16:46.640 --> 00:16:52.720] defined, in a bit. In addition to that, following the GPT-2 paper, we need an | |
[00:16:52.720 --> 00:16:58.720] additional final layer norm that we're going to put in there. And then we have the final classifier, | |
[00:16:58.720 --> 00:17:07.680] the language model head, which projects from 768, the number of embedding dimensions in this GPT, | |
[00:17:07.680 --> 00:17:13.760] all the way to the vocab size, which is 50,257. And GPT-2 uses no bias for this final | |
[00:17:13.760 --> 00:17:22.080] sort of projection. So this is the skeleton. And you can see that it reflects this. So the WTE | |
[00:17:22.080 --> 00:17:26.720] has the token embeddings. Here it's called output embedding, but it's really the token embeddings. | |
[00:17:26.720 --> 00:17:32.320] The wpe is the positional encodings. Those two pieces of information, as we saw previously, are | |
[00:17:32.320 --> 00:17:38.640] going to be added, and then go into the transformer. The .h is all the blocks in gray. And the | |
[00:17:38.640 --> 00:17:45.840] ln_f is this new layer norm that gets added here by the GPT-2 model. And lm_head is this linear part | |
[00:17:45.840 --> 00:17:53.040] here. So that's the skeleton of the GPT-2. | |
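As a rough sketch of that skeleton (the hyperparameter defaults follow the 124M numbers discussed in this video; Block is the nn.Module we define next):

```python
from dataclasses import dataclass
import torch.nn as nn

@dataclass
class GPTConfig:
    block_size: int = 1024   # maximum sequence length
    vocab_size: int = 50257  # number of tokens
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe = nn.Embedding(config.block_size, config.n_embd),  # position embeddings
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),                    # final layer norm, added by GPT-2
        ))
        # final classifier from n_embd to the vocab size, with no bias
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
```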
[00:17:53.040 --> 00:17:59.760] We now have to implement the block. Okay, so let's now recurse to the block itself. So we want to define the block. So I'll start putting them here. | |
[00:18:00.320 --> 00:18:06.960] So the block, I like to write it out like this. These are some of the initializations. And then | |
[00:18:06.960 --> 00:18:12.560] this is the actual forward pass of what this block computes. And notice here that there's a change | |
[00:18:12.560 --> 00:18:19.200] from the transformer again that is mentioned in the GPT two paper. So here, the layer normalizations | |
[00:18:19.200 --> 00:18:23.360] are after the application of attention or feed forward. In addition to that note, | |
[00:18:23.360 --> 00:18:28.480] that the normalizations are inside the residual stream. You see how feed forward is applied, | |
[00:18:28.480 --> 00:18:33.920] and then this arrow goes through and through the normalization. So that means that your residual | |
[00:18:33.920 --> 00:18:40.080] pathway has normalizations inside them. And this is not very good or desirable. You actually | |
[00:18:40.080 --> 00:18:45.600] prefer to have a single clean residual stream all the way from supervision all the way down to | |
[00:18:45.600 --> 00:18:52.480] the inputs, the tokens. And this is very desirable and nice because the gradients that flow from the | |
[00:18:52.480 --> 00:18:58.480] top, if you remember from your micrograd, addition just distributes gradients during the backward | |
[00:18:58.480 --> 00:19:05.600] pass to both of its branches equally. So addition just branches the gradients. And so that means | |
[00:19:05.600 --> 00:19:11.840] that the gradients from the top flow straight to the inputs, the tokens, through the residual pathways | |
[00:19:11.840 --> 00:19:16.000] unchanged. But then in addition to that, the gradient also flows through the blocks. And the | |
[00:19:16.000 --> 00:19:20.560] blocks, you know, contribute their own contribution over time and kick in and change the optimization | |
[00:19:20.560 --> 00:19:27.280] over time. But basically, clean residual pathway is desirable from an optimization perspective. And | |
[00:19:27.280 --> 00:19:33.040] then this is the pre normalization version, where you see that our X first goes through the layer | |
[00:19:33.040 --> 00:19:39.120] normalization, and then the attention, and then goes back out to go to the layer normalization | |
[00:19:39.120 --> 00:19:44.960] number two, and the multi-layer perceptron, sometimes also referred to as a feed forward network or an | |
[00:19:44.960 --> 00:19:50.800] FFN. And then that goes into the residual stream again. And then one more thing that is kind of | |
[00:19:50.800 --> 00:19:56.000] interesting to note is that recall that attention is a communication operation. It is where all the | |
[00:19:56.000 --> 00:20:01.760] tokens, and there's 1024 tokens lined up in a sequence. And this is where the tokens communicate, | |
[00:20:01.760 --> 00:20:09.280] this is where they exchange information. So attention is an aggregation function. It's a pooling function. | |
[00:20:09.280 --> 00:20:18.000] It's a weighted sum function. It is a reduce operation. Whereas MLP, this MLP here happens | |
[00:20:18.000 --> 00:20:22.160] at every single token individually. There's no information being collected or exchanged between | |
[00:20:22.160 --> 00:20:28.240] the tokens. So the attention is the reduce, and the MLP is the map. And what you end up with is | |
[00:20:28.240 --> 00:20:33.520] that the transformers just ends up just being a repeated application of map reduce. If you want | |
[00:20:33.520 --> 00:20:38.480] to think about it that way. So this is where they communicate, and this is where they think | |
[00:20:38.480 --> 00:20:42.640] individually about the information that they gathered, and every one of these blocks, | |
[00:20:42.640 --> 00:20:49.520] iteratively defines the representation inside the residual stream. So this is our block. | |
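Continuing the sketch from above (MLP and CausalSelfAttention are defined in the next sections), the block described here looks roughly like:

```python
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # pre-norm: the layer norms sit before attention/MLP, keeping the residual stream clean
        x = x + self.attn(self.ln_1(x))  # attention: tokens communicate ("reduce")
        x = x + self.mlp(self.ln_2(x))   # MLP: per-token computation ("map")
        return x
```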
[00:20:49.520 --> 00:20:55.200] It's slightly modified from this picture. Okay, so let's now move on to the MLP. | |
[00:20:55.200 --> 00:21:01.280] So the MLP block I implemented as follows. It is relatively straightforward. We basically have | |
[00:21:01.280 --> 00:21:08.400] two linear projections here that are sandwiched in between the GELU nonlinearity. So nn.GELU, | |
[00:21:08.400 --> 00:21:15.040] approximate is tanh. Now when we swing over to the PyTorch documentation, | |
[00:21:15.040 --> 00:21:20.240] this is nn.GELU, and it has this format, and it has two versions, the original version of GELU, | |
[00:21:20.240 --> 00:21:25.600] which we'll step into in a bit, and the approximate version of GELU, which we can request using tanh. | |
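The MLP sketch, continuing the code above (the 4x expansion of the hidden dimension matches GPT-2):

```python
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc   = nn.Linear(config.n_embd, 4 * config.n_embd)  # expand 4x
        self.gelu   = nn.GELU(approximate='tanh')                  # tanh approximation, as in GPT-2
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)  # project back down

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))
```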
[00:21:26.640 --> 00:21:33.360] So as you can see, just as a preview here, GELU is basically like a ReLU, except there's no flat, | |
[00:21:33.360 --> 00:21:39.520] exactly flat tail here at exactly zero. But otherwise, it looks very much like a slightly | |
[00:21:39.520 --> 00:21:46.320] smoother ReLU. It comes from this paper here, Gaussian Error Linear Units, and you can step | |
[00:21:46.320 --> 00:21:50.960] through this paper, and there's some mathematical kind of reasoning that leads to interpretation, | |
[00:21:50.960 --> 00:21:56.560] at least of this specific formulation. It has to do with stochastic regularizers, and the expectation | |
[00:21:56.560 --> 00:22:00.560] of a modification to dropout, so you can read through all of that if you'd like here. | |
[00:22:00.560 --> 00:22:05.920] And there's a little bit of the history as to why there's an approximate version of GELU, | |
[00:22:05.920 --> 00:22:09.040] and that comes from this issue here, as far as I can tell. | |
[00:22:09.040 --> 00:22:16.880] And in this issue, Dan Hendrycks mentions that at the time when they developed this nonlinearity, | |
[00:22:16.880 --> 00:22:22.000] the erf function, which you need to evaluate the exact GELU, was very slow in TensorFlow, | |
[00:22:22.000 --> 00:22:26.960] so they ended up basically developing this approximation. And this approximation then ended up being | |
[00:22:26.960 --> 00:22:31.680] picked up by BERT and by GPT-2, et cetera. But today, there's no real good reason to use the | |
[00:22:31.680 --> 00:22:37.200] approximate version. You'd prefer to just use the exact version, because my expectation is | |
[00:22:37.200 --> 00:22:42.640] that there's no big difference anymore, and this is kind of like a historical kind of quirk. | |
[00:22:42.640 --> 00:22:49.760] But we are trying to reproduce GPT-2 exactly, and GPT-2 used the tanh approximate version, | |
[00:22:49.760 --> 00:22:56.720] so we prefer to stick with that. Now, one other reason to actually just intuitively use GELU, | |
[00:22:56.720 --> 00:23:02.000] instead of ReLU, is that previously, in videos in the past, we've spoken about the dead ReLU | |
[00:23:02.000 --> 00:23:08.640] neuron problem, where in this tail of the ReLU, if it's exactly flat at zero, any activations that fall | |
[00:23:08.640 --> 00:23:13.280] there will get exactly zero gradient. There's no change, there's no adaptation, there's no | |
[00:23:13.280 --> 00:23:18.560] development of the network, if any of these activations end in this flat region. But the | |
[00:23:18.560 --> 00:23:23.120] GELU always contributes a local gradient, and so there's always going to be a change, always | |
[00:23:23.120 --> 00:23:28.080] going to be an adaptation, and sort of smoothing it out ends up empirically working better in practice, | |
[00:23:28.080 --> 00:23:33.360] as demonstrated in this paper, and also as demonstrated by it being picked up by the BERT paper, GPT-2 | |
[00:23:33.360 --> 00:23:39.280] paper, and so on. So for that reason, we adopt this nonlinearity here in the GPT-2 | |
[00:23:39.280 --> 00:23:44.960] reproduction. Now, in more modern networks, like Llama 3 and so on, this nonlinearity | |
[00:23:44.960 --> 00:23:51.520] further changes to SwiGLU and other variants like that, but for GPT-2, they use this approximate GELU. | |
[00:23:51.520 --> 00:23:57.440] Okay, and finally, we have the attention operation. So let me paste in my attention. | |
[00:23:57.440 --> 00:24:04.400] So I know this is a lot, so I'm going to go through this a bit quickly, a bit slowly, but | |
[00:24:04.400 --> 00:24:08.160] not too slowly, because we have covered this in the previous video, and I would just point you there. | |
[00:24:08.160 --> 00:24:14.480] So this is the attention operation. Now, in the previous video, you will remember, this is not | |
[00:24:14.480 --> 00:24:20.560] just attention. This is multi-headed attention, right? And so in the previous video, we had this | |
[00:24:20.560 --> 00:24:26.400] multi-headed attention module, and this implementation made it obvious that these heads are not actually | |
[00:24:26.400 --> 00:24:32.720] that complicated. There's basically in parallel inside every attention block. There's multiple | |
[00:24:32.720 --> 00:24:38.480] heads, and they're all functioning in parallel, and their outputs are just being concatenated, | |
[00:24:38.480 --> 00:24:43.520] and that becomes the output of the multi-headed attention. So the heads are just kind of like | |
[00:24:43.520 --> 00:24:49.840] parallel streams, and their outputs get concatenated. And so it was very simple and made the head | |
[00:24:49.840 --> 00:24:53.440] be kind of like fairly straightforward in terms of its implementation. | |
[00:24:53.440 --> 00:24:59.440] What happens here is that instead of having two separate modules, and indeed many more modules | |
[00:24:59.440 --> 00:25:06.400] that get concatenated, all of that is just put into a single self-attention module. And instead, | |
[00:25:06.400 --> 00:25:13.520] I'm being very careful in doing a bunch of transpose split tensor gymnastics to make this very | |
[00:25:13.520 --> 00:25:17.360] efficient in PyTorch. But fundamentally and algorithmically, nothing is different from the | |
[00:25:17.360 --> 00:25:26.240] implementation we saw before in that earlier repository. So to remind you very briefly, and I don't want | |
[00:25:26.240 --> 00:25:32.880] to go into this in too much time, but we have these tokens lined up in a sequence, and there's | |
[00:25:32.880 --> 00:25:39.040] 1,024 of them. And then each token at this stage of the attention emits three vectors, | |
[00:25:39.040 --> 00:25:46.560] the query, key, and value. And first what happens here is that the queries and the keys | |
[00:25:46.560 --> 00:25:52.960] have to multiply each other to get sort of the attention amount, like how interesting they find | |
[00:25:52.960 --> 00:25:57.440] each other. So they have to interact multiplicatively. So what we're doing here is we're calculating the | |
[00:25:57.440 --> 00:26:03.440] qkv and splitting it. And then there's a bunch of gymnastics, as I mentioned here. And the way this | |
[00:26:03.440 --> 00:26:10.160] works is that we're basically making the number of heads and H into a batch dimension. And so it's | |
[00:26:10.160 --> 00:26:16.800] a batch dimension just like B, so that in these operations that follow PyTorch treats B and H as | |
[00:26:16.800 --> 00:26:22.880] batches. And it applies all the operations on all of them in parallel in both the batch and the heads. | |
[00:26:24.080 --> 00:26:28.800] And the operations that can apply are number one, the queries and the keys interact to give us | |
[00:26:28.800 --> 00:26:35.280] our attention. This is the autoregressive mask that makes sure that the tokens only attend to | |
[00:26:35.280 --> 00:26:42.160] tokens before them and never to tokens in the future. The softmax here normalizes the attention, | |
[00:26:42.160 --> 00:26:48.080] so it sums to one always. And then recall from the previous video that doing the attention matrix | |
[00:26:48.080 --> 00:26:53.280] multiply with the values is basically a way to do a weighted sum of the values of the tokens that | |
[00:26:53.280 --> 00:26:58.400] we found interesting at every single token. And then the final transpose contiguous | |
[00:26:58.400 --> 00:27:03.760] and view is just reassembling all of that again. And this actually performs the concatenation | |
[00:27:03.760 --> 00:27:09.760] operation. So you can step through this slowly if you'd like, but it is equivalent mathematically | |
[00:27:09.760 --> 00:27:14.960] to our previous implementation; it's just more efficient in PyTorch. So that's why I chose this | |
[00:27:14.960 --> 00:27:20.240] implementation instead. Now in addition to that, I'm being careful with how I name my variables. | |
[00:27:20.240 --> 00:27:26.560] So for example, c_attn here is the same as the c_attn there. And so actually our keys should basically exactly | |
[00:27:26.560 --> 00:27:30.560] follow the schema of the Hugging Face transformers code. And that will make it very easy for us to | |
[00:27:30.560 --> 00:27:36.800] now port over all the weights from exactly this sort of naming conventions, because all of our | |
[00:27:36.800 --> 00:27:42.800] variables are named the same thing. But at this point, we have finished the GPT-2 implementation. | |
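A sketch of that multi-head causal self-attention, with the c_attn/c_proj naming mirroring the Hugging Face code (continuing the sketches above):

```python
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # a single linear layer produces the queries, keys and values for all heads at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # the autoregressive mask (a buffer, not a parameter), named "bias" as in the HF code
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                          .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimension
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        # fold the heads into a batch-like dimension: (B, n_head, T, head_size)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # queries and keys interact multiplicatively to give the attention scores
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))  # causal mask
        att = F.softmax(att, dim=-1)                      # normalize so each row sums to one
        y = att @ v                                       # weighted sum of the values
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble heads: this is the "concat"
        return self.c_proj(y)
```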
[00:27:42.800 --> 00:27:48.000] And what that allows us to do is we don't have to basically use this file from Hugging Face, | |
[00:27:48.000 --> 00:27:58.560] which is fairly long. This is 2,000 lines of code. Instead, we just have less than 100 | |
[00:27:58.560 --> 00:28:03.120] lines of code. And this is the complete GPT2 implementation. So at this stage, we should just | |
[00:28:03.120 --> 00:28:08.320] be able to take over all the weights, set them, and then do generation. So let's see what that | |
[00:28:08.320 --> 00:28:12.640] looks like. Okay, so here, I've also changed the GPTConfig so that the numbers here, | |
[00:28:12.640 --> 00:28:18.400] the hyperparameters, agree with the GPT-2 124M model. So the maximum sequence length, which I call block | |
[00:28:18.400 --> 00:28:26.880] size here, is 1024. The number of tokens is 50,257, which, if you watched my tokenization video, you know that | |
[00:28:26.880 --> 00:28:35.920] this is 50,000 BPE merges, 256 byte tokens, the leaves of the BPE tree, and one special end | |
[00:28:35.920 --> 00:28:41.520] of text token that delimits different documents and can start generation as well. And there are | |
[00:28:41.520 --> 00:28:46.720] 12 layers. There are 12 heads in the attention and the dimension of the transformer is 768. | |
[00:28:46.720 --> 00:28:53.920] So here's how we can now load the parameters from hugging face to our code here and initialize | |
[00:28:53.920 --> 00:29:00.240] the GPT class with those parameters. So let me just copy paste a bunch of code here. And I'm | |
[00:29:00.240 --> 00:29:07.360] not going to go through this code too slowly, because honestly, it's not | |
[00:29:07.360 --> 00:29:10.960] that interesting. It's not that exciting. We're just loading the weights. So it's kind of dry. | |
[00:29:10.960 --> 00:29:15.760] But as I mentioned, there are four models in this mini series of GPT2. This is some of the | |
[00:29:15.760 --> 00:29:21.440] Jupyter code that we had here on the right. I'm just porting it over. These are | |
[00:29:21.440 --> 00:29:26.320] the hyper parameters of the GPT2 models. We're creating the config object and creating our own | |
[00:29:26.320 --> 00:29:31.600] model. And then what's happening here is we're creating the state dict, both for our model and | |
[00:29:31.600 --> 00:29:37.920] for the hugging face model. And then what we're doing here is we're going over the hugging face | |
[00:29:37.920 --> 00:29:45.600] model keys. And we're copying over those tensors. And in the process, we are kind of ignoring a | |
[00:29:45.600 --> 00:29:50.720] few of the buffers. They're not parameters, they're buffers. So for example, attn.bias, | |
[00:29:50.720 --> 00:29:56.000] that's just used for the autoregressive mask. And so we are ignoring some of those masks. | |
[00:29:56.000 --> 00:30:01.680] And that's it. And then one additional kind of annoyance is that this comes from the TensorFlow | |
[00:30:01.680 --> 00:30:06.400] repo. And I'm not sure why, this is a little bit annoying, but some of the weights are transposed | |
[00:30:06.400 --> 00:30:11.920] from what PyTorch would want. And so manually, I hard coded the weights that should be transposed. | |
[00:30:11.920 --> 00:30:18.240] And then we transpose them if that is so. And then we return this model. So the | |
[00:30:18.240 --> 00:30:26.480] from_pretrained is a constructor, or a class method in Python, that returns the GPT object if we just | |
[00:30:26.480 --> 00:30:30.880] give it the model type, which in our case is GPT2, the smallest model that we're interested in. | |
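A simplified sketch of the core of that copying loop; the transposed list reflects the TensorFlow-style weights mentioned above, while the exact buffer names being skipped are an assumption about the Hugging Face state dict:

```python
# roughly what happens inside GPT.from_pretrained
sd = model.state_dict()        # our randomly initialized parameters
sd_hf = model_hf.state_dict()  # the pretrained Hugging Face parameters

transposed = ['attn.c_attn.weight', 'attn.c_proj.weight',
              'mlp.c_fc.weight', 'mlp.c_proj.weight']
for k in sd_hf:
    if k.endswith('.attn.bias') or k.endswith('.attn.masked_bias'):
        continue  # buffers for the autoregressive mask, not parameters
    with torch.no_grad():
        if any(k.endswith(w) for w in transposed):
            sd[k].copy_(sd_hf[k].t())  # these checkpoints store the weight transposed
        else:
            sd[k].copy_(sd_hf[k])
```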
[00:30:31.760 --> 00:30:37.200] So this is the code. And this is how you would use it. And we can pop open the terminal here | |
[00:30:37.200 --> 00:30:44.960] in VS code. And we can do python train_gpt2.py and fingers crossed. | |
[00:30:44.960 --> 00:30:53.680] Okay, so we didn't crash. And so we can load the weights and the biases and everything else | |
[00:30:53.680 --> 00:30:58.080] into our nn.Module. But now let's also get additional confidence that this is working. | |
[00:30:58.080 --> 00:31:02.160] And let's try to actually generate from this model. Okay, now before we can actually generate | |
[00:31:02.160 --> 00:31:06.080] from this model, we have to be able to forward it. We didn't actually write that code yet. | |
[00:31:06.080 --> 00:31:12.640] So here's the forward function. So the input to the forward is going to be our indices, | |
[00:31:12.640 --> 00:31:20.160] our tokens, token indices. And they are always of shape B by T. And so we have batch dimension | |
[00:31:20.160 --> 00:31:27.120] of B. And then we have the time dimension of up to T. And the T cannot be more than the block size. | |
[00:31:27.120 --> 00:31:32.800] The block size is the maximum sequence length. So B by T indices arranged in sort of like a | |
[00:31:32.800 --> 00:31:39.520] two dimensional layout. And remember that basically every single row of this is of size up to block | |
[00:31:39.520 --> 00:31:46.320] size. And this is T tokens that are in a sequence. And then we have B independent sequences stacked | |
[00:31:46.320 --> 00:31:51.760] up in a batch so that this is efficient. Now here we are forwarding the position embeddings | |
[00:31:51.760 --> 00:31:55.760] and the token embeddings. And this code should be very recognizable from the previous lecture. | |
[00:31:56.320 --> 00:32:02.080] So we basically use arange, which is kind of like a version of range, but for PyTorch. | |
[00:32:02.080 --> 00:32:08.880] And we're iterating from zero to T and creating these position indices. | |
[00:32:08.880 --> 00:32:15.120] And then we are making sure that they're on the same device as idx, because we're not going to be training | |
[00:32:15.120 --> 00:32:19.040] on just the CPU. That's going to be too inefficient. We want to be training on GPU. And that's going | |
[00:32:19.040 --> 00:32:24.960] to come in a bit. Then we have the position embeddings and the token embeddings and the addition | |
[00:32:24.960 --> 00:32:29.360] operation of those two. Now notice that position embeddings are going to be identical for every | |
[00:32:29.360 --> 00:32:36.880] single row of input. And so there's broadcasting hidden inside this plus, where we have to create | |
[00:32:36.880 --> 00:32:41.200] an additional dimension here. And then these two add up because the same position embeddings | |
[00:32:41.200 --> 00:32:46.240] apply to every single row of our examples stacked up in a batch. Then we forward the | |
[00:32:46.240 --> 00:32:52.320] transformer blocks and finally the last layer norm and the lm_head. So what comes out after | |
[00:32:52.320 --> 00:32:59.200] forward is the logits. And if the input was B by T indices, then at every single B by T, | |
[00:32:59.200 --> 00:33:06.720] we will calculate the logits for what token comes next in a sequence. So what is the token | |
[00:33:06.720 --> 00:33:14.080] B, T plus one, the one on the right of this token. And vocab size here is the number of | |
[00:33:14.080 --> 00:33:19.120] possible tokens. And so therefore this is the tensor that we're going to obtain. And these | |
[00:33:19.120 --> 00:33:25.680] logits are just a softmax away from becoming probabilities. So this is the forward pass | |
[00:33:25.680 --> 00:33:30.000] of the network. And now we can get logits, and so we're going to be able to generate from the model imminently. | |
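A rough sketch of that forward pass, as a method of the GPT class from before:

```python
class GPT(nn.Module):
    # ... __init__ as sketched earlier ...

    def forward(self, idx):
        B, T = idx.size()  # idx is the (B, T) tensor of token indices
        assert T <= self.config.block_size, "sequence length exceeds block size"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # positions 0..T-1, on idx's device
        pos_emb = self.transformer.wpe(pos)  # (T, n_embd)
        tok_emb = self.transformer.wte(idx)  # (B, T, n_embd)
        x = tok_emb + pos_emb                # broadcasting adds the same positions to every row
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)             # (B, T, vocab_size), a softmax away from probabilities
        return logits
```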
[00:33:30.000 --> 00:33:35.440] Okay, so now we're going to try to set up the identical thing on the left here | |
[00:33:35.440 --> 00:33:41.120] that matches hugging face on the right. So here we sampled from the pipeline and we sampled | |
[00:33:41.120 --> 00:33:46.560] five times, up to 30 tokens, with a prefix of "Hello, I'm a language model,". And these are the | |
[00:33:46.560 --> 00:33:51.040] completions that we achieved. So we're going to try to replicate all of that here. So | |
[00:33:51.040 --> 00:33:54.960] num_return_sequences is 5, max_length is 30. So the first thing we do, of course, is we initialize | |
[00:33:54.960 --> 00:34:00.480] our model, then we put it into evaluation mode. Now this is a good practice to put the model into | |
[00:34:00.480 --> 00:34:05.680] eval when you're not going to be training it, you're just going to be using it. And I don't | |
[00:34:05.680 --> 00:34:09.760] actually know if this is doing anything right now for the following reason. Our model up above | |
[00:34:09.760 --> 00:34:16.080] here contains no modules or layers that actually have a different behavior at training or evaluation | |
[00:34:16.080 --> 00:34:20.960] time. So for example, dropout, batch norm, and a bunch of other layers have this kind of behavior. | |
[00:34:20.960 --> 00:34:24.960] But all of these layers that we've used here should be identical in both training and evaluation | |
[00:34:24.960 --> 00:34:32.080] time. So potentially model.eval() does nothing, but then I'm not actually sure if this is the case | |
[00:34:32.080 --> 00:34:37.200] and maybe pytorch internals do some clever things depending on the evaluation mode inside here. | |
[00:34:37.200 --> 00:34:43.680] The next thing we're doing here is we are moving the entire model to CUDA. So we're moving this | |
[00:34:43.680 --> 00:34:49.840] all of the tensors to GPU. So I'm SSH'd here to a cloud box and I have a bunch of GPUs on this box. | |
[00:34:49.840 --> 00:34:55.600] And here I'm moving the entire model and all of its members and all of its tensors and everything | |
[00:34:55.600 --> 00:35:01.600] like that. Everything gets shipped off to basically a whole separate computer that is sitting on the | |
[00:35:01.600 --> 00:35:06.880] GPU. And the GPU is connected to the CPU and they can communicate, but it's basically a whole separate | |
[00:35:06.880 --> 00:35:11.840] computer with its own computer architecture. And it's really built for parallel processing tasks | |
[00:35:11.840 --> 00:35:17.360] like those of running neural networks. So I'm doing this so that the model lives on the GPU, | |
[00:35:17.360 --> 00:35:21.840] a whole separate computer, and it's just going to make our code a lot more efficient, | |
[00:35:21.840 --> 00:35:28.320] because all of this stuff runs a lot more efficiently on the GPU. So that's the model itself. | |
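A minimal sketch of that setup, using the from_pretrained class method described above:

```python
model = GPT.from_pretrained('gpt2')  # load the OpenAI GPT-2 124M weights into our class
model.eval()                         # we're only sampling here, not training
model.to('cuda')                     # move every parameter and buffer to the GPU
```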
[00:35:28.320 --> 00:35:35.280] Now, the next thing we want to do is we want to start with this as the prefix when we do the | |
[00:35:35.280 --> 00:35:42.160] generation. So let's actually create those prefix tokens. So here's the code that I've written. | |
[00:35:42.160 --> 00:35:47.520] We're going to import the tiktoken library from OpenAI and we're going to get the GPT-2 encoding. | |
[00:35:47.520 --> 00:35:54.800] So that's the tokenizer for GPT-2. And then we're going to encode this string and get a list of | |
[00:35:54.800 --> 00:36:00.400] integers which are the tokens. Now these integers here should actually be fairly straightforward | |
[00:36:00.400 --> 00:36:05.360] because we can just copy paste this string. And we can sort of inspect what it is in | |
[00:36:05.360 --> 00:36:09.680] tiktokenizer. So just pasting that in, these are the tokens that are going to come out. | |
[00:36:09.680 --> 00:36:16.720] So this list of integers is what we expect tokens to become. And as you recall, if you saw my video, | |
[00:36:16.720 --> 00:36:21.600] of course, all the tokens, they're just little string chunks, right? So these are, this is the | |
[00:36:21.600 --> 00:36:28.800] chunking of this string into GPT-2 tokens. So once we have those tokens, it's a list of integers, | |
[00:36:28.800 --> 00:36:33.760] we can create a torch tensor out of it. In this case, it's eight tokens. And then we're going to | |
[00:36:33.760 --> 00:36:40.800] replicate these eight tokens for five times to get five rows of eight tokens. And that is our | |
[00:36:40.800 --> 00:36:50.160] initial input X, as I call it here. And it lives on the GPU as well. So X now is this idx that we | |
[00:36:50.160 --> 00:36:57.360] can put into forward to get our logits so that we know what comes as the sixth token, | |
[00:36:57.360 --> 00:37:04.720] sorry, as the ninth token in every one of these five rows. | |
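Building that prefix is roughly (assuming the tiktoken package is installed):

```python
import tiktoken

num_return_sequences = 5
max_length = 30

enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode("Hello, I'm a language model,")           # a list of 8 token ids
tokens = torch.tensor(tokens, dtype=torch.long)               # shape (8,)
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)  # shape (5, 8)
x = tokens.to('cuda')                                         # this is the idx we feed into forward
```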
[00:37:04.720 --> 00:37:10.400] Okay, and we are now ready to generate. So let me paste in one more code block here. So what's happening here in this code block is | |
[00:37:10.400 --> 00:37:17.600] we have these X, which is of size B by T, right? So batch by time. And we're going to be in every | |
[00:37:17.600 --> 00:37:23.200] iteration of this loop, we're going to be adding a column of new indices into each one of these rows, | |
[00:37:23.200 --> 00:37:28.240] right? And so these are the new indices, and we're appending them to the sequence as we're | |
[00:37:28.240 --> 00:37:34.640] sampling. So with each loop iteration, we get one more column into X. And all of the operations | |
[00:37:34.640 --> 00:37:38.320] happening in the context manager of torch.no_grad; this is just telling PyTorch that we're | |
[00:37:38.320 --> 00:37:42.560] not going to be calling .backward() on any of this. So it doesn't have to cache all the | |
[00:37:42.560 --> 00:37:46.720] intermediate tensors, it's not going to have to prepare in any way for a potential backward | |
[00:37:46.720 --> 00:37:53.520] later. And this saves a lot of space and also possibly some time. So we get our logits, | |
[00:37:53.520 --> 00:37:58.960] we get the logits at only the last location, we throw away all the other logits, we don't need | |
[00:37:58.960 --> 00:38:05.920] them, we only care about the last column's logits. So this is being wasteful. But this is just kind | |
[00:38:05.920 --> 00:38:12.640] of like an inefficient implementation of sampling. So it's correct but inefficient. So we get the last | |
[00:38:12.640 --> 00:38:16.960] column of logits, pass it through a softmax to get our probabilities. Then here, I'm doing | |
[00:38:16.960 --> 00:38:21.520] top-k sampling with k of 50. And I'm doing that because this is the Hugging Face default. So just | |
[00:38:21.520 --> 00:38:30.000] looking at the Hugging Face docs here for pipeline, there are a bunch of quirks that go into Hugging | |
[00:38:30.000 --> 00:38:36.720] Face. And I mean, that's kind of a lot, honestly. But I guess the important one that I noticed is | |
[00:38:36.720 --> 00:38:43.440] that they're using top K by default, which is 50. And what that does is that, so that's being used | |
[00:38:43.440 --> 00:38:47.760] here as well. And what that does is basically we want to take our probabilities, and we only want | |
[00:38:47.760 --> 00:38:54.080] to keep the top 50 probabilities. And anything that is lower than the 50th probability, we just | |
[00:38:54.080 --> 00:38:59.680] clamp to zero and renormalize. And so that way we are never sampling very rare tokens. | |
[00:38:59.680 --> 00:39:04.640] The tokens we're going to be sampling are always in the top 50 of most likely tokens. | |
[00:39:04.640 --> 00:39:09.040] And this helps keep the model kind of on track. And it doesn't blabber on and it doesn't get lost, | |
[00:39:09.040 --> 00:39:14.640] and doesn't go off the rails as easily. And it kind of like sticks in the vicinity of likely | |
[00:39:14.640 --> 00:39:19.280] tokens a lot better. So this is the way to do it in PyTorch. And you can step through it if you like, | |
[00:39:19.280 --> 00:39:23.600] I don't think it's super insightful. So I'll speed through it. But roughly speaking, we get this new | |
[00:39:23.600 --> 00:39:31.440] column of tokens. We append them to X, and basically the columns of X grow until this while loop condition gets | |
[00:39:31.440 --> 00:39:40.800] tripped. And then finally, we have an entire X of size five by 30 in this case, in this example. | |
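A sketch of that sampling loop (the manual seed is an assumption, just to make runs repeatable):

```python
torch.manual_seed(42)  # example seed
while x.size(1) < max_length:
    with torch.no_grad():                  # no backward pass, so don't cache intermediates
        logits = model(x)                  # (B, T, vocab_size)
        logits = logits[:, -1, :]          # keep only the logits at the last position
        probs = F.softmax(logits, dim=-1)
        # top-k sampling with k=50 (the Hugging Face pipeline default):
        # keep the 50 most likely tokens and sample among them
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)  # both (B, 50)
        ix = torch.multinomial(topk_probs, 1)                     # sample within the top 50
        xcol = torch.gather(topk_indices, -1, ix)                 # map back to token ids
        x = torch.cat((x, xcol), dim=1)                           # append one new column to x

# decode and print each of the five rows
for i in range(num_return_sequences):
    print(">", enc.decode(x[i, :max_length].tolist()))
```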
[00:39:40.800 --> 00:39:46.480] And we can just basically print all those individual rows. So I'm getting all the rows. | |
[00:39:46.480 --> 00:39:51.040] I'm getting all the tokens that are sampled. And I'm using the decode function from the | |
[00:39:51.040 --> 00:39:57.120] tiktoken tokenizer to get back the string, which we can print. And so, terminal, new terminal. | |
[00:39:59.280 --> 00:40:01.600] And let me do python train_gpt2.py. | |
[00:40:01.600 --> 00:40:14.640] Okay. So these are the generations that we're getting. Hello, I'm a language model, not a program. | |
[00:40:14.640 --> 00:40:21.600] New line, new line, etc. Hello, I'm a language model. And one of the main things that bothers | |
[00:40:21.600 --> 00:40:26.000] me when they create languages is how easy it becomes to create something that I mean, so this | |
[00:40:26.000 --> 00:40:30.160] will just like blabber on right in all these cases. Now one thing you will notice is that these | |
[00:40:30.160 --> 00:40:36.960] generations are not the generations from Hugging Face here. And I can't find the discrepancy, to be | |
[00:40:36.960 --> 00:40:41.120] honest, and I didn't fully go through all these options, but probably there's something else hiding | |
[00:40:41.120 --> 00:40:45.680] in there in addition to the top-p. So I'm not able to match it up. But just for correctness, | |
[00:40:45.680 --> 00:40:51.920] down here below in the Jupyter notebook, using the Hugging Face model, so this is the Hugging | |
[00:40:51.920 --> 00:41:00.560] Face model here, I replicated the code. And if I do this and I run that, then I am getting the | |
[00:41:00.560 --> 00:41:07.280] same results. So basically, the model internals are not wrong. It's just I'm not 100% sure what | |
[00:41:07.280 --> 00:41:12.000] the pipeline does in hugging face. And that's why we're not able to match them up. But otherwise, | |
[00:41:12.000 --> 00:41:17.600] the code is correct. And we've loaded all the tensors correctly. So we're initializing the model | |
[00:41:17.600 --> 00:41:22.560] correctly and everything here works. So long story short, we've ported all the weights, we | |
[00:41:22.560 --> 00:41:28.080] initialized the GPT-2, this is the exact OpenAI GPT-2, and it can generate sequences and they | |
[00:41:28.080 --> 00:41:34.400] look sensible. And now here, of course, we're initializing with GPT-2 model weights. But now | |
[00:41:34.400 --> 00:41:38.880] we want to initialize from scratch from random numbers. And we want to actually train the model | |
[00:41:38.880 --> 00:41:46.080] that will give us sequences as good as or better than these ones in quality. And so that's what we | |
[00:41:46.080 --> 00:41:50.400] turn to next. So it turns out that using the random model is actually fairly straightforward, | |
[00:41:50.400 --> 00:41:57.520] because PyTorch already initializes our model randomly by default. So when we create the | |
[00:41:57.520 --> 00:42:03.520] GPT model in the constructor, all of these layers and modules have random | |
[00:42:03.520 --> 00:42:08.800] initializers that are there by default. So when these linear layers get created and so on, | |
[00:42:08.800 --> 00:42:12.880] there are default constructors, for example, using the Xavier initialization that we saw in the past | |
[00:42:13.520 --> 00:42:19.440] to construct the weights of these layers. And so creating a random model instead of a GPT- | |
[00:42:19.440 --> 00:42:24.480] 2 model is actually fairly straightforward. And we would just come here. And instead we would | |
[00:42:24.480 --> 00:42:31.360] create model equals GPT. And then we want to use the default config, GPTConfig. And the default | |
[00:42:31.360 --> 00:42:37.760] config uses the 124M parameters. So this is the random model initialization. And we can run it. | |
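A sketch of that change (GPTConfig's defaults are the 124M hyperparameters):

```python
# model = GPT.from_pretrained('gpt2')  # previously: load the OpenAI weights
model = GPT(GPTConfig())               # now: random initialization at the 124M size
model.eval()
model.to('cuda')
```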
[00:42:43.120 --> 00:42:49.120] And we should be able to get results. Now the results here, of course, are total garbage | |
[00:42:49.120 --> 00:42:53.440] garble. And that's because it's a random model. And so we're just getting all these random token | |
[00:42:53.440 --> 00:42:59.200] string pieces chunked up totally at random. So that's what we have right now. Now, one more | |
[00:42:59.200 --> 00:43:03.120] thing I wanted to point out, by the way, is in case you do not have CUDA available, because you | |
[00:43:03.120 --> 00:43:08.400] don't have a GPU, you can still follow along with what we're doing here, to some extent. | |
[00:43:09.440 --> 00:43:13.360] And probably not to the very end, because by the end, we're going to be using multiple GPUs and | |
[00:43:13.360 --> 00:43:17.600] actually doing a serious training run. But for now, you can actually follow along decently. Okay. | |
[00:43:17.600 --> 00:43:22.960] So one thing that I like to do in PyTorch is I like to auto detect the device that is available | |
[00:43:22.960 --> 00:43:30.480] to you. So in particular, you could do that like this. So here we are trying to detect the device to | |
[00:43:30.480 --> 00:43:35.360] run on that has the highest compute capability. You can think about it that way. So by default, | |
[00:43:35.360 --> 00:43:39.200] we start with CPU, which of course is available everywhere, because every single computer will | |
[00:43:39.200 --> 00:43:45.440] have a CPU. But then we can try to detect: do we have a GPU? If so, use CUDA. And then if you don't | |
[00:43:45.440 --> 00:43:52.080] have CUDA, do you at least have MPS? MPS is the backend for Apple Silicon. So if you have a MacBook | |
[00:43:52.080 --> 00:43:56.720] that is fairly new, you probably have Apple Silicon on the inside. And then that has a GPU that is | |
[00:43:56.720 --> 00:44:01.680] actually fairly capable, depending on which MacBook you have. And so you can use MPS, which will be | |
[00:44:01.680 --> 00:44:06.720] potentially faster than CPU. And so we can print the device here. Now, once we have the device, | |
[00:44:06.720 --> 00:44:14.720] we can actually use it in place of CUDA. So we just swap it in. And notice that here, when we call | |
[00:44:14.720 --> 00:44:22.720] model on X, if this X here is on CPU, instead of GPU, then it will work fine, because here in the | |
[00:44:22.720 --> 00:44:28.880] forward, which is where PyTorch will end up when we call the model: when we create pos, we are careful to use the device | |
[00:44:28.880 --> 00:44:34.480] of idx to create this tensor as well. And so there won't be any mismatch where one tensor is on | |
[00:44:34.480 --> 00:44:40.880] CPU, one is on GPU, and that you can't combine those. But here we are carefully initializing on | |
[00:44:40.880 --> 00:44:47.680] the correct device, as indicated by the input to this model. So this will auto detect device. | |
[00:44:47.680 --> 00:44:53.040] For me, this will be, of course, GPU. So using device CUDA. | |
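For reference, a minimal sketch of the device auto-detection described above:

```python
import torch

# Auto-detect the most capable device available: CUDA, then Apple MPS, else CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")
```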
[00:44:55.600 --> 00:45:00.480] But you can also run with, as I mentioned, another device. And it's not going to be too | |
[00:45:00.480 --> 00:45:08.160] much slower. So if I override device here, if I override device equals CPU, then | |
[00:45:08.160 --> 00:45:15.120] it will still print CUDA, of course, but now we're actually using the CPU: one, two, three, | |
[00:45:15.120 --> 00:45:22.640] four, five, six, okay, about six seconds. And actually, we're not using torch.compile and | |
[00:45:22.640 --> 00:45:27.040] stuff like that, which will speed everything up a lot as well. But you can follow along, | |
[00:45:27.040 --> 00:45:33.120] even on a CPU, I think, to a decent extent. So that's a note on that. Okay, so I do want to loop | |
[00:45:33.120 --> 00:45:38.160] around eventually into what it means to have different devices in PyTorch. And what it is exactly | |
[00:45:38.160 --> 00:45:43.840] that PyTorch does in the background for you, when you do something like module.to(device), | |
[00:45:43.840 --> 00:45:48.640] or where you take a torch tensor and do a .to(device). And what exactly happens and how that | |
[00:45:48.640 --> 00:45:53.600] works? But for now, I'd like to get to training. And I'd like to start training the model. And for | |
[00:45:53.600 --> 00:45:59.040] now, let's just say the device makes code go fast. And let's go into how we can actually train the | |
[00:45:59.040 --> 00:46:04.000] model. So to train the model, we're going to need some data set. And for me, the best debugging | |
[00:46:04.000 --> 00:46:09.600] simplest data set that I like to use is the tiny Shakespeare data set. And it's available at this | |
[00:46:09.600 --> 00:46:16.000] URL. So you can wget it, or you can just search tiny Shakespeare data set. And so I have it in my | |
[00:46:16.000 --> 00:46:23.040] file system; it's just ls, input.txt. So I already downloaded it. And here I'm reading the data set, | |
[00:46:23.040 --> 00:46:28.880] getting the first 1,000 characters and printing the first 100. Now remember that GPT-2 has | |
[00:46:28.880 --> 00:46:34.480] roughly a compression ratio, the tokenizer has a compression ratio of roughly three to one. | |
[00:46:34.480 --> 00:46:39.840] So 1,000 characters is roughly 300 tokens here that will come out of the slice that | |
[00:46:39.840 --> 00:46:47.120] we're currently getting. So these are the first few characters. And if you want to get a few more | |
[00:46:47.120 --> 00:46:54.480] statistics on this, we can do word count on input dot txt. So we can see that this is 40,000 lines, | |
[00:46:54.480 --> 00:47:00.960] about 200,000 words in this data set, and about 1 million bytes in this file. And knowing that | |
[00:47:00.960 --> 00:47:05.600] this file is only ASCII characters, there's no crazy Unicode here as far as I know. And so | |
[00:47:05.600 --> 00:47:10.960] every ASCII character is encoded with one byte. And so this is the same number, roughly a million | |
[00:47:10.960 --> 00:47:17.040] characters inside this data set. So that's the data set size by default, very small and minimal | |
[00:47:17.040 --> 00:47:22.320] data set for debugging to get us off the ground. In order to tokenize this data set, we're going to | |
[00:47:22.320 --> 00:47:30.800] get the tiktoken encoding for GPT-2, encode the data, the first 1,000 characters, and then I'm | |
[00:47:30.800 --> 00:47:36.000] only going to print the first 24 tokens. So these are the tokens as a list of integers. | |
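A minimal sketch of this tokenization step, assuming tiny Shakespeare has already been downloaded to input.txt:

```python
import tiktoken

# Read the data and tokenize the first 1,000 characters with the GPT-2 tokenizer.
with open("input.txt", "r") as f:
    text = f.read()
data = text[:1000]
enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode(data)
print(tokens[:24])   # a list of integers; 198 is the "\n" (newline) token
```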
[00:47:36.000 --> 00:47:41.680] And if you can read GPT-2 tokens, you will see that 198 here, you'll recognize that as the | |
[00:47:41.680 --> 00:47:46.320] \n character. So that is a newline. And then here, for example, we have two newlines. So | |
[00:47:46.320 --> 00:47:53.120] that's 198 twice here. So this is just the tokenization of the first 24 tokens. So what we want to do | |
[00:47:53.120 --> 00:47:59.120] now is we want to actually process these token sequences and feed them into a transformer. And | |
[00:47:59.120 --> 00:48:05.680] in particular, we want them, we want to rearrange these tokens into this ID X variable that we're | |
[00:48:05.680 --> 00:48:09.440] going to be feeding into the transformer. So we don't want a single very long one dimensional | |
[00:48:09.440 --> 00:48:16.880] sequence. We want an entire batch, where each sequence is up to, it's basically T tokens, | |
[00:48:16.880 --> 00:48:21.840] and T cannot be larger than the maximum sequence length. And then we have these | |
[00:48:22.640 --> 00:48:29.600] T-long sequences of tokens. And we have B independent examples of sequences. So how can we create a | |
[00:48:29.600 --> 00:48:35.360] B by T tensor that we can feed into the forward pass out of these one-dimensional sequences? So here's | |
[00:48:35.360 --> 00:48:41.440] my favorite way to to achieve this. So if we take torch, and then we create a tensor object out of | |
[00:48:41.440 --> 00:48:46.720] this list of integers and just the first 24 tokens, my favorite way to do this is basically you do a | |
[00:48:46.720 --> 00:48:54.800] dot view of, for example, four by six, which multiply to 24. And so it's just a two dimensional | |
[00:48:54.800 --> 00:48:58.560] rearrangement of these tokens. And you'll notice that when you view this one dimensional sequence | |
[00:48:58.560 --> 00:49:07.280] as two dimensional four by six here, the first six tokens up to here end up being the first row, | |
[00:49:07.280 --> 00:49:13.200] the next six tokens here end up being the second row, and so on. And so basically, it's just going | |
[00:49:13.200 --> 00:49:21.440] to stack up every six tokens, in this case, as independent rows, and it creates a batch of | |
[00:49:21.440 --> 00:49:27.840] tokens in this case. And so for example, if we are token 25, in the transformer, when we feed this | |
[00:49:27.840 --> 00:49:33.440] in, and this becomes the idx, this token is going to see these three tokens, and it's going to try | |
[00:49:33.440 --> 00:49:40.640] to predict that 198 comes next. So in this way, we are able to create this two dimensional batch, | |
[00:49:40.640 --> 00:49:46.480] that's quite nice. Now, in terms of the label that we're going to need for the target to calculate | |
[00:49:46.480 --> 00:49:50.880] the loss function, how do we get that? Well, we could write some code inside the forward pass, | |
[00:49:50.880 --> 00:49:55.840] because we know that the next token in a sequence, which is the label, is just to the right of us. | |
[00:49:55.840 --> 00:50:02.160] But you'll notice that actually, for this token at the very end, 13, we don't actually have the | |
[00:50:02.160 --> 00:50:08.000] next correct token because we didn't load it. So we actually didn't get enough information here. | |
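Here is a small sketch of the .view() rearrangement just described, reusing the `enc` and `text` names from the previous snippet (illustrative, not the video's exact code):

```python
import torch

# Take the first 24 tokens and view them as a (B=4, T=6) batch of sequences.
tokens = enc.encode(text[:1000])
buf = torch.tensor(tokens[:24])
x = buf.view(4, 6)   # each row becomes an independent sequence of 6 tokens
print(x)
# Note: the very last token in this buffer has no "next token" loaded, so we
# cannot form its target yet; the fix (fetching one extra token) comes next.
```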
[00:50:09.200 --> 00:50:14.560] So I'll show you my favorite way of basically getting these batches. And I like to personally | |
[00:50:14.560 --> 00:50:19.840] have not just the input to the transformer, which I like to call x, but I also like to create the | |
[00:50:19.840 --> 00:50:26.000] labels tensor, which is of the exact same size as x, but contains the targets at every single | |
[00:50:26.000 --> 00:50:30.880] position. And so here's the way that I like to do that. I like to make sure that I fetch plus one | |
[00:50:30.880 --> 00:50:38.800] token, because we need the ground truth for the very last token, for 13. And then when we're creating | |
[00:50:38.800 --> 00:50:44.800] the input, we take everything up to the last token, not including and view it as four by six. And when | |
[00:50:44.800 --> 00:50:51.440] we're creating targets, we do the buffer, but starting at index one, not index zero. So we're | |
[00:50:51.440 --> 00:50:55.840] skipping the first element and we view it in the exact same size. And then when I print this, | |
[00:50:55.840 --> 00:51:02.320] here's what happens, where we see that basically as an example for this token 25, | |
[00:51:02.320 --> 00:51:08.320] its target was 198. And that's now just stored at the exact same position in the target tensor, | |
[00:51:08.320 --> 00:51:16.320] which is 198. And also this last token 13 now has its label, which is 198. And that's just because | |
[00:51:16.320 --> 00:51:22.400] we loaded this plus one here. So basically, this is the way I like to do it. You take long sequences, | |
[00:51:22.400 --> 00:51:29.280] you view them in two dimensions, so that you get batches by time. And then we make sure to | |
[00:51:29.280 --> 00:51:36.080] load one additional token. So we basically load a buffer of B times T plus one tokens. And then | |
[00:51:36.080 --> 00:51:40.800] we sort of offset things and view them. And then we have two tensors, one of them is the input to | |
[00:51:40.800 --> 00:51:47.440] the transformer, and the other one is exactly the labels. And so let's now reorganize this code and | |
[00:51:47.440 --> 00:51:53.760] create a very simple data loader object that tries to basically load these tokens and | |
[00:51:53.760 --> 00:51:59.040] feed them to the transformer and calculate the loss. Okay, so I reshuffled the code here, | |
[00:51:59.040 --> 00:52:04.240] accordingly. So as you can see here, I'm temporarily overriding to run on CPU. | |
[00:52:05.040 --> 00:52:09.520] And importing tiktoken, and all of this should look familiar; we're loading 1,000 characters. | |
[00:52:09.520 --> 00:52:13.920] I'm setting B, T to just be 4 and 32 right now, just because we're debugging, | |
[00:52:13.920 --> 00:52:18.640] we just want to have a single batch that's very small. And all of this should now look familiar | |
[00:52:18.640 --> 00:52:23.280] and follows what we did on the right. And then here, we create the model | |
[00:52:23.280 --> 00:52:30.720] and get the logits. And so here, as you see, I already ran this; it only runs in a few seconds. | |
[00:52:30.720 --> 00:52:38.640] But because we have a batch of 4 by 32, our logits are now of size 4 by 32 by 50,257. | |
[00:52:38.640 --> 00:52:44.480] So those are the logits for what comes next at every position. And now we have the labels, | |
[00:52:44.480 --> 00:52:49.440] which are stored in Y. So now is the time to calculate the loss, and then do the backward pass, | |
[00:52:49.440 --> 00:52:55.040] and then the optimization. So let's first calculate loss. Okay, so to calculate the loss, | |
[00:52:55.040 --> 00:52:59.840] we're going to adjust the forward function of this nn.Module, the model. And in particular, | |
[00:52:59.840 --> 00:53:03.040] we're not just going to be returning logits, but also we're going to return the loss. | |
[00:53:03.040 --> 00:53:08.400] And we're not just going to be passing in the input indices, but also the targets, y. | |
[00:53:08.400 --> 00:53:14.800] And now we will not print just logits.shape anymore. We're actually going to print the loss, | |
[00:53:14.800 --> 00:53:19.200] and then sys.exit(0) so that we skip some of the sampling logic. | |
[00:53:19.200 --> 00:53:26.080] So now let's swing up to the forward function, which gets called there, because now we also have | |
[00:53:26.080 --> 00:53:32.720] these optional targets. And when we get the targets, we can also calculate the loss. | |
[00:53:32.720 --> 00:53:38.720] And remember that we want to basically return the logits and the loss, and loss by default is None. But | |
[00:53:38.720 --> 00:53:50.400] let's put this here: if targets is not None, then we want to calculate the loss. And Copilot is | |
[00:53:50.400 --> 00:53:55.600] already getting excited here and calculating the what looks to be correct loss. It is using the | |
[00:53:55.600 --> 00:54:03.440] cross entropy loss as is documented here. So this is a function in PyTorch under the functional. | |
[00:54:03.440 --> 00:54:09.520] Now, what is actually happening here, because it looks a little bit scary. Basically, the F dot | |
[00:54:09.520 --> 00:54:14.560] cross entropy does not like multi dimensional inputs. It can't take a B by T by vocab size. | |
[00:54:14.560 --> 00:54:19.280] So what's happening here is that we are flattening out of this three dimensional tensor into just | |
[00:54:19.280 --> 00:54:23.600] two dimensions. The first dimension is going to be calculated automatically, and it's going to be | |
[00:54:23.600 --> 00:54:30.160] B times T. And then the last dimension is vocab size. So basically, this is flattening out this | |
[00:54:30.160 --> 00:54:36.560] three dimensional tensor of logits to just be two dimensional B times T, all individual examples | |
[00:54:36.560 --> 00:54:42.800] and vocab size on in terms of the length of each row. And then it's also flattening out the | |
[00:54:42.800 --> 00:54:47.360] targets, which are also two dimensional at this stage. But we're going to just flatten them out. | |
[00:54:47.360 --> 00:54:52.080] So they're just a single tensor of B times T. And this can then pass into cross entropy to | |
[00:54:52.080 --> 00:54:57.360] calculate a loss, which we return. So this should basically at this point run, | |
[00:54:57.360 --> 00:55:04.320] because it's not too complicated. So let's run it. And let's see if we should be printing the loss. | |
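A small sketch of the flattening that F.cross_entropy needs, as described above (standalone here with random logits; inside GPT.forward the same two .view() calls are applied to the real logits and targets):

```python
import torch
import torch.nn.functional as F

def cross_entropy_loss(logits, targets):
    # logits: (B, T, vocab_size); targets: (B, T) of token ids
    return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

B, T, V = 4, 32, 50257
loss = cross_entropy_loss(torch.randn(B, T, V), torch.randint(0, V, (B, T)))
print(loss)   # a single-element tensor; for random logits it sits near -ln(1/V)
```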
[00:55:04.320 --> 00:55:14.320] And here we see that we printed roughly 11. | |
[00:55:16.960 --> 00:55:22.080] And notice that this is a tensor of a single element, which is this number 11. Now, we also | |
[00:55:22.080 --> 00:55:27.520] want to be able to calculate a reasonable kind of starting point for a randomly initialized network. | |
[00:55:27.520 --> 00:55:33.760] So we covered this in previous videos, but our vocabulary size is 50,257. At initialization of | |
[00:55:33.760 --> 00:55:40.960] the network, you would hope that every vocab element is getting roughly a uniform probability, | |
[00:55:41.680 --> 00:55:46.880] so that we're not favoring at initialization, any token way too much. We're not confidently | |
[00:55:46.880 --> 00:55:51.760] wrong at initialization. So we're hoping is that the probability of any arbitrary token is roughly | |
[00:55:51.760 --> 00:55:58.560] one over 50,257. And now we can sanity check the loss, because remember that the cross entropy | |
[00:55:58.560 --> 00:56:05.920] loss is just basically the negative log likelihood. So if we now take this probability, and we take | |
[00:56:05.920 --> 00:56:11.440] it through the natural logarithm, and then we do the negative, that is the loss we expect at | |
[00:56:11.440 --> 00:56:17.040] initialization, and we covered this in previous videos. So I would expect something around 10.82, | |
[00:56:17.040 --> 00:56:21.360] and we're seeing something around that level. So it's not way off. This is roughly what we | |
[00:56:21.360 --> 00:56:26.480] expect at initialization. So that tells me that at initialization our probability distribution | |
[00:56:26.480 --> 00:56:32.160] is roughly diffuse. It's a good starting point. And we can now perform the optimization and | |
[00:56:32.160 --> 00:56:35.680] tell the network which elements, you know, should follow correctly in what order. | |
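The sanity check above is a one-liner:

```python
import math

# With a uniform distribution over the 50,257-token vocabulary, the expected
# cross-entropy at initialization is -ln(1/50257).
print(-math.log(1.0 / 50257))   # ~10.82
```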
[00:56:35.680 --> 00:56:41.200] So at this point, we can do loss.backward(), calculate the gradients, and do an optimization. | |
[00:56:41.200 --> 00:56:47.200] So let's get to that. Okay, so let's do the optimization now. So here we have | |
[00:56:47.200 --> 00:56:54.320] the loss; this is how we get the loss. But now basically we want a little for loop here. So for | |
[00:56:54.320 --> 00:57:00.400] i in range, let's do 50 steps or something like that. Let's create an optimizer object in PyTorch. | |
[00:57:02.000 --> 00:57:08.640] the Adam optimizer, which is an alternative to the stochastic gradient | |
[00:57:08.640 --> 00:57:13.760] descent optimizer, SGD, that we were using. So SGD is a lot simpler, Adam is a bit more involved. | |
[00:57:13.760 --> 00:57:19.440] And I actually specifically like the AdamW variation, because in my opinion, it kind of just fixes | |
[00:57:19.440 --> 00:57:26.080] a bug. So AdamW is a bug fix of Adam, is what I would say. When we go to the documentation for | |
[00:57:26.080 --> 00:57:34.080] AdamW. Oh my gosh. We see that it takes a couple of parameters and it's a little bit more | |
[00:57:34.080 --> 00:57:38.720] complicated than the SGD we were looking at before, because in addition to basically | |
[00:57:38.720 --> 00:57:43.440] updating the parameters with the gradient scaled by the learning rate, it keeps these buffers | |
[00:57:43.440 --> 00:57:49.200] around and it keeps two buffers, the M and the V, which it calls the first and the second moment. | |
[00:57:49.200 --> 00:57:53.040] So something that looks a bit like momentum and something that looks a bit like RMSProp, | |
[00:57:53.040 --> 00:57:57.120] if you're familiar with it. But you don't have to be, it's just kind of like a normalization | |
[00:57:57.120 --> 00:58:01.520] that happens on each gradient element individually and speeds up the optimization, | |
[00:58:01.520 --> 00:58:05.360] especially for language models. But I'm not going to go into the detail right here. | |
[00:58:05.360 --> 00:58:11.680] We're going to treat this a bit of a black box and it just optimizes the objective faster | |
[00:58:11.680 --> 00:58:16.560] than SGD, which is what we've seen in the previous lectures. So let's use it as a black box in our | |
[00:58:16.560 --> 00:58:23.680] case. Create the optimizer object and then go through the optimization. | |
[00:58:23.680 --> 00:58:34.800] The first thing: always make sure that Copilot did not forget to zero the gradients. So always | |
[00:58:34.800 --> 00:58:39.440] remember that you have to start with a zero gradient. Then when you get your loss and you do a | |
[00:58:39.440 --> 00:58:45.840] .backward(), .backward() adds to the gradients. So it deposits gradients. It always does a plus equals | |
[00:58:45.840 --> 00:58:50.080] on whatever the gradients are, which is why you must set them to zero. So this accumulates | |
[00:58:50.080 --> 00:58:56.960] the gradient from this loss. And then we call the step function on the optimizer to update the | |
[00:58:56.960 --> 00:59:05.360] parameters and to decrease the loss. And then we print the step, and loss.item() is used | |
[00:59:05.360 --> 00:59:11.440] here because loss is a tensor with a single element; .item() will actually convert that to a single | |
[00:59:11.440 --> 00:59:17.520] float. And this float will now live on the CPU. So this gets to some of the internals again | |
[00:59:17.520 --> 00:59:24.000] of the devices, but loss is a tensor with a single element and it lives on the GPU for me because | |
[00:59:24.000 --> 00:59:29.680] I'm using GPUs. When you call .item(), PyTorch behind the scenes will take that single-element | |
[00:59:29.680 --> 00:59:36.640] tensor, ship it back to CPU memory, and convert it into a float that we can just print. So this | |
[00:59:36.640 --> 00:59:46.320] is the optimization and this should probably just work. Let's see what happens. Actually, | |
[00:59:46.320 --> 00:59:51.600] sorry, instead of using the CPU override, let me delete that. So this is a bit faster for me | |
[00:59:51.600 --> 01:00:02.960] and it runs on CUDA. Oh, expected all tensors to be on the same device but found at least two devices, | |
[01:00:02.960 --> 01:00:09.280] cuda:0 and CPU. So cuda:0 is the zeroth GPU, because I actually have 8 GPUs on this box. | |
[01:00:09.280 --> 01:00:17.280] So the 0th GPU on my box, and CPU. Now, the model we have moved to device, but when I was writing | |
[01:00:17.280 --> 01:00:22.640] this code, I actually introduced a bug, because buf we never moved to device. And you have to be | |
[01:00:22.640 --> 01:00:30.400] careful, because you can't just do buf.to(device). It's not stateful. It doesn't convert it | |
[01:00:30.400 --> 01:00:37.280] in place. It instead returns a pointer to new memory which is on the device. So while | |
[01:00:37.280 --> 01:00:42.080] we can just do model.to(device), that does not apply to tensors. You have to do buf = | |
[01:00:42.080 --> 01:00:48.960] buf.to(device). And then this should work. Okay. | |
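A small self-contained illustration of that point (.to() on a tensor is not in-place, while nn.Module.to() is):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

buf = torch.tensor([1, 2, 3])
buf.to(device)          # returns a new tensor; buf itself stays where it was
buf = buf.to(device)    # correct: rebind the name to the moved tensor

model = torch.nn.Linear(3, 3)
model.to(device)        # nn.Module.to() moves the parameters in place, so this is fine
```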
[01:00:48.960 --> 01:00:54.720] So what do we expect to see? We expect to see a reasonable loss in the beginning and then we | |
[01:00:54.720 --> 01:00:59.120] continue to optimize just a single batch. And so we want to see that we can overfit this single | |
[01:00:59.120 --> 01:01:04.240] batch. We can crush this little batch and we can perfectly predict the indices on just this | |
[01:01:04.240 --> 01:01:10.960] little batch. And indeed that is roughly what we're seeing here. So we started off at roughly | |
[01:01:10.960 --> 01:01:16.240] 10.82, 11 in this case. And then as we continue optimizing on this single batch without loading | |
[01:01:16.240 --> 01:01:20.560] new examples, we are making sure that we can overfit a single batch. And we are getting to | |
[01:01:20.560 --> 01:01:25.280] very, very low loss. So the transformer is memorizing this single individual batch. | |
[01:01:26.080 --> 01:01:30.320] And one more thing I didn't mention is that the learning rate here is 3e-4, | |
[01:01:30.320 --> 01:01:36.400] which is a pretty good default for most optimizations that you want to run at a very early debugging | |
[01:01:36.400 --> 01:01:43.840] stage. So this is our simple inner loop. And we are overfitting a single batch. And this looks good. | |
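Putting it together, a sketch of this single-batch overfitting loop (model, x, and y are assumed from earlier and already on the same device):

```python
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
    optimizer.zero_grad()                     # gradients accumulate, so reset them first
    logits, loss = model(x, y)
    loss.backward()                           # deposits (+=) gradients into .grad
    optimizer.step()                          # update the parameters to decrease the loss
    print(f"step {i}, loss: {loss.item()}")   # .item() ships the scalar back to the CPU
```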
[01:01:43.840 --> 01:01:47.760] So now, what comes next is we don't just want to overfit a single batch. We actually want | |
[01:01:47.760 --> 01:01:52.960] to do an optimization. So we actually need to iterate these xy batches and create a little data | |
[01:01:52.960 --> 01:01:57.760] loader that makes sure that we're always getting a fresh batch and that we're actually optimizing | |
[01:01:57.760 --> 01:02:02.160] a reasonable objective. So let's do that next. Okay, so this is what I came up with, and I wrote | |
[01:02:02.160 --> 01:02:08.320] a little DataLoaderLite. So what this data loader does is we're importing tiktoken up here, | |
[01:02:08.320 --> 01:02:14.800] reading the entire text file from this single input.txt, tokenizing it, and then we're just | |
[01:02:14.800 --> 01:02:20.880] printing the number of tokens in total. And the number of batches and a single epoch of iterating | |
[01:02:20.880 --> 01:02:26.160] over this dataset. So how many unique batches do we output before we loop back around to the beginning | |
[01:02:26.160 --> 01:02:32.320] of the document and start reading it again. So we start off at position zero, and then we simply | |
[01:02:32.320 --> 01:02:38.320] walk the document in batches of b times t. So we take chunks of b times t, and then always advance | |
[01:02:38.320 --> 01:02:45.280] by b times t. And it's important to note that we're always advancing our position by exactly b times | |
[01:02:45.280 --> 01:02:50.640] t. But when we're fetching the tokens, we're actually fetching from current position to b times | |
[01:02:50.640 --> 01:02:57.440] t plus one. And we need that plus one, because remember, we need the target token for the last | |
[01:02:57.440 --> 01:03:05.760] token in the current batch. And so that way we can do the x, y exactly as we did it before. And | |
[01:03:05.760 --> 01:03:12.480] if we are to run out of data, we'll just loop back around to zero. So this is one way to write | |
[01:03:12.480 --> 01:03:18.800] a very, very simple data loader. That simply just goes through the file in chunks. And this | |
[01:03:18.800 --> 01:03:24.800] is good enough for us for current purposes. And we're going to complexify it later. And now we'd like | |
[01:03:24.800 --> 01:03:29.440] to come back around here, and we'd like to actually use our data loader. So the import tiktoken has | |
[01:03:29.440 --> 01:03:36.320] moved up. And actually all of this is now useless. So instead we just want a train loader for the | |
[01:03:36.320 --> 01:03:43.440] training data. And we want to use the same hyperparameters as before. So the batch size was 4 and time was 32. | |
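For reference, here is a sketch of the DataLoaderLite idea described above and how it gets constructed with B=4, T=32; the details are a reconstruction, not necessarily the video's exact code:

```python
import tiktoken
import torch

class DataLoaderLite:
    def __init__(self, B, T):
        self.B, self.T = B, T
        with open("input.txt", "r") as f:
            text = f.read()
        enc = tiktoken.get_encoding("gpt2")
        self.tokens = torch.tensor(enc.encode(text))
        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)   # inputs
        y = buf[1:].view(B, T)    # targets: the inputs shifted by one token
        self.current_position += B * T
        # if the next batch would run past the end of the document, loop back around
        if self.current_position + B * T + 1 > len(self.tokens):
            self.current_position = 0
        return x, y

train_loader = DataLoaderLite(B=4, T=32)
```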
[01:03:44.320 --> 01:03:50.320] And then here, we need to get the x, y for the current batch. So let's see if Copilot gets it, | |
[01:03:50.320 --> 01:03:56.720] because this is simple enough. So we call the next batch. And then we make sure that we have to | |
[01:03:56.720 --> 01:04:05.040] move our tensors from CPU to the device. So here, when I converted the tokens, | |
[01:04:05.040 --> 01:04:10.800] notice that I didn't actually move these tokens to the GPU, I left them on the CPU, which is default. | |
[01:04:12.080 --> 01:04:16.640] And that's just because I'm trying not to waste too much memory on the GPU. In this case, this is a | |
[01:04:16.640 --> 01:04:22.960] tiny data set that would fit. But it's fine to just ship it to GPU right now for our purposes | |
[01:04:22.960 --> 01:04:29.120] right now. So we get the next batch, we keep the data loader simple and on the CPU. And then here, | |
[01:04:29.120 --> 01:04:37.040] we actually ship it to the GPU and do all the computation. And let's see if this runs. So python | |
[01:04:37.040 --> 01:04:42.560] train_gpt2.py. And what do we expect to see before this actually happens? What we | |
[01:04:42.560 --> 01:04:47.680] expect to see is now we're actually getting the next batch. So we expect to not overfit a single | |
[01:04:47.680 --> 01:04:54.640] batch. And so I expect our loss to come down, but not too much. And that's because I still | |
[01:04:54.640 --> 01:05:01.600] expected to come down because in the 50,257 tokens, many of those tokens never occur in our data set. | |
[01:05:01.600 --> 01:05:06.240] So there's some very easy gains to be made here in the optimization by, for example, taking the | |
[01:05:06.240 --> 01:05:11.520] biases of all the logits that never occur and driving them to negative infinity. And that would | |
[01:05:11.520 --> 01:05:16.080] basically just, it's just that all of these crazy unicodes or different languages, those tokens | |
[01:05:16.080 --> 01:05:20.480] never occur. So their probability should be very low. And so the gains that we should be seeing | |
[01:05:20.480 --> 01:05:26.000] are along the lines of basically deleting the usage of tokens that never occur. That's probably | |
[01:05:26.000 --> 01:05:30.800] most of the loss gain that we're going to see at this scale right now. But we shouldn't come to | |
[01:05:31.680 --> 01:05:38.160] zero because we are only doing 50 iterations. And I don't think that's enough to do an epoch | |
[01:05:38.160 --> 01:05:46.640] right now. So let's see what we got. We have 338,000 tokens, which makes sense with our | |
[01:05:46.640 --> 01:05:52.720] three to one compression ratio, because there are one million characters. So one epoch with the | |
[01:05:52.720 --> 01:05:59.840] current setting of B and T will take 2,600 batches. And we're only doing 50 batches of optimization | |
[01:05:59.840 --> 01:06:06.800] in here. So we start off in a familiar territory, as expected, and then we seem to come down to | |
[01:06:06.800 --> 01:06:13.440] about 6.6. So basically, things seem to be working okay right now with respect to our expectations. | |
[01:06:13.440 --> 01:06:19.120] So that's good. Okay, next, I want to actually fix a bug that we have in our code. It's not a | |
[01:06:19.120 --> 01:06:27.040] major bug, but it is a bug with respect to how GPT-2 training should happen. So the bug is the | |
[01:06:27.040 --> 01:06:31.120] following. We were not being careful enough when we were loading the weights from hugging face and | |
[01:06:31.120 --> 01:06:39.680] we actually missed a little detail. So if we come here, notice that the shape of these two tensors is | |
[01:06:39.680 --> 01:06:47.040] the same. So this one here is the token embedding at the bottom of the transformer. Right, so and | |
[01:06:47.040 --> 01:06:53.440] this one here is the language modeling head at the top of the transformer. And both of these are | |
[01:06:53.440 --> 01:07:00.160] basically two dimensional tensors and their shape is identical. So here, the first one is the output | |
[01:07:00.160 --> 01:07:05.040] embedding, the token embedding. And the second one is this linear layer at the very top, the classifier | |
[01:07:05.040 --> 01:07:14.640] layer. Both of them are of shape 50,257 by 768. This one here is giving us our token embeddings | |
[01:07:14.640 --> 01:07:21.040] at the bottom. And this one here is taking the 768 channels of the transformer and trying to upscale | |
[01:07:21.040 --> 01:07:28.240] that to 50,257 to get the logits for the next token. So they're both the same shape. But more | |
[01:07:28.240 --> 01:07:35.200] than that, actually, if you look at comparing their elements in PyTorch, this is an element | |
[01:07:35.200 --> 01:07:39.600] wise equality. So then we use .all(), and we see that every single element is identical. | |
[01:07:39.600 --> 01:07:46.080] And more than that, we see that if we actually look at the data pointer, this is what this is a | |
[01:07:46.080 --> 01:07:51.360] way in PyTorch to get the actual pointer to the data and the storage, we see that actually the | |
[01:07:51.360 --> 01:07:56.480] pointer is identical. So not only are these two separate tensors that happen to have the same shape | |
[01:07:56.480 --> 01:08:02.240] and elements, they're actually pointing to the identical tensor. So what's happening here is | |
[01:08:02.240 --> 01:08:10.000] that this is a common weight tying scheme that actually comes from the original, from the original | |
[01:08:10.000 --> 01:08:15.040] attention is all you need paper, and actually even the reference before it. So if we come here, | |
[01:08:15.040 --> 01:08:25.920] embeddings and softmax in the attention is all you need paper, they mention that in our model, | |
[01:08:25.920 --> 01:08:30.720] we share the same weight matrix between the two embedding layers and the pre-softmax linear | |
[01:08:30.720 --> 01:08:37.520] transformation, similar to [30]. So this is an awkward way to phrase that these two are shared, | |
[01:08:37.520 --> 01:08:42.000] and they're tied, and they're the same matrix. And the [30] reference is this paper. | |
[01:08:42.000 --> 01:08:49.440] So this came out in 2017. And you can read the full paper, but basically it argues for this | |
[01:08:49.440 --> 01:08:54.880] weight tying scheme. And I think intuitively the idea for why you might want to do this | |
[01:08:54.880 --> 01:09:00.080] comes from this paragraph here. And basically you can observe that | |
[01:09:03.280 --> 01:09:10.080] you actually want these two matrices to behave similarly, in the following sense. If two tokens | |
[01:09:10.080 --> 01:09:14.560] are very similar semantically, like maybe one of them is all lowercase and the other one is | |
[01:09:14.560 --> 01:09:18.240] all uppercase, or it's the same token in the different language or something like that, | |
[01:09:18.240 --> 01:09:22.160] if you have similarity between two tokens, presumably you would expect that they are | |
[01:09:22.160 --> 01:09:27.600] nearby in the token embedding space. But in the exact same way, you'd expect that if you have | |
[01:09:27.600 --> 01:09:32.720] two tokens that are similar semantically, you'd expect them to get the same probabilities | |
[01:09:33.280 --> 01:09:40.240] at the output of a transformer, because they are semantically similar. And so both positions | |
[01:09:40.240 --> 01:09:45.120] in the transformer at the very bottom and at the top have this property that similar tokens | |
[01:09:45.120 --> 01:09:50.720] should have similar embeddings or similar weights. And so this is what motivates their | |
[01:09:50.720 --> 01:09:54.400] exploration here. And they kind of, you know, I don't want to go through the entire paper. | |
[01:09:54.400 --> 01:10:00.160] And you can go through it, but this is what they observe. They also observe that if you look at | |
[01:10:00.160 --> 01:10:06.960] the output embeddings, they also behave like word embeddings. If you just kind of try to use those | |
[01:10:06.960 --> 01:10:12.800] weights as word embeddings. So they kind of observe this similarity. They try to tie them, | |
[01:10:12.800 --> 01:10:18.160] and they observe that they can get much better performance in that way. And so this was adopted | |
[01:10:18.160 --> 01:10:23.680] in the Attention Is All You Need paper. And then it was used again in GPT-2 as well. So | |
[01:10:25.200 --> 01:10:30.160] I couldn't find it in the Transformers implementation. I'm not sure where they tie those embeddings, | |
[01:10:30.160 --> 01:10:37.920] but I can find it in the original GPT-2 code introduced by OpenAI. So this is the OpenAI GPT-2 | |
[01:10:37.920 --> 01:10:42.880] source code, the model. And here, where they are forwarding this model, and this is in TensorFlow, but | |
[01:10:42.880 --> 01:10:50.560] that's okay, we see that they get the WTE token embeddings. And then here is the encoder of the | |
[01:10:50.560 --> 01:10:55.840] token embeddings and the position. And then here at the bottom, they use the WTE again | |
[01:10:55.840 --> 01:11:02.800] to do the logits. So when they get the logits, it's a matmul of this output from the transformer | |
[01:11:02.800 --> 01:11:10.240] and the WTE tensor is reused. And so the WTE tensor basically is used twice on the bottom of the | |
[01:11:10.240 --> 01:11:15.920] transformer and on the top of the transformer. And in the backward pass, we'll get gradients | |
[01:11:15.920 --> 01:11:22.080] contributions from both branches. And these gradients will add up on the WTE tensor. | |
[01:11:22.080 --> 01:11:27.440] So we'll get a contribution from the classifier layer. And then at the very end of the transformer, | |
[01:11:27.440 --> 01:11:33.760] we'll get a contribution at the bottom of it flowing again into the WTE tensor. | |
[01:11:33.760 --> 01:11:39.920] So we are currently not sharing WTE in our code, but we want to do that. | |
[01:11:42.160 --> 01:11:49.840] So, weight sharing scheme. And one way to do this, let's see if Copilot gets it. | |
[01:11:49.840 --> 01:11:58.800] Oh, it does. Okay. So this is one way to do it. Basically, relatively straightforward, | |
[01:11:58.800 --> 01:12:04.640] what we're doing here is we're taking the wte.weight, and we're simply redirecting | |
[01:12:05.520 --> 01:12:13.120] it to point to the lm_head. So this basically copies the data pointer, right, it copies the | |
[01:12:13.120 --> 01:12:20.960] reference. And now the wte.weight becomes orphaned, the old value of it, and PyTorch will | |
[01:12:20.960 --> 01:12:28.640] clean it up. And so we are only left with a single tensor. And it's going to be used twice | |
[01:12:28.640 --> 01:12:34.640] in the forward pass. And this is, to my knowledge, all that's required. So we should be able to | |
[01:12:35.520 --> 01:12:40.960] use this. And this should probably train. We're just going to basically be using this exact same | |
[01:12:40.960 --> 01:12:48.960] tensor twice. And we weren't being careful with tracking the likelihoods. But according to the | |
[01:12:48.960 --> 01:12:52.240] paper, and according to the results, you'd actually expect slightly better results doing this. | |
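A small self-contained illustration of the weight tying: the token embedding and the classifier end up referencing one shared tensor (the single assignment inside the GPT module is shown in the trailing comment):

```python
import torch.nn as nn

vocab_size, n_embd = 50257, 768
wte = nn.Embedding(vocab_size, n_embd)               # token embedding, shape (50257, 768)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # classifier, weight also (50257, 768)
wte.weight = lm_head.weight                          # redirect the reference: one shared tensor

print(wte.weight.data_ptr() == lm_head.weight.data_ptr())   # True
# Inside the GPT module's __init__ (names as used in this video):
# self.transformer.wte.weight = self.lm_head.weight
```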
[01:12:52.240 --> 01:12:57.600] And in addition to that, one other reason that this is very, very nice for us is that this is a | |
[01:12:57.600 --> 01:13:06.160] ton of parameters, right? What is the size of here? It's 768 times 50,257. So this is 40 million | |
[01:13:06.160 --> 01:13:13.840] parameters. And this is a 124 million parameter model. So 40 divided by 124: this is like 30% of | |
[01:13:13.840 --> 01:13:19.440] the parameters being saved using this weight tying scheme. And so this might be one of the | |
[01:13:19.440 --> 01:13:23.120] reasons that this is working slightly better. If you're not training the model long enough, | |
[01:13:23.840 --> 01:13:28.240] because of the weight tying, you don't have to train as many parameters. And so you become more efficient | |
[01:13:28.240 --> 01:13:34.240] in terms of the training process, because you have fewer parameters. And you're putting in this | |
[01:13:34.240 --> 01:13:38.960] inductive bias that these two embeddings should share similarities between tokens. | |
[01:13:38.960 --> 01:13:45.040] So this is the weight tying scheme. And we saved a ton of parameters. And we expect our model to work | |
[01:13:45.040 --> 01:13:48.960] slightly better because of this scheme. Okay, next, I would like us to be a bit more careful | |
[01:13:48.960 --> 01:13:53.920] with the initialization and to try to follow the way GPT-2 initialized their model. Now, | |
[01:13:53.920 --> 01:13:59.360] unfortunately, the GPT-2 paper and the GPT-3 paper are not very explicit about initialization. | |
[01:13:59.360 --> 01:14:04.240] So we kind of have to read between the lines. And instead of going to the paper, which is quite vague, | |
[01:14:04.240 --> 01:14:10.480] there's a bit of information in the code that OpenAI released. So when we go to the model.py, | |
[01:14:10.480 --> 01:14:16.160] we see that when they initialize their weights, they are using the standard deviation of 0.02. | |
[01:14:16.720 --> 01:14:22.160] And that's how they do it. So this is a normal distribution for the weights. And the standard | |
[01:14:22.160 --> 01:14:29.040] deviation is 0.02. For the bias, they initialize that with zero. And then when we scroll down here, | |
[01:14:29.040 --> 01:14:39.040] why is this not scrolling? The token embeddings are initialized at 0.02. And position embeddings | |
[01:14:39.040 --> 01:14:45.120] at 0.01 for some reason. So those are the initializations. And we'd like to mirror what GPT-2 does | |
[01:14:46.320 --> 01:14:51.520] in our module here. So here's a snippet of code that I sort of came up with very quickly. | |
[01:14:51.520 --> 01:14:58.880] So what's happening here is at the end of our initializer for the GPT module, we're calling the | |
[01:14:58.880 --> 01:15:06.240] apply function of nn.Module. And that iterates over all the sub-modules of this module and applies | |
[01:15:06.240 --> 01:15:12.400] the _init_weights function on them. And so what's happening here is that we're iterating | |
[01:15:12.400 --> 01:15:17.680] over all the modules here. And if they are an nn.Linear module, then we're going to make sure to | |
[01:15:17.680 --> 01:15:23.360] initialize the weight using a normal with the standard deviation of 0.02. If there's a bias in | |
[01:15:23.360 --> 01:15:28.640] this layer, we make sure to initialize that to zero. Note that zero initialization for the bias | |
[01:15:28.640 --> 01:15:35.760] is not actually the pytorch default. By default, the bias here is initialized with a uniform. So | |
[01:15:35.760 --> 01:15:41.120] that's interesting. So we make sure to use zero. And for the embedding, we're just going to get 0.02. | |
[01:15:41.120 --> 01:15:46.960] And keep it the same. So we're not going to change it to 0.01 for positional, because it's about the | |
[01:15:46.960 --> 01:15:52.240] same. And then if you look through our model, the only other layer that requires initialization, | |
[01:15:52.240 --> 01:15:56.640] and that has parameters, is the layer norm. And the pytorch default initialization | |
[01:15:56.640 --> 01:16:01.200] sets the scale in the layer norm to be one, and the offset in the layer norm to be zero. | |
[01:16:01.200 --> 01:16:04.640] So that's exactly what we want. And so we're just going to keep it that way. | |
[01:16:05.520 --> 01:16:15.440] And so this is the default initialization, if we are following the, where is it, the GPT-2 source | |
[01:16:15.440 --> 01:16:20.960] code that they released. I would like to point out, by the way, that typically the standard deviation | |
[01:16:20.960 --> 01:16:25.440] here in this initialization, if you follow the Xavier initialization, would be one over the square | |
[01:16:25.440 --> 01:16:31.120] root of the number of features that are incoming into this layer. But if you'll notice, actually, | |
[01:16:31.120 --> 01:16:37.200] 0.02 is basically consistent with that, because the d_model sizes inside these transformers for GPT-2 | |
[01:16:37.200 --> 01:16:44.640] are roughly 768, 1,600, etc. So one over the square root of, for example, 768 gives us 0.03. | |
[01:16:44.640 --> 01:16:55.680] If we plug in 1,600, we get 0.02. If we plug in three times that, 0.014, etc. So basically 0.02 is | |
[01:16:55.680 --> 01:17:04.640] roughly in the vicinity of reasonable values for these initializations anyway. So it's not | |
[01:17:04.640 --> 01:17:11.920] completely crazy to be hard coding 0.02 here. But you'd like typically something that grows with | |
[01:17:11.920 --> 01:17:16.640] the model size instead. But we will keep this because that is the GPT-2 initialization per | |
[01:17:16.640 --> 01:17:21.120] their source code. But we are not fully done yet on initialization because there's one more caveat | |
[01:17:21.120 --> 01:17:27.920] here. So here, a modified initialization, which accounts for the accumulation on the residual path | |
[01:17:27.920 --> 01:17:32.560] with model depth is used. We scaled the weight of residual layers of initialization by factor | |
[01:17:32.560 --> 01:17:37.680] one over square root of n, where n is the number of residual layers. So this is what the GPT-2 paper | |
[01:17:37.680 --> 01:17:43.360] says. So we have not implemented that yet. And we can do so now. Now I'd like to actually kind of | |
[01:17:43.360 --> 01:17:48.560] like motivate a little bit what they mean here, I think. So here's roughly what they mean. | |
[01:17:50.240 --> 01:17:56.240] If you start out with zeros in your residual stream, remember that each residual stream | |
[01:17:56.240 --> 01:18:02.800] is of this form, where we continue adding to it: x is x plus something, some kind of contribution. | |
[01:18:02.800 --> 01:18:09.840] So every single block of the residual network contributes some amount and it gets added. | |
[01:18:09.840 --> 01:18:17.040] And so what ends up happening is that the variance of the activations in the residual stream grows. | |
[01:18:17.840 --> 01:18:23.840] So here's a small example: if we start at zero, then we have sort of this | |
[01:18:23.840 --> 01:18:33.040] residual stream of 768 zeros. And then, 100 times, we add randn, which is a normal distribution, | |
[01:18:33.040 --> 01:18:38.320] zero mean, one standard deviation. If we keep adding to it, then by the end, the residual stream has | |
[01:18:38.320 --> 01:18:44.240] grown to have a standard deviation of 10. And that's just because we're always adding | |
[01:18:46.320 --> 01:18:52.560] these numbers. And so this scaling factor that they use here exactly compensates for that growth. | |
[01:18:52.560 --> 01:19:00.080] So if we take n, and we basically scale down every one of these contributions into the residual | |
[01:19:00.080 --> 01:19:06.560] stream by one over the square root of n. So one over the square root of n is n to the negative 0.5, | |
[01:19:06.560 --> 01:19:14.000] right? Because n to the 0.5 is the square root, and then one over the square root is n negative 0.5. | |
[01:19:14.000 --> 01:19:19.120] If we scale it in this way, then we see that we actually get 1. | |
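The toy experiment just described, in a few lines:

```python
import torch

# Summing n unit-variance contributions makes the residual stream's std grow
# like sqrt(n); scaling each contribution by n**-0.5 keeps it around 1.
x = torch.zeros(768)
n = 100
for _ in range(n):
    x += n**-0.5 * torch.randn(768)
print(x.std())   # ~1.0; drop the n**-0.5 factor and this grows to ~10
```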
[01:19:19.120 --> 01:19:27.040] So this is a way to control the growth of activations inside the residual stream | |
[01:19:27.040 --> 01:19:32.320] in the forward pass. And so we'd like to initialize in the same way, where these weights | |
[01:19:32.320 --> 01:19:40.080] that are at the end of each block, so this c_proj layer, the GPT-2 paper proposes to scale | |
[01:19:40.080 --> 01:19:43.360] down those weights by one over the square root of the number of residual layers. | |
[01:19:43.360 --> 01:19:47.040] So one crude way to implement this is the following. | |
[01:19:47.040 --> 01:19:53.600] I don't know if this is PyTorch-sanctioned, but it works for me, is what we'll do in the initialization. | |
[01:19:53.600 --> 01:20:06.880] On c_proj, that special NANOGPT_SCALE_INIT is one. So we're setting kind of like a flag for this module. | |
[01:20:07.680 --> 01:20:10.960] There must be a better way in PyTorch, right, but I don't know. | |
[01:20:10.960 --> 01:20:16.960] Okay, so we're basically attaching this flag and trying to make sure that it doesn't conflict | |
[01:20:16.960 --> 01:20:23.440] with anything previously. And then when we come down here, this std should be 0.02 by default. | |
[01:20:24.880 --> 01:20:33.600] But then if hasattr(module, 'NANOGPT_SCALE_INIT') on this thing, then std times equals... | |
[01:20:33.600 --> 01:20:42.480] Copilot is not guessing correctly. So we want one over the square root of the number of layers. So | |
[01:20:42.480 --> 01:20:48.240] the number of residual layers here is two times | |
[01:20:49.520 --> 01:20:56.800] self.config.n_layer. And then this to the power of negative 0.5. So we want to scale down | |
[01:20:56.800 --> 01:21:03.440] that standard deviation. And this should correctly implement that. I should clarify, | |
[01:21:03.440 --> 01:21:07.280] by the way, that the two times number of layers comes from the fact that every single one of | |
[01:21:07.280 --> 01:21:12.000] our layers in the transformer actually has two blocks that add to the residual pathway, right? | |
[01:21:12.000 --> 01:21:15.760] We have the attention and then the MLP. So that's where the two times comes from. | |
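Putting the pieces together, a sketch of the resulting _init_weights method (a method on the GPT module, applied via self.apply(self._init_weights); NANOGPT_SCALE_INIT is the ad-hoc flag set on each block's c_proj layer):

```python
import torch.nn as nn

def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        std = 0.02
        if hasattr(module, "NANOGPT_SCALE_INIT"):
            # 2 * n_layer residual additions (attention + MLP in every block)
            std *= (2 * self.config.n_layer) ** -0.5
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)   # PyTorch's default for bias would be uniform
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
```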
[01:21:16.800 --> 01:21:22.320] And the other thing to mention is that what's slightly awkward, but we're not going to fix it, | |
[01:21:22.320 --> 01:21:29.600] is that because we are weight sharing the wte and the lm_head, in this iteration over all | |
[01:21:29.600 --> 01:21:34.000] sub-modules, we're actually going to come around to that tensor twice. So we're going to first | |
[01:21:34.000 --> 01:21:38.800] initialize it as an embedding with 0.02. And then we're going to come back around it again in a | |
[01:21:38.800 --> 01:21:44.960] linear and initialize it again using 0.02. And it's going to be 0.02 because the lm_head is, of course, | |
[01:21:44.960 --> 01:21:50.240] not scaled. So it's not going to come here. It's just going to be basically initialized twice | |
[01:21:50.240 --> 01:21:55.920] using the identical same initialization. But that's OK. And then scrolling over here, | |
[01:21:55.920 --> 01:22:04.640] I added some code here so that we have reproducibility to set the seeds. And now we should be able to | |
[01:22:04.640 --> 01:22:11.440] python train_gpt2.py and let this run. And as far as I know, this is the GPT-2 initialization | |
[01:22:12.240 --> 01:22:20.080] in the way we've implemented it right now. So this looks reasonable to me. Okay, so at this point, | |
[01:22:20.080 --> 01:22:24.560] we have the GPT2 model. We have some confidence that it's correctly implemented. We've initialized | |
[01:22:24.560 --> 01:22:28.640] it properly. And we have a data loader that's iterating through data batches. And we can train. | |
[01:22:28.640 --> 01:22:33.760] So now comes the fun part. I'd like us to speed up the training by a lot. So we're getting our | |
[01:22:33.760 --> 01:22:38.640] money's worth with respect to the hardware that we are using here. And we're going to speed up | |
[01:22:38.640 --> 01:22:44.320] the training by quite a bit. Now you always want to start with what hardware do you have? What does | |
[01:22:44.320 --> 01:22:51.840] it offer? And are you fully utilizing it? So in my case, if we go to NVIDIA SMI, we can see that | |
[01:22:51.840 --> 01:23:02.640] I have eight GPUs. And each one of those GPUs is an A100 SXM 80 gigabyte. So this is the GPU that | |
[01:23:02.640 --> 01:23:09.440] I have available to me in this box. Now, when I'd like to spin up one of these kinds of boxes, by the way, | |
[01:23:09.440 --> 01:23:16.800] my favorite place to go to is Lambda Labs. They do sponsor my development and that of my projects. | |
[01:23:16.800 --> 01:23:22.000] But this is my favorite place to go. And this is where you can spin up one of these machines, | |
[01:23:22.000 --> 01:23:26.160] and you pay per hour. And it's very, very simple. So I like to spin them up and then | |
[01:23:26.160 --> 01:23:31.360] connect VS Code to it. And that's how I develop. Now, when we look at the A100s that are available | |
[01:23:31.360 --> 01:23:40.640] here, the A100 80GB SXM is the GPU that I have here. And we have a bunch of numbers here for | |
[01:23:40.640 --> 01:23:48.960] how many calculations you can expect out of this GPU. So when I come over here, and I break in | |
[01:23:48.960 --> 01:23:54.800] right after here. So I'm breaking in right after we calculate the logits and the loss. | |
[01:23:56.560 --> 01:24:03.680] And the interesting thing I'd like you to note is when I do logits.dtype, this prints | |
[01:24:03.680 --> 01:24:09.600] torch.float32. So by default in PyTorch, when you create tensors, and this is the case for all | |
[01:24:09.600 --> 01:24:13.680] the activations and for the parameters of the network and so on, by default, everything is in | |
[01:24:13.680 --> 01:24:21.200] float 32. That means that every single number activation or weight and so on is using a float | |
[01:24:21.200 --> 01:24:27.520] representation that has 32 bits. And that's actually quite a bit of memory. And it turns out empirically | |
[01:24:27.520 --> 01:24:32.400] that for deep learning as a computational workload, this is way too much. And deep learning and the | |
[01:24:32.400 --> 01:24:38.960] training of these networks can tolerate significantly lower precision. Not all computational workloads | |
[01:24:38.960 --> 01:24:46.000] can tolerate small precision. So for example, if we go back to the data sheet, you'll see that | |
[01:24:46.000 --> 01:24:51.360] actually these GPUs support up to FP64. And this is quite useful, I understand for a lot of | |
[01:24:51.360 --> 01:24:56.560] scientific computing applications. And there they really need this. But we don't need that | |
[01:24:56.560 --> 01:25:03.520] much precision for deep learning training. So currently we are here, FP 32. And with this code | |
[01:25:03.520 --> 01:25:10.800] as it is right now, we expect to get at most 19.5 teraflops of performance. That means we're | |
[01:25:10.800 --> 01:25:16.480] doing 19.5 trillion operations, floating point operations. So this is floating point multiply, | |
[01:25:16.480 --> 01:25:23.680] add, most likely. And so these are the floating point operations. | |
[01:25:23.680 --> 01:25:32.160] Now notice that if we are willing to go down in precision, so TF 32 is a lower precision format, | |
[01:25:32.160 --> 01:25:36.560] we're going to see in a second, you can actually get an 8x improvement here. And if you're willing | |
[01:25:36.560 --> 01:25:42.400] to go down to float16 or bfloat16, you can actually get 16x performance, | |
[01:25:42.400 --> 01:25:48.080] all the way to 312 teraflops. You see here that NVIDIA likes to cite numbers that have | |
[01:25:48.080 --> 01:25:54.240] a asterisk here. This asterisk says with sparsity. But we are not going to be using sparsity in | |
[01:25:54.240 --> 01:25:58.800] our code. And I don't know that this is very widely used in the industry right now. So most | |
[01:25:58.800 --> 01:26:04.240] people look at this number here without sparsity. And you'll notice that we could have got even | |
[01:26:04.240 --> 01:26:12.800] more here. But this is int8. And int8 is used for inference, not for training, because int8 | |
[01:26:12.800 --> 01:26:24.240] basically has uniform spacing. And we actually require a float so that we get a better match | |
[01:26:24.240 --> 01:26:30.720] to the normal distributions that occur during training of neural networks, where both activations | |
[01:26:30.720 --> 01:26:36.320] and weights are distributed as a normal distribution. And so floating points are really important to | |
[01:26:36.320 --> 01:26:43.200] match that representation. So we're not typically using int8 for training, but we are using it for | |
[01:26:43.200 --> 01:26:49.280] inference. And if we bring down the precision, we can get a lot more teraflops out of the tensor | |
[01:26:49.280 --> 01:26:53.840] cores available in the GPUs. We'll talk about that in a second. But in addition to that, if all | |
[01:26:53.840 --> 01:26:59.680] of these numbers have fewer bits of representation, it's going to be much easier to move them around. | |
[01:27:00.240 --> 01:27:03.920] And that's where we start to get into the memory bandwidth and the memory of the model. | |
[01:27:03.920 --> 01:27:08.800] So not only do we have a finite capacity of the number of bits that our GPU can store, | |
[01:27:08.800 --> 01:27:13.360] but in addition to that, there's a speed with which you can access this memory. | |
[01:27:13.360 --> 01:27:20.080] And you have a certain memory bandwidth. It's a very precious resource. And in fact, many of the | |
[01:27:20.080 --> 01:27:25.520] deep learning workloads for training are memory bound. And what that means is actually that the | |
[01:27:25.520 --> 01:27:30.400] tensor cores that do all these extremely fast multiplications, most of the time they're waiting | |
[01:27:30.400 --> 01:27:37.120] around, they're idle, because we can't feed them with data fast enough. We can't load the data fast | |
[01:27:37.120 --> 01:27:41.440] enough from memory. So typical utilizations of your hardware: if you're getting 60% | |
[01:27:41.440 --> 01:27:48.400] utilization, you're actually doing extremely well. So half of the time in a well-tuned application, | |
[01:27:48.400 --> 01:27:52.960] your tensor cores are not doing multiplies because the data is not available. So the memory | |
[01:27:52.960 --> 01:27:58.080] bandwidth here is extremely important as well. And if we come down in the precision for all the | |
[01:27:58.080 --> 01:28:04.240] floats, all the numbers, weights and activations, suddenly require less memory. So we can store more | |
[01:28:04.240 --> 01:28:09.840] and we can access it faster. So everything speeds up and it's amazing. And now let's reap the | |
[01:28:09.840 --> 01:28:16.000] benefits of it. And let's first look at the tensor float 32 format. Okay, so first of all, | |
[01:28:16.000 --> 01:28:22.160] what are tensor cores? Well, tensor cores, tensor core is just an instruction in the A100 | |
[01:28:22.160 --> 01:28:28.480] architecture, right? So what it does is it does basically a little four by four matrix multiply. | |
[01:28:28.480 --> 01:28:36.160] So this is just matrix multiplication here of four by four matrices. And there are multiple | |
[01:28:36.160 --> 01:28:42.560] configurations as to what precision any of these matrices are. In what precision the internal | |
[01:28:42.560 --> 01:28:47.760] accumulate happens. And then what is the output precision, input precision, etc. So there's a few | |
[01:28:47.760 --> 01:28:53.280] switches, but it's basically a four by four multiply. And then anytime we have any operations that | |
[01:28:53.280 --> 01:28:59.360] require matrix multiplication, they get broken up into this instruction of little four | |
[01:28:59.360 --> 01:29:03.680] by four multiplies. And so everything gets broken up into this instruction because it's the fastest | |
[01:29:03.680 --> 01:29:08.560] way to multiply matrices. And it turns out that most of the computational work that we're doing up | |
[01:29:08.560 --> 01:29:14.240] above, all of it really is matrix multiplication. Most of the work computationally happens in the | |
[01:29:14.240 --> 01:29:22.560] linear layers, linear, linear, etc. There's a few things sandwiched in between. So there's some | |
[01:29:22.560 --> 01:29:27.680] additions in residuals, there's some GELU nonlinearities, there's some layer norms, etc. | |
[01:29:27.680 --> 01:29:32.720] But if you just time them, you'll see that these are nothing. Like basically, the entire transformer | |
[01:29:32.720 --> 01:29:39.520] is just a bunch of matrix multiplications really. And especially at this small scale 124 million | |
[01:29:39.520 --> 01:29:45.120] parameter model. Actually, the biggest matrix multiplication by far is the classifier layer at | |
[01:29:45.120 --> 01:29:52.800] the top. That is a massive matrix multiply of going from 768 to 50,257. And that matrix multiply | |
[01:29:52.800 --> 01:29:59.280] dominates anything else that happens in that network roughly speaking. So it's matrix multiplies | |
[01:29:59.280 --> 01:30:05.040] that become a lot faster, which are hidden inside our linear layers. And they're accelerated through | |
[01:30:05.040 --> 01:30:09.760] tensor cores. Now, the best reference I would say for tensor cores is basically just go to the | |
[01:30:09.760 --> 01:30:17.840] A100 architecture white paper. And it's pretty detailed. But I think it's | |
[01:30:17.840 --> 01:30:24.000] relatively readable, mostly, if you understand what's happening. So figure nine, | |
[01:30:24.000 --> 01:30:30.000] tensor float 32. So this is the explanation basically for TF32 and what happens here. | |
[01:30:31.200 --> 01:30:35.840] And you see that there's many configuration options here available. So the input operands | |
[01:30:35.840 --> 01:30:44.400] and what precision are they in? The accumulator and what basically the internal representation | |
[01:30:44.400 --> 01:30:48.480] within the instruction, when you do the accumulate of this matrix multiplication. | |
[01:30:48.480 --> 01:30:54.960] So the intermediate plus equals of the intermediate little vector multiplies here, | |
[01:30:54.960 --> 01:31:01.520] that all happens in FP32. And then this is an 8x improvement, as I mentioned, to the | |
[01:31:01.520 --> 01:31:06.960] ops that we get. So TF32 specifically, we're looking at this row here. And the way this works is | |
[01:31:06.960 --> 01:31:19.040] normally FP32 has 32 bits. TF32 keeps the same leading bits: we have one sign bit, | |
[01:31:19.040 --> 01:31:25.840] we have eight exponent bits, except the mantissa bits get cropped in the float. And so basically, | |
[01:31:25.840 --> 01:31:33.920] we end up with just 19 bits instead of 32 bits, because the last 13 bits get truncated, they get | |
[01:31:33.920 --> 01:31:40.160] dropped. And all of this is internal to the instruction. So none of it is visible to anything | |
[01:31:40.160 --> 01:31:45.520] in our PyTorch. None of our PyTorch code will change. All of the numbers will look identical. | |
[01:31:45.520 --> 01:31:51.280] It's just that when you call the tensor core instruction internally in the hardware, | |
[01:31:51.280 --> 01:31:57.840] it will crop out these 13 bits. And that allows it to calculate this little matrix | |
[01:31:57.840 --> 01:32:04.080] multiply significantly faster, eight X faster. Now, of course, this speed up comes at a cost. | |
[01:32:04.080 --> 01:32:09.360] And the cost is that we are reducing the precision. Our accumulate is still in FP32. Our output | |
[01:32:09.360 --> 01:32:16.160] is FP32. Our inputs are FP32. But internally, things get truncated in the operands to perform | |
[01:32:16.160 --> 01:32:20.240] the operation faster. And so our results are starting to be a bit more approximate. | |
[01:32:20.240 --> 01:32:23.520] But empirically, when you actually train with this, you basically can't tell the difference. | |
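Just to make that bit layout concrete, here is a tiny standalone Python sketch of my own (not anything from the video's code) that zeroes out the lowest 13 mantissa bits of a float32; the real tensor core hardware may round rather than simply truncate, so treat this purely as an illustration:

```python
import struct

def crop_to_tf32(x: float) -> float:
    # reinterpret the float32 bit pattern as a 32-bit unsigned integer
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # keep sign (1 bit) + exponent (8 bits) + top 10 mantissa bits; zero the low 13 mantissa bits
    bits &= ~((1 << 13) - 1)
    return struct.unpack("<f", struct.pack("<I", bits))[0]

x = 0.1
print(x, crop_to_tf32(x))  # the cropped value is close to x, just a coarser approximation
```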
[01:32:23.520 --> 01:32:29.200] So the reason I like TF32 is because if you can tolerate a little bit of a precision fudge, | |
[01:32:29.200 --> 01:32:36.320] then this is pretty great, because none of your code sees this. It's fully internal to the operation. | |
[01:32:36.320 --> 01:32:40.880] And the operation, to you, just goes 8x faster, and it's a bit more approximate. | |
[01:32:40.880 --> 01:32:46.320] And so it's a pretty sweet spot, I would say an optimization. And let's see what that looks like | |
[01:32:46.320 --> 01:32:53.280] first. So I've set up our code to just time the iterations. So import time, I changed the hyper | |
[01:32:53.280 --> 01:32:57.920] parameters so that we have something that reflects more the kind of workload that we want to | |
[01:32:57.920 --> 01:33:03.360] run, because we want to do a fairly large run at the end of this. So let's use batch size 16. | |
[01:33:03.360 --> 01:33:08.480] And let's now use the actual GPT-2 maximum sequence length of 1024 tokens. | |
[01:33:08.480 --> 01:33:16.640] So this is the configuration. And then for 50 iterations, I'm just doing something very lazy | |
[01:33:16.640 --> 01:33:22.240] here. I'm doing time.time() to get the current time. And then this is the optimization loop. | |
[01:33:22.240 --> 01:33:30.400] And now I want to time how long this takes. Now, one issue with working with GPUs is that | |
[01:33:31.040 --> 01:33:39.200] when your CPU runs, it's just scheduling work on the GPU. It's ordering some work, right? | |
[01:33:39.200 --> 01:33:44.080] And so it sends a request and then it continues running. And so it can happen | |
[01:33:44.080 --> 01:33:51.360] sometimes that we sort of speed through this, and we queue up a lot of kernels to run on the GPU. | |
[01:33:51.360 --> 01:33:55.520] And then the CPU sort of like gets here and takes the time. But actually the GPU is still | |
[01:33:55.520 --> 01:34:00.560] running, because it takes time to actually work through the work that was scheduled to run. | |
[01:34:01.440 --> 01:34:06.080] And so you're just building up a queue for the GPU. And so actually, if you need to, | |
[01:34:06.080 --> 01:34:13.120] you want to wait, and this will wait for the GPU to finish all the work that was scheduled to run | |
[01:34:13.120 --> 01:34:18.960] up above here. And then we can actually take the time. So basically, we're waiting for the GPU to | |
[01:34:18.960 --> 01:34:26.320] stop this iteration, take the time, and then we're going to just print it. So here I'm going to run | |
[01:34:26.320 --> 01:34:32.000] the training loop. And here on the right, I'm watching NVIDIA SMI. So we start off at zero. | |
[01:34:32.000 --> 01:34:37.520] We're not using the GPU. And then by default, PyTorch will use GPU zero. So we see that it gets | |
[01:34:37.520 --> 01:34:44.800] filled up. And we're using 35 gigabytes out of 80 gigabytes available. And then here on the left, | |
[01:34:44.800 --> 01:34:52.240] we see that because we cranked up the batch size, now it's only 20 batches to do a single epoch on | |
[01:34:52.240 --> 01:34:57.840] our tiny Shakespeare. And we see that we're seeing roughly 1000 milliseconds per iteration here. | |
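For reference, the timing pattern just described looks roughly like this; it is only a sketch, and names like train_loader, model, optimizer and device stand in for the objects built earlier in the video (the tokens-per-second line is something that gets added a little later):

```python
import time
import torch

for step in range(50):
    t0 = time.time()
    x, y = train_loader.next_batch()      # placeholder for the data loader from earlier
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    logits, loss = model(x, y)
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()              # wait for the GPU to finish the queued work
    t1 = time.time()
    dt_ms = (t1 - t0) * 1000
    tokens_per_sec = (x.size(0) * x.size(1)) / (t1 - t0)
    print(f"step {step}: {dt_ms:.1f} ms, {tokens_per_sec:.0f} tok/sec")
```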
[01:34:57.840 --> 01:35:05.200] So the first iteration sometimes is slower. And that's because PyTorch might be doing a lot of | |
[01:35:05.200 --> 01:35:10.240] initializations here on the very first iteration. And so it's probably initializing all these tensors | |
[01:35:10.240 --> 01:35:15.680] and buffers to hold the gradients. And I'm not sure all the work that happens here. But | |
[01:35:15.680 --> 01:35:19.360] this could be a slower iteration. When you're timing your logic, you always want to be careful | |
[01:35:19.360 --> 01:35:26.240] with that. But basically, we're seeing 1000 milliseconds per iteration. And so this will run for roughly | |
[01:35:26.240 --> 01:35:32.880] 50 seconds as we have it right now. So that's our baseline in float 32. One more thing I wanted | |
[01:35:32.880 --> 01:35:38.080] to mention is that if this doesn't fit into your GPU, and you're getting out of memory errors, | |
[01:35:38.080 --> 01:35:43.120] then start decreasing your batch size until things fit. So instead of 16, try eight or four, | |
[01:35:43.120 --> 01:35:49.120] or whatever you need to fit the batch into your GPU. And if you have a bigger GPU, you can actually | |
[01:35:49.120 --> 01:35:55.680] potentially get away with 32 and so on. By default, you want to basically max out the batch size that | |
[01:35:55.680 --> 01:36:01.760] fits on your GPU. And you want to keep it to nice numbers. So use numbers that have lots of powers | |
[01:36:01.760 --> 01:36:09.280] of two in them. So 16 is a good number, eight, 24, 32, 48, these are nice numbers. But don't | |
[01:36:09.280 --> 01:36:14.160] do something like 17, because that will run very inefficiently on the GPU. And we're going to see | |
[01:36:14.160 --> 01:36:22.080] that a bit later as well. So for now, let's just stick with 16, 1024. And the one thing that I added | |
[01:36:22.080 --> 01:36:28.480] also here and I ran it again, is I'm calculating tokens per second throughput during training, | |
[01:36:28.480 --> 01:36:34.720] because we might end up changing the batch size around over time. But tokens per second is the | |
[01:36:34.720 --> 01:36:39.280] objective measure that we actually really care about. How many tokens of data are we training on? | |
[01:36:39.280 --> 01:36:43.760] And what is the throughput of tokens that we're getting in our optimization? So right now we're | |
[01:36:43.760 --> 01:36:49.760] processing and training on 163,000 tokens per second roughly. And that's a bit more objective | |
[01:36:49.760 --> 01:36:56.880] metric. Okay, so let's now enable TF 32. Now luckily PyTorch makes this fairly easy for us. | |
[01:36:56.880 --> 01:37:04.000] And to enable TF 32, you just need to do a single line. And it's this. And when we go to the PyTorch | |
[01:37:04.000 --> 01:37:08.880] documentation here for this function, basically this tells PyTorch what kind of kernels to run. | |
[01:37:09.520 --> 01:37:16.160] And by default, I believe it is 'highest', the highest precision for matmul. And that means that everything | |
[01:37:16.160 --> 01:37:21.360] happens in float 32, just like it did before. But if we set it to 'high', as we do right now, | |
[01:37:21.360 --> 01:37:30.000] matrix multiplications will now use tensor float 32 when available. My GPU is the A100. So it's | |
[01:37:30.000 --> 01:37:36.400] an Ampere series. And therefore, TF 32 is available. If you're on an older GPU, this might not be available | |
[01:37:36.400 --> 01:37:41.760] for you. But for my GPU, it's available. And so what I expect PyTorch to do is that every single | |
[01:37:41.760 --> 01:37:46.800] place where we see an nn.Linear, inside there, there's a matrix multiplication. And I expect | |
[01:37:46.800 --> 01:37:53.600] that matrix multiplication now to be running on tensor cores, utilizing the TF 32 precision. | |
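For reference, the line in question is the documented matmul-precision switch; 'highest' is the default (pure FP32) and 'high' allows TF32 on GPUs that support it:

```python
import torch

# allow matmuls to use TF32 tensor cores where available ('highest' would keep full FP32)
torch.set_float32_matmul_precision("high")
```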
[01:37:53.600 --> 01:38:00.720] So this is the single line of change that is, I believe, necessary. And let's rerun this. Now, | |
[01:38:00.720 --> 01:38:06.480] we saw that in terms of the throughput that is promised to us, we're supposed to be getting | |
[01:38:06.480 --> 01:38:18.560] 8x roughly. So let's see what happens. And that 8x came from here, right? 8x. And it also came from | |
[01:38:18.560 --> 01:38:26.560] looking at it here 156 T flops instead of 19.5. Okay, so what actually happened? | |
[01:38:27.440 --> 01:38:36.480] So we're seeing that our throughput improved roughly 3x, not 8x. So we are going from 1000 milliseconds, | |
[01:38:36.480 --> 01:38:41.360] we're going down to 300 milliseconds. And our throughput is now about 50,000 tokens per second. | |
[01:38:41.360 --> 01:38:46.240] So we have a roughly three X instead of eight X. So what happened? And basically, what's happening | |
[01:38:46.240 --> 01:38:54.240] here is, again, a lot of these workloads are memory bound. And so even though the TF 32 offers in | |
[01:38:54.240 --> 01:39:01.920] principle, a lot faster throughput, all of these numbers everywhere are still float 32s. And it's | |
[01:39:01.920 --> 01:39:06.560] float 32 numbers that are being shipped all over the place through the memory system. And it's just | |
[01:39:06.560 --> 01:39:10.560] costing us way too much time to shuttle around all this data. And so even though we've made the | |
[01:39:10.560 --> 01:39:15.680] multiply itself much faster, we are memory bound, and we're not actually seeing the full benefit | |
[01:39:15.680 --> 01:39:22.800] that would come from this napkin math here. That said, we are getting a 3x faster throughput. | |
[01:39:22.800 --> 01:39:30.400] And this is free. Single line of code in PyTorch, all your variables are still float 32 everywhere. | |
[01:39:30.400 --> 01:39:34.800] It just runs faster, and it's slightly more approximate, but we're not going to notice it | |
[01:39:34.800 --> 01:39:44.480] basically. So that's TF 32. Okay, so let's now continue. So we've exercised this row, and we saw | |
[01:39:44.480 --> 01:39:49.440] that we can crop out some of the precision inside the operation itself. But we saw that we're still | |
[01:39:49.440 --> 01:39:53.360] memory bound. We're still moving around all these floats, right, and we're paying that | |
[01:39:53.360 --> 01:39:58.320] cost because of this. So let's now decrease the amount of stuff that we're going to be moving | |
[01:39:58.320 --> 01:40:04.560] around. And we're going to do that by dropping down to B float 16. So we're only going to be | |
[01:40:04.560 --> 01:40:10.000] maintaining 16 bits per float. And we're going to use B float 16. Now let me explain in a bit | |
[01:40:10.000 --> 01:40:16.640] the BF16 difference. And we're going to be in this row. So when we go back to the documentation | |
[01:40:16.640 --> 01:40:25.280] here for the A100, we see here the precisions that are available. And this is the original | |
[01:40:25.280 --> 01:40:32.960] FP32. The TF 32 crops out the precision. And then here in BF 16, you see that it is very similar | |
[01:40:32.960 --> 01:40:39.600] to TF 32, but it's even more aggressive in cropping off the precision, the mantissa of this float. | |
[01:40:39.600 --> 01:40:44.800] So the important thing with B float 16 is that the exponent bits and the sign bit, of course, | |
[01:40:44.800 --> 01:40:50.160] remain unchanged. So if you're familiar with your float numbers, and I think this | |
[01:40:50.160 --> 01:40:56.720] should probably be an entire video by itself, the exponent sets the range that you can represent | |
[01:40:56.720 --> 01:41:02.640] of your numbers. And the precision is how much precision you have for your numbers. And so | |
[01:41:02.640 --> 01:41:10.160] the range of numbers is identical, but we have fewer possibilities within that range, | |
[01:41:10.720 --> 01:41:14.400] because we are truncating the mantissa. So we have less precision in that range. | |
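You can see the range-versus-precision tradeoff directly with torch.finfo; this is just an illustrative check on the side, not part of the training code (float16, which comes up next, is included for contrast):

```python
import torch

for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    # bfloat16 keeps roughly the float32 exponent range but has far fewer mantissa bits;
    # float16 has more mantissa bits but a much smaller representable range
    print(dtype, "max:", info.max, "eps:", info.eps)
```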
[01:41:14.400 --> 01:41:20.720] What that means is that things are actually fairly nice, because we have the original range of | |
[01:41:20.720 --> 01:41:26.880] numbers that are representable in float. But we just have less precision for it. And the difference | |
[01:41:26.880 --> 01:41:33.760] with FP16 is that they actually touch and change the range. So FP16 cannot represent the full range | |
[01:41:33.760 --> 01:41:39.360] of FP32. It has a reduced range. And that's where you start to actually run into issues, | |
[01:41:39.360 --> 01:41:44.560] because now you need these gradient scalers and things like that. And I'm not going to go into | |
[01:41:44.560 --> 01:41:50.720] detail on that in this video, because that's a whole video by itself. But FP16 actually historically | |
[01:41:50.720 --> 01:41:56.800] came first, that was available in the Volta series before Ampere. And so FP16 came first, | |
[01:41:56.800 --> 01:42:01.200] and everyone started to train in FP16. But everyone had to use all these gradient scaling | |
[01:42:01.200 --> 01:42:05.920] operations, which are kind of annoying. And it's an additional source of state and complexity. | |
[01:42:05.920 --> 01:42:11.920] And the reason for that was because the exponent range was reduced in FP16. So that's the IEEE | |
[01:42:11.920 --> 01:42:18.240] FP16 spec. And then they came out with BF16 in Ampere. And they made it much simpler, | |
[01:42:18.240 --> 01:42:22.480] because we're just truncating the mantissa, we have the exact same range. And we do not need gradient | |
[01:42:22.480 --> 01:42:28.960] scalers. So everything is much, much simpler. Now, when we do use BF 16, though, we are impacting | |
[01:42:28.960 --> 01:42:35.120] the numbers that we might be seeing in our PyTorch code. This change is not just local | |
[01:42:35.120 --> 01:42:42.480] to the operation itself. So let's see how that works. There's some documentation here that, | |
[01:42:42.480 --> 01:42:46.960] so I think this is probably the best page to explain how to use mixed precision in PyTorch. | |
[01:42:46.960 --> 01:42:52.960] Because there are many other tutorials and so on, even within PyTorch documentation, | |
[01:42:52.960 --> 01:42:56.320] they are a lot more confusing. And so I recommend specifically this one, | |
[01:42:57.200 --> 01:43:01.680] because there's five other copies that I would not recommend. And then when we come here, | |
[01:43:01.680 --> 01:43:06.560] ignore everything, ignore everything about the gradient scalers, | |
[01:43:06.560 --> 01:43:15.280] and only look at Torch.autocast. And basically, also, this comes to a single line of code at the end. | |
[01:43:15.280 --> 01:43:22.720] So this is the context manager that we want. And we want to use that in our network. When you | |
[01:43:22.720 --> 01:43:30.400] click into the torch.autocast autocasting, it has a bit more guidance for you. | |
[01:43:30.400 --> 01:43:36.560] So it's telling you do not call BF16 on any of your tensors, just use autocast, | |
[01:43:36.560 --> 01:43:42.480] and only surround the forward pass of the model and the loss calculation. And that's the only | |
[01:43:42.480 --> 01:43:46.400] two things that you should be surrounding. Leave the backward and the optimizer step alone. | |
[01:43:46.400 --> 01:43:51.360] So that's the guidance that comes from the PyTorch team. So we're going to follow that guidance. | |
[01:43:51.360 --> 01:43:55.680] And for us, because the loss calculation is inside of the model forward pass for us, | |
[01:43:55.680 --> 01:44:01.440] we are going to be doing this. And then we don't want to be using torch.float16, | |
[01:44:01.440 --> 01:44:04.720] because if we do that, we need to start using gradient scalers as well. | |
[01:44:04.720 --> 01:44:09.120] So we are going to be using bfloat16. This is only possible to do on Ampere. | |
[01:44:09.120 --> 01:44:14.400] But this means that the changes are extremely minimal, like basically just this one line of code. | |
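As a sketch, under the same placeholder names as before (train_loader, model, optimizer, device), the change looks something like this, wrapping only the forward pass and loss:

```python
import torch

for step in range(50):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    # only the forward pass and the loss calculation go inside autocast
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits, loss = model(x, y)   # logits come out as bfloat16; parameters stay float32
    # backward and the optimizer step are left outside, per the PyTorch guidance
    loss.backward()
    optimizer.step()
```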
[01:44:16.800 --> 01:44:24.480] Let me first break in to here before we actually run this. So right after logits, | |
[01:44:24.480 --> 01:44:30.000] I'd like to show you that, different from the TF32 that we saw, this is actually going to impact | |
[01:44:30.000 --> 01:44:38.160] our tensors. So this logits tensor, if we now look at this and we look at the dtype, | |
[01:44:38.160 --> 01:44:44.800] we suddenly see that this is now BF16. It's not float 32 anymore. So our activations have been | |
[01:44:44.800 --> 01:44:52.000] changed. The activations tensor is now BF16, but not everything has changed. So model dot transformer | |
[01:44:52.000 --> 01:45:01.680] dot WTE. This is the weight token embedding table. It has a dot weight inside it. And the D type of | |
[01:45:01.680 --> 01:45:08.880] this weight, this parameter is still torch float 32. So our parameters seem to still be in float 32, | |
[01:45:08.880 --> 01:45:14.480] but our activations, the logits are now in BF16. So clearly, this is why we get the mixed | |
[01:45:14.800 --> 01:45:20.000] precision. Some things pytorch is keeping in float 32. Some things pytorch is converting | |
[01:45:20.000 --> 01:45:28.640] to lower precision. And what gets converted, at what point is not super clear. I remember scrolling | |
[01:45:28.640 --> 01:45:40.720] down. Is it here? Okay, I can't find it. I thought it was here. Okay, there we go. | |
[01:45:42.000 --> 01:45:49.040] So there are a few docs on when you're using this autocast, what gets converted to BF16 and when. | |
[01:45:49.040 --> 01:45:54.880] So for example, only these matrix multiply like operations get converted to BF16. But a lot of | |
[01:45:54.880 --> 01:45:59.440] operations remain in float 32. So in particular, a lot of normalizations, like layer norms and | |
[01:45:59.440 --> 01:46:05.760] things like that, not all of those layers might be converted. So only some layers selectively | |
[01:46:05.760 --> 01:46:14.000] would be running BF16. But things like softmax, layer norms, log, log softmax. So loss function | |
[01:46:14.000 --> 01:46:18.480] calculations, a lot of those things might remain in float 32, because they are more susceptible to | |
[01:46:18.480 --> 01:46:27.120] precision changes. Matrix multiplies are fairly robust to precision changes. So some parts of the | |
[01:46:27.120 --> 01:46:35.040] network are impacted more or less by the precision change. So basically, only some parts of | |
[01:46:35.040 --> 01:46:40.240] the model are running in reduced precision. Let's take it for a spin. And let's actually see what kind | |
[01:46:40.240 --> 01:46:52.800] of improvement we achieve here. Okay, so we used to be 333 milliseconds. We're now 300. | |
[01:46:52.800 --> 01:46:58.800] And we used to be somewhere around 50,000 tokens per second. We're now at 55,000. So we're definitely running | |
[01:46:58.800 --> 01:47:04.400] faster, but maybe not a lot faster. And that's because there are still many, many bottlenecks | |
[01:47:04.400 --> 01:47:09.280] in our GPT-2, we're just getting started. But we have dropped down the precision as far as we can | |
[01:47:09.280 --> 01:47:15.200] with my current GPU, which is an A100, we're using PyTorch autocast. Unfortunately, I don't | |
[01:47:15.200 --> 01:47:21.760] actually exactly know what PyTorch AutoCast does. I don't actually know exactly what's in BF16, | |
[01:47:21.760 --> 01:47:27.280] what's in float 32, we could go in and we could start scrutinize it. But these are the kinds of | |
[01:47:27.280 --> 01:47:33.600] rules that PyTorch has internally. And unfortunately, they don't document it very well. So we're not | |
[01:47:33.600 --> 01:47:39.520] gonna go into that in too much detail. But for now, we are training in BF16. We do not need a gradient | |
[01:47:39.520 --> 01:47:46.880] scaler. And the reason things are running faster is because we are able to run tensor cores in | |
[01:47:46.880 --> 01:47:53.920] BF16 now. That means we are in this row. But we are also paying in precision for this. | |
[01:47:53.920 --> 01:48:01.520] So we expect slightly less accurate results with respect to the original FP32. But empirically, | |
[01:48:01.520 --> 01:48:06.800] in many cases, this is a worth-it kind of tradeoff, because it allows you to run faster. And you | |
[01:48:06.800 --> 01:48:15.440] could, for example, train longer and make up for that precision decrease. So that's BF16 for now. | |
[01:48:15.440 --> 01:48:20.960] Okay, so as we can see, we are currently at about 300 milliseconds per iteration. And we're now | |
[01:48:20.960 --> 01:48:24.960] gonna reach for some really heavy weapons in the PyTorch arsenal. And in particular, | |
[01:48:24.960 --> 01:48:30.480] we're going to introduce torch.compile. So torch.compile is really quite incredible infrastructure | |
[01:48:30.480 --> 01:48:36.240] from the PyTorch team. And it's basically a compiler for neural networks. Like it's almost like GCC | |
[01:48:36.240 --> 01:48:44.480] for C/C++ code. And this is just the GCC of neural nets. It came out a while ago and is extremely | |
[01:48:44.480 --> 01:48:51.040] simple to use. The way to use torch.compile is to do this. It's a single line of code to compile | |
[01:48:51.040 --> 01:48:57.280] your model and return it. Now this line of code will cost you compilation time. But as you might | |
[01:48:57.280 --> 01:49:01.920] guess, it's going to make the code a lot faster. So let's actually run that. And because this will | |
[01:49:01.920 --> 01:49:06.720] take some time to run, let's currently remember we're 300 milliseconds. And we'll see what happens. | |
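The change itself is just the documented one-liner; GPT, GPTConfig and device here stand in for the classes and variables from earlier in the video:

```python
import torch

model = GPT(GPTConfig())
model.to(device)
model = torch.compile(model)   # pays a one-time compilation cost, then runs much faster
```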
[01:49:06.720 --> 01:49:11.680] Now, while this is running, I'd like to explain a little bit of what Torch.compile does under the | |
[01:49:11.680 --> 01:49:18.160] hood. So feel free to read this page of PyTorch. But basically, there's no real good reason for | |
[01:49:18.160 --> 01:49:22.960] you to not use Torch compile in your PyTorch. I kind of feel like you should be using it almost | |
[01:49:22.960 --> 01:49:29.120] by default, unless you're debugging, if you want your code to run really fast. And there's one | |
[01:49:29.120 --> 01:49:32.880] line here in Torch compile that I found that actually kind of like gets to why this is faster. | |
[01:49:32.880 --> 01:49:39.680] Speedup comes from reducing Python overhead and GPU read writes. So let me unpack that a little bit. | |
[01:49:39.680 --> 01:49:46.480] Okay, here we are. Okay, so we're going from 300 milliseconds. We're now running at 129 milliseconds. | |
[01:49:47.760 --> 01:49:54.080] So this is 300 divided 129, about 2.3x improvement from a single line of code in PyTorch. | |
[01:49:54.080 --> 01:49:59.520] So quite incredible. So what is happening? What's happening under the hood? Well, when you pass | |
[01:49:59.520 --> 01:50:06.720] the model to torch.compile, what we have here in this nn.Module, this is really just the algorithmic | |
[01:50:06.720 --> 01:50:13.200] description of what we'd like to happen in our network. And torch.compile will analyze the entire | |
[01:50:13.200 --> 01:50:18.480] thing. And it will look at what operations you'd like to use. And with the benefit of knowing | |
[01:50:18.480 --> 01:50:23.680] exactly what's going to happen, it doesn't have to run in what's called eager mode. It doesn't | |
[01:50:23.680 --> 01:50:30.160] have to just kind of like go layer by layer, like the Python interpreter normally would start at the | |
[01:50:30.160 --> 01:50:36.480] forward. And the Python interpreter will go, okay, let's do this operation. And then let's do that | |
[01:50:36.480 --> 01:50:41.840] operation. And it kind of materializes all the operations as it goes through. So these | |
[01:50:42.400 --> 01:50:48.000] calculations are dispatched and run in this order. And the Python interpreter and this code doesn't | |
[01:50:48.000 --> 01:50:52.640] know what kind of operations are going to happen later. But Torch compile sees your entire code | |
[01:50:52.640 --> 01:50:58.160] at the same time. And it's able to know what operations you intend to run. And it will kind of | |
[01:50:58.160 --> 01:51:02.800] optimize that process. The first thing we'll do is we'll take out the Python interpreter from | |
[01:51:02.800 --> 01:51:07.840] the forward pass entirely. And it will kind of compile this entire neural net as a single object | |
[01:51:07.840 --> 01:51:12.720] with no Python interpreter involved. So it knows exactly what's going to run, it will just run that. | |
[01:51:12.720 --> 01:51:20.400] And it's all going to be running an efficient code. The second thing that happens is this read | |
[01:51:20.400 --> 01:51:24.480] write that they mentioned very briefly. So a good example of that, I think, is the | |
[01:51:24.480 --> 01:51:30.240] GELU nonlinearity that we've been looking at. So here we use the nn.GELU. Now this here | |
[01:51:30.240 --> 01:51:37.600] is me basically just breaking up the nn.GELU, which you remember has this formula. So this | |
[01:51:37.600 --> 01:51:42.160] here is the equivalent implementation to what's happening inside GELU; algorithmically, it's identical. | |
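For reference, the broken-up version is roughly the tanh approximation of GELU, something like the sketch below (it mirrors what nn.GELU(approximate='tanh') computes):

```python
import math
import torch

def manual_gelu(x: torch.Tensor) -> torch.Tensor:
    # tanh approximation of GELU; without torch.compile, each operation below
    # dispatches its own kernel and round-trips through GPU memory
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
```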
[01:51:42.160 --> 01:51:49.600] Now, by default, if we just were using this instead of the nn.GELU here, what would happen | |
[01:51:49.600 --> 01:51:54.320] without torch.compile? Well, the Python interpreter would make its way here. And then it would be, | |
[01:51:54.320 --> 01:52:00.160] okay, well, there's an input. Well, let me first, let me raise this input to the third power. | |
[01:52:00.160 --> 01:52:04.080] And it's going to dispatch a kernel that takes your input and raises to the third power. | |
[01:52:04.800 --> 01:52:11.360] And that kernel will run. And when this kernel runs, what ends up happening is this input is | |
[01:52:11.360 --> 01:52:17.360] stored in the memory of the GPU. So here's a helpful example of the layout of what's happening, | |
[01:52:17.360 --> 01:52:22.560] right? You have your CPU, this is in every single computer. There's a few cores in there, | |
[01:52:22.560 --> 01:52:29.200] and you have your RAM, your memory. And the CPU can talk to the memory, and this is all well known. | |
[01:52:29.200 --> 01:52:33.120] But now we've added the GPU. And the GPU is a slightly different architecture, of course, | |
[01:52:33.120 --> 01:52:38.560] they can communicate. And it's different in that it's got a lot more cores than a CPU. | |
[01:52:38.560 --> 01:52:43.600] All of those cores are individually a lot simpler, too. But it also has memory, right? | |
[01:52:43.600 --> 01:52:50.640] This high bandwidth memory. Sorry if I'm botching it, HBM. I don't even know what that stands for. | |
[01:52:50.640 --> 01:52:55.440] I'm just realizing now. But this is the memory, and it's very equivalent to | |
[01:52:55.440 --> 01:53:01.520] RAM, basically, in the computer. And what's happening is that input is living in the memory. | |
[01:53:02.160 --> 01:53:11.840] And when you do input cubed, this has to travel to the GPU, to the cores, and to all the caches | |
[01:53:11.840 --> 01:53:18.960] and registers on the actual chip of this GPU. And it has to calculate all the elements to the third, | |
[01:53:18.960 --> 01:53:24.960] and then it saves the result back to the memory. And it's this travel time that actually causes | |
[01:53:24.960 --> 01:53:30.960] a lot of issues. So here, remember this memory bandwidth? We can communicate about two terabytes | |
[01:53:30.960 --> 01:53:37.440] per second, which is a lot. But also, we have to traverse this link, and it's very slow. So here | |
[01:53:37.440 --> 01:53:42.080] on the GPU, we're on chip, and everything is super fast within the chip. But going to the memory is | |
[01:53:42.080 --> 01:53:48.480] extremely expensive, it takes an extremely long amount of time. And so we load the input, do the calculations, | |
[01:53:48.480 --> 01:53:54.400] and load back the output. And this round trip takes a lot of time. And now right after we do that, | |
[01:53:54.400 --> 01:54:01.360] we multiply by this constant. So what happens then is we dispatch another kernel, and then the result | |
[01:54:01.360 --> 01:54:06.640] travels back, all the elements get multiplied by the constant, and then the result travels back to | |
[01:54:06.640 --> 01:54:13.360] the memory. And then we take the result, and we add back the input. And so this entire thing, again, | |
[01:54:13.360 --> 01:54:19.360] travels to the GPU, adds the inputs and gets written back. So we're making all these | |
[01:54:19.360 --> 01:54:24.560] round trips from the memory to actually where the computation happens. Because all the tensor cores | |
[01:54:24.560 --> 01:54:29.360] and ALUs and everything like that is all stored on the chip in the GPU. So we're doing a ton of | |
[01:54:29.360 --> 01:54:35.920] round trips. And PyTorch, without using torch.compile, doesn't know to optimize this, because | |
[01:54:35.920 --> 01:54:40.240] it doesn't know what kind of operations you're running later. You're just telling it, raise the | |
[01:54:40.240 --> 01:54:44.400] power to the third, then do this, then do that, and it will just do that in that sequence. | |
[01:54:44.400 --> 01:54:48.960] But torch.compile sees your entire code. It will come here, and it will realize, wait, all of these | |
[01:54:48.960 --> 01:54:54.320] are element-wise operations. And actually, what I'm going to do is I'm going to do a single | |
[01:54:54.320 --> 01:54:59.920] trip of input to the GPU. Then for every single element, I'm going to do all of these operations | |
[01:54:59.920 --> 01:55:06.320] while that memory is on the GPU, or chunks of it, rather. And then I'm going to write back a single | |
[01:55:06.320 --> 01:55:10.160] time. So we're not going to have these round trips. And that's one example of what's called | |
[01:55:10.160 --> 01:55:15.360] kernel fusion, and is a major way in which everything is sped up. So basically, if you have the | |
[01:55:15.360 --> 01:55:19.280] benefit of foresight, and you know exactly what you're going to compute, you can optimize your | |
[01:55:19.280 --> 01:55:23.680] round trips to the memory, and you're not going to pay the memory bandwidth cost. | |
[01:55:23.680 --> 01:55:28.240] And that's fundamentally what makes some of these operations a lot faster, and what they mean by | |
[01:55:28.240 --> 01:55:36.480] read writes here. So let me erase this, because we are not using it. And yeah, we should be using | |
[01:55:36.480 --> 01:55:43.440] torch.compile. And our code is now significantly faster, and we're doing about 125,000 tokens per | |
[01:55:43.440 --> 01:55:49.200] second. But we still have a long way to go. Before we move on, I wanted to supplement the discussion | |
[01:55:49.200 --> 01:55:53.440] a little bit with a few more figures, because this is a complicated topic, but it's worth | |
[01:55:53.440 --> 01:55:58.000] understanding on the high level what's happening here. And I could probably spend an entire video | |
[01:55:58.000 --> 01:56:04.240] of like two hours on this, but just the preview of that basically. So this chip here, that is the | |
[01:56:04.240 --> 01:56:11.680] GPU. This chip is where all the calculations happen, mostly. But this chip also does have some memory | |
[01:56:11.680 --> 01:56:18.880] in it. But most of the memory by far is here in the high bandwidth memory, HBM, and is connected, | |
[01:56:18.880 --> 01:56:26.560] they're connected. But these are two separate chips, basically. Now, here, this is a zoom in of kind of | |
[01:56:26.560 --> 01:56:33.200] this cartoon diagram of a GPU. And what we're seeing here is number one, you see this HBM. I | |
[01:56:33.200 --> 01:56:38.080] realize it's probably very small for you. But on the sides here, it says HBM. And so that's the | |
[01:56:38.080 --> 01:56:45.360] links to the HBM. Now the HBM is again, off chip. On the chip, there are a large number of these | |
[01:56:45.360 --> 01:56:51.840] streaming multiprocessors. Every one of these is an SM. There's 120 of them in total. And this is | |
[01:56:51.840 --> 01:56:57.280] where a lot of the calculations happen. And this is a zoom in of a single individual SM. It has | |
[01:56:57.280 --> 01:57:01.120] these four quadrants. And see, for example, tensor core, this is where a lot of the matrix | |
[01:57:01.120 --> 01:57:06.160] multiply stuff happens. But there's all these other units to do all different kinds of calculations | |
[01:57:06.160 --> 01:57:13.440] for FP64, FP32, and for integers, and so on. Now, so we have all this logic here to do the | |
[01:57:13.440 --> 01:57:18.400] calculations. But in addition to that, on the chip, there is memory sprinkled throughout the chip. | |
[01:57:18.400 --> 01:57:25.520] So L2 cache is some amount of memory that lives on the chip. And then on the SMs themselves, | |
[01:57:25.520 --> 01:57:31.360] there's L1 cache. I realize it's probably very small for you, but this blue bar is L1. And there's | |
[01:57:31.360 --> 01:57:38.080] also registers. And so there is memory stored here. But the way this memory is stored is very | |
[01:57:38.080 --> 01:57:44.400] different from the way memory is stored in HBM. This is a very different implementation using | |
[01:57:44.400 --> 01:57:49.520] just in terms of like what the silicon looks like. It's a very different implementation. | |
[01:57:49.520 --> 01:57:55.280] So here you would be using transistors and capacitors. And here it's a very different | |
[01:57:55.280 --> 01:58:04.240] implementation with SRAM and what that looks like. But long story short, there is memory | |
[01:58:04.240 --> 01:58:10.320] inside the chip, but it's not a lot of memory. That's the critical point. So this is an example | |
[01:58:10.320 --> 01:58:15.760] diagram of a slightly different GPU, just like here, where it shows that, for example, typical | |
[01:58:15.760 --> 01:58:22.800] numbers for CPU DRAM memory, which is this thing here, you might have one terabyte of it, but it | |
[01:58:22.800 --> 01:58:27.040] would be extremely expensive to access, especially for a GPU. You have to go through the CPU here. | |
[01:58:27.040 --> 01:58:33.600] Now, next we have the HBM. So we have tens of gigabytes of HBM memory on a typical GPU here. | |
[01:58:33.600 --> 01:58:40.400] But as I mentioned, very expensive to access. And then on the chip itself, everything is extremely | |
[01:58:40.400 --> 01:58:46.080] fast within the chip, but we only have a couple of tens of megabytes of memory collectively throughout | |
[01:58:46.080 --> 01:58:51.840] the chip. And so there's just not enough space because the memory is very expensive on the chip. | |
[01:58:51.840 --> 01:58:55.920] And so there's not a lot of it, but it is lightning fast to access in relative terms. | |
[01:58:55.920 --> 01:59:02.240] And so basically, whenever we have these kernels, the more accurate picture of what's happening here | |
[01:59:02.240 --> 01:59:08.480] is that we take these inputs which live by default on the global memory, and now we need to perform | |
[01:59:08.480 --> 01:59:15.680] some calculation. So we start streaming the data from the global memory to the chip. | |
[01:59:15.680 --> 01:59:20.160] We perform the calculations on the chip and then stream it back and store it back to the global | |
[01:59:20.160 --> 01:59:26.160] memory. And so if we don't have torch compile, we are streaming the data through the chip doing | |
[01:59:26.160 --> 01:59:30.160] the calculations and saving to the memory. And we're doing those round trips many, many times. | |
[01:59:30.160 --> 01:59:35.760] But if it's torch compiled, then we start streaming the memory as before. But then while | |
[01:59:35.760 --> 01:59:43.040] we're on the chip, we have a chunk of the data that we're trying to process. So that chunk now | |
[01:59:43.040 --> 01:59:47.680] lives on the chip. While it's on the chip, it's extremely fast to operate on. So if we have kernel | |
[01:59:47.680 --> 01:59:52.880] fusion, we can do all the operations right there in an element-wise fashion. And those are very | |
[01:59:52.880 --> 01:59:59.280] cheap. And then we do a single round trip back to the global memory. So operator fusion basically | |
[01:59:59.280 --> 02:00:04.000] allows you to keep your chunk of data on the chip and do lots of calculations on it before you | |
[02:00:04.000 --> 02:00:10.800] write it back. And that gives huge savings. And that's why torch compile ends up being a lot faster, | |
[02:00:10.800 --> 02:00:16.160] or that's one of the major reasons. So again, just a very brief intro to the memory hierarchy, | |
[02:00:16.160 --> 02:00:21.200] and roughly what torch compile does for you. Now torch compile is amazing, but there are | |
[02:00:21.200 --> 02:00:26.880] operations that torch compile will not find. And an amazing example of that is flash attention | |
[02:00:26.880 --> 02:00:33.440] to which we turn next. So flash attention comes from this paper from Stanford in 2022. | |
[02:00:33.440 --> 02:00:41.360] And it's this incredible algorithm for performing attention, and running it a lot faster. So flash | |
[02:00:41.360 --> 02:00:48.640] attention will come here, and we will take out these four lines. And flash attention implements | |
[02:00:48.640 --> 02:00:54.880] these four lines really, really quickly. And how does it do that? Well, flash attention is a | |
[02:00:54.880 --> 02:01:01.760] kernel fusion operation. So you see here, we have in this diagram, they're showing PyTorch, | |
[02:01:01.760 --> 02:01:06.880] and you have these four operations, they're including dropout, but we are not using dropout | |
[02:01:06.880 --> 02:01:12.880] here. So we just have these four lines of code here. And instead of those, we are fusing them | |
[02:01:12.880 --> 02:01:19.520] into a single fused kernel of flash attention. So it's a kernel fusion algorithm, but it's a | |
[02:01:19.520 --> 02:01:24.720] kernel fusion that torch compile cannot find. And the reason that it cannot find it is that it | |
[02:01:24.720 --> 02:01:30.160] requires an algorithmic rewrite of how attention is actually implemented here in this case. | |
[02:01:30.160 --> 02:01:35.280] And what's remarkable about it is that flash attention, actually, if you just count the number | |
[02:01:35.280 --> 02:01:41.920] of flops, flash attention does more flops than this attention here. But flash attention is | |
[02:01:41.920 --> 02:01:49.280] actually significantly faster. In fact, they cite 7.6 times faster potentially. And that's because | |
[02:01:49.280 --> 02:01:55.520] it is very mindful of the memory hierarchy, as I described it just now. And so it's very mindful | |
[02:01:55.520 --> 02:02:00.800] about what's in high bandwidth memory, what's in the shared memory. And it is very careful | |
[02:02:00.800 --> 02:02:06.240] with how it orchestrates the computation, such that we have fewer reads and writes to the | |
[02:02:06.240 --> 02:02:10.720] high bandwidth memory. And so even though we're doing more flops, the expensive part is the load | |
[02:02:10.720 --> 02:02:16.960] and store into HBM. And that's what they avoid. And so in particular, they do not ever materialize | |
[02:02:16.960 --> 02:02:23.200] this N-by-N attention matrix, this att here. Flash attention is designed such that this | |
[02:02:23.200 --> 02:02:28.880] matrix never gets materialized at any point, and it never gets read or written to the HBM. | |
[02:02:29.440 --> 02:02:34.480] And this is a very large matrix, right? Because this is where all the queries and keys interact, | |
[02:02:34.480 --> 02:02:42.240] and we're sort of getting, for each head, for each batch element, we're getting a T by T matrix | |
[02:02:42.240 --> 02:02:47.600] of attention, which is a million numbers, even for a single head, at a single batch index, | |
[02:02:47.600 --> 02:02:54.480] so basically this is a ton of memory, and this is never materialized. And the way that this is | |
[02:02:54.480 --> 02:03:00.720] achieved is that basically the fundamental algorithmic rewrite here relies on this online | |
[02:03:00.720 --> 02:03:05.680] softmax trick, which was proposed previously, and I'll show you the paper in a bit. And the | |
[02:03:05.680 --> 02:03:12.320] online softmax trick coming from a previous paper shows how you can incrementally evaluate | |
[02:03:12.320 --> 02:03:18.880] a softmax without having to sort of realize all of the inputs to the softmax for the normalization. | |
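As a toy illustration of that idea (purely my own sketch in plain Python, not the actual kernel), you keep a running maximum m and a running normalizer l, and rescale l whenever the maximum changes; the intermediate variables described in the next sentence play exactly these roles:

```python
import math

def online_softmax(xs):
    m = float("-inf")   # running max seen so far
    l = 0.0             # running sum of exp(x - m)
    for x in xs:
        m_new = max(m, x)
        # rescale the old normalizer to the new max, then add the new term
        l = l * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    # a second pass produces the probabilities
    return [math.exp(x - m) / l for x in xs]

print(online_softmax([1.0, 2.0, 3.0]))  # matches the ordinary softmax
```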
[02:03:18.880 --> 02:03:23.520] And you do that by having these intermediate variables m and l, and there's an update to them | |
[02:03:23.520 --> 02:03:29.840] that allows you to evaluate the softmax in an online manner. Now flash attention, | |
[02:03:29.840 --> 02:03:34.160] actually, so recently FlashAttention 2 came out as well, so I have that paper up here as well, | |
[02:03:34.160 --> 02:03:39.600] that has additional gains to how it calculates flash attention. And the original paper that | |
[02:03:39.600 --> 02:03:44.240] this is based on basically is this online normalization calculation for softmax, | |
[02:03:44.240 --> 02:03:49.520] and remarkably it came out of NVIDIA, and it came out really early, like 2018, | |
[02:03:49.520 --> 02:03:55.760] so this is four years before flash attention. And this paper says that we propose a way | |
[02:03:55.760 --> 02:03:59.760] to compute the classical softmax with fewer memory accesses, and hypothesize that this | |
[02:03:59.760 --> 02:04:05.600] reduction in memory accesses should improve softmax performance on actual hardware. And so | |
[02:04:05.600 --> 02:04:10.480] they are extremely correct in this hypothesis. But it's really fascinating to me that they're | |
[02:04:10.480 --> 02:04:15.280] from NVIDIA, and that they had this realization, but they didn't actually take it to the actual | |
[02:04:15.280 --> 02:04:20.800] flash attention that had to come four years later from Stanford. So I don't fully understand the | |
[02:04:20.800 --> 02:04:26.560] historical how this happened historically, but they do basically propose this online update to | |
[02:04:26.560 --> 02:04:32.800] the softmax right here. And this is fundamentally what they reuse here to calculate the softmax | |
[02:04:32.800 --> 02:04:36.960] in a streaming manner. And then they realize that they can actually fuse all the other operations | |
[02:04:36.960 --> 02:04:42.240] with the online softmax calculation into a single fused kernel, flash attention, | |
[02:04:42.240 --> 02:04:48.960] and that's what we are about to use. So a great example I think of being aware of memory hierarchy, | |
[02:04:48.960 --> 02:04:53.040] the fact that flops don't matter, the entire memory access pattern matters, | |
[02:04:53.040 --> 02:04:57.760] and that torch compile is amazing, but there are many optimizations that are still available to us | |
[02:04:57.760 --> 02:05:03.120] that potentially torch compile cannot find. Maybe one day it could, but right now it seems like | |
[02:05:03.120 --> 02:05:07.360] a lot to ask. So here's what we're going to do. We're going to use flash attention, | |
[02:05:08.000 --> 02:05:13.680] and the way to do that basically in PyTorch is we are going to comment out these four lines, | |
[02:05:13.680 --> 02:05:19.600] and we're going to replace them with a single line. And here we are calling this compound | |
[02:05:19.600 --> 02:05:28.240] operation in PyTorch called scaled dot product attention. And PyTorch will call flash attention | |
[02:05:28.240 --> 02:05:33.680] when you use it in this way. I'm not actually 100% sure why torch compile doesn't realize that | |
[02:05:33.680 --> 02:05:39.040] these four lines should just call flash attention in this exact way. We have to do it ourselves, | |
[02:05:39.040 --> 02:05:47.440] which in my opinion is a little bit odd, but here we are. So you have to use this compound op. | |
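The replacement line, as a sketch (q, k, v are the (B, n_head, T, head_size) tensors from the attention code earlier in the video):

```python
import torch.nn.functional as F

# replaces the manual softmax(q @ k^T / sqrt(d)) @ v lines; PyTorch can dispatch this
# to a flash attention kernel under the hood
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```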
[02:05:47.440 --> 02:05:53.920] Let's wait for a few moments before torch.compile gets around to it. And then let's remember | |
[02:05:53.920 --> 02:06:00.560] that we achieved 6.05661, I have it here. That's the loss we are expecting to see, | |
[02:06:01.360 --> 02:06:07.840] and we took 130 milliseconds before this change. So we're expecting to see the exact same result | |
[02:06:07.840 --> 02:06:14.320] by iteration 49, but we expect to see faster runtime, because flash attention is just an | |
[02:06:14.320 --> 02:06:18.320] algorithmic rewrite, and it's a faster kernel, but it doesn't actually change any of the computation, | |
[02:06:18.320 --> 02:06:23.440] and we should have the exact same optimization. So okay, so we're a lot faster. We're at about 95 | |
[02:06:23.440 --> 02:06:32.720] milliseconds, and we achieve 6.058. Okay, so they're basically identical up to a floating point | |
[02:06:32.720 --> 02:06:39.120] fudge factor. So it's the identical computation, but it's significantly faster going from 130 to | |
[02:06:39.120 --> 02:06:48.800] roughly 96, and so this is 96 divided by 130-ish, so this is maybe a 27-ish percent improvement. | |
[02:06:50.240 --> 02:06:56.560] So really interesting, and that is flash attention. Okay, we are now getting to one of my favorite | |
[02:06:56.560 --> 02:07:02.000] optimizations, and it is simultaneously the dumbest and the most brilliant optimization, | |
[02:07:02.000 --> 02:07:07.760] and it's always a little bit surprising to me. Anyway, so basically, I mentioned a few minutes | |
[02:07:07.760 --> 02:07:15.120] ago that there are some numbers that are nice, and some numbers that are ugly. So 64 is a beautiful | |
[02:07:15.120 --> 02:07:21.120] nice number. 128 is even nicer. 256 is beautiful. What makes these numbers beautiful is that there | |
[02:07:21.120 --> 02:07:27.360] are many powers of two inside them. You can divide by two many times, and examples of ugly numbers | |
[02:07:27.360 --> 02:07:32.560] are like 13 and 17 and something like that, prime numbers, numbers that are not even and so on. | |
[02:07:32.560 --> 02:07:36.880] And so pretty much you always want to use nice numbers in all of your code that deals with neural | |
[02:07:36.880 --> 02:07:43.280] networks or CUDA, because everything in CUDA works in sort of like powers of two, and lots of kernels | |
[02:07:43.840 --> 02:07:50.240] are written in terms of powers of two, and there are lots of blocks of sizes 16 and 64 and so on. | |
[02:07:50.240 --> 02:07:54.720] So everything is written in those terms, and you always have special case handling for all kinds | |
[02:07:54.720 --> 02:08:01.680] of logic for when your inputs are not made of nice numbers. So let's see what that looks like. | |
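Before scanning the code, here is a throwaway helper of my own for eyeballing how 'nice' a number is, just by counting its factors of two (50,304 is the padded vocabulary size that comes up in a moment):

```python
def powers_of_two(n: int) -> int:
    # how many times n divides evenly by 2
    count = 0
    while n % 2 == 0:
        n //= 2
        count += 1
    return count

print(powers_of_two(50257))  # 0: it's odd, an ugly number
print(powers_of_two(50304))  # 7: divisible by 128, a much nicer number
```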
[02:08:01.680 --> 02:08:08.400] Basically scan your code and look for ugly numbers is roughly the heuristic. So three times is kind | |
[02:08:08.400 --> 02:08:14.640] of ugly. I'm not 100% sure maybe this can be improved, but this is ugly and not ideal. | |
[02:08:14.640 --> 02:08:24.880] Four times is nice. So that's nice. 1024 is very nice. That's a power of two. | |
[02:08:24.880 --> 02:08:32.800] 12 is a little bit suspicious. Not too many powers of two. 768 is great. 50,257 is a really, | |
[02:08:32.800 --> 02:08:40.560] really ugly number. First of all, it's odd, and there are not too many powers of two in there. | |
[02:08:40.560 --> 02:08:46.480] So this is a very ugly number, and it's highly suspicious. And then when we scroll down, all these | |
[02:08:46.480 --> 02:08:53.680] numbers are nice. And then here we have mostly nice numbers except for 25. So in this configuration | |
[02:08:53.680 --> 02:08:59.280] of GPT-2 XL, the number of heads is 25. That's a really ugly number. That's an odd number. And | |
[02:09:00.080 --> 02:09:03.440] actually this did cause a lot of headaches for us recently when we were trying to optimize some | |
[02:09:03.440 --> 02:09:10.240] kernels to run fast, and it required a bunch of special case handling. So basically, among these numbers | |
[02:09:10.240 --> 02:09:15.040] we have some ugly numbers and some of them are easier to fix than others. And in particular, | |
[02:09:15.040 --> 02:09:19.760] the vocab size being 50,257, that's a very ugly number, very suspicious, and we're going to | |
[02:09:19.760 --> 02:09:25.040] fix it. Now when you fix these things, one of the easy ways to do that is you basically | |
[02:09:26.320 --> 02:09:32.240] increase the number until it's the nearest nice number that you like. So here's a much nicer number. | |
[02:09:32.240 --> 02:09:42.880] It's 50,304. And why is that? Because 50,304 can be divided by 8 or by 16 or by 32, 64. | |
[02:09:42.880 --> 02:09:50.320] It can even be divided by 128, I think. Yeah. So it's a very nice number. So what we're going to do | |
[02:09:50.320 --> 02:09:55.200] here is this is the GPT config. And you see that we initialize vocab size to 50,257. | |
[02:09:55.760 --> 02:10:05.520] Let's override just that element to be 50,304. Okay. So everything else stays the same, | |
[02:10:05.520 --> 02:10:10.800] we're just increasing our vocabulary size. So we're adding, it's almost like we're adding fake | |
[02:10:10.800 --> 02:10:16.800] tokens. So that vocab size has powers of two inside it. Now, actually, what I'm doing here, | |
[02:10:16.800 --> 02:10:20.640] by the way, is I'm increasing the amount of computation that our network will be doing. | |
[02:10:20.640 --> 02:10:25.040] If you just count the flops, like, do the math of how many flops we're doing, we're going to be | |
[02:10:25.040 --> 02:10:31.120] doing more flops. And we still have to think through whether this doesn't break anything. | |
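The override itself, as a sketch (GPTConfig and GPT are the classes built earlier in the video):

```python
# 50,304 = 128 * 393, so it divides cleanly by 8, 16, 32, 64 and 128
model = GPT(GPTConfig(vocab_size=50304))
```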
[02:10:31.120 --> 02:10:40.080] But if I just run this, let's see what we get. Currently, this ran in maybe 96.5 milliseconds | |
[02:10:40.080 --> 02:10:44.720] per step. I'm just kind of eyeballing it. And let's see what kind of result we're going to get. | |
[02:10:44.720 --> 02:10:52.480] While this is compiling, let's think through whether our code actually works. Okay. When we | |
[02:10:52.480 --> 02:10:56.880] increase the vocab size like this, let's look at where vocab size is actually used. | |
[02:10:56.880 --> 02:11:02.560] So we swing up to the in it. And we see that it's used inside the embedding table, of course. | |
[02:11:02.560 --> 02:11:06.400] So all the way at the bottom of the transformer. And it's used at the classifier layer, all the | |
[02:11:06.400 --> 02:11:12.480] way at the top of the transformer. So in two places. And let's take a look. And we're running at 93. | |
[02:11:12.480 --> 02:11:20.400] So 93 milliseconds instead of 96.5. So we are seeing a roughly 4% improvement here. | |
[02:11:21.280 --> 02:11:27.920] By doing more calculations. And the reason for this is we fixed, we made an ugly number | |
[02:11:27.920 --> 02:11:33.280] into a nice number. I'm going to come back to the explanation for that in a little bit. | |
[02:11:33.280 --> 02:11:36.640] But for now, let's just convince ourselves that we're not breaking anything when we do this. | |
[02:11:36.640 --> 02:11:42.320] So first of all, we've made the WTE, the embedding table for the tokens. We've made it larger. | |
[02:11:42.320 --> 02:11:48.080] It's almost like we introduced more tokens at the bottom. And these tokens are never used, | |
[02:11:48.080 --> 02:11:54.880] because the GPT-2 tokenizer only has tokens up to 50,256. And so we'll never index into | |
[02:11:54.880 --> 02:12:00.000] the rows that we've added. So we're wasting a little bit of space here by creating memory that's | |
[02:12:00.000 --> 02:12:04.960] never going to be accessed, never going to be used, et cetera. Now that's not fully correct, | |
[02:12:04.960 --> 02:12:10.400] because this WTE weight ends up being shared and ends up being used in the classifier here at the | |
[02:12:10.400 --> 02:12:15.920] end. So what is that doing to the classifier? Well, what that's doing is we're | |
[02:12:15.920 --> 02:12:20.800] predicting additional dimensions of the classifier now. And we're predicting probabilities for tokens | |
[02:12:20.800 --> 02:12:26.160] that will, of course, never be present in the training set. And so therefore, the network has | |
[02:12:26.160 --> 02:12:32.080] to learn that these probabilities have to be driven to zero. And so the logits that the network | |
[02:12:32.080 --> 02:12:37.760] produces have to drive those dimensions of the output to negative infinity. But that's no different | |
[02:12:37.760 --> 02:12:43.840] from all the other tokens that are already in our dataset, or rather that are not in our dataset. | |
[02:12:43.840 --> 02:12:49.680] So Shakespeare only probably uses, let's say, 1,000 tokens out of the 50,257 tokens. So most | |
[02:12:49.680 --> 02:12:53.920] of the tokens are already being driven to zero probability by the optimization. We've just introduced | |
[02:12:53.920 --> 02:12:58.800] a few more tokens now that in a similar manner will never be used and have to be driven to zero | |
[02:12:58.800 --> 02:13:07.600] in probability. So functionally, though, nothing breaks, we're using a bit more extra memory. But | |
[02:13:07.600 --> 02:13:12.720] otherwise, this is a harmless operation, as far as I can tell. And even though we're adding calculation, | |
[02:13:12.720 --> 02:13:16.800] it's running faster. And it's running faster because, as I mentioned, in CUDA, | |
[02:13:16.800 --> 02:13:24.960] so many kernels use block tiles. And these block tiles are usually nice numbers. So powers of 2. | |
[02:13:24.960 --> 02:13:31.040] So calculations are done in like chunks of 64 or chunks of 32. And when | |
[02:13:31.040 --> 02:13:37.760] your desired calculation doesn't neatly fit into those block tiles, there are all kinds of boundary | |
[02:13:37.760 --> 02:13:43.600] kernels that can kick in to like do the last part. So basically, in a lot of kernels, | |
[02:13:43.600 --> 02:13:48.080] they will chunk up your input, and they will do the nice part first. And then they have a whole | |
[02:13:48.080 --> 02:13:54.880] second phase, where they come back to anything that remains. And then they process | |
[02:13:54.880 --> 02:13:58.880] the remaining part. And the kernels for that could be very inefficient. And so you're basically | |
[02:13:58.880 --> 02:14:04.640] spinning up all this extra compute, and it's extremely inefficient. So you might as well pad | |
[02:14:04.640 --> 02:14:11.120] your inputs and make it fit nicely. And usually that empirically ends up actually running faster. | |
[02:14:11.120 --> 02:14:18.880] So this is another example of a 4% improvement that we've added. And this is something that also | |
[02:14:18.880 --> 02:14:23.200] torch.compile did not find for us. You would hope that torch.compile at some point could | |
[02:14:23.200 --> 02:14:28.800] figure an optimization like this out. But for now, this is it. And I also have to point out that | |
[02:14:28.800 --> 02:14:34.480] we're using PyTorch nightly. So that's why we're only seeing 4%. If you're using PyTorch 2.3.1, | |
[02:14:34.640 --> 02:14:39.040] or earlier, you would actually see something like 30% improvement, just from this change, | |
[02:14:39.040 --> 02:14:48.000] from changing it from 50,257 to 50,304. So again, one of my favorite examples also | |
[02:14:48.000 --> 02:14:52.240] of having to understand what's going on under the hood and how it all works, and to know what kinds of things | |
[02:14:52.240 --> 02:14:56.880] to tinker with to push the performance of your code. Okay, so at this point, we have improved | |
[02:14:56.880 --> 02:15:01.840] the performance by about 11x, right? Because we started at about 1000 milliseconds per step, | |
[02:15:01.840 --> 02:15:08.240] and we're now down to like 93 milliseconds. So that's quite good. And we're doing a much better | |
[02:15:08.240 --> 02:15:13.840] job of utilizing our GPU resources. So I'm going to now turn to more algorithmic changes | |
[02:15:13.840 --> 02:15:18.320] and improvements to the actual optimization itself. And what we would like to do is we'd | |
[02:15:18.320 --> 02:15:23.200] like to follow the hyper parameters that are mentioned in the GPT-2 or GPT-3 paper. | |
[02:15:23.200 --> 02:15:30.160] Now, sadly, GPT-2 doesn't actually say too much. It's very nice of them that they released the | |
[02:15:30.160 --> 02:15:34.720] model weights and the code, but the paper itself is extremely vague as to the optimization details. | |
[02:15:34.720 --> 02:15:40.240] The code itself that they released as well, the code we've been looking at, this is just the | |
[02:15:40.240 --> 02:15:44.640] inference code. So there's no training code here and very few hyper parameters. So this doesn't | |
[02:15:44.640 --> 02:15:51.360] also tell us too much. So for that, we have to turn to the GPT-3 paper. And in the appendix of the | |
[02:15:51.360 --> 02:15:58.800] GPT-3 paper, they have a lot more hyper parameters here for us to use. And the GPT-3 paper in general | |
[02:15:58.800 --> 02:16:04.880] is a lot more detailed as to all the small details that go into the model training, | |
[02:16:04.880 --> 02:16:10.560] but GPT-3 models were never released. So GPT-2, we have the weights, but no details, | |
[02:16:10.560 --> 02:16:16.960] and GPT-3, we have lots of details, but no weights. But roughly speaking, GPT-2 and GPT-3 | |
[02:16:16.960 --> 02:16:22.880] architectures are very, very similar. And basically, there are very few changes. The context length | |
[02:16:22.880 --> 02:16:28.720] was expanded from 1024 to 2048. And that's kind of like the major change. And some of the hyper | |
[02:16:28.720 --> 02:16:32.720] parameters around the transformer have changed. But otherwise, they're pretty much the same model. | |
[02:16:32.720 --> 02:16:37.520] It's just that GPT-3 was trained for a lot longer on a bigger data set and has a lot more | |
[02:16:37.520 --> 02:16:46.880] thorough evaluations. And the GPT-3 model is 175 billion instead of 1.6 billion in the GPT-2. | |
[02:16:46.880 --> 02:16:51.600] So long story short, we're going to go to the GPT-3 paper to follow along some of the hyperparameters. | |
[02:16:52.160 --> 02:16:58.880] So to train all the versions of GPT-3, we use Adam with beta 1, beta 2 of 0.9 and 0.95. | |
[02:16:58.880 --> 02:17:04.480] So let's move over here and make sure that the betas parameter, which you can see here, | |
[02:17:04.480 --> 02:17:13.680] defaults to 0.9 and 0.999, is actually set to 0.9 and 0.95. And then the epsilon parameter, | |
[02:17:13.680 --> 02:17:19.680] you can see the default is 1e-8, and this is also 1e-8. Let's just | |
[02:17:19.680 --> 02:17:26.640] put it in so that we're explicit. Now, next up, they say we clip the global norm of the | |
[02:17:26.640 --> 02:17:32.160] gradient at 1.0. So what this is referring to is that once we calculate the gradients | |
[02:17:32.160 --> 02:17:37.600] right after loss.backward(), we basically have the gradients at all the parameter tensors. | |
[02:17:37.600 --> 02:17:43.680] And what people like to do is basically clip them to have some kind of a maximum norm. | |
[02:17:44.800 --> 02:17:50.080] So in PyTorch, this is fairly easy to do. It's one line of code here that we have to insert right | |
[02:17:50.080 --> 02:17:57.040] after we calculate the gradients. And what this utility function is doing is it's calculating the | |
[02:17:57.040 --> 02:18:02.800] global norm of the parameters. So every single gradient on all the parameters, | |
[02:18:02.800 --> 02:18:08.480] you square it and you add it all up and you take a big square root of that. And that's the norm of | |
[02:18:08.480 --> 02:18:15.440] the parameter vector, basically. It's the length of it, if you like to look at it that way. | |
[02:18:15.440 --> 02:18:20.480] And we are basically making sure that its length is no more than 1.0 and we're going to clip it. | |
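A minimal sketch of both settings, using the standard PyTorch AdamW constructor and gradient clipping utility; model and loss stand in for the objects from the training loop above:

    import torch

    # AdamW with the GPT-3 betas and epsilon (PyTorch's own defaults are betas=(0.9, 0.999), eps=1e-8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)

    # ... forward pass producing `loss`, then:
    loss.backward()
    # clip the global norm of all gradients at 1.0; the returned value is the pre-clip norm
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()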
[02:18:20.480 --> 02:18:25.920] And the reason that people like to use this is that sometimes you can get unlucky during the | |
[02:18:25.920 --> 02:18:31.120] optimization. Maybe it's a bad data batch or something like that. And if you get very unlucky in the | |
[02:18:31.120 --> 02:18:36.160] batch, you might get really high loss and really high loss could lead to a really high gradient. | |
[02:18:36.160 --> 02:18:42.320] And this could basically shock your model and shock the optimization. So people like to use a | |
[02:18:42.320 --> 02:18:49.920] gradient norm clipping to prevent the model from basically getting too big of shocks in terms of | |
[02:18:49.920 --> 02:18:55.440] the gradient magnitude, upper bounding it in this way. It's a bit of a hacky solution, | |
[02:18:55.440 --> 02:19:01.440] it's a bit like a patch on top of deeper issues, but people still do it fairly frequently. | |
[02:19:02.000 --> 02:19:08.080] Now the clip grad norm returns the norm of the gradient, which I like to always visualize | |
[02:19:08.080 --> 02:19:14.480] because it is useful information. And sometimes you can look at the norm of the gradient. And if | |
[02:19:14.480 --> 02:19:18.880] it's well behaved, things are good. If it's climbing, things are bad and they're destabilizing during | |
[02:19:18.880 --> 02:19:22.880] training, sometimes you could get a spike in the norm. And that means there's some kind of | |
[02:19:22.880 --> 02:19:33.840] an issue or instability. So the norm here will be the returned norm. And let's print it with a .4f format or something like that. | |
[02:19:33.840 --> 02:19:43.120] And I believe this is just a float. And so we should be able to print that. So that's global | |
[02:19:43.120 --> 02:19:49.920] gradient clipping. Now they go into the details of the learning rate scheduler. So they don't just | |
[02:19:49.920 --> 02:19:55.200] use a fixed learning rate like we do here, 3e-4, but there's actually basically | |
[02:19:55.200 --> 02:20:03.280] a cosine decay learning rate schedule. It's got a warm up and it's got a cosine decay to 10% | |
[02:20:03.280 --> 02:20:11.760] over some horizon. And so we're going to implement this in a second. I just like to see the norm | |
[02:20:11.760 --> 02:20:17.200] printed here. Okay, there we go. So what happened here is the norm is actually really high in the | |
[02:20:17.200 --> 02:20:22.640] beginning 30 or so. And you see that as we continue training, it kind of stabilizes | |
[02:20:22.640 --> 02:20:31.200] at values below one. And this is not that crazy uncommon for the norm to be high in the very | |
[02:20:31.200 --> 02:20:35.280] first few stages. Basically what's happening here is the model is completely random. And so | |
[02:20:35.280 --> 02:20:39.520] there's a ton of learning happening very early in the network. But that learning is kind of like, | |
[02:20:39.520 --> 02:20:45.920] you know, it's mostly learning the biases of the output tokens. And so it's a bit of an unstable | |
[02:20:45.920 --> 02:20:50.880] time. But the network usually stabilizes in the first few iterations. So this looks relatively | |
[02:20:50.880 --> 02:20:55.760] reasonable to me, except it looks a little bit funky that we go from 28 | |
[02:20:55.760 --> 02:21:04.240] to 62 and then to 10. It's not completely insane, but it's just kind of a little bit funky. Okay, | |
[02:21:04.240 --> 02:21:08.480] so let's now get to the learning rate scheduler. So the learning rate schedule that's used here in | |
[02:21:08.480 --> 02:21:15.280] GPT-3 is what's called a cosine decay learning rate schedule with warmup. And the way this looks is | |
[02:21:15.280 --> 02:21:21.680] that the learning rate is basically starts right at around zero, linearly ramps up over some amount | |
[02:21:21.680 --> 02:21:28.080] of time, and then comes down with this cosine sort of form, and comes down to some kind of a | |
[02:21:28.080 --> 02:21:33.280] minimum learning rate that's up to you. So here the minimum learning rate is zero. But here in the | |
[02:21:33.280 --> 02:21:38.880] paper, they said that they use cosine decay for learning rate down to 10% of its value over the | |
[02:21:38.880 --> 02:21:46.160] first 260 billion tokens, and then training continues at 10% of that after. And there's a linear warm-up | |
[02:21:46.160 --> 02:21:52.080] over the first 375 million tokens. So that's about the learning rate. So let's now implement this. | |
[02:21:52.080 --> 02:21:59.040] So I already implemented it here. And the way this works is, let me scroll down first here, | |
[02:21:59.040 --> 02:22:04.400] I changed our training loop a little bit. So this was "for i in max_steps". I just changed i to step | |
[02:22:04.400 --> 02:22:11.120] now, so that we have the notion that a step is a single optimization step in the for loop. And then | |
[02:22:11.120 --> 02:22:18.000] here, I get the LR for this step of the optimization using a new function I call get LR. And then in | |
[02:22:18.000 --> 02:22:22.320] PyTorch to set the learning rate, I think this is the way to set the learning rate. It's a little | |
[02:22:22.320 --> 02:22:27.280] bit gnarly, because basically there's a notion of different parameter groups that could | |
[02:22:27.280 --> 02:22:31.040] exist in the optimizer. And so you actually have to iterate over them, even though we currently have | |
[02:22:31.040 --> 02:22:36.640] a single parameter group only. And you have to set the LR in this for-loop kind of style, | |
[02:22:36.640 --> 02:22:41.920] is my impression right now. So we have this loop, and for each group we set the learning rate, | |
[02:22:41.920 --> 02:22:47.600] and then on the bottom also printing it. So that's all the changes I made to this loop. | |
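Concretely, the per-step learning rate assignment looks roughly like this; get_lr is the scheduler function sketched a little further down:

    lr = get_lr(step)
    # an optimizer can hold several parameter groups, so set the same lr on each of them
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr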
[02:22:47.600 --> 02:22:52.240] And then of course the get LR is my scheduler. Now it's worth pointing out that PyTorch actually | |
[02:22:52.240 --> 02:22:56.880] has learning rate schedulers, and you can use them. And I believe there's a cosine learning rate | |
[02:22:56.880 --> 02:23:02.640] schedule in PyTorch. I just don't really love using that code, because honestly, | |
[02:23:02.640 --> 02:23:08.240] it's like five lines of code. And I fully understand what's happening inside these lines. So I don't | |
[02:23:08.240 --> 02:23:13.200] love to use abstractions when they're kind of inscrutable. And then I don't know what they're | |
[02:23:13.200 --> 02:23:19.360] doing. So, personal style. So the max learning rate here is, let's say, 3e-4, | |
[02:23:19.360 --> 02:23:26.000] but we're going to see that in GPT three here, they have a table of what the maximum learning | |
[02:23:26.000 --> 02:23:36.880] rate is for every model size. So for this one, basically the 12-layer, 768-channel GPT-3. So the | |
[02:23:36.880 --> 02:23:42.720] GPT-3 Small is roughly like a GPT-2 124M. We see that here, they use a learning rate | |
[02:23:42.720 --> 02:23:47.200] of 6e-4. So we could actually go higher. In fact, we may want to try to follow | |
[02:23:47.200 --> 02:23:54.400] that. So let's just set the max LR here to 6e-4. Then the minimum learning rate is | |
[02:23:55.360 --> 02:24:01.600] 10% of that per description in the paper, some number of steps that we're going to warm up over, | |
[02:24:01.600 --> 02:24:06.000] and then the maximum steps of the optimization, which I now also use in the for loop down here. | |
[02:24:06.000 --> 02:24:12.240] And then you can go over this code if you like. It's not terribly insightful or interesting. | |
[02:24:12.240 --> 02:24:17.600] I'm just modulating based on the iteration number of which learning rate there should be. | |
[02:24:17.600 --> 02:24:24.560] So this is the warm up region. This is the region after the optimization. And then this is the region | |
[02:24:24.560 --> 02:24:29.280] sort of in between. And this is where I calculate the cosine learning rate schedule. And you can | |
[02:24:29.280 --> 02:24:33.360] step through this in detail if you'd like. But this is basically implementing this curve. | |
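A sketch of such a scheduler; the specific warmup_steps and max_steps values here are placeholders for this small run, not numbers prescribed by the paper:

    import math

    max_lr = 6e-4
    min_lr = max_lr * 0.1  # decay down to 10% of the maximum
    warmup_steps = 10      # placeholder
    max_steps = 50         # placeholder

    def get_lr(step):
        # 1) linear warmup for the first warmup_steps steps
        if step < warmup_steps:
            return max_lr * (step + 1) / warmup_steps
        # 2) past the decay horizon, hold at the minimum learning rate
        if step > max_steps:
            return min_lr
        # 3) in between, cosine decay from max_lr down to min_lr
        decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1.0 to 0.0
        return min_lr + coeff * (max_lr - min_lr)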
[02:24:33.360 --> 02:24:39.040] And I ran this already. And this is what that looks like. | |
[02:24:39.040 --> 02:24:49.040] So when we now run, we start at some very low number. Now note that we don't start exactly at zero, | |
[02:24:49.040 --> 02:24:52.880] because it would not be useful to update with a learning rate of zero. That's why there's an | |
[02:24:52.880 --> 02:24:58.000] (it + 1), so that on the zeroth iteration, we are not using exactly zero. We're using something | |
[02:24:58.000 --> 02:25:03.040] very, very low. Then we linearly warm up to maximum learning rate, which in this case was | |
[02:25:03.040 --> 02:25:08.400] 3e-4 when I ran it, but now would be 6e-4. And then it starts to decay | |
[02:25:08.400 --> 02:25:15.040] all the way down to 3e-5, which was at the time 10% of the original learning rate. | |
[02:25:15.040 --> 02:25:18.720] Now one thing we are not following exactly is that they mentioned that | |
[02:25:21.280 --> 02:25:26.960] let me see if I can find it again. We're not exactly following what they did because | |
[02:25:26.960 --> 02:25:34.400] they mentioned that their training horizon is 300 billion tokens. And they come down to 10% | |
[02:25:34.400 --> 02:25:41.440] of the initial learning rate at 260 billion tokens. And then they train after 260 billion with 10%. So basically | |
[02:25:41.440 --> 02:25:46.560] their decay time is less than the max steps time, whereas for us, they're exactly equal. | |
[02:25:46.560 --> 02:25:54.400] So it's not exactly faithful, but it's okay for us and for our purposes right now. | |
[02:25:54.400 --> 02:26:00.320] And we're just going to use this ourselves. I don't think it makes too big of a difference, | |
[02:26:00.320 --> 02:26:05.440] honestly. I should point out that what learning rate schedule you use is totally up to you. | |
[02:26:05.440 --> 02:26:11.280] There's many different types. Cosine learning rate has been popularized a lot by GPT-2 and GPT-3, | |
[02:26:11.280 --> 02:26:15.760] but people have come up with all kinds of other learning rate schedules. And this is kind of | |
[02:26:15.760 --> 02:26:20.880] an active area of research as to which one is the most effective at training these networks. | |
[02:26:20.880 --> 02:26:28.160] Okay, next up, the paper talks about the gradual batch size increase. So there's a ramp on the | |
[02:26:28.160 --> 02:26:32.800] batch size that is linear. And you start with very small batch size and you ramp up to a big | |
[02:26:32.800 --> 02:26:37.600] batch size over time. We're going to actually skip this and we're not going to work with it. | |
[02:26:37.600 --> 02:26:42.000] And the reason I don't love to use it is that it complicates a lot of the arithmetic because | |
[02:26:42.000 --> 02:26:46.400] you are changing the number of tokens that you're processing at every single step of the optimization. | |
[02:26:46.400 --> 02:26:50.800] And I like to keep that math very, very simple. Also, my understanding is that this is not like | |
[02:26:50.800 --> 02:26:57.280] a major improvement. And also, my understanding is that this is not like an algorithmic optimization | |
[02:26:57.280 --> 02:27:02.080] improvement. It's more of a systems and speed improvement. And roughly speaking, this is because | |
[02:27:02.080 --> 02:27:10.000] in the early stages of the optimization, again, the model is in a very atypical setting. | |
[02:27:10.000 --> 02:27:15.520] And mostly what you're learning is that you're mostly learning to ignore the tokens that don't | |
[02:27:15.520 --> 02:27:20.880] come up in your training set very often. You're learning very simple biases and that kind of a | |
[02:27:20.880 --> 02:27:27.120] thing. And so every single example that you put through your network is basically just telling you | |
[02:27:27.120 --> 02:27:31.200] use these tokens and don't use these tokens. And so the gradients from every single example are | |
[02:27:31.200 --> 02:27:36.960] actually extremely highly correlated. They all look roughly the same in the early parts of | |
[02:27:36.960 --> 02:27:40.800] the optimization because they're all just telling you that these tokens don't appear and these tokens | |
[02:27:40.800 --> 02:27:46.240] do appear. And so because the gradients are all very similar and they're highly correlated, | |
[02:27:46.240 --> 02:27:51.840] then why are you doing batch sizes of like millions when if you do a batch size of 32K, | |
[02:27:51.840 --> 02:27:57.040] you're basically getting the exact same gradient early on in the training. And then later in the | |
[02:27:57.040 --> 02:28:01.600] optimization, once you've learned all the simple stuff, that's where the actual work starts. And | |
[02:28:01.600 --> 02:28:06.160] that's where the gradients become more decorrelated across examples. And that's where they actually offer | |
[02:28:06.160 --> 02:28:12.400] you sort of statistical power in some sense. So we're going to skip this just because it kind of | |
[02:28:12.400 --> 02:28:19.440] complicates things. And we're going to go to: data are sampled without replacement during training, | |
[02:28:19.440 --> 02:28:24.880] until an epoch boundary is reached. So without replacement means that they're not | |
[02:28:24.880 --> 02:28:31.840] sampling from some fixed pool, taking a sequence, training on it, and then returning | |
[02:28:31.840 --> 02:28:37.600] the sequence to the pool; they are exhausting the pool. So when they draw a sequence, it's gone until | |
[02:28:37.600 --> 02:28:44.400] the next epoch of training. So we're already doing that because our data loader iterates over chunks | |
[02:28:44.400 --> 02:28:49.920] of data. So there's no replacement. They don't become eligible to be drawn again until the next | |
[02:28:49.920 --> 02:28:57.600] epoch. So we're basically already doing that. All models use a weight decay of 0.1 to provide | |
[02:28:57.600 --> 02:29:02.560] a small amount of regularization. So let's implement a weight decay. And you see here | |
[02:29:02.560 --> 02:29:07.040] that I've already kind of made the changes. And in particular, instead of creating the optimizer | |
[02:29:07.040 --> 02:29:14.320] right here, I'm creating a new configure optimizer function inside the model. And I'm passing in some | |
[02:29:14.320 --> 02:29:18.880] of the hyper parameters instead. So let's look at the configure optimizers, which is supposed to | |
[02:29:18.880 --> 02:29:28.400] return the optimizer object. Okay, so it looks complicated, but it's actually really simple. | |
[02:29:28.400 --> 02:29:33.280] It's just us being very careful. And there are a few settings here to go through. | |
[02:29:33.280 --> 02:29:37.600] The most important thing with respect to this line is that you see there's a weight decay | |
[02:29:37.600 --> 02:29:45.200] parameter here. And I'm passing that into, well, I'm passing that into something called | |
[02:29:45.200 --> 02:29:50.720] optim groups that eventually ends up going into the AdamW optimizer. And the weight decay | |
[02:29:50.720 --> 02:29:57.680] that's used by default in AdamW here is 0.01. So it's 10 times lower than what's used in the | |
[02:29:57.680 --> 02:30:04.160] GPT-3 paper here. So the weight decay basically ends up making its way into AdamW through the | |
[02:30:04.160 --> 02:30:09.280] optimizer groups. Now what else is going on here in this function? So the two things that are happening | |
[02:30:09.280 --> 02:30:14.000] here that are important is that I'm splitting up the parameters into those that should be weight | |
[02:30:14.000 --> 02:30:18.960] decay and those that should not be weight decay. So in particular, it is common to not | |
[02:30:18.960 --> 02:30:25.840] weight decay biases and any other sort of one dimensional tensors. So the one dimensional | |
[02:30:25.840 --> 02:30:31.840] tensors are in the no-decay params. And these are also things like layer norm, | |
[02:30:31.840 --> 02:30:36.240] scales and biases. It doesn't really make sense to weight decay those. You mostly want to weight | |
[02:30:36.240 --> 02:30:41.920] decay the weights that participate in matrix multiplications. And you want to potentially | |
[02:30:41.920 --> 02:30:48.000] weight decay the embeddings. And we've covered in previous video why it makes sense to decay the | |
[02:30:48.000 --> 02:30:52.320] weights because you can sort of think of it as a regularization because when you're pulling down | |
[02:30:52.320 --> 02:30:58.080] all the weights, you're forcing the optimization to use more of the weights. And you're not allowing | |
[02:30:58.080 --> 02:31:03.920] any one of the weights individually to be way too large. You're forcing the network | |
[02:31:03.920 --> 02:31:09.280] to kind of distribute the work across more channels because there's sort of like a pool of gravity | |
[02:31:09.280 --> 02:31:16.240] on the weights themselves. So that's why we are separating it in those ways here. We're only | |
[02:31:16.240 --> 02:31:23.120] decaying the embeddings and the matmul-participating weights. We're printing the number of parameters | |
[02:31:23.120 --> 02:31:27.360] that we're decaying and not. Most of the parameters will be decayed. And then one more thing that | |
[02:31:27.360 --> 02:31:35.040] we're doing here is I'm doing another optimization here. Previous versions of AdamW did not have this option, | |
[02:31:35.040 --> 02:31:40.160] but later versions of PyTorch introduced it. And that's why I'm guarding it with an inspect | |
[02:31:40.160 --> 02:31:48.160] dot signature, which is basically checking if this fused kwarg is present inside AdamW. | |
[02:31:48.160 --> 02:31:53.840] And then if it is present, I'm going to end up using it and passing it in here because some | |
[02:31:53.840 --> 02:32:00.320] earlier versions do not have fused=. So here's AdamW's fused argument. It did not use to exist | |
[02:32:00.320 --> 02:32:06.240] and it was added later. And there's some docs here for what's happening. And basically they say that | |
[02:32:06.240 --> 02:32:11.200] by default they do not use fused because it is relatively new and we want to give it sufficient | |
[02:32:11.200 --> 02:32:16.240] bake-in time. So by default, they don't use fused. But fused is a lot faster when it is available | |
[02:32:16.240 --> 02:32:21.840] and when you're running on CUDA. And what that does is instead of iterating in a for loop over all | |
[02:32:21.840 --> 02:32:28.000] the parameter tensors and updating them, that would launch a lot of kernels, right? And so fused | |
[02:32:28.000 --> 02:32:34.000] just means that all those kernels are fused into a single kernel. You get rid of a lot of overhead | |
[02:32:34.000 --> 02:32:42.000] and you call, a single time, one kernel on all the parameters to update them. And so it's just basically | |
[02:32:42.000 --> 02:32:48.480] kernel fusion for the Adam W update instead of iterating over all the tensors. So that's the | |
[02:32:48.480 --> 02:32:54.080] configure optimizers function that I like to use. And we can rerun and we're not going to see any | |
[02:32:54.080 --> 02:32:59.760] major differences from what we saw before, but we are going to see some prints coming from here. | |
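A hedged sketch of what such a configure_optimizers could look like; the function name and argument list here just mirror the description above:

    import inspect
    import torch

    def configure_optimizers(model, weight_decay, learning_rate, device):
        params = [p for p in model.parameters() if p.requires_grad]
        # 2D+ tensors (matmul weights, embeddings) get weight decay; 1D tensors (biases, layernorm scales) do not
        decay_params = [p for p in params if p.dim() >= 2]
        nodecay_params = [p for p in params if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        # use the fused AdamW kernel only if this PyTorch build exposes it and we're on CUDA
        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and "cuda" in device
        return torch.optim.AdamW(optim_groups, lr=learning_rate,
                                 betas=(0.9, 0.95), eps=1e-8, fused=use_fused)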
[02:32:59.760 --> 02:33:05.840] So let's just take a look at what they look like. So we see that the number of decayed tensors is 50, | |
[02:33:05.840 --> 02:33:10.000] and it's most of the parameters. The number of non-decayed tensors is 98. And these are | |
[02:33:10.000 --> 02:33:16.000] the biases and the layer norm parameters mostly. And there's only 100,000 of those. So most of | |
[02:33:16.000 --> 02:33:22.160] it is decayed. And then we are using the fused implementation of Adam W, which will be a lot faster. | |
[02:33:22.160 --> 02:33:26.720] So if you have it available, I would advise you to use it. I'm not actually 100% sure why they | |
[02:33:26.720 --> 02:33:31.680] don't default to it. It seems fairly benign and harmless. And also because we are using the | |
[02:33:31.680 --> 02:33:38.080] fused implementation, I think this is why we have dropped. Notice that the running time used to be | |
[02:33:38.080 --> 02:33:43.760] 93 milliseconds per step. And we're now down to 90 milliseconds per step because of using the fused | |
[02:33:43.760 --> 02:33:51.440] AdamW optimizer. So in a single commit here, we are introducing fused AdamW, getting improvements | |
[02:33:51.440 --> 02:33:56.800] on the time. And we're adding or changing the weight decay, but we're only weight decaying the | |
[02:33:56.800 --> 02:34:00.800] two-dimensional parameters: the embeddings and the matrices that participate in linears. | |
[02:34:00.800 --> 02:34:10.320] So that is this and we can take this out. And yeah, that is it for this line. One more quick note | |
[02:34:10.320 --> 02:34:14.720] before we continue here, I just want to point out that the relationship between weight decay, | |
[02:34:14.720 --> 02:34:20.320] learning rate, batch size, the Adam parameters beta 1, beta 2, the epsilon and so on, these are | |
[02:34:20.320 --> 02:34:27.520] very complicated mathematical relationships in the optimization literature. And for the most part, | |
[02:34:27.520 --> 02:34:31.680] in this video I'm just trying to copy-paste the settings that OpenAI used. But this is a | |
[02:34:31.680 --> 02:34:37.840] complicated and quite deep topic. And yeah, in this video, I just want to copy the parameters because | |
[02:34:37.840 --> 02:34:41.520] it's a whole different video to really talk about that in detail and do it proper | |
[02:34:41.520 --> 02:34:46.320] justice instead of just high-level intuitions. And now the next thing that I want to move on to | |
[02:34:46.320 --> 02:34:51.360] is this paragraph here, which, by the way, we're going to turn back around to when we | |
[02:34:51.360 --> 02:34:57.760] improve our data loader. For now, I want to swing back around to this table, | |
[02:34:57.760 --> 02:35:07.120] where you will notice that for different models, we, of course, have different | |
[02:35:07.120 --> 02:35:11.840] hyperparameters for the transformer that dictate the size of the transformer network. We also have | |
[02:35:11.840 --> 02:35:15.200] a different learning rate. So we're seeing the pattern that the bigger networks are trained by | |
[02:35:15.200 --> 02:35:21.840] slightly lower learning rates. And we also see this batch size, where in the small networks, | |
[02:35:21.840 --> 02:35:25.360] they use a smaller batch size, and then the bigger networks, they use a bigger batch size. | |
[02:35:25.360 --> 02:35:32.720] Now the problem for us is we can't just use a 0.5 million batch size, because if I just try to | |
[02:35:32.720 --> 02:35:47.920] come in here and try to set this B. Where's my B? Okay, B equals 16. | |
[02:35:47.920 --> 02:35:54.320] If I try to set, well, we have to be careful, it's not 0.5 million, because this is the batch size | |
[02:35:54.320 --> 02:36:02.320] in the number of tokens. Every single one of our rows is 1024 tokens. So 0.5e6, half a million, | |
[02:36:02.320 --> 02:36:10.480] divided by 1024: this would need a batch size of about 488. So the problem is I can't come in here and set | |
[02:36:10.480 --> 02:36:17.600] this to 488, because my GPU would explode. This would not fit for sure. And so, | |
[02:36:17.600 --> 02:36:23.600] but we still want to use this batch size, because again, as I mentioned, the batch size is correlated | |
[02:36:23.600 --> 02:36:28.160] with all the other optimization hyperparameters, and the learning rates and so on. So we want to | |
[02:36:28.160 --> 02:36:33.200] have a faithful representation of all the hyper parameters. And therefore, we need to use a batch | |
[02:36:33.200 --> 02:36:39.360] size of 0.5 million roughly. But the question is, how do we use 0.5 million if we only have a small | |
[02:36:39.360 --> 02:36:44.640] GPU? Well, for that, we need to use what's called gradient accumulation. So we're going to turn | |
[02:36:44.640 --> 02:36:50.480] to that next, and it allows us to simulate in a serial way, any arbitrary batch size that we set. | |
[02:36:50.480 --> 02:36:56.720] And so we can do a batch size of 0.5 million, we just have to run longer, and we have to process | |
[02:36:56.720 --> 02:37:02.960] multiple sequences, and basically add up all the gradients from them to simulate a batch size of | |
[02:37:02.960 --> 02:37:07.360] 0.5 million. So let's turn to that next. Okay, so I started the implementation right here, | |
[02:37:07.360 --> 02:37:13.040] just by adding these lines of code. And basically, what I did is first, I set the total batch size | |
[02:37:13.040 --> 02:37:18.560] that we desire. So this is roughly 0.5 million. And I used a nice number, a power of two, | |
[02:37:18.560 --> 02:37:24.480] because 2 to the 19 is 524,288. So it's roughly 0.5 million, and a nice number. | |
[02:37:25.360 --> 02:37:32.160] Now our micro batch size, as we call it now, is 16. So we still have B by T | |
[02:37:32.160 --> 02:37:37.200] indices that go into the transformer and do forward backward. But we're not going to do an | |
[02:37:37.200 --> 02:37:41.840] update, right? We're going to do many forward backwards. We're going to, and those gradients | |
[02:37:41.840 --> 02:37:46.160] are all going to plus equals on the parameter gradients, they're all going to add up. So we're | |
[02:37:46.160 --> 02:37:51.280] going to do forward backward grad_accum_steps number of times. And then we're going to do a single | |
[02:37:51.280 --> 02:37:57.280] update once all that is accumulated. So in particular, our micro batch size is just now | |
[02:37:57.280 --> 02:38:01.840] controlling how many tokens, how many rows we're processing in a single go of a forward backward. | |
[02:38:01.840 --> 02:38:09.280] So here, we are doing 16 times 1024, we're doing 16,384 | |
[02:38:09.280 --> 02:38:17.040] tokens per forward backward. And we are supposed to be doing two to the 19, | |
[02:38:17.600 --> 02:38:24.000] oops, what am I doing? 2 to the 19 in total. So the grad accum will be 32. | |
[02:38:24.000 --> 02:38:32.560] So therefore, grad accum steps here will work out to 32. And we have to do 32 forward backwards, | |
[02:38:32.560 --> 02:38:38.640] and then a single update. Now we see that we have about 100 milliseconds for a single forward | |
[02:38:38.640 --> 02:38:45.520] backward. So doing 32 of them will be, we'll make every step roughly three seconds, just napkin math. | |
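In code, the setup described here comes down to a few lines (a sketch; B and T are the micro batch size and sequence length already used above):

    total_batch_size = 524288  # 2**19 tokens, roughly 0.5M, a nice power of two
    B = 16    # micro batch size: rows per forward/backward
    T = 1024  # sequence length
    assert total_batch_size % (B * T) == 0, "total batch size must divide evenly"
    grad_accum_steps = total_batch_size // (B * T)  # -> 32 with these settings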
[02:38:47.120 --> 02:38:51.360] So that's grad accum steps, but now we'd actually like to implement that. So we're going to swing | |
[02:38:51.360 --> 02:38:59.280] over to our training loop, because now this part here, and this part here, the forward and the | |
[02:38:59.280 --> 02:39:06.000] backward, we have to now repeat this 32 times before we do everything else that follows. So let's | |
[02:39:06.000 --> 02:39:10.800] see how we can implement that. So let's come over here. And actually, we do have to load a new batch | |
[02:39:10.800 --> 02:39:15.600] every single time. So let me move that over here. And now this is where we have the inner loop. So | |
[02:39:15.600 --> 02:39:25.040] for micro_step in range(grad_accum_steps), we do this. And remember that loss.backward() always | |
[02:39:25.040 --> 02:39:29.520] deposits gradients. So inside loss.backward(), there's always a plus-equals on the | |
[02:39:29.520 --> 02:39:34.720] gradients. So in every single loss.backward(), gradients will add up on the gradient tensors. | |
[02:39:34.720 --> 02:39:43.120] So we do loss.backward(), and then we get all the gradients accumulated there. And then we normalize | |
[02:39:43.120 --> 02:39:48.800] and everything else should just follow. So we're very close. But actually, there's like | |
[02:39:48.800 --> 02:39:55.360] a subtle and deep issue here. And this is actually incorrect. So I invite you to think about why | |
[02:39:55.360 --> 02:40:01.440] this is not yet sufficient. And let me fix it then. Okay, so I brought back the Jupyter Notebook. So | |
[02:40:01.440 --> 02:40:07.120] we can think about this carefully in a simple toy setting and see what's happening. So let's | |
[02:40:07.120 --> 02:40:12.320] create a very simple neural net that takes a vector of 16 numbers and returns a single number. | |
[02:40:13.600 --> 02:40:21.520] And then here I'm creating some random examples x and some targets y. And then we are using the | |
[02:40:21.520 --> 02:40:28.640] mean squared loss here to calculate the loss. So basically what this is is four individual | |
[02:40:28.640 --> 02:40:34.080] examples. And we're just doing simple regression with the mean squared loss over those four examples. | |
[02:40:34.080 --> 02:40:38.800] Now when we calculate the loss and we lost that backward and look at the gradient, | |
[02:40:39.360 --> 02:40:45.600] this is the gradient that we achieve. Now the loss objective here, notice that in MSE loss, | |
[02:40:45.600 --> 02:40:52.800] the default for the loss function is reduction='mean'. So we're calculating the average loss, | |
[02:40:52.800 --> 02:41:01.440] the mean loss here over the four examples. So this is the exact loss objective. And this is the | |
[02:41:01.440 --> 02:41:06.400] average, the one over four, because there are four independent examples here. And then we have the | |
[02:41:07.040 --> 02:41:12.000] four examples and their mean squared error, the squared error, and then this makes it the mean | |
[02:41:12.000 --> 02:41:18.240] squared error. So therefore, we are, we calculate the squared error, and then we normalize it to | |
[02:41:18.240 --> 02:41:22.880] make it the mean over the examples. And there's four examples here. So now when we come to the | |
[02:41:22.880 --> 02:41:31.200] gradient accumulation version of it, this here is the gradient accumulation version of it, | |
[02:41:31.200 --> 02:41:36.320] where we have grad accum steps of four. I reset the gradients and do grad accum steps of | |
[02:41:36.320 --> 02:41:40.880] four. And now I'm evaluating all the examples individually instead and calling loss | |
[02:41:40.880 --> 02:41:45.440] dot backward on them many times. And then we're looking at the gradient that we achieve from that. So | |
[02:41:45.440 --> 02:41:51.280] basically now we forward our function, calculate the exact same loss, do a backward, and we do that | |
[02:41:51.280 --> 02:41:56.560] four times. And when we look at the gradient, you'll notice that the gradients don't match. | |
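A self-contained toy version of this experiment (the layer sizes and seed are arbitrary; the point is only the comparison):

    import torch

    torch.manual_seed(42)
    net = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.GELU(), torch.nn.Linear(32, 1))
    x = torch.randn(4, 16)
    y = torch.randn(4, 1)

    # version 1: a single batch of 4, MSE with the default reduction='mean'
    net.zero_grad()
    torch.nn.functional.mse_loss(net(x), y).backward()
    grad_full = net[0].weight.grad.clone()

    # version 2: 4 gradient accumulation steps of batch size 1 (naive, no loss scaling)
    net.zero_grad()
    for i in range(4):
        torch.nn.functional.mse_loss(net(x[i:i+1]), y[i:i+1]).backward()  # grads += each time
    print(torch.allclose(grad_full, net[0].weight.grad))  # False: the 1/4 normalizer is missing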
[02:41:56.560 --> 02:42:04.400] So here we did a single batch of four. And here we did four gradient accumulation steps | |
[02:42:04.400 --> 02:42:09.760] of batch size one. And the gradients are not the same. And basically the reason that they're not | |
[02:42:09.760 --> 02:42:15.600] the same is exactly because this mean squared error gets lost. This one quarter in this loss | |
[02:42:15.600 --> 02:42:22.240] gets lost, because what happens here is the loss objective for every one of the loops is just a | |
[02:42:22.240 --> 02:42:27.840] mean squared error, which in this case, because there's only a single example, is just this term | |
[02:42:27.840 --> 02:42:33.840] here. So that was the loss in the zeroth iteration, and the same in the first, the third, and so on. And then | |
[02:42:33.840 --> 02:42:39.840] when you do the loss backward, we're accumulating gradients. And what happens is that accumulation of | |
[02:42:39.840 --> 02:42:48.080] the gradient is basically equivalent to doing a sum in the loss. So our loss actually here is | |
[02:42:48.080 --> 02:42:55.760] this without the factor of one quarter outside of it. So we're missing the normalizer. And therefore | |
[02:42:55.760 --> 02:43:00.240] our gradients are off. And so the way to fix this or one of them is basically we can actually come | |
[02:43:00.240 --> 02:43:07.680] here and we can say loss equals loss divided by four. And what happens now is that we're | |
[02:43:07.680 --> 02:43:12.560] scaling our loss, we're introducing a one quarter in front of all of these places. | |
[02:43:12.560 --> 02:43:20.000] So all the individual losses are now scaled by one quarter. And then when we backward, | |
[02:43:20.000 --> 02:43:25.520] all of these accumulate with a sum. But now there's a one quarter inside every one of these | |
[02:43:25.520 --> 02:43:33.760] components. And now our losses will be equivalent. So when I run this, you see that the gradients | |
[02:43:33.760 --> 02:43:39.360] are now identical. So long story short, with this simple example, when you step through it, | |
[02:43:39.360 --> 02:43:44.560] you can see that basically the reason that this is not correct is because in the same way as | |
[02:43:44.560 --> 02:43:54.560] here in the MSE loss, the loss that we're calculating here in the model is using a reduction of mean | |
[02:43:54.560 --> 02:44:01.760] as well. So where's the loss? It's this cross entropy here. And by default, the reduction here in | |
[02:44:01.760 --> 02:44:07.520] cross entropy is also, I don't know why they don't show it, but it's the mean, the mean loss over all | |
[02:44:07.520 --> 02:44:15.280] the B by T elements, right? So there's a reduction by mean in there. And if we're just doing this | |
[02:44:15.280 --> 02:44:20.160] gradient accumulation here, we're missing that. And so the way to fix this is to simply compensate | |
[02:44:20.160 --> 02:44:24.320] for the number of gradient accumulation steps, and we can in the same way divide this loss. | |
[02:44:24.320 --> 02:44:30.960] So in particular here, what we're doing is: loss divided by | |
[02:44:30.960 --> 02:44:36.960] gradient accumulation steps. So even Copilot, sort of, gets the modification. But in the same way | |
[02:44:36.960 --> 02:44:42.240] exactly, we are scaling down the loss so that when we do loss.backward(), which basically corresponds | |
[02:44:42.240 --> 02:44:49.200] to a sum in the objective, we are summing up the already normalized loss. And therefore, | |
[02:44:49.200 --> 02:44:54.800] when we sum up the losses divided by grad accum steps, we are recovering the additional normalizer. | |
[02:44:54.800 --> 02:45:02.320] And so now this will be equivalent to the original sort of optimization, | |
[02:45:02.320 --> 02:45:06.560] because the gradient will come out the same. Okay, so I had to do a few more touch ups, | |
[02:45:06.560 --> 02:45:10.800] and I launched the optimization here. So in particular, one thing we want to do, | |
[02:45:10.800 --> 02:45:16.080] because we want to print things nicely, is well, first of all, we need to create like an accumulator | |
[02:45:16.080 --> 02:45:19.920] over the loss, we can't just print the loss, because we'd be printing only the final loss | |
[02:45:19.920 --> 02:45:24.800] at the final micro step. So instead, we have a loss_accum, which I initialize at zero, | |
[02:45:24.800 --> 02:45:32.560] and then I accumulate the loss into it. And I'm using detach() so that I'm detaching the tensor | |
[02:45:32.560 --> 02:45:38.560] from the graph, and I'm just trying to keep track of the values. So I'm making these leaf nodes | |
[02:45:38.560 --> 02:45:43.840] when I add them. So that's loss_accum, and then we're printing that here instead of loss. | |
[02:45:43.840 --> 02:45:49.120] And then in addition to that, I have to account for the grad accum steps inside the tokens processed, | |
[02:45:49.120 --> 02:45:53.920] because now the tokens processed per step is B times T times the gradient accumulation steps. | |
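Putting these pieces together, the inner micro-step loop looks roughly like this (a sketch; train_loader, model, device and grad_accum_steps refer to objects assumed from earlier in the video):

    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        logits, loss = model(x, y)
        # cross entropy's reduction='mean' only averages within the micro batch,
        # so reintroduce the missing 1/grad_accum_steps normalizer here
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()  # track the value only, detached from the graph
        loss.backward()              # gradients += into .grad on every call
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()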
[02:45:53.920 --> 02:46:00.800] So long story short, here we have the optimization. It looks reasonable, right? We're starting at a | |
[02:46:00.800 --> 02:46:07.280] good spot, we calculated the grad accum steps to be 32. And we're getting about three seconds here, | |
[02:46:07.280 --> 02:46:17.280] right? And so this looks pretty good. Now if you'd like to verify that your optimization | |
[02:46:17.280 --> 02:46:21.280] and the implementation here is correct and you're working on this on the side, well, now because we have the | |
[02:46:21.280 --> 02:46:26.640] total batch size and the gradient accumulation steps, our setting of B is purely a performance | |
[02:46:26.640 --> 02:46:32.080] optimization kind of setting. So if you have a big GPU, you can actually increase this to 32, | |
[02:46:32.080 --> 02:46:36.320] and you'll probably go a bit faster. If you have a very small GPU, you can try eight or four, | |
[02:46:36.880 --> 02:46:41.200] but in any case, you should be getting the exact same optimization and the same answers up to | |
[02:46:41.200 --> 02:46:47.360] a floating point error, because the gradient accumulation kicks in and can handle everything | |
[02:46:47.360 --> 02:46:54.080] serially as necessary. So that's it for gradient accumulation, I think. Okay, so now is the time | |
[02:46:54.080 --> 02:46:59.120] to bring out the heavy weapons. You've noticed that so far, we've only been using a single GPU | |
[02:46:59.120 --> 02:47:04.960] for training, but actually I am paying for 8 GPUs here. And so we should be putting all of them to | |
[02:47:04.960 --> 02:47:11.760] work. And in particular, they're all going to collaborate and optimize over tokens at the same | |
[02:47:11.760 --> 02:47:18.160] time and communicate so that they're all kind of collaborating on the optimization. For this, | |
[02:47:18.160 --> 02:47:22.160] we are going to be using the distributed data parallel from PyTorch. There's also a legacy | |
[02:47:22.160 --> 02:47:27.840] data parallel, which I recommend you not use, as that's kind of legacy. Distributed data | |
[02:47:27.840 --> 02:47:34.240] parallel works in a very simple way. We have eight GPUs. So we're going to launch eight processes. | |
[02:47:34.240 --> 02:47:40.320] And each process is going to be assigned a GPU. And for each process, the training loop and | |
[02:47:40.320 --> 02:47:44.640] everything we've worked on so far is going to look pretty much the same. Each GPU, as far as | |
[02:47:44.640 --> 02:47:49.680] it's concerned, is just working on exactly what we've built so far. But now secretly, there's eight | |
[02:47:49.680 --> 02:47:55.120] of them, and they're all going to be processing slightly different parts of the data. And we're | |
[02:47:55.120 --> 02:48:00.080] going to add one more part where once they all calculate their gradients, there's one more part | |
[02:48:00.080 --> 02:48:06.400] where we do a average of those gradients. And so that's how they're going to be collaborating | |
[02:48:06.400 --> 02:48:12.720] on the computational workload here. So to use all eight of them, we're not going to be launching | |
[02:48:12.720 --> 02:48:19.440] our script anymore with just python train_gpt2.py. We're going to be running it with a special | |
[02:48:19.440 --> 02:48:25.760] command called torchrun in PyTorch. We'll see it in a bit. And torchrun, when it runs our | |
[02:48:25.760 --> 02:48:32.560] Python script, will actually make sure to run eight of them in parallel. And it creates these | |
[02:48:32.560 --> 02:48:39.040] environmental variables where each of these processes can look up which basically which one | |
[02:48:39.040 --> 02:48:45.840] of the processes it is. So for example, torchrun will set the RANK, LOCAL_RANK and WORLD_SIZE environment | |
[02:48:45.840 --> 02:48:52.000] variables. And so this is a bad way to detect whether DDP is running. So if we're using torch | |
[02:48:52.000 --> 02:48:58.800] run, if DDP is running, then we have to make sure that CUDA is available, because I don't know that | |
[02:48:58.800 --> 02:49:06.960] you can run this on CPU anymore, or that it makes sense to do. This is some setup code here. | |
[02:49:06.960 --> 02:49:12.080] The important part is that there's a world size, which for us will be eight. That's the total number | |
[02:49:12.080 --> 02:49:19.680] of processes running. There's a rank which is each process will basically run the exact same code | |
[02:49:19.680 --> 02:49:25.040] at the exact same time roughly. But all the process, the only difference between these processes | |
[02:49:25.040 --> 02:49:31.920] is that they all have a different DDP rank. So the GPU zero will have DDP rank of zero, | |
[02:49:31.920 --> 02:49:37.200] GPU one will have a rank of one, et cetera. So otherwise they're all running the exact | |
[02:49:37.200 --> 02:49:42.000] same script. It's just that DDP rank will be a slightly different integer. And that is the way | |
[02:49:42.000 --> 02:49:47.120] for us to coordinate that they don't, for example, run on the same data. We want them to run on | |
[02:49:47.120 --> 02:49:53.600] different parts of the data and so on. Now local rank is something that is only used in a multi-node | |
[02:49:53.600 --> 02:50:00.480] setting. We only have a single node with 8 GPUs. And so local rank is the rank of the GPU on a single | |
[02:50:00.480 --> 02:50:06.960] node. So from zero to seven as an example. But for us, we're mostly going to be running on a single | |
[02:50:06.960 --> 02:50:12.720] box. So the things we care about are rank and world size. This is eight. And this will be whatever | |
[02:50:12.720 --> 02:50:18.400] it is depending on which GPU this particular instantiation of the script runs on. | |
[02:50:18.400 --> 02:50:27.520] Now here, we make sure that according to the local rank, we are setting the device to be | |
[02:50:27.520 --> 02:50:33.360] cuda colon the local rank, and the colon indicates which GPU to use if there is more than one GPU. | |
[02:50:33.360 --> 02:50:40.400] So depending on the local rank of this process, it's going to use just the appropriate GPU. So | |
[02:50:40.400 --> 02:50:45.520] there's no collisions on which GPUs being used by which process. And finally, there's a Boolean | |
[02:50:45.520 --> 02:50:51.440] variable that I like to create, which is ddp_rank == 0. So the master process | |
[02:50:51.440 --> 02:50:56.320] is arbitrarily process number zero. And it does a lot of the printing, logging, checkpointing, | |
[02:50:56.320 --> 02:51:00.800] et cetera. And the other processes are thought of mostly as a compute processes that are assisting. | |
[02:51:00.800 --> 02:51:05.680] And so master process zero will have some additional work to do. All the other processes | |
[02:51:05.680 --> 02:51:10.720] will will most just be doing forward backwards. And if we're not using DDP and none of these | |
[02:51:10.720 --> 02:51:15.440] variables are set, we revert back to single GPU training. So that means that we only have rank | |
[02:51:15.440 --> 02:51:22.480] zero. The world size is just one. And we are the master process. And we try to auto detect the | |
[02:51:22.480 --> 02:51:28.400] device. And this is all as normal. So, so far, all we've done is we've initialized DDP. | |
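A sketch of this setup using the standard torch.distributed API and the environment variables torchrun sets; the variable names here just mirror the description:

    import os
    import torch
    import torch.distributed as dist

    ddp = int(os.environ.get("RANK", -1)) != -1  # torchrun sets RANK, LOCAL_RANK, WORLD_SIZE
    if ddp:
        assert torch.cuda.is_available(), "this DDP path assumes CUDA"
        dist.init_process_group(backend="nccl")
        ddp_rank = int(os.environ["RANK"])
        ddp_local_rank = int(os.environ["LOCAL_RANK"])
        ddp_world_size = int(os.environ["WORLD_SIZE"])
        device = f"cuda:{ddp_local_rank}"
        torch.cuda.set_device(device)
        master_process = ddp_rank == 0  # this one does the logging, checkpointing, etc.
    else:
        # single-process fallback: behaves exactly as before
        ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1
        master_process = True
        device = "cuda" if torch.cuda.is_available() else "cpu"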
[02:51:28.400 --> 02:51:32.960] And in the case where we're running with torch run, which we'll see in a bit, | |
[02:51:32.960 --> 02:51:37.520] there's going to be eight copies running in parallel. Each one of them will have a different | |
[02:51:37.520 --> 02:51:42.880] rank. And now we have to make sure that everything happens correctly afterwards. | |
[02:51:42.880 --> 02:51:48.160] So the tricky thing with running multiple processes is you always have to imagine that | |
[02:51:48.160 --> 02:51:53.360] there's going to be eight processes running in parallel. So as you read the code now, | |
[02:51:53.360 --> 02:51:58.800] you have to imagine there's eight, you know, eight Python interpreters running down these lines of | |
[02:51:58.800 --> 02:52:03.760] code. And the only difference between them is that they have a different DDP rank. So they all | |
[02:52:03.760 --> 02:52:09.360] come here, they all pick the exact same seed. They all make all of these calculations completely | |
[02:52:09.360 --> 02:52:14.560] unaware of the other copies running, roughly speaking, right? So they all make the exact same | |
[02:52:14.560 --> 02:52:19.760] calculations. And now we have to adjust these calculations to take into account that there's | |
[02:52:19.760 --> 02:52:26.400] actually like a certain world size and certain ranks. So in particular, these micro batches and | |
[02:52:26.400 --> 02:52:32.800] sequence links, these are all just per GPU, right? So now there's going to be num processes of them | |
[02:52:32.800 --> 02:52:38.560] running in parallel. So we have to adjust this, right? Because the grad accum steps now is going to | |
[02:52:38.560 --> 02:52:49.520] be total batch size divided by B times T times ddp_world_size. Because each process will do B times T, | |
[02:52:49.520 --> 02:52:54.640] and there's this many of them. And so in addition to that, we want to make sure | |
[02:52:54.640 --> 02:53:01.040] that this fits nicely into total batch size, which for us it will be, because 16 times 1024 times | |
[02:53:01.760 --> 02:53:09.760] eight GPUs is 131,072. Okay. And 524,288 divided by that means that our | |
[02:53:09.760 --> 02:53:16.800] grad accum will be four with the current settings, right? So there's going to be 16 times 1024 processed | |
[02:53:16.800 --> 02:53:23.120] on each GPU, and there are 8 GPUs. So we're going to be doing 131,072 tokens in a single forward | |
[02:53:23.120 --> 02:53:30.480] backward on the 8 GPUs. So we're going to make sure that this fits nicely so that we can derive | |
[02:53:30.480 --> 02:53:36.640] a nice gradient accumulation of steps. And yeah, let's just adjust the comments here, | |
[02:53:36.640 --> 02:53:46.400] times ddp_world_size. Okay. So each GPU calculates this. Now this is where we start to run into | |
[02:53:46.400 --> 02:53:52.480] issues, right? So each process is going to come by this print, and they're all going to print. | |
[02:53:52.480 --> 02:53:57.680] So we're going to have eight copies of these prints. So one way to deal with this is exactly | |
[02:53:57.680 --> 02:54:03.920] this master process variable that we have. So if master process, then guard this. And that's | |
[02:54:03.920 --> 02:54:07.920] just so that we just print this a single time, because otherwise all the processes would have | |
[02:54:07.920 --> 02:54:13.840] computed the exact same variables, and there's no need to print this eight times. Before getting | |
[02:54:13.840 --> 02:54:19.280] into the data loader, we're going to have to refactor it, obviously. But maybe at this point, | |
[02:54:19.280 --> 02:54:25.360] we should do some prints, and just take it out for a spin and exit at this point. So import sys. | |
[02:54:27.360 --> 02:54:44.480] And sys.exit(), and print "I am GPU" plus the DDP rank, and then print "bye". | |
[02:54:44.480 --> 02:54:51.920] So now let's try to run this and just see how this works. So let's take it for a spin, | |
[02:54:51.920 --> 02:54:56.720] just so we see what it looks like. So normally we used to launch python train_gpt2.py like | |
[02:54:56.720 --> 02:55:02.400] this. Now we're going to run it with torchrun. And this is what it looks like. So torchrun, standalone, | |
[02:55:02.400 --> 02:55:07.760] number of processes per node, for example, is eight for us, because we have 8 GPUs, and then train_gpt2 | |
[02:55:07.760 --> 02:55:13.600] .py. So this is what the command would look like. And torchrun, again, will run eight of these. | |
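Assuming the script is called train_gpt2.py, the launch command looks something like:

    torchrun --standalone --nproc_per_node=8 train_gpt2.py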
[02:55:13.600 --> 02:55:20.720] So let's just see what happens. So first, it gets a little busy. So there's a lot going on here. So | |
[02:55:20.720 --> 02:55:26.480] first of all, there's some warnings from distributed. And I don't actually know that these mean anything. | |
[02:55:26.480 --> 02:55:30.400] I think this is just like, the code is setting up and the processes are coming online. And we're | |
[02:55:30.400 --> 02:55:36.400] seeing some preliminary failure to collect while the processes come up. I'm not 100% sure about that. | |
[02:55:36.400 --> 02:55:44.560] But we start to then get into actual prints. So all the processes went down. And then the first | |
[02:55:44.560 --> 02:55:52.000] print actually comes from process five, just by chance. And then it printed. So process five | |
[02:55:52.000 --> 02:55:59.920] basically got here first, it said I am GPU five, bye. And then these prints come from | |
[02:55:59.920 --> 02:56:05.760] the master process. So process five just finished first for whatever reason, it just depends on how | |
[02:56:05.760 --> 02:56:11.600] the operating system scheduled the processes to run. Then GPU zero ended, then GPU three and two. | |
[02:56:11.600 --> 02:56:19.840] And then probably process five or something like that has exited. And DDP really doesn't like that | |
[02:56:19.840 --> 02:56:28.240] because we didn't properly dispose of the multi-GPU setting. And so the process group has not been | |
[02:56:28.240 --> 02:56:33.840] destroyed before we destruct. So it really doesn't like that. And in an actual application, we would | |
[02:56:33.840 --> 02:56:40.400] want to call destroy process group, so that we clean up DDP properly. And so it doesn't like that | |
[02:56:40.400 --> 02:56:46.240] too much. And then the rest of the GPUs finish. And that's it. So basically, we can't guarantee | |
[02:56:46.240 --> 02:56:50.240] when these processes are running, it's totally arbitrary, but they are running in parallel. | |
[02:56:50.240 --> 02:56:58.800] And we don't want all of this printing, so let's erase this. Next up, we want to make | |
[02:56:58.800 --> 02:57:03.520] sure that when we create the DataLoaderLite, we now make it aware of this multi-process | |
[02:57:03.520 --> 02:57:10.160] setting, because we don't want all the processes to be loading the exact same data. We want every | |
[02:57:10.160 --> 02:57:14.720] process to get its own chunk of data so that they're all working on different parts of the data set, | |
[02:57:14.720 --> 02:57:21.280] of course. So let's adjust that. So one particularly simple and naive way to do this is we have | |
[02:57:21.280 --> 02:57:27.840] to make sure that we pass in the rank and the size to the data loader. And then we come up here, | |
[02:57:27.840 --> 02:57:33.360] we see that we now take the rank and the number of processes and we save them. Now the current position will not be | |
[02:57:33.360 --> 02:57:40.480] zero, because what we want is we want to stride out all the processes. So one way to do this is | |
[02:57:40.480 --> 02:57:45.760] we basically take self.B times self.T, and then multiply it by the process rank. | |
[02:57:45.760 --> 02:57:52.880] So process rank zero will start at zero, but process rank one now starts at B times T. | |
[02:57:52.880 --> 02:57:58.880] Process rank two starts at two times B times T, etc. So that is the initialization. | |
[02:57:58.880 --> 02:58:05.840] Now we still, they still do this identically, but now when we advance, we don't advance by B | |
[02:58:05.840 --> 02:58:12.400] times T, we advance by B times T times number of processes, right? So basically, | |
[02:58:12.400 --> 02:58:20.000] the total number of tokens that we're consuming is B times T times the number of processes, and they all | |
[02:58:20.000 --> 02:58:27.680] go off to a different rank. And the position has to advance by the entire chunk. And then | |
[02:58:27.680 --> 02:58:34.560] here, if B times T times self.num_processes, plus one, were to exceed the number of tokens, | |
[02:58:34.560 --> 02:58:38.640] then we're going to loop. And when we loop, we want to, of course, loop in the exact same way. | |
[02:58:38.640 --> 02:58:46.480] So we sort of like reset back. So this is the simplest change that I can find for kind of a | |
[02:58:46.480 --> 02:58:52.160] very simple distributed DataLoaderLite. And you can notice that if the process rank is zero, | |
[02:58:52.160 --> 02:58:56.640] and the number of processes is one, then the whole thing will be identical to what we had before. | |
[02:58:56.640 --> 02:59:00.800] But now we can have actually multiple processes running, and this should work fine. | |
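A minimal sketch of that rank-strided loader (assuming the tokens are already loaded into a 1D torch tensor elsewhere; the names are meant to mirror the DataLoaderLite described above, not the exact file):

```python
class DataLoaderLite:
    """Sketch of the rank-strided loader: each rank starts at a different offset
    and all ranks together advance through the data in one big stride."""
    def __init__(self, tokens, B, T, process_rank, num_processes):
        self.tokens = tokens                      # 1D torch.long tensor of token ids
        self.B, self.T = B, T
        self.process_rank = process_rank
        self.num_processes = num_processes
        # each rank starts at its own offset, strided by B*T
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position + B*T + 1]
        x = buf[:-1].view(B, T)   # inputs
        y = buf[1:].view(B, T)    # targets (next token)
        # advance by the chunk consumed by *all* processes this step
        self.current_position += B * T * self.num_processes
        # if the next chunk would run off the end, loop back to this rank's start
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_position = self.B * self.T * self.process_rank
        return x, y
```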
[02:59:03.360 --> 02:59:07.920] So that's the data loader. Okay, so next up, once they've all initialized the data loader, | |
[02:59:07.920 --> 02:59:15.440] they come here and they all create a GPT model. So we create eight GPT models on eight processes. | |
[02:59:15.440 --> 02:59:19.440] But because the seeds are fixed here, they all create the same identical model. | |
[02:59:19.440 --> 02:59:25.280] They all move it to the device of their rank, and they all compile the model. And because the | |
[02:59:25.280 --> 02:59:30.080] models are identical, there are eight identical compilations happening in parallel, but that's okay. | |
[02:59:31.040 --> 02:59:35.440] Now, none of this changes because that is on a per step basis. And we're currently working | |
[02:59:35.440 --> 02:59:41.280] kind of within the step, because all the changes we're making are kind of | |
[02:59:41.280 --> 02:59:46.720] within-step changes. Now, the important thing here is when we construct the model, | |
[02:59:46.720 --> 02:59:51.120] we actually have a bit of work to do here, get logits is deprecated. So create model. | |
[02:59:51.120 --> 02:59:58.160] We need to actually wrap the model into the distributed data parallel container. | |
[02:59:59.360 --> 03:00:06.080] So this is how we wrap the model into the DDP container. And these are the docs for DDP. And | |
[03:00:06.080 --> 03:00:10.240] they're quite extensive. And there's a lot of caveats and a lot of things to be careful with | |
[03:00:10.240 --> 03:00:15.920] because everything complexifies times 10 when multiple processes are involved. But roughly | |
[03:00:15.920 --> 03:00:20.320] speaking, this device ID, I believe has to be passed in. Now, unfortunately, the docs for what | |
[03:00:20.320 --> 03:00:27.360] device_ids is, is extremely unclear. So when you actually come here, this comment for what | |
[03:00:27.360 --> 03:00:34.880] device_ids is, is roughly nonsensical. But I'm pretty sure it's supposed to be the DDP local rank. | |
[03:00:34.880 --> 03:00:41.840] So not the DDP rank, the local rank. So this is what you pass in here. This wraps the model. | |
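Putting the setup together, a hedged sketch of the DDP initialization and wrapping might look like this, where GPT and GPTConfig stand in for the model classes built earlier in the video, and the exact structure is an assumption rather than a copy of the script:

```python
# launch with e.g.: torchrun --standalone --nproc_per_node=8 train_gpt2.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

ddp = int(os.environ.get('RANK', -1)) != -1   # is this a torchrun / DDP run?
if ddp:
    dist.init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
else:
    ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = GPT(GPTConfig(vocab_size=50304))  # the GPT class built earlier in the video
model.to(device)
if ddp:
    # note: device_ids takes the *local* rank, i.e. which GPU on this node to use
    model = DDP(model, device_ids=[ddp_local_rank])
```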
[03:00:41.840 --> 03:00:46.720] And in particular, what DDP does for you is in a forward pass, it actually behaves identically. | |
[03:00:46.720 --> 03:00:51.920] So my understanding of it is nothing should be changed in the forward pass. But in the backward | |
[03:00:51.920 --> 03:00:58.400] pass, as you are doing the backward pass, in the simplest setting, once the backward pass is over | |
[03:00:58.400 --> 03:01:05.600] on each independent GPU, each independent GPU has the gradients for all the parameters. And what DDP | |
[03:01:05.600 --> 03:01:11.040] does for you is once the backward pass is over, it will call what's called all reduce. And it | |
[03:01:11.040 --> 03:01:18.960] basically does an average across all the ranks of their gradients. And then it will deposit that | |
[03:01:18.960 --> 03:01:24.000] average on every single rank. So every single rank will end up with the average | |
[03:01:24.000 --> 03:01:29.120] on it. And so basically, that's the communication. It just synchronizes and averages the gradients. | |
[03:01:29.120 --> 03:01:36.000] And that's what DDP offers you. Now DDP actually is a little bit more involved in that because | |
[03:01:36.000 --> 03:01:40.880] as you are doing the backward pass through the layers in the transformer, it actually can dispatch | |
[03:01:40.880 --> 03:01:46.880] communications for the gradient while the backward pass is still happening. So there's overlap of the | |
[03:01:46.880 --> 03:01:51.360] communication of the gradients and the synchronization of them and the backward pass. | |
[03:01:51.360 --> 03:01:57.840] And it's just more efficient to do it that way. So that's what DDP does for you. | |
[03:01:57.840 --> 03:02:04.880] Forward is unchanged and backward is mostly unchanged and we're tacking on this average as we'll see | |
[03:02:04.880 --> 03:02:11.840] in a bit. Okay, so now let's go to the optimization. Nothing here changes. Let's go to the optimization | |
[03:02:11.840 --> 03:02:17.760] here, the inner loop and think through the synchronization of these gradients in the DDP. So basically, | |
[03:02:17.760 --> 03:02:22.640] by default, what happens, as I mentioned, is when you do loss.backward here, it will do the backward | |
[03:02:22.640 --> 03:02:28.800] pass and then it will synchronize the gradients. The problem here is because of the gradient | |
[03:02:28.800 --> 03:02:35.360] accumulation steps loop here, we don't actually want to do the synchronization after every single | |
[03:02:35.360 --> 03:02:40.320] loss.backward, because we are just depositing gradients. And we're doing that serially. And we | |
[03:02:40.320 --> 03:02:44.560] just want them adding up. And we don't want to synchronize every single time; that would be extremely | |
[03:02:44.560 --> 03:02:49.920] wasteful. So basically, we want to add them up. And then it's only on the very last step, | |
[03:02:49.920 --> 03:02:55.680] when micro step becomes grad accum steps minus one, only at that last step | |
[03:02:55.680 --> 03:03:03.920] that we want to actually do the all-reduce to average the gradients. So to do that, we come here | |
[03:03:03.920 --> 03:03:11.040] and the official sanctioned way, by the way, is to do this no sync context manager. So PyTorch says, | |
[03:03:11.040 --> 03:03:16.720] this is a context manager to disable gradient synchronization across DDP processes. So within | |
[03:03:16.720 --> 03:03:22.800] this context, gradients will be accumulated. And basically, when you do no sync, there will be no | |
[03:03:22.800 --> 03:03:28.880] communication. So they are telling us to do 'with ddp.no_sync()', do the gradient accumulation, | |
[03:03:28.880 --> 03:03:34.320] accumulate grads, and then they are asking us to run DDP again on another input and call backward on that. | |
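For reference, the pattern the docs are describing looks roughly like the following sketch, where micro_batches is just a stand-in for our gradient accumulation inner loop:

```python
# sketch of the documented pattern: suppress the all-reduce for all but the
# last micro batch, then let the final backward outside no_sync() synchronize
with model.no_sync():
    for x, y in micro_batches[:-1]:
        logits, loss = model(x, y)
        loss.backward()        # gradients accumulate locally, no communication
x_last, y_last = micro_batches[-1]
logits, loss = model(x_last, y_last)
loss.backward()                # this backward triggers the gradient all-reduce
```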
[03:03:34.320 --> 03:03:39.600] And I just really don't love this. I just really don't like it. The fact that you have to copy | |
[03:03:39.600 --> 03:03:44.000] paste your code here and use a context manager. And it's just super ugly. So when I went to this | |
[03:03:44.000 --> 03:03:51.360] source code here, you can see that when you enter, you simply toggle this variable, this | |
[03:03:51.360 --> 03:03:58.560] require_backward_grad_sync. And this is being toggled around and changed. And this is the | |
[03:03:58.560 --> 03:04:05.200] variable that basically, if you step through it, is being toggled to determine if the gradient is | |
[03:04:05.200 --> 03:04:10.720] going to be synchronized. So I actually just kind of like to use that directly. So instead, | |
[03:04:10.720 --> 03:04:18.000] what I like to do is the following: right here, before the loss.backward, if we are using DDP, | |
[03:04:18.880 --> 03:04:25.760] then basically, we only want to synchronize, we only want this variable to be true, | |
[03:04:25.760 --> 03:04:33.200] when it is the final iteration. In all the other iterations inside the micro steps, we want it to be | |
[03:04:33.200 --> 03:04:39.280] false. So I just toggle it like this. So require_backward_grad_sync should only turn on when the | |
[03:04:39.280 --> 03:04:46.480] micro step is the last step. And so I'm toggling this variable directly. And I hope that that | |
[03:04:46.480 --> 03:04:51.040] impacts loss.backward. And this is a naughty thing to do because, you know, they could probably | |
[03:04:51.040 --> 03:04:56.400] change DDP and this variable could go away. But for now, I believe this works. And it | |
[03:04:56.400 --> 03:05:01.200] allows me to avoid the use of context managers and code duplication. I'm just toggling the variable | |
[03:05:01.200 --> 03:05:04.800] and then loss.backward will not synchronize on most of the steps, and it will synchronize on the | |
[03:05:04.800 --> 03:05:15.120] very last step. And so once this is over, and we come out, every single rank will suddenly magically | |
[03:05:15.120 --> 03:05:22.320] have the average of all the gradients that are stored on all the ranks. So now we have to think | |
[03:05:22.320 --> 03:05:29.600] through whether that is what we want, and also whether this suffices and how it works with the | |
[03:05:29.600 --> 03:05:34.480] loss and what loss_accum is, so let's think this through. And the problem I'm | |
[03:05:34.480 --> 03:05:40.160] getting at is that we've averaged the gradients, which is great, but the loss_accum has not been | |
[03:05:40.160 --> 03:05:45.840] impacted yet. And this is outside of the DDP container, so that is not being averaged. | |
[03:05:45.840 --> 03:05:51.200] And so here, when we are printing loss_accum, well, presumably, we're only going to be printing on | |
[03:05:51.200 --> 03:05:56.080] the master process, rank zero. And it's just going to be printing the losses that it saw on its | |
[03:05:56.080 --> 03:06:01.680] process. But instead, we want it to print the loss over all the processes and the average of | |
[03:06:01.680 --> 03:06:06.880] that loss, because we did average of gradients. So we want the average of loss as well. So simply | |
[03:06:06.880 --> 03:06:14.560] here after this, this is the code that I've used in the past. And instead of loss, we want loss | |
[03:06:14.560 --> 03:06:25.280] accum. So if DDP, again, then this is torch distributed, which I imported. What did I import? | |
[03:06:25.280 --> 03:06:34.000] Oh, gosh. So this file is starting to get out of control, huh? So, import torch | |
[03:06:34.000 --> 03:06:42.400] distributed as dist. So dist.all_reduce, and we're doing the average on loss_accum. And so this | |
[03:06:42.400 --> 03:06:47.840] loss_accum tensor exists on all the ranks; when we call all_reduce with average, it creates the | |
[03:06:47.840 --> 03:06:53.440] average of those numbers, and it deposits that average on all the ranks. So all the ranks after | |
[03:06:53.440 --> 03:07:01.200] this call will now contain loss_accum averaged up. And so when we print here on the master process, | |
[03:07:01.200 --> 03:07:05.920] the loss_accum is identical on all the other ranks as well. So here, if master process, | |
[03:07:05.920 --> 03:07:12.640] oops, we want to print like this. Okay, and finally, we have to be careful because we're now | |
[03:07:12.640 --> 03:07:20.160] processing even more tokens. So times DDP world size, that's the number of tokens that we've processed | |
[03:07:20.160 --> 03:07:29.440] up above. And everything else should be fine. The only other thing to be careful with is as I | |
[03:07:29.440 --> 03:07:34.000] mentioned, you want to destroy the process group so that we are nice to NCCL and to DDP, and it's not going | |
[03:07:34.000 --> 03:07:42.480] to complain to us when we exit here. So that should be it. | |
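So the adjusted inner loop, with the sync toggle, the loss_accum all-reduce, and the cleanup at the end, might look roughly like this sketch (it assumes the ddp / master_process flags, the dist import, model, optimizer, train_loader and grad_accum_steps from earlier, and it is not the exact script):

```python
import torch
import torch.distributed as dist   # already imported earlier in the script

optimizer.zero_grad()
loss_accum = 0.0
for micro_step in range(grad_accum_steps):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    logits, loss = model(x, y)
    loss = loss / grad_accum_steps        # scale so the accumulation is a mean, not a sum
    loss_accum += loss.detach()
    if ddp:
        # only all-reduce gradients on the last micro step; this pokes at a DDP
        # internal, so it could break if PyTorch changes the attribute
        model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
    loss.backward()
if ddp:
    # average the logged loss across ranks so rank 0 prints a global number
    dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
if master_process:
    print(f"loss: {loss_accum.item():.6f} | norm: {norm:.4f}")

# ...and at the very end of the script, clean up so NCCL / DDP don't complain:
if ddp:
    dist.destroy_process_group()
```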
[03:07:42.480 --> 03:07:46.720] Let's try to take it for a spin. Okay, so I launched the script, and it should be printing | |
[03:07:46.720 --> 03:07:51.440] here imminently. We're now training with eight GPUs at the same time. So the gradient accumulation | |
[03:07:51.440 --> 03:07:59.840] steps is not 32, it's now divided by eight, and it's just four. So otherwise, this is what the optimization | |
[03:07:59.840 --> 03:08:08.000] looks like. And wow, we're going really fast; we're processing 1.5 million tokens per second | |
[03:08:08.000 --> 03:08:12.880] now. So these are some serious numbers. And the tiny Shakespeare data set is so tiny that we're | |
[03:08:12.880 --> 03:08:17.760] just doing like so many epochs over it, most likely. But this is roughly what it looks like. | |
[03:08:19.600 --> 03:08:24.320] One thing that I had to fix, by the way, is that this was model.configure_optimizers, | |
[03:08:24.320 --> 03:08:29.360] which now doesn't work because model is now a DDP model. So instead, this has to become | |
[03:08:29.360 --> 03:08:36.080] raw_model.configure_optimizers, where raw_model is something I create here. So right after I wrapped | |
[03:08:36.080 --> 03:08:43.280] the model into DDP, I have to create the raw_model, which in the case of DDP is model.module, | |
[03:08:43.280 --> 03:08:49.040] which is where DDP stores the raw nn.Module of GPT-2 as we have it, which contains the | |
[03:08:49.040 --> 03:08:53.600] configure_optimizers function that we want to call. So that's one thing that I had to fix. | |
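In other words, something like this small sketch (assuming configure_optimizers keeps the signature we gave it earlier in the video; the hyperparameter values shown are the ones discussed there, not new choices):

```python
# DDP wraps the model, so the original nn.Module now lives at model.module
raw_model = model.module if ddp else model
optimizer = raw_model.configure_optimizers(
    weight_decay=0.1, learning_rate=6e-4, device=device)
```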
[03:08:53.600 --> 03:08:58.480] Otherwise, this seems to run. Now, one thing you'll notice is that when you actually compare | |
[03:08:58.480 --> 03:09:04.160] this run and the numbers in it to just running on a single GPU, you'll notice that this is the single | |
[03:09:04.160 --> 03:09:11.600] GPU run with 32 grad accum steps. The numbers won't exactly match up. And it's kind of a boring reason for | |
[03:09:11.600 --> 03:09:16.560] why that happens. The reason for that is that in the data loader, we're basically just iterating | |
[03:09:16.560 --> 03:09:21.200] through batches in a slightly different way, because now we're looking at an entire page of data. | |
[03:09:21.200 --> 03:09:28.320] And if that page for all the GPUs, if that chunk, exceeds the number of tokens, we just loop. And so | |
[03:09:28.320 --> 03:09:35.200] actually the single GPU and the eight GPU processes will end up resetting in a slightly different manner, | |
[03:09:35.200 --> 03:09:38.640] as our batches are slightly different. And so we get slightly different numbers. | |
[03:09:39.200 --> 03:09:43.920] But one way to convince yourself that this is okay, is just make the total batch size much | |
[03:09:43.920 --> 03:09:52.560] smaller, and the B and the T. And so I think I used 4 times 1024 times 8, so I used | |
[03:09:52.560 --> 03:09:59.040] 32,768 as the total batch size. And then I made sure that the single GPU will do eight | |
[03:09:59.040 --> 03:10:04.240] gradient accumulation steps, compared to the multi GPU run. And then you're reducing the boundary effects | |
[03:10:04.240 --> 03:10:08.960] of the data loader, and you'll see that the numbers match up. So long story short, we're now | |
[03:10:08.960 --> 03:10:13.840] going really, really fast. The optimization is mostly consistent with the GPT-2 and GPT-3 hyper | |
[03:10:13.840 --> 03:10:20.640] parameters. And we have outgrown our tiny Shakespeare file. And we want to upgrade it. So let's move | |
[03:10:20.640 --> 03:10:25.040] to that next. So let's now take a look at what datasets were used by GPT-2 and GPT-3. | |
[03:10:25.040 --> 03:10:31.600] So GPT two used this web text data set that was never released. There's an attempt at | |
[03:10:31.600 --> 03:10:36.160] reproducing it called open web text. So basically, roughly speaking, what they say here in the paper | |
[03:10:36.160 --> 03:10:42.880] is that they scraped all outbound links from Reddit, with at least three karma. And that | |
[03:10:42.880 --> 03:10:46.880] was kind of like their starting point and they collected all the web pages and all the text in | |
[03:10:46.880 --> 03:10:52.320] them. And so this was 45 million links. And this ended up being 40 gigabytes of text. So | |
[03:10:52.320 --> 03:10:58.640] so that's roughly what GPT two says about its data set. So it's basically outbound links from | |
[03:10:58.640 --> 03:11:03.920] Reddit. Now when we go over to GPT three, there's a training data set section. And that's where they | |
[03:11:03.920 --> 03:11:11.360] start to talk about Common Crawl, which is used a lot more. Actually, I think even GPT-2 | |
[03:11:11.360 --> 03:11:16.640] talked about Common Crawl. But basically, it's not a very high quality dataset all by itself, | |
[03:11:16.640 --> 03:11:21.120] because it's extremely noisy. This is a completely random subset of the internet. And it's much | |
[03:11:21.120 --> 03:11:25.520] worse than you think. So people go into great lengths to filter common crawl, because there's | |
[03:11:25.520 --> 03:11:30.480] good stuff in it. But most of it is just like ad spam and random tables and numbers and stock | |
[03:11:30.480 --> 03:11:39.760] tickers. And it's just a total mess. So that's why people like to train on these data mixtures | |
[03:11:39.760 --> 03:11:45.760] that they curate and are careful with. So a large chunk of these data mixtures typically will be | |
[03:11:45.760 --> 03:11:50.880] common crawl, like for example, 50% of the tokens will be common crawl. But then here in GPT three, | |
[03:11:50.880 --> 03:11:55.280] they're also using WebText2 from before. So that's the Reddit outbound links. But they're also adding, | |
[03:11:55.280 --> 03:11:59.840] for example, books. And they're adding Wikipedia. There are many other things you can decide to add. | |
[03:12:00.560 --> 03:12:05.280] Now, this data set for GPT three was also never released. So today, some of the datasets that | |
[03:12:05.280 --> 03:12:09.840] I'm familiar with that are quite good and would be representative of something along these lines. | |
[03:12:09.840 --> 03:12:15.760] are, number one, the RedPajama dataset, or more specifically, for example, the SlimPajama subset | |
[03:12:15.760 --> 03:12:21.280] of the RedPajama dataset, which is a cleaned and deduplicated version of it. And just to give | |
[03:12:21.280 --> 03:12:27.600] you a sense, again, it's a bunch of Common Crawl, C4, which is also, as far as I know, more Common | |
[03:12:27.600 --> 03:12:33.600] Crawl, but processed differently. And then we have GitHub, Books, arXiv, Wikipedia, StackExchange. | |
[03:12:33.600 --> 03:12:37.920] These are the kinds of datasets that would go into these data mixtures. Now, specifically the one | |
[03:12:37.920 --> 03:12:44.080] that I like that came out recently is called the FineWeb dataset. So this is an attempt to basically | |
[03:12:44.080 --> 03:12:50.320] collect really high quality common crawl data and filter it, in this case, to 15 trillion tokens. | |
[03:12:50.320 --> 03:12:56.000] And then in addition to that, more recently, Hugging Face released this FineWeb-Edu subset, | |
[03:12:56.000 --> 03:13:02.160] which is 1.3 trillion tokens of educational content and 5.4 trillion tokens of high educational content. | |
[03:13:02.160 --> 03:13:08.000] So basically, they're trying to filter common crawl to very high quality educational subsets. | |
[03:13:08.000 --> 03:13:15.040] And this is the one that we will use. There's a long web page here on FineWeb, and they go into a | |
[03:13:15.040 --> 03:13:19.120] ton of detail about how they process the data, which is really fascinating reading, by the way. | |
[03:13:19.120 --> 03:13:22.560] And I would definitely recommend, if you're interested in data mixtures and so on, | |
[03:13:22.560 --> 03:13:27.600] and how data gets processed at these scales, to look at this page. And more specifically, | |
[03:13:27.600 --> 03:13:33.200] we'll be working with FineWeb-Edu, I think. And it's basically educational content from the | |
[03:13:33.200 --> 03:13:41.920] internet. They show that training on educational content in their metrics works really, really well. | |
[03:13:41.920 --> 03:13:49.200] And we're going to use this sample 10 billion tokens subsample of it, because we're not going | |
[03:13:49.200 --> 03:13:55.120] to be training on trillions of tokens. We're just going to train on this 10 billion token sample of Fine | |
[03:13:55.120 --> 03:14:00.240] Web-Edu, because empirically, in my previous few experiments, this actually suffices to really | |
[03:14:00.240 --> 03:14:05.520] get close to GPT two performance. And it's simple enough to work with. And so let's work with the | |
[03:14:05.520 --> 03:14:12.960] sample-10BT subset. So our goal will be to download it, process it, and make sure that our data loader | |
[03:14:12.960 --> 03:14:19.840] can work with it. So let's get to that. Okay, so I introduced another file here that will | |
[03:14:19.840 --> 03:14:26.320] basically download FineWeb-Edu from Hugging Face datasets. It will preprocess and pretokenize | |
[03:14:26.320 --> 03:14:36.480] all of the data, and it will save data shards to a folder on a local disk. And so while this is running, | |
[03:14:36.480 --> 03:14:41.840] just wanted to briefly mention that you can kind of look through the data set viewer here, | |
[03:14:41.840 --> 03:14:45.280] just to get a sense of what's in here. And it's kind of interesting. I mean, it's a, | |
[03:14:45.280 --> 03:14:50.160] it basically looks like it's working fairly well. Like it's talking about nuclear energy in France, | |
[03:14:50.160 --> 03:14:58.320] it's talking about Mexican America, some Mac, Pi J's, et cetera. So actually, it seems like | |
[03:14:58.320 --> 03:15:03.680] their filters are working pretty well. The filters here, by the way, were applied automatically using | |
[03:15:04.560 --> 03:15:11.200] Llama 3 70B, I believe. And so basically, LLMs are judging which content is educational, and that | |
[03:15:11.200 --> 03:15:15.760] ends up making it through the filter. So that's pretty cool. Now in terms of the script itself, | |
[03:15:15.760 --> 03:15:20.800] I'm not going to go through the full script, because it's not as interesting and not as LLM | |
[03:15:20.800 --> 03:15:25.040] centric. But when you run this, basically, number one, we're going to load the data set, | |
[03:15:25.040 --> 03:15:30.800] which this is all hugging face code, running this, you're going to need to pip install datasets. | |
[03:15:33.360 --> 03:15:39.360] So it's downloading the dataset, then it is tokenizing all of the documents inside this dataset. | |
[03:15:39.360 --> 03:15:44.960] Now, when we tokenize the documents, you'll notice that to tokenize a single document, | |
[03:15:44.960 --> 03:15:52.080] we first start the tokens with the end of text token. And this is a special token in the GPT-2 | |
[03:15:52.080 --> 03:15:59.040] tokenizer, as you know. So 50,256 is the ID of the end of text. And this is what begins a document, | |
[03:15:59.040 --> 03:16:03.440] even though it's called end of text. But this is the first token that begins a document. | |
[03:16:03.440 --> 03:16:09.680] Then we extend with all of the tokens of that document, then we create a NumPy array out of that. | |
[03:16:09.680 --> 03:16:18.320] We make sure that all the tokens are between, oh, okay, let me debug this. Okay, so apologies for | |
[03:16:18.320 --> 03:16:22.880] that. It just had to do with me using a float division in Python, it must be integer division, | |
[03:16:22.880 --> 03:16:29.360] so that this is an int and everything is nice. Okay, but basically, the tokenization here is | |
[03:16:29.360 --> 03:16:36.320] relatively straightforward: it returns tokens in np.uint16. We're using uint16 to save a little bit of space, | |
[03:16:36.320 --> 03:16:43.440] because two to the 16 minus one is 65,535. So the GPT-2 max token ID is well below that. | |
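The tokenization step, as described, might look roughly like this sketch (it assumes each dataset row is a dict with a "text" field, which is how the Hugging Face FineWeb-Edu rows are laid out; the exact file may differ in details):

```python
import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")
eot = enc._special_tokens['<|endoftext|>']  # end-of-text delimiter, id 50256

def tokenize(doc):
    # every document starts with the end-of-text token, even though it's "end" of text
    tokens = [eot]
    tokens.extend(enc.encode_ordinary(doc["text"]))
    tokens_np = np.array(tokens)
    # uint16 is enough because GPT-2 token ids stay well below 2**16 - 1 = 65535
    assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), \
        "token dictionary too large for uint16"
    return tokens_np.astype(np.uint16)
```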
[03:16:43.440 --> 03:16:47.920] And then here, there's a bunch of multi processing code, and it's honestly not that exciting, | |
[03:16:47.920 --> 03:16:52.640] so I'm not going to step through it. But we're loading the dataset, we're tokenizing it, | |
[03:16:52.640 --> 03:16:59.520] and we're saving everything to shards. And the shards are NumPy files. So just storing a NumPy | |
[03:16:59.520 --> 03:17:08.480] array, which is very, very similar to torch tensors. And the first shard, 000, is a validation shard, | |
[03:17:08.480 --> 03:17:14.320] and all the other shards are training shards. And as I mentioned, they all have 100 million tokens | |
[03:17:14.320 --> 03:17:21.680] in them exactly. And that just makes it easier to work with, to shard the files, | |
[03:17:21.680 --> 03:17:25.840] because if we just have a single massive file, sometimes it can be hard to work with on the disk. | |
[03:17:25.840 --> 03:17:32.480] And so sharding it is just nicer from that perspective. And yeah, so we'll just let | |
[03:17:32.480 --> 03:17:38.960] this run. This will be probably 30ish minutes or so, and then we're going to come back to actually | |
[03:17:38.960 --> 03:17:43.120] train on this data. And we're going to be actually doing some legit pre training in this case. This | |
[03:17:43.120 --> 03:17:49.360] is a good data set. We're doing lots of tokens per second. We have eight GPUs, the code is ready. | |
[03:17:49.360 --> 03:17:53.440] And so we're actually going to be doing a serious training run. So let's get back in a minute. | |
[03:17:53.440 --> 03:18:01.040] Okay, so we're back. So if we ls the edu_fineweb directory, we see that there are now 100 shards in it. | |
[03:18:01.040 --> 03:18:07.600] And that makes sense because each shard is 100 million tokens. So 100 shards of that is 10 billion | |
[03:18:07.600 --> 03:18:12.960] tokens in total. Now swinging over to the main file, I made some adjustments to our data loader | |
[03:18:12.960 --> 03:18:19.120] again. And that's because we're not running with Shakespeare anymore. We want to use the FineWeb | |
[03:18:19.120 --> 03:18:23.680] shards. And so you'll see some code here that additionally basically can load these shards. | |
[03:18:23.680 --> 03:18:31.600] We load the uint16 NumPy file. We convert it to a torch.long tensor, which is what a lot of the | |
[03:18:31.600 --> 03:18:38.400] layers up top expect by default. And then here we're just enumerating all the shards. I also added | |
[03:18:38.400 --> 03:18:44.000] a split to DataLoaderLite, so we can load the split train, but also the split val, which is the zeroth | |
[03:18:44.000 --> 03:18:50.400] shard. And then we can load the shards. And then here we also have not just a current position now, | |
[03:18:50.400 --> 03:18:56.720] but also the current shard. So we have a position inside a shard. And then when we run out of tokens | |
[03:18:56.720 --> 03:19:02.880] in a single shard, we first advance the shard and loop if we need to. And then we get the tokens | |
[03:19:02.880 --> 03:19:09.920] and readjust the position. So this data loader will now iterate all the shards as well. So I | |
[03:19:09.920 --> 03:19:15.200] changed that. And then the other thing that I did while the data is processing is our train | |
[03:19:15.200 --> 03:19:24.000] loader now has split train, of course. And down here I set up some numbers. So we are doing two to | |
[03:19:24.000 --> 03:19:34.320] the 19 tokens per step. And we want to do roughly 10 billion tokens, | |
[03:19:34.320 --> 03:19:39.680] because that's how many unique tokens we have. So if we did 10 billion tokens, then divide that | |
[03:19:39.680 --> 03:19:46.000] by two to the 19, we see that this is 19,073 steps. So that's where that's from. And then | |
[03:19:46.000 --> 03:19:50.880] the GPT three paper says that they warm up the learning rate over 375 million tokens. | |
[03:19:51.760 --> 03:20:01.760] So I came here and 375 e6 tokens divide two to the 19 is 715 steps. So that's why warm up steps | |
[03:20:01.760 --> 03:20:07.680] is set to 715. So this will exactly match the warm up schedule that GPT three used. | |
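Just to double check that arithmetic with a tiny snippet:

```python
tokens_per_step = 2**19                              # 524,288 tokens per optimizer step
max_steps = 10_000_000_000 // tokens_per_step        # 10B unique tokens -> 19073 steps
warmup_steps = 375_000_000 // tokens_per_step        # 375M warmup tokens -> 715 steps
print(max_steps, warmup_steps)                       # 19073 715
```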
[03:20:07.680 --> 03:20:13.520] And I think 715 by the way is very mild. And this could be made significantly more aggressive, | |
[03:20:13.520 --> 03:20:19.280] probably even like 100 is good enough. But it's okay, let's leave it for now so that we have the | |
[03:20:19.280 --> 03:20:26.560] exact hyper parameters of GPT three. So I fix that. And then that's pretty much it. We can, | |
[03:20:26.560 --> 03:20:34.720] we can run. So we have our script here. And we can launch. And actually, sorry, let me do one more thing. | |
[03:20:34.720 --> 03:20:47.920] Excuse me. For my GPUs, I can actually fit a bigger batch size. And I believe I can fit 64 on my GPU | |
[03:20:48.800 --> 03:20:51.120] as a micro batch size. So let me try that. | |
[03:20:51.120 --> 03:21:01.680] I could be misremembering. But that means 64 times 1024 per GPU. And then we have eight GPUs. So that | |
[03:21:01.680 --> 03:21:07.680] means we would not even be doing gradient accumulation if this fits, because this just multiplies out to | |
[03:21:07.680 --> 03:21:14.800] the full total batch size. So no gradient accumulation. And that would run pretty quickly if that fits. | |
[03:21:19.680 --> 03:21:31.840] Let's go. Let's go. I mean, if this works, then this is basically a serious pre-training run. | |
[03:21:31.840 --> 03:21:36.480] We're not logging. We're not evaluating the validation split. We're not running any | |
[03:21:36.480 --> 03:21:42.800] evaluations yet, so we haven't crossed our t's and dotted our i's. But if we let this run | |
[03:21:42.800 --> 03:21:48.000] for a while, we're going to actually get a pretty good model. And the model that might even be | |
[03:21:48.000 --> 03:21:55.360] on par with or better than the GPT-2 124M. Okay. So it looks like everything is going great. | |
[03:21:55.360 --> 03:21:58.160] We're processing 1.5 million tokens per second. | |
[03:21:58.160 --> 03:22:08.720] Everything here looks good. We're doing 330 milliseconds per iteration. And we have to do a total of | |
[03:22:10.320 --> 03:22:20.560] where are we printing that? 19,073. So 19,073 times 0.33 is this many seconds, this many minutes. | |
[03:22:20.560 --> 03:22:30.560] So this will run for about 1.7 hours, so roughly a one and a half hour run like this. And we don't even have to use | |
[03:22:30.560 --> 03:22:34.720] gradient accumulation, which is nice. And you might not have that luxury in your GPU. In that case, | |
[03:22:34.720 --> 03:22:38.800] just start decreasing the batch size until things fit. But keep it to nice numbers. | |
[03:22:41.040 --> 03:22:44.640] So that's pretty exciting. We're currently warming up the learning rate. So you see that | |
[03:22:44.640 --> 03:22:49.360] it's still very low, 1e-4. So this will ramp up over the next few steps all the way | |
[03:22:49.360 --> 03:22:58.080] to 6e-4 here. Very cool. So now what I'd like to do is let's cross our t's and dot | |
[03:22:58.080 --> 03:23:03.360] our i's. Let's evaluate on the validation split. And let's try to figure out how we can run evals, | |
[03:23:03.360 --> 03:23:09.040] how we can do logging, how we can visualize our losses, and all the good stuff. So let's get to | |
[03:23:09.040 --> 03:23:13.600] that before we actually do the run. Okay, so I've adjusted the code so that we're evaluating on | |
[03:23:13.600 --> 03:23:18.400] the validation split. So creating the val loader, just by passing in split equals val, | |
[03:23:18.400 --> 03:23:22.000] that will basically create a data loader just for the validation shard. | |
[03:23:22.000 --> 03:23:28.560] The other thing I did is in the data loader, I introduced a new function reset, which is called | |
[03:23:28.560 --> 03:23:34.240] at init. And it basically resets the data loader. And that is very useful because when we come to | |
[03:23:34.240 --> 03:23:40.720] the main training loop now, so this is the code I've added. And basically every 100th iteration, | |
[03:23:40.720 --> 03:23:46.480] including the zeroth iteration, we put the model into evaluation mode, we reset the val loader, | |
[03:23:46.480 --> 03:23:55.360] and then, with no gradients involved, we're going to basically accumulate the loss over, say, 20 | |
[03:23:55.360 --> 03:24:02.720] steps, and then average it all up and print out the validation loss. And so that basically | |
[03:24:02.720 --> 03:24:07.840] is the exact same logic as the training, roughly, but there's no loss.backward. It's only | |
[03:24:07.840 --> 03:24:12.160] inference. We're just measuring the loss. We're adding it up. Everything else otherwise applies | |
[03:24:12.160 --> 03:24:17.760] and is exactly as we've seen it before. And so this will print the validation loss every 100th | |
[03:24:17.760 --> 03:24:24.160] iteration, including the very first iteration. So that's nice. That will tell us some amount, | |
[03:24:24.160 --> 03:24:29.760] a little bit about how much we're overfitting. That said, we have roughly infinity data. | |
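A minimal sketch of that periodic validation block (the names val_loader, val_loss_steps and the every-100-steps cadence are assumptions that mirror the description above, not the exact script):

```python
import torch
import torch.distributed as dist   # already imported earlier in the script

if step % 100 == 0:
    model.eval()
    val_loader.reset()
    with torch.no_grad():
        val_loss_accum = 0.0
        val_loss_steps = 20
        for _ in range(val_loss_steps):
            x, y = val_loader.next_batch()
            x, y = x.to(device), y.to(device)
            logits, loss = model(x, y)
            val_loss_accum += (loss / val_loss_steps).detach()
    if ddp:
        # average the validation loss across all ranks
        dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
    if master_process:
        print(f"validation loss: {val_loss_accum.item():.4f}")
```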
[03:24:29.760 --> 03:24:34.240] Given the roughly infinite data, we're mostly expecting our train and val loss to be about the same. But the other reason | |
[03:24:34.240 --> 03:24:39.280] I'm kind of interested in this is because we can take the GPT-2 124M as OpenAI released it, | |
[03:24:39.280 --> 03:24:43.680] we can initialize from it, and we can basically see what kind of loss it achieves on the validation | |
[03:24:43.680 --> 03:24:48.560] loss as well. And that gives us kind of an indication as to how well that model | |
[03:24:48.560 --> 03:24:54.560] generalizes, sorry, to the FineWeb-Edu validation data. But that said, | |
[03:24:54.560 --> 03:24:58.160] it's not a super fair comparison to GPT two because it was trained on a very different data | |
[03:24:58.160 --> 03:25:02.720] distribution. But it's still kind of like an interesting data point. And in any case, you would | |
[03:25:02.720 --> 03:25:08.160] always want to have a validation split in a training run like this, so that you can make sure that | |
[03:25:08.160 --> 03:25:14.880] you are not overfitting. And this is especially a concern if we were to make more epochs in our | |
[03:25:14.880 --> 03:25:20.240] training data. So for example, right now we're just doing a single epoch. But if we get to a point | |
[03:25:20.240 --> 03:25:25.360] where we're doing multiple epochs over the training data, or something like that, we would want to be really careful that maybe we | |
[03:25:25.360 --> 03:25:30.640] are memorizing that data too much, if we have a big enough model. And our validation split would | |
[03:25:30.640 --> 03:25:34.880] be one way to tell whether that is happening. Okay, and in addition to that, if you remember, | |
[03:25:34.880 --> 03:25:39.440] at the bottom of our script, we had all of this orphaned code for sampling from way back when. | |
[03:25:39.440 --> 03:25:45.680] So I deleted that code and I moved it up to here. So once in a while, we simply evaluate | |
[03:25:45.680 --> 03:25:53.520] validation. Once in a while, we sample, we generate samples. And then we do that only every 100 | |
[03:25:53.520 --> 03:25:57.840] steps, and we train on every single step. So that's how I have a structure right now. And | |
[03:25:57.840 --> 03:26:01.920] I've been running this for 1000 iterations. So here are some samples on iteration 1000. | |
[03:26:01.920 --> 03:26:08.320] Hello, I'm a language model, and I'm not able to get more creative. | |
[03:26:08.320 --> 03:26:13.920] I'm a language model and languages file you're learning about here is, or is the beginning of | |
[03:26:13.920 --> 03:26:21.360] a computer. Okay, so this is all like pretty, there's still a garble, but we're only at | |
[03:26:21.360 --> 03:26:26.800] iteration 1000. And we've only just barely reached the maximum learning rate. So this is still | |
[03:26:26.800 --> 03:26:37.760] learning. We're about to get some more samples coming up in 100 steps. Okay, this is, you know, the | |
[03:26:37.760 --> 03:26:45.280] model is still a young baby. Okay, so basically all of this sampling code that I've put here, | |
[03:26:45.280 --> 03:26:49.360] everything should be familiar to you and came from before. The only thing that I did is I | |
[03:26:49.360 --> 03:26:55.040] created a generator object in PyTorch, so that I have a direct control over the sampling | |
[03:26:55.040 --> 03:26:59.920] of the random numbers, because I don't want to impact the RNG state of the random number | |
[03:26:59.920 --> 03:27:04.400] generator that is the global one used for training. I want this to be completely outside of the | |
[03:27:04.400 --> 03:27:11.040] training loop. And so I'm using a special sampling RNG. And then I make sure to seed it | |
[03:27:11.040 --> 03:27:16.240] so that every single rank has a different seed. And then I pass it in here, where we sort of | |
[03:27:16.240 --> 03:27:20.880] consume random numbers in multinomial, where the sampling happens; I make sure to pass in the | |
[03:27:20.880 --> 03:27:27.040] generator object there. Otherwise, this is identical. Now the other thing is, you'll notice that we're | |
[03:27:27.040 --> 03:27:32.240] running a bit slower. That's because I actually had to disable torch.compile to get this to sample. | |
[03:27:32.240 --> 03:27:37.040] And so we're running a bit slower. So for some reason it works with no torch compile, | |
[03:27:37.040 --> 03:27:41.520] but when I torch compile my model, I get a really scary error from PyTorch and I have no idea how | |
[03:27:41.520 --> 03:27:46.080] to resolve it right now. So probably by the time you see this code released or something like that, | |
[03:27:46.080 --> 03:27:51.120] maybe it's fixed. But for now, I'm just going to set compile to False. And if I bring back | |
[03:27:51.120 --> 03:27:57.760] torch compile, then you're not going to get samples. And I think I'll fix this later. By the way, | |
[03:27:57.760 --> 03:28:03.440] I will be releasing all this code. And actually, I've been very careful about making Git commits | |
[03:28:03.440 --> 03:28:08.160] every time we add something. And so I'm going to release the entire repo that starts completely | |
[03:28:08.160 --> 03:28:13.440] from scratch all the way to now and after this as well. And so everything should be | |
[03:28:13.440 --> 03:28:19.200] exactly documented in the Git commit history. And so I think that will be nice. So hopefully, | |
[03:28:19.200 --> 03:28:23.360] by the time you go to GitHub, this is removed and it's working. And I will have fixed the bug. | |
[03:28:23.360 --> 03:28:28.960] Okay, so I have the optimization running here. And it's stepping and we're on step 6000 or so, | |
[03:28:28.960 --> 03:28:33.200] so we're about 30% through training. Now, while this is training, I would like to introduce | |
[03:28:33.200 --> 03:28:39.120] one evaluation that we're going to use to supplement the validation set. And that is the HellaSwag eval. | |
[03:28:39.120 --> 03:28:46.160] So HellaSwag comes from this paper back in 2019. So it's a five year old eval now. And the way | |
[03:28:46.160 --> 03:28:51.520] HellaSwag works is there is basically a sentence completion dataset. So it's multiple choice. | |
[03:28:51.520 --> 03:28:57.680] For every one of these questions, we have basically a shared context like a woman is outside with | |
[03:28:57.680 --> 03:29:04.000] a bucket and a dog. The dog is running around trying to avoid a bath. She, A, rinses the bucket | |
[03:29:04.000 --> 03:29:10.800] off with soap and blow dries the dog's head, B, uses a hose to keep it from getting soapy, C, | |
[03:29:10.800 --> 03:29:17.600] gets the dog wet and it runs away again, or D, gets into a bathtub with the dog. And so basically, | |
[03:29:17.600 --> 03:29:23.280] the idea is that these multiple choice are constructed so that one of them is a natural | |
[03:29:23.280 --> 03:29:32.880] continuation of the sentence and the others are not. And the others might not make sense like | |
[03:29:32.880 --> 03:29:37.120] uses the hose to keep it from getting soapy, that makes no sense. And so what happens is that | |
[03:29:37.120 --> 03:29:43.040] models that are not trained very well are not able to tell these apart, but models that have a lot | |
[03:29:43.040 --> 03:29:49.440] of world knowledge, and can tell a lot about the world, will be able to | |
[03:29:49.440 --> 03:29:55.680] pick out the right completion. And these sentences are sourced from ActivityNet and from WikiHow. | |
[03:29:55.680 --> 03:30:04.560] And at the bottom of the paper, there's kind of like a cool chart of the kinds of domains | |
[03:30:04.560 --> 03:30:09.840] in WikiHow. So there are a lot of sentences from computers and electronics and homes and garden. | |
[03:30:09.840 --> 03:30:14.320] And it has kind of a broad coverage of the kinds of things you need to know about the world in order | |
[03:30:14.320 --> 03:30:21.040] to find the most likely completion and the identity of that completion. | |
[03:30:21.040 --> 03:30:25.920] One more thing that's kind of interesting about HellaSwag is the way it was constructed, | |
[03:30:25.920 --> 03:30:35.600] is that the incorrect options are deliberately adversarially sourced. So they're not just random | |
[03:30:35.600 --> 03:30:40.400] sentences. They're actually sentences generated by language models. And they're generated in such | |
[03:30:40.400 --> 03:30:45.360] a way that language models basically find them difficult, but humans find them easy. And so | |
[03:30:45.360 --> 03:30:50.400] they mentioned that humans have a 95% accuracy on this set. But at the time, the state of the art | |
[03:30:50.400 --> 03:30:57.440] language models had only 48%. And so at the time, this was a good benchmark. Now, you can read the | |
[03:30:57.440 --> 03:31:02.960] details of this paper to learn more. The thing to point out though is that this is five years ago. | |
[03:31:02.960 --> 03:31:10.960] And since then, what happened to HellaSwag is that it's been totally, just, solved. And so now the | |
[03:31:10.960 --> 03:31:17.120] language models here are 96%. So basically, the last 4% is probably errors in the dataset, | |
[03:31:17.120 --> 03:31:21.440] or the questions are really, really hard. And so basically, this dataset is kind of crushed with | |
[03:31:21.440 --> 03:31:25.040] respect to language models. But back then, the best language models were only at about 50%. | |
[03:31:25.040 --> 03:31:32.000] But this is how far things have come. But still, the reason people like HellaSwag, | |
[03:31:32.560 --> 03:31:39.920] and it's not used, by the way, in GPT-2, but in GPT-3, there is a HellaSwag eval. And lots of people use | |
[03:31:39.920 --> 03:31:48.560] HellaSwag. And so with GPT-3, we have results here that are cited. So we know what percent | |
[03:31:48.560 --> 03:31:55.360] accuracy GPT-3 attains at all these different model checkpoints for the HellaSwag eval. And the | |
[03:31:55.360 --> 03:32:00.960] reason people like it is because HellaSwag is a smooth eval. And it is an eval that offers, | |
[03:32:00.960 --> 03:32:07.360] quote, unquote, early signal. So early signal means that even small language models are going to | |
[03:32:07.360 --> 03:32:11.680] start at the random chance of 25%. But they're going to slowly improve. And you're going to see | |
[03:32:11.680 --> 03:32:18.800] 25, 26, 27, et cetera. And you can see slow improvement, even when the models are very small, | |
[03:32:18.800 --> 03:32:26.960] and it's very early. So it's smooth. It has early signal. And it's been around for a long time. So | |
[03:32:26.960 --> 03:32:32.720] that's why people will kind of like this eval. Now, the way that we're going to evaluate this | |
[03:32:32.720 --> 03:32:39.440] is as follows. As I mentioned, we have a shared context. And this is kind of like a multiple | |
[03:32:39.440 --> 03:32:44.320] choice task. But instead of giving the model a multiple choice question and asking it for A, | |
[03:32:44.320 --> 03:32:50.320] B, C, or D, we can't do that because these models, when they are so small, as we are seeing here, | |
[03:32:50.320 --> 03:32:54.720] the models can't actually do multiple choice. They don't understand the concept of associating | |
[03:32:54.720 --> 03:32:59.760] a label to one of the options of multiple choice. They don't understand that. So we have to give | |
[03:32:59.760 --> 03:33:06.240] it to them in native form. And the native form is a token completion. So here we do we construct a | |
[03:33:06.240 --> 03:33:14.240] batch of four rows and T tokens, whatever that T happens to be. Then the shared context, that is | |
[03:33:14.240 --> 03:33:20.080] basically the context for the four choices, the tokens of that are shared across all of the rows. | |
[03:33:20.640 --> 03:33:25.280] And then we have the four options. So we kind of like lay them out. And then only one of the | |
[03:33:25.280 --> 03:33:31.360] options is correct. In this case, label three, option three. And so this is the correct option, | |
[03:33:31.360 --> 03:33:37.200] and the other options are incorrect. Now, these options might be of different lengths. | |
[03:33:37.200 --> 03:33:42.400] So what we do is we sort of like take the longest length, and that's the size of the batch B by T. | |
[03:33:42.400 --> 03:33:47.840] And then some of these here are going to be padded dimensions. So they're going to be | |
[03:33:47.840 --> 03:33:54.000] unused. And so we need the tokens. We need the correct label. And we need a mask | |
[03:33:54.000 --> 03:33:59.920] that tells us which tokens are active. And the mask is then zero for these padded areas. | |
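A hedged sketch of that rendering step, assuming the HellaSwag JSON fields are named "ctx", "endings" and "label" (the real hellaswag.py may differ in details):

```python
import torch
import tiktoken

enc = tiktoken.get_encoding("gpt2")

def render_example(example):
    """Pack one HellaSwag item into (tokens, mask, label) of shape (4, T),
    where T is the longest context+ending; mask is 1 over the ending tokens."""
    ctx_tokens = enc.encode(example["ctx"])
    tok_rows, mask_rows = [], []
    for end in example["endings"]:
        end_tokens = enc.encode(" " + end)   # the leading space matters for GPT-2 BPE
        tok_rows.append(ctx_tokens + end_tokens)
        mask_rows.append([0] * len(ctx_tokens) + [1] * len(end_tokens))
    T = max(len(row) for row in tok_rows)
    tokens = torch.zeros((4, T), dtype=torch.long)   # padded regions stay zero
    mask = torch.zeros((4, T), dtype=torch.long)
    for i, (t, m) in enumerate(zip(tok_rows, mask_rows)):
        tokens[i, :len(t)] = torch.tensor(t)
        mask[i, :len(m)] = torch.tensor(m)
    return tokens, mask, example["label"]
```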
[03:33:59.920 --> 03:34:06.160] So that's how we construct these batches. And then in order to get the language model to predict | |
[03:34:06.160 --> 03:34:10.880] A, B, C, or D, the way this works is basically we're just going to look at the tokens, | |
[03:34:10.880 --> 03:34:18.000] their probabilities. And we're going to pick the option that gets the lowest average loss, or | |
[03:34:18.000 --> 03:34:25.920] equivalently the highest average probability, for its tokens, because that is the most likely completion | |
[03:34:25.920 --> 03:34:31.440] according to the language model. So we're just going to look at the probabilities here, | |
[03:34:31.440 --> 03:34:37.280] and average them up within each option, and pick the option with the highest average probability, | |
[03:34:37.280 --> 03:34:45.200] roughly speaking. So this is how we're going to do HellaSwag. And this is, I believe, also how | |
[03:34:45.200 --> 03:34:53.200] GPT 3 did it. This is how GPT 3 did it, as far as I know. But you should note that some of the | |
[03:34:53.200 --> 03:34:58.000] other evals, where you might see HellaSwag, may not do it this way. They may do it in a multiple | |
[03:34:58.000 --> 03:35:03.760] choice format where you sort of give the context a single time, and then the four completions. | |
[03:35:03.760 --> 03:35:08.560] And so the model is able to see all the four options before it picks the best possible option. | |
[03:35:08.560 --> 03:35:12.960] And that's actually an easier task for a model because you get to see the other options when | |
[03:35:12.960 --> 03:35:18.560] you're picking your choice. But unfortunately, models of our size can't do that. Only models | |
[03:35:18.560 --> 03:35:23.600] of a bigger size are able to do that. And so our models are actually slightly handicapped in this | |
[03:35:23.600 --> 03:35:28.960] way that they are not going to see the other options. They're only going to see one option at a time, | |
[03:35:28.960 --> 03:35:33.200] and they just have to assign probabilities. And the correct option has to win out in this metric. | |
[03:35:33.920 --> 03:35:38.400] All right, so let's now implement this very briefly and incorporate it into our script. | |
[03:35:38.400 --> 03:35:43.280] Okay, so what I've done here is I've introduced a new file called hella swag.py, | |
[03:35:43.280 --> 03:35:47.040] and you can take a look into it. And I'm not going to step through all of it because | |
[03:35:47.040 --> 03:35:52.960] this is not exactly deep code. It's kind of like a little bit tedious, honestly, | |
[03:35:52.960 --> 03:35:57.520] because what's happening is I'm downloading HellaSwag from GitHub, and I'm rendering all of | |
[03:35:57.520 --> 03:36:02.640] its examples. And there are a total of 10,000 examples. I am rendering them into this format. | |
[03:36:04.160 --> 03:36:10.880] And so here at the end of this render example function, you can see that I'm returning the tokens, | |
[03:36:10.880 --> 03:36:21.360] the tokens of this four by T array of tokens, the mask, which tells us which parts are the options, | |
[03:36:21.360 --> 03:36:27.120] and everything else is zero, and the label, that is the correct label. And so that allows us to | |
[03:36:27.120 --> 03:36:31.840] then iterate the examples and render them. And I have an evaluate function here, which can load | |
[03:36:31.840 --> 03:36:40.560] a GPT-2 from Hugging Face. And it runs the eval here. And basically it just calculates, | |
[03:36:40.560 --> 03:36:47.040] just as I described: it predicts the option that has the lowest average loss, i.e. the highest probability. And | |
[03:36:47.040 --> 03:36:51.840] the way to do that actually is we can basically evaluate the cross entropy loss. So we're basically | |
[03:36:51.840 --> 03:36:57.120] evaluating the loss of predicting the next token in the sequence. And then we're looking at the row | |
[03:36:57.120 --> 03:37:05.040] that has the lowest average loss. And that's the option that we pick as the prediction. | |
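That selection step might look roughly like the following helper (a sketch, not necessarily the exact function in the repo): compute the per-token cross entropy against the shifted targets, zero out everything except the completion region using the mask, average per row, and take the argmin.

```python
import torch.nn.functional as F

def get_most_likely_row(tokens, mask, logits):
    """Pick the candidate completion (row) with the lowest average cross-entropy
    over the completion tokens only (where mask == 1)."""
    # shift so that position i predicts token i+1
    shift_logits = logits[..., :-1, :].contiguous()
    shift_tokens = tokens[..., 1:].contiguous()
    losses = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_tokens.view(-1),
        reduction='none').view(tokens.size(0), -1)
    shift_mask = mask[..., 1:].contiguous()          # align the mask with the targets
    masked_losses = losses * shift_mask              # keep only completion positions
    avg_loss = masked_losses.sum(dim=1) / shift_mask.sum(dim=1)
    return avg_loss.argmin().item()                  # index of the predicted option
```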
[03:37:05.040 --> 03:37:09.440] And then we do some stats and prints and stuff like that. So that is a way to evaluate the | |
[03:37:09.440 --> 03:37:15.520] loss. Now, if you go up here, I'm showing that for GPT-2 124M, if you run this script, | |
[03:37:15.520 --> 03:37:22.720] you're going to see that HellaSwag gets 29.55%. So that's the performance we get here. Now, | |
[03:37:22.720 --> 03:37:28.560] remember that random chance is 25%. So we haven't gone too far above that. And GPT-2 XL, | |
[03:37:28.560 --> 03:37:35.840] which is the biggest GPT-2, gets all the way up to 49%, roughly. So these are pretty low values | |
[03:37:35.840 --> 03:37:40.720] considering that today's state of the art is more like 95%. So these are definitely older models by | |
[03:37:40.720 --> 03:37:45.760] now. And then there's one more thing called the Eleuther harness, which is a very common piece of | |
[03:37:45.760 --> 03:37:50.240] infrastructure for running evals for language models. And they get slightly different numbers. | |
[03:37:50.240 --> 03:37:55.280] And I'm not 100% sure what the discrepancy is for these. It could be that they actually | |
[03:37:55.280 --> 03:37:59.840] do the multiple choice instead of just the completions. And then that could be the | |
[03:37:59.840 --> 03:38:06.080] discrepancy. But I'm not 100% sure about that. I'd have to take a look. But for now, our script | |
[03:38:06.080 --> 03:38:12.320] reports 29.55. And so that is the number that we'd like to beat if we're training a GPT-2 124M | |
[03:38:12.320 --> 03:38:21.120] from scratch ourselves. So now I'm going to go into actually incorporating this eval | |
[03:38:21.120 --> 03:38:28.880] into our main training script. And basically, because we want to evaluate it in a periodic manner, | |
[03:38:28.880 --> 03:38:34.880] so that we can track HellaSwag and how it evolves over time, and see when and if we cross | |
[03:38:35.600 --> 03:38:43.040] this 29.55 sort of region. So let's now walk through some of the changes to train_gpt2.py. | |
[03:38:43.040 --> 03:38:49.280] The first thing I did here is actually made use_compile optional, kind of, and I disabled it | |
[03:38:49.280 --> 03:38:55.120] by default. And the problem with compile is that, unfortunately, while it | |
[03:38:55.120 --> 03:38:59.520] does make our code faster, it actually breaks the evaluation code and the sampling code. It | |
[03:38:59.520 --> 03:39:03.520] gives me a very gnarly message and I don't know why. So hopefully by the time you get to the | |
[03:39:04.160 --> 03:39:08.400] code base, when I put it up on GitHub, we're gonna fix that by then. But for now I'm running | |
[03:39:08.400 --> 03:39:13.360] without torch compile, which is why you see this be a bit slower. So we're running without torch | |
[03:39:13.360 --> 03:39:20.960] compile. I also created a log directory log where we can place our log.txt, which will record the | |
[03:39:20.960 --> 03:39:26.080] train loss, validation loss, and the HellaSwag accuracies. So a very simple text file, and we're | |
[03:39:26.080 --> 03:39:31.520] going to open for writing so that it sort of starts empty. And then we're going to append to it. | |
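For reference, a minimal sketch of that logging setup (the file name and the line format are assumptions taken from the description):

```python
import os

log_dir = "log"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, "log.txt")
with open(log_file, "w"):      # open for writing once, just to start it empty
    pass

# ...then inside the training loop, e.g. after computing the train loss:
with open(log_file, "a") as f:
    f.write(f"{step} train {loss_accum.item():.6f}\n")
```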
[03:39:33.440 --> 03:39:39.520] I created a simple variable that helps tell us when we have a last step. And then basically | |
[03:39:39.520 --> 03:39:46.080] periodically inside this loop, every 250th iteration or at the last step, we're going to evaluate | |
[03:39:46.080 --> 03:39:54.560] the validation loss. And then every 250th iteration, we are going to evaluate HellaSwag, | |
[03:39:54.560 --> 03:40:00.320] but only if we are not using compile because compile breaks it. So I'm going to come back to | |
[03:40:00.320 --> 03:40:06.080] this code for evaluating HellaSwag in a second. And then every 250th iteration as well, | |
[03:40:06.080 --> 03:40:09.920] we're also going to sample from the model. And so you should recognize this as our | |
[03:40:09.920 --> 03:40:14.160] ancient code from way back when we started the video. And we're just sampling from the model. | |
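Putting those pieces together, a minimal sketch of the loop skeleton might look like the following. The helper functions and flags here (evaluate_val_loss, evaluate_hellaswag, sample_from_model, training_step, use_compile, master_process, max_steps) are placeholders for illustration, not necessarily the names used in the actual repo:

```python
# Minimal sketch of the periodic evaluation/logging structure described above.
# Helper names and flags are placeholders; the real train_gpt2.py differs in detail.
import os

log_dir = "log"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, "log.txt")
with open(log_file, "w"):
    pass  # open once for writing so the file starts out empty

for step in range(max_steps):
    last_step = (step == max_steps - 1)

    # once in a while evaluate the validation loss
    if step % 250 == 0 or last_step:
        val_loss = evaluate_val_loss(model, val_loader)
        if master_process:
            with open(log_file, "a") as f:
                f.write(f"{step} val {val_loss:.4f}\n")

    # once in a while evaluate HellaSwag (torch.compile currently breaks this)
    if (step % 250 == 0 or last_step) and not use_compile:
        acc = evaluate_hellaswag(model)
        if master_process:
            with open(log_file, "a") as f:
                f.write(f"{step} hella {acc:.4f}\n")

    # once in a while sample a few completions from the model
    if (step % 250 == 0 or last_step) and not use_compile:
        sample_from_model(model, prefix="Hello, I'm a language model,")

    # one regular optimization step, and log the training loss
    train_loss = training_step(model, optimizer, train_loader)
    if master_process:
        with open(log_file, "a") as f:
            f.write(f"{step} train {train_loss:.6f}\n")
```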
[03:40:14.160 --> 03:40:22.800] And then finally here, after we validate, sample, and evaluate HellaSwag, | |
[03:40:22.800 --> 03:40:28.560] we actually do a training step here. And so this is one step of training, and you should be pretty | |
[03:40:28.560 --> 03:40:33.440] familiar with all of what this does. And at the end here, once we get our training loss, | |
[03:40:33.440 --> 03:40:38.320] we write it to the file. So the only thing that changed that I really added is this entire section | |
[03:40:38.320 --> 03:40:43.680] for the HellaSwag eval. And the way this works is I'm trying to get all the GPUs to collaborate on | |
[03:40:43.680 --> 03:40:50.480] the HellaSwag eval. And so we're iterating over the examples. And then each process only picks the | |
[03:40:50.480 --> 03:40:56.320] examples that are assigned to it. So we sort of take i and mod it by the world size, and we have to | |
[03:40:56.320 --> 03:41:02.480] make it equal to rank, otherwise we continue. And then we render an example, put it on a GPU, | |
[03:41:02.480 --> 03:41:08.000] we get the logits, and then I created a helper function that helps us basically predict the option with | |
[03:41:08.000 --> 03:41:13.840] the lowest loss. So this comes here, the prediction. And then if it's correct, we sort of keep count. | |
[03:41:13.840 --> 03:41:19.040] And then if multiple processes were collaborating on all this, then we need to synchronize their | |
[03:41:19.040 --> 03:41:24.960] stats. And so one way to do that is to package up our statistics here into tensors, | |
[03:41:25.600 --> 03:41:33.120] on which we can then call dist.all_reduce with a sum. And then here we sort of unwrap them | |
[03:41:33.120 --> 03:41:38.480] from tensors so that we just have ints. And then here the master process will print and log the | |
[03:41:38.480 --> 03:41:47.120] HellaSwag accuracy. So that's kind of it. And that's what I'm running right here. | |
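As a rough sketch, that GPU collaboration looks something like the code below. The helpers (iterate_examples, render_example, get_most_likely_row) and the DDP variables (ddp, ddp_rank, ddp_world_size, device, master_process) are assumed from the surrounding description, so treat the exact names as illustrative:

```python
# Sketch of the distributed HellaSwag eval described above; helper names and
# DDP variables are assumptions based on the description in the lecture.
import torch
import torch.distributed as dist

num_correct_norm = 0
num_total = 0
for i, example in enumerate(iterate_examples("val")):
    # each process only handles the examples assigned to it
    if i % ddp_world_size != ddp_rank:
        continue
    # render the example into tokens, a completion mask, and the correct label
    _, tokens, mask, label = render_example(example)
    tokens, mask = tokens.to(device), mask.to(device)
    with torch.no_grad():
        logits, _ = model(tokens)
        pred = get_most_likely_row(tokens, mask, logits)  # option with the lowest completion loss
    num_total += 1
    num_correct_norm += int(pred == label)

# if multiple processes collaborated, sum their counts across ranks
if ddp:
    num_total_t = torch.tensor(num_total, dtype=torch.long, device=device)
    num_correct_t = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
    dist.all_reduce(num_total_t, op=dist.ReduceOp.SUM)
    dist.all_reduce(num_correct_t, op=dist.ReduceOp.SUM)
    num_total = int(num_total_t.item())
    num_correct_norm = int(num_correct_t.item())

acc_norm = num_correct_norm / num_total
if master_process:
    print(f"HellaSwag accuracy: {num_correct_norm}/{num_total} = {acc_norm:.4f}")
```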
[03:41:47.120 --> 03:41:54.080] So you see this optimization here. And we just had a generation. And this is step 10,000 out of | |
[03:41:54.080 --> 03:41:59.680] about 20,000, right? So we are halfway done. And these are kinds of samples that we are getting | |
[03:41:59.680 --> 03:42:05.360] at this stage. So let's take a look. Hello, I'm a language model. So I'd like to use it to generate | |
[03:42:05.360 --> 03:42:09.920] some kinds of output. Hello, I'm a language model, and I'm a developer for a lot of companies. | |
[03:42:11.040 --> 03:42:21.520] A long language model. Let's see if I can find any fun one. | |
[03:42:21.520 --> 03:42:32.800] I don't know, you can go through this yourself, but certainly the predictions are getting less and | |
[03:42:32.800 --> 03:42:38.960] less random. It seems like the model is a little bit more self-aware in using language that is a bit | |
[03:42:38.960 --> 03:42:46.240] more specific to it being a language model. Hello, I'm a language model. And like how the language | |
[03:42:46.240 --> 03:42:51.440] is used to communicate, I'm a language model and are going to be speaking English and German. | |
[03:42:51.440 --> 03:42:57.760] So let's just wait until this optimization finishes. And we'll see what kind of samples we get. | |
[03:42:57.760 --> 03:43:03.680] And we're also going to look at the train, the val, and the HellaSwag accuracy and see how | |
[03:43:03.680 --> 03:43:10.560] we're doing with respect to GPT-2. Okay, good morning. So focusing for a moment on the Jupyter | |
[03:43:10.560 --> 03:43:15.520] notebook here on the right, I created a new cell that basically allows us to visualize the train, | |
[03:43:15.520 --> 03:43:23.040] val, and the HellaSwag score. And you can step through this. It basically parses the log file that we | |
[03:43:23.040 --> 03:43:29.520] are writing. And a lot of this is just like boring matplotlib code. But basically, this is what our | |
[03:43:29.520 --> 03:43:40.160] optimization looks like. So we ran for 19,073 steps, which is roughly 10 billion tokens, | |
[03:43:40.160 --> 03:43:45.520] which is, whoops, oh my gosh, which is one epoch of the sample-10B of FineWeb EDU. | |
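As an aside, a minimal version of such a parsing-and-plotting cell could look like the sketch below. The log line format follows the sketch earlier in this section, and the reference value for the horizontal line is just the 29.55% quoted above, so treat the details as illustrative rather than the notebook's actual code:

```python
# Minimal sketch of parsing log/log.txt and plotting the curves; the exact log
# format and the reference value are assumptions based on the numbers quoted above.
import matplotlib.pyplot as plt

streams = {"train": {}, "val": {}, "hella": {}}
with open("log/log.txt") as f:
    for line in f:
        step, kind, value = line.split()
        streams[kind][int(step)] = float(value)

fig, (ax_loss, ax_hella) = plt.subplots(1, 2, figsize=(12, 4))

# left panel: train/val loss
for kind in ("train", "val"):
    xs = sorted(streams[kind])
    ax_loss.plot(xs, [streams[kind][x] for x in xs], label=f"{kind} loss")
ax_loss.set_xlabel("step"); ax_loss.set_ylabel("loss"); ax_loss.legend()

# right panel: HellaSwag accuracy with the GPT-2 124M checkpoint as a reference line
xs = sorted(streams["hella"])
ax_hella.plot(xs, [streams["hella"][x] for x in xs], label="ours")
ax_hella.axhline(y=0.2955, color="r", linestyle="--", label="OpenAI GPT-2 124M (29.55%)")
ax_hella.set_xlabel("step"); ax_hella.set_ylabel("HellaSwag accuracy"); ax_hella.legend()

plt.show()
```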
[03:43:45.520 --> 03:43:52.000] On the left, we have the loss. And in the blue, we have the training loss. In orange, we have the | |
[03:43:52.000 --> 03:43:59.360] validation loss. And in red, as a horizontal line, we have the OpenAI GPT-2 124M model checkpoint, | |
[03:43:59.360 --> 03:44:06.640] when it's just evaluated on the validation set of this FineWeb EDU. So you can see that we are | |
[03:44:06.640 --> 03:44:12.560] surpassing this, the orange is below the red. So we're surpassing it on the validation set of this dataset. | |
[03:44:12.560 --> 03:44:16.640] And like I mentioned, the data distribution is very different from what GPT2 trained on. | |
[03:44:16.640 --> 03:44:22.720] So this is not exactly fair comparison, but it's a good cross check to look at. | |
[03:44:22.720 --> 03:44:29.040] Now we would ideally like something that is withheld and comparable and somewhat standard. | |
[03:44:29.040 --> 03:44:35.680] And so for us, that is HellaSwag. And so on here, we see the HellaSwag progress we made from 25% | |
[03:44:35.680 --> 03:44:44.560] all the way up to here. In red, we see the OpenAI GPT-2 124M model. So it achieves this | |
[03:44:44.560 --> 03:44:51.920] HellaSwag accuracy here. And the GPT-3 124M model, which was trained on 300 billion tokens, | |
[03:44:51.920 --> 03:45:00.000] is shown in green. So that's over here. So you see that we basically surpassed the GPT-2 124M model | |
[03:45:00.000 --> 03:45:07.520] right here, which is really nice. Now, interestingly, we were able to do so with | |
[03:45:07.520 --> 03:45:12.160] only training on 10 billion tokens, while GPT2 was trained on 100 billion tokens. | |
[03:45:12.960 --> 03:45:17.760] So for some reason, we were able to get away with significantly fewer tokens for training. | |
[03:45:17.760 --> 03:45:22.960] There are many possible reasons as to why we could match or surpass this accuracy | |
[03:45:22.960 --> 03:45:30.960] with only 10 billion training tokens. So number one, it could be that OpenAI's GPT-2 was trained on | |
[03:45:30.960 --> 03:45:37.520] a much wider data distribution. So in particular, FineWeb EDU is all English. It's not multilingual. | |
[03:45:38.080 --> 03:45:44.240] And there's not that much math and code. And so math and code and multilingual could have been | |
[03:45:44.240 --> 03:45:52.000] stealing capacity from the original GPT-2 model. And basically, that could partially explain | |
[03:45:52.000 --> 03:45:57.600] why we can get away with fewer tokens here. There are many other reasons. So for example, the HellaSwag eval is | |
[03:45:57.600 --> 03:46:03.600] fairly old, maybe five years or so. It is possible that aspects of HellaSwag, in some way or even | |
[03:46:03.600 --> 03:46:10.080] verbatim, have made it into the training set of FineWeb. We don't know for sure, but if that | |
[03:46:10.080 --> 03:46:13.360] was the case, then we are basically looking at the training curve instead of the validation curve. | |
[03:46:13.360 --> 03:46:18.880] So long story short, this is not a perfect eval and there's some caveats here. But at least we | |
[03:46:18.880 --> 03:46:25.920] have some confidence that we're not doing something completely wrong. And it's probably the case | |
[03:46:25.920 --> 03:46:30.240] that when people try to create these datasets, they try to make sure that test sets that are very | |
[03:46:30.240 --> 03:46:35.520] common are not part of the training set. For example, when Hugging Face created FineWeb EDU, | |
[03:46:35.520 --> 03:46:40.240] they used HellaSwag as an eval. So I would hope that they made sure to deduplicate and | |
[03:46:40.240 --> 03:46:45.760] that there's no HellaSwag in the training set. But we can't be sure. The other thing I wanted | |
[03:46:45.760 --> 03:46:50.960] to address briefly is, look at this loss curve. This looks really wrong here. | |
[03:46:50.960 --> 03:46:56.720] I don't actually know 100% what this is. And I suspect it's because the 10 billion token sample of | |
[03:46:56.720 --> 03:47:04.960] FineWeb EDU was not properly shuffled. And there's some issue here with the data that I don't fully | |
[03:47:04.960 --> 03:47:10.560] understand yet. And there's some weird periodicity to it. And because we are in a very lazy way sort | |
[03:47:10.560 --> 03:47:15.120] of serializing all the tokens and just iterating on them from scratch without doing any permutations | |
[03:47:15.120 --> 03:47:21.360] or any random sampling ourselves, I think we're inheriting some of the ordering that they have | |
[03:47:21.360 --> 03:47:27.840] in a dataset. So this is not ideal. But hopefully by the time you get to this repo, | |
[03:47:27.840 --> 03:47:34.400] some of these things, by the way, will hopefully be fixed. And I will release this build-nanogpt | |
[03:47:34.400 --> 03:47:39.520] repo. And right now it looks a little ugly and preliminary. So hopefully by the time you get | |
[03:47:39.520 --> 03:47:45.360] here, it's nicer. But down here, I'm going to show errata. And I'm going to talk about some of the | |
[03:47:45.360 --> 03:47:51.520] things that happened after the video. And I expect that we will have fixed the small issue. But for | |
[03:47:51.520 --> 03:47:58.320] now, basically, this shows that our training is not completely wrong. And it shows that we're able | |
[03:47:58.320 --> 03:48:05.120] to surpass its accuracy with only a tenth of the token budget. And possibly it could be also that the | |
[03:48:05.120 --> 03:48:11.680] dataset might have improved. So the original GPT-2 dataset was WebText. It's possible that | |
[03:48:11.680 --> 03:48:16.720] not a lot of care and attention went into the dataset. This was very early in LLMs. Whereas | |
[03:48:16.720 --> 03:48:22.720] now there's a lot more scrutiny on good practices around deduplication, filtering, quality filtering, | |
[03:48:22.720 --> 03:48:27.120] and so on. And it's possible that the data set we're training on is just a higher quality per token. | |
[03:48:27.120 --> 03:48:32.080] And that could be giving us a boost as well. So a number of caveats to think about. But for now, | |
[03:48:32.080 --> 03:48:38.080] we're pretty happy with this. And yeah, now the next thing I was interested in is, as you see, | |
[03:48:38.080 --> 03:48:42.800] it's a morning now. So there was an overnight. And I wanted to basically see how far I could | |
[03:48:42.800 --> 03:48:48.720] push the result. So to do an overnight run, I basically did instead of one epoch, which took | |
[03:48:48.720 --> 03:48:53.680] roughly two hours. I just did a times four, so that it would take eight hours while I was sleeping. | |
[03:48:53.680 --> 03:48:59.600] And so we did four epochs or roughly 40 billion tokens of training. And I was trying to see how | |
[03:48:59.600 --> 03:49:05.600] far we could get. And so this was the only change and I reran the script. And when I read and plot | |
[03:49:05.600 --> 03:49:13.520] the log file of the 40B run, this is what the curve looked like. Okay, so to narrate this, number one, | |
[03:49:13.520 --> 03:49:18.400] we are seeing this issue here with the periodicity through the different epochs and something really | |
[03:49:18.400 --> 03:49:25.280] weird with the FineWeb EDU dataset. And that is to be determined. But otherwise, we are seeing | |
[03:49:25.280 --> 03:49:32.320] that the HellaSwag accuracy actually went up by a lot. And we almost made it to the GPT-3 | |
[03:49:32.320 --> 03:49:39.120] 124M accuracy up here, but not quite. So it's too bad that I didn't sleep slightly longer. | |
[03:49:39.120 --> 03:49:47.520] And I think if this was a five epoch run, we may have gotten here. Now, one thing to point out is | |
[03:49:47.520 --> 03:49:52.880] that if you're doing multi epoch runs, we're not actually being very careful in our data loader. | |
[03:49:52.880 --> 03:50:00.960] This data loader goes through the data in exactly the same format and exactly the | |
[03:50:00.960 --> 03:50:05.280] same order. And this is kind of suboptimal. And you would want to look into extensions where you | |
[03:50:05.280 --> 03:50:11.360] actually permute the data randomly. You permute the documents around in every single shard on every | |
[03:50:11.360 --> 03:50:18.080] single new epoch, and potentially even permute the shards. And that would go a long way into | |
[03:50:18.080 --> 03:50:22.560] decreasing the periodicity. And it's also better for the optimization, so that you're not seeing | |
[03:50:22.560 --> 03:50:28.000] things in the identical format. And you're introducing some of the randomness in how the | |
[03:50:28.000 --> 03:50:32.640] documents follow each other. Because you have to remember that in every single row, these documents | |
[03:50:32.640 --> 03:50:36.480] follow each other. And then there's the end of text token, and then the next document. So the | |
[03:50:36.480 --> 03:50:41.920] documents are currently glued together in the exact same identical manner. But we actually want | |
[03:50:41.920 --> 03:50:46.640] to break up the documents and shuffle them around, because the order of the documents shouldn't | |
[03:50:46.640 --> 03:50:51.680] matter. And we basically want to break up that dependence, because it's kind of a | |
[03:50:51.680 --> 03:50:56.960] spurious correlation. And so our data loader is not currently doing that. And that's one improvement you could think of making. | |
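To make that concrete, here is a rough sketch of per-epoch shuffling at the level of shards and documents. The shard layout (flat 1D arrays of token ids delimited by an end-of-text token) and all the names here are assumptions for illustration; this is not the data loader from the actual repo:

```python
# Rough sketch of per-epoch shuffling for a sharded token dataset, as described above.
# The shard layout and function name are assumptions, not the repo's actual loader.
import numpy as np

def iter_documents_shuffled(shard_paths, eot_token, epoch):
    """Yield documents with shard order and within-shard document order permuted each epoch."""
    rng = np.random.default_rng(epoch)  # new permutation every epoch
    for shard_idx in rng.permutation(len(shard_paths)):
        tokens = np.load(shard_paths[shard_idx])       # flat 1D array of token ids
        # split the shard back into documents at the end-of-text delimiters
        cut_points = np.where(tokens == eot_token)[0] + 1
        docs = [d for d in np.split(tokens, cut_points) if len(d) > 0]
        for doc_idx in rng.permutation(len(docs)):
            yield docs[doc_idx]

# usage sketch: re-glue the shuffled documents into a token stream for each epoch
# for epoch in range(num_epochs):
#     stream = np.concatenate(list(iter_documents_shuffled(shard_paths, 50256, epoch)))
```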
[03:50:56.960 --> 03:51:03.120] The other thing to point out is we're almost matching GPT-3 accuracy | |
[03:51:03.120 --> 03:51:09.040] with only 40 billion tokens, GPT-3 trained on 300 billion tokens. So again, we're seeing about a | |
[03:51:09.040 --> 03:51:16.400] 10x improvement here with respect to learning efficiency. And | |
[03:51:16.400 --> 03:51:19.680] I don't actually know exactly what to attribute this to, other than some of the things that I | |
[03:51:19.680 --> 03:51:24.800] already mentioned previously for the previous run. The other thing I wanted to briefly mention is | |
[03:51:26.160 --> 03:51:31.360] the max LR here. I saw some people already play with this a little bit in a previous related | |
[03:51:31.360 --> 03:51:36.880] repository. And it turns out that you can actually almost like 3x this. So it's possible that the | |
[03:51:36.880 --> 03:51:41.200] maximum learning rate can be a lot higher. And for some reason, the GPT-3 hyperparameters that we | |
[03:51:41.200 --> 03:51:45.520] are inheriting are actually extremely conservative. And you can actually get away with higher learning | |
[03:51:45.520 --> 03:51:52.000] rate and it would train faster. So a lot of these hyperparameters are quite tunable and feel free | |
[03:51:52.000 --> 03:51:59.200] to play with them. And they're probably not set precisely correctly. And it's possible that you | |
[03:51:59.200 --> 03:52:05.200] can get away with doing this, basically. And if you wanted to exactly be faithful to GPT-3, | |
[03:52:05.200 --> 03:52:11.920] you would also want to make the following difference. You'd want to come here. And the sequence length | |
[03:52:11.920 --> 03:52:19.600] of GPT-3 is 2x. It's 2048 instead of 1024. So you would come here and change this to 2048 for T. | |
[03:52:19.600 --> 03:52:25.440] And then if you want the exact same number of tokens, half a million per iteration or per step, | |
[03:52:25.440 --> 03:52:29.440] you would then decrease this to 32, so they still multiply to half a million. | |
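To spell out the arithmetic, here is a small sketch of the batch bookkeeping; the variable names and the 8-GPU assumption are illustrative, not taken verbatim from the script:

```python
# Sketch of the batch bookkeeping for the GPT-3-style settings described above; names are illustrative.
total_batch_size = 524288          # ~0.5M tokens per optimizer step
B, T = 32, 2048                    # micro-batch halved, sequence length doubled (was 64, 1024)
ddp_world_size = 8                 # e.g. 8 GPUs; adjust to your setup
assert total_batch_size % (B * T * ddp_world_size) == 0
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
print(f"tokens per micro-batch: {B*T}, grad accum steps: {grad_accum_steps}")
```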
[03:52:29.440 --> 03:52:36.400] So that would give your model a sequence length equal to that of GPT-3. And in that case, basically, | |
[03:52:36.400 --> 03:52:44.480] the models would be roughly identical as far as I'm aware. Because again, GPT-2 and GPT-3 are very, | |
[03:52:44.480 --> 03:52:49.120] very similar models. Now, we can also look at some of the samples here from the model | |
[03:52:49.120 --> 03:52:56.000] that was trained overnight. So this is the optimization. And you see that here we stepped | |
[03:52:56.000 --> 03:53:04.480] all the way to 76,290 or so. And the HellaSwag accuracy we achieved was 33.24. | |
[03:53:04.480 --> 03:53:10.320] And these are some of the samples from the model. And you can see that if you read through this | |
[03:53:10.320 --> 03:53:16.560] and pause the video briefly, you can see that they are a lot more coherent. And they're actually | |
[03:53:16.560 --> 03:53:20.000] addressing the fact that it's a language model almost. So | |
[03:53:20.000 --> 03:53:25.840] hello, I'm a language model, and I try to be as accurate as possible. | |
[03:53:25.840 --> 03:53:30.400] I'm a language model, not a programming language. | |
[03:53:30.400 --> 03:53:35.280] I know how to communicate. I use Python. | |
[03:53:35.280 --> 03:53:43.040] I don't know. If you pause this and look at it and then compare it to the model that was | |
[03:53:43.040 --> 03:53:46.720] only trained for 10 billion, you will see that these are a lot more coherent. | |
[03:53:46.720 --> 03:53:51.040] And you can play with this yourself. One more thing I added to the code, by the way, | |
[03:53:51.040 --> 03:53:56.240] is this chunk of code here. So basically, right after we evaluate the validation loss, | |
[03:53:56.240 --> 03:54:01.520] if we are the master process, in addition to logging the validation loss, every 5,000 steps | |
[03:54:01.520 --> 03:54:06.080] we're also going to save the checkpoint, which is really just the state dictionary of the model. | |
[03:54:06.080 --> 03:54:10.720] And so checkpointing is nice just because you can save the model and later you can | |
[03:54:11.360 --> 03:54:16.880] use it in some way. If you wanted to resume the optimization, then in addition to saving the model, | |
[03:54:16.880 --> 03:54:22.240] we have to also save the optimizer state dict, because remember that the optimizer has a few | |
[03:54:22.240 --> 03:54:29.200] additional buffers because of Adam. So it's got the m and v, and you need to also resume the | |
[03:54:29.200 --> 03:54:34.080] optimizer properly. You have to be careful with the RNG seeds, random number generators, and so on. | |
[03:54:34.080 --> 03:54:38.960] So if you wanted to exactly be able to resume optimization, you have to think through the state | |
[03:54:38.960 --> 03:54:43.440] of the training process. But if you just want to save the model, this is how you would do it. | |
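A minimal sketch of that checkpointing block follows; the dictionary keys, the file path, and names like raw_model (the unwrapped model under DDP) are assumptions, and as mentioned, exact resumption would additionally require optimizer and RNG state:

```python
# Sketch of periodic checkpointing as described above; keys, paths, and names are assumptions.
import os
import torch

if (step > 0 and step % 5000 == 0) or last_step:
    if master_process:
        checkpoint = {
            "model": raw_model.state_dict(),
            "config": raw_model.config,
            "step": step,
            "val_loss": val_loss,
        }
        # for exact resumption you would additionally save, for example:
        # checkpoint["optimizer"] = optimizer.state_dict()   # includes Adam's m and v buffers
        # checkpoint["rng_state"] = torch.get_rng_state()
        checkpoint_path = os.path.join(log_dir, f"model_{step:05d}.pt")
        torch.save(checkpoint, checkpoint_path)
```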
[03:54:43.440 --> 03:54:48.640] And one nice reason why you might want to do this is because you may want to evaluate the model | |
[03:54:48.640 --> 03:54:54.640] a lot more carefully. So here we are only kind of winging the HellaSwag eval, but you may want | |
[03:54:54.640 --> 03:55:02.960] to use something nicer. Like, for example, the Eleuther evaluation harness. | |
[03:55:04.000 --> 03:55:11.680] So this is a way to also evaluate language models. And so it's possible that | |
[03:55:11.680 --> 03:55:18.880] you may want to use basically different infrastructure to more thoroughly evaluate the models on different | |
[03:55:18.880 --> 03:55:25.840] evaluations and compare it to the OpenAI GPT-2 model on many other tasks. Like, for example, | |
[03:55:25.840 --> 03:55:30.480] that involve math, code, or different languages and so on. So this is nice functionality to have as | |
[03:55:30.480 --> 03:55:36.960] well. And then the other thing I wanted to mention is that everything we've built here, | |
[03:55:36.960 --> 03:55:43.840] this is only the pre-training step. So the GPT here just dreams documents, it just predicts | |
[03:55:43.840 --> 03:55:49.520] the next token. You can't talk to it like you can talk to ChatGPT. If you wanted to | |
[03:55:49.520 --> 03:55:54.400] talk to the model, we have to fine tune it into the chat format. And that's not actually like | |
[03:55:54.400 --> 03:55:58.960] that complicated. If you're looking at supervised fine tuning or SFT, really what that means is | |
[03:55:58.960 --> 03:56:02.960] we're just swapping out a dataset into a dataset that is a lot more conversational, | |
[03:56:02.960 --> 03:56:07.200] and there's a user-assistant user-assistant kind of structure. And we just fine tune on it. | |
[03:56:07.200 --> 03:56:13.440] And then we basically fill in the user tokens and we sample the assistant tokens. It's not a lot | |
[03:56:13.440 --> 03:56:19.120] deeper than that. But basically we swap out the dataset and continue training. But for now, | |
[03:56:19.120 --> 03:56:23.520] we're going to stop at pre-training. One more thing that I wanted to briefly show you is that, | |
[03:56:23.520 --> 03:56:28.800] of course, what we've built up today was building towards nano GPT, which is this repository from | |
[03:56:28.800 --> 03:56:34.320] earlier. But also there's actually another nano GPT implementation, and it's hiding in a | |
[03:56:34.320 --> 03:56:42.240] more recent project that I've been working on, called LLM.C. And LLM.C is a pure C CUDA implementation | |
[03:56:42.240 --> 03:56:49.040] of GPT-2 or GPT-3 training. And it just directly uses CUDA and is written as C CUDA. | |
[03:56:49.600 --> 03:56:54.880] Now the nano GPT here acts as reference code in PyTorch to the C implementation. So we're | |
[03:56:54.880 --> 03:56:59.920] trying to exactly match up the two, but we're hoping that the C CUDA is faster. And of course, | |
[03:56:59.920 --> 03:57:05.200] currently that seems to be the case, because it is a direct optimized implementation. So | |
[03:57:05.200 --> 03:57:12.000] train_gpt2.py in LLM.C is basically the nanoGPT. And when you scroll through this file, | |
[03:57:12.000 --> 03:57:19.360] you'll find a lot of things that very much look like things that we've built up in this lecture. | |
[03:57:19.920 --> 03:57:25.360] And then when you look at train_gpt2.cu, this is the C CUDA implementation. | |
[03:57:25.360 --> 03:57:33.280] So there's a lot of MPI and NCCL, GPU, CUDA, C/C++. And you have to be familiar with that. But | |
[03:57:33.280 --> 03:57:40.240] when this is built up, we can actually run the two side by side. And they're going to produce | |
[03:57:40.240 --> 03:57:46.480] the exact same results, but LLM.C actually runs faster. So let's see that. So on the left, I have | |
[03:57:46.480 --> 03:57:52.960] PyTorch, nano GPT looking thing. On the right, I have the LLM.C call. And here I'm going to | |
[03:57:52.960 --> 03:57:58.000] launch the two. Both of these are going to be running on a single GPU. And here I'm putting | |
[03:57:58.000 --> 03:58:04.000] LLM.C on GPU 1. And this one will grab GPU 0 by default. And then | |
[03:58:04.000 --> 03:58:12.240] we can see here that LLM.C compiled and then allocated space and is stepping. | |
[03:58:13.520 --> 03:58:21.200] So basically, meanwhile, PyTorch is still compiling because torch.compile is a bit slower here | |
[03:58:21.200 --> 03:58:28.400] than the LLM.C nvcc CUDA compile. And so this program has already started running. And we're | |
[03:58:28.400 --> 03:58:33.440] still waiting here for torch.compile. Now, of course, this is a very specific implementation | |
[03:58:33.440 --> 03:58:38.320] of GPT-2 and 3. PyTorch is a very general neural network framework. So they're not exactly | |
[03:58:38.320 --> 03:58:43.280] comparable. But if you're only interested in training GPT-2 and 3, LLM.C is very fast. | |
[03:58:43.280 --> 03:58:51.920] It takes less space. It's faster to start and it's faster per step. And so PyTorch started | |
[03:58:51.920 --> 03:58:58.080] stepping here. And as you can see, we're running at about 223,000 tokens per second here and about | |
[03:58:58.080 --> 03:59:06.080] 185,000 tokens per second here. So quite a bit slower. But I don't have full confidence that I | |
[03:59:06.640 --> 03:59:11.440] exactly squeezed out all the juice from the PyTorch implementation. But the important thing | |
[03:59:11.440 --> 03:59:17.200] here is notice that if I line up the steps, you will see that the losses and the norms that are | |
[03:59:17.200 --> 03:59:23.120] printed between these two are identical. So on the left, we have the PyTorch and on the right, | |
[03:59:23.120 --> 03:59:28.880] the C CUDA implementation. And they're the same except this one runs faster. So that's kind of, | |
[03:59:28.880 --> 03:59:34.480] I wanted to show you also briefly, LLM.C. And this is a parallel implementation. And it's also | |
[03:59:34.480 --> 03:59:39.440] something that you may want to play with or look at. And it's kind of interesting. | |
[03:59:39.440 --> 03:59:43.360] Okay. So at this point, I should probably start wrapping up the video because I think it's getting | |
[03:59:43.360 --> 03:59:48.720] way longer than anticipated. But we did cover a lot of ground and we built everything from scratch. | |
[03:59:48.720 --> 03:59:54.560] So as a brief summary, we were looking at the GPT two and GPT three papers. | |
[03:59:54.560 --> 04:00:00.560] We were looking at how you set up these training runs and all the considerations involved. | |
[04:00:00.560 --> 04:00:05.120] We wrote everything from scratch. And then we saw that over the duration of either a two-hour | |
[04:00:05.120 --> 04:00:10.960] training run or an overnight run, we can actually match the 124 million parameter checkpoints of | |
[04:00:10.960 --> 04:00:16.880] GPT two and GPT three to a very large extent. In principle, the code that we wrote would be | |
[04:00:16.880 --> 04:00:20.400] able to train even bigger models if you have the patience or the computing resources. | |
[04:00:20.400 --> 04:00:24.480] And so you could potentially think about training some of the bigger checkpoints as well. | |
[04:00:25.920 --> 04:00:30.960] There are a few remaining issues to address. What's happening with the loss here, which I suspect | |
[04:00:30.960 --> 04:00:36.880] has to do with the fine web EDU data sampling. Why can't we turn on torch compile? It currently | |
[04:00:36.880 --> 04:00:41.760] breaks generation and hella swag. What's up with that? In the data loader, we should probably be | |
[04:00:41.760 --> 04:00:46.640] permuting our data when we reach epoch boundaries. So there's a few more issues like that. And I | |
[04:00:46.640 --> 04:00:51.840] expect to be documenting some of those over time in the build-nanogpt repository here, | |
[04:00:52.960 --> 04:00:57.920] which I'm going to be releasing with this video. If you have any questions or would like to talk | |
[04:00:57.920 --> 04:01:04.160] about anything that we covered, please go to the Discussions tab so we can talk there. Or please go | |
[04:01:04.160 --> 04:01:08.720] to Issues or Pull Requests, depending on what you'd like to contribute. | |
[04:01:08.720 --> 04:01:15.040] Or also have a look at the Zero to Hero Discord. And I'm going to be hanging out there | |
[04:01:15.040 --> 04:01:24.160] as well. Otherwise, for now, I'm pretty happy about where we got. And I hope you enjoyed the video. | |
[04:01:24.160 --> 04:01:25.680] And I will see you later. |