@TheLustriVA
Last active August 11, 2022 16:59
GPT-3 explains training an image model in basic terms

A conversation with GPT-3 on the basics of training an AI

[Image: a diffusion model's response to the prompt 'The Latent Space.']

Even with some formal postgraduate education in machine learning, I've lately been getting lost in the complexity of the field's technical literature.

What follows is my first attempt at using Codex's code-explanation mode to drill down from some recent ML code to concepts I'm more confident I have a solid grasp of.

I hope that this and any potential future 'conversations' illustrate that:

  1. Everyone needs a refresher, especially for disciplines as complex as machine learning.
  2. Formal education is helpful but not complete. I found this video on 'experts', which says it better than I can.

The following code snippet was taken from the Dataloader section of Lucidrain's imagen_pytorch GitHub repo README.

from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset

# unets for unconditional imagen

unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 1,
    layer_attns = (False, False, False, True),
    layer_cross_attns = False
)

# imagen, which contains the unet above

imagen = Imagen(
    condition_on_text = False,  # this must be set to False for unconditional Imagen
    unets = unet,
    image_sizes = 128,
    timesteps = 1000
)

trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True # whether to split a validation set from the training set
).cuda()

# instantiate your dataloader, which returns the necessary inputs to the DDPM as a tuple in the order (images, text embeddings, text masks).
# in this case only images are returned, since this is unconditional training

dataset = Dataset('/path/to/training/images', image_size = 128)

trainer.add_train_dataset(dataset, batch_size = 16)

# working training loop

for i in range(200000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
    print(f'loss: {loss}')

    if not (i % 50):
        valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
        print(f'valid loss: {valid_loss}')

    if not (i % 100) and trainer.is_main: # is_main makes sure this can run in distributed
        images = trainer.sample(batch_size = 1, return_pil_images = True) # returns List[Image]
        images[0].save(f'./sample-{i // 100}.png')

The following text was generated by starting with the code snippet and the prompt "Here's what the above code is doing: 1." It might be best read as an expert trying to explain a complex topic to someone who responds only with blank looks.

Here's what the above code is doing:

  1. instantiate a unet
  2. instantiate an imagen, which contains the unet above
  3. instantiate a trainer, which contains the imagen above
  4. add a training dataset to the trainer
  5. train for 200000 steps, validating every 50 steps and saving samples every 100 steps

In simple terms, a Unet is a generator, an Imagen is a model, and an ImagenTrainer is a trainer. A Unet generator functions by taking in a text embedding and a text mask, and generating an image. An Imagen model functions by taking in a text embedding and a text mask, and generating an image. An ImagenTrainer trainer functions by taking in a text embedding and a text mask, and generating an image. A text embedding is a tensor of shape (batch_size, text_embedding_dim), and a text mask is a tensor of shape (batch_size, max_text_length). text_embedding_dim is best explained as the number of dimensions in your text embedding. max_text_length is best explained as the maximum length of your text.

A text embedding looks like this:

[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

A text mask looks like this:

[[1, 1, 1], [1, 1, 0]]
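To make the shapes concrete, here is a minimal sketch in PyTorch that builds the same toy embedding and mask as tensors (the numbers are the illustrative values from the examples above, not real model output):

import torch

# Toy text embedding: two rows of three values each, as in the example above.
text_embeds = torch.tensor([[0.1, 0.2, 0.3],
                            [0.4, 0.5, 0.6]])

# Toy text mask: 1 marks a position that counts, 0 marks one that is ignored.
text_masks = torch.tensor([[1, 1, 1],
                           [1, 1, 0]], dtype=torch.bool)

print(text_embeds.shape)  # torch.Size([2, 3])
print(text_masks.shape)   # torch.Size([2, 3])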

The numbers in a text embedding refer to the embedding of each word in your text. The numbers in a text mask refer to whether or not that word is present in your text. You create a text embedding by using a text embedding model, such as BERT. You create a text mask by using a text embedding model, such as BERT. Text embedding models, such as BERT, create text embeddings by taking in a text, and outputting a text embedding. Text embedding models, such as BERT, create text masks by taking in a text, and outputting a text mask. In simple terms, a text embedding is a tensor that represents a text, and a text mask is a tensor that represents whether or not each word in the text is present. A text embedding tensor represents a text by representing each word in the text as a vector. A text mask tensor represents whether or not each word in the text is present by having a 1 if the word is present, and a 0 if the word is not present.
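For a sense of how a model like BERT produces these in practice, the sketch below uses the Hugging Face transformers library (not part of the imagen_pytorch snippet above, and just one of several possible choices). The encoder returns a per-token embedding tensor, and the tokenizer returns an attention mask whose 1s mark real tokens and whose 0s mark padding:

import torch
from transformers import AutoTokenizer, AutoModel

# Load a pretrained BERT tokenizer and encoder (weights download on first use).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

texts = ["hello world", "a longer example sentence"]

# Tokenize both texts, padding the shorter one so the batch shares one length.
batch = tokenizer(texts, padding=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**batch)

text_embeds = outputs.last_hidden_state  # shape: (batch, sequence_length, 768)
text_masks = batch["attention_mask"]     # 1 for real tokens, 0 for padding

print(text_embeds.shape, text_masks.shape)

Note that the embeddings here carry a sequence dimension as well as an embedding dimension, since each token gets its own vector; different libraries expect slightly different shapes.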

Text is represented as a vector using the following method:

  1. Take a text, such as "hello world".
  2. Split the text into words, such as ["hello", "world"].
  3. For each word in the text, look up its vector representation in a dictionary.
  4. Concatenate all of the vectors together to form a tensor.
  5. The resulting tensor is the text embedding.
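Those five steps can be sketched directly in Python. The three-dimensional word vectors below are made-up values standing in for entries from a real pretrained dictionary such as the ones listed next:

import torch

# A made-up lookup table mapping each known word to a 3-dimensional vector.
word_vectors = {
    "hello": [0.1, 0.2, 0.3],
    "world": [0.4, 0.5, 0.6],
}

text = "hello world"              # step 1: take a text
words = text.split()              # step 2: split it into words

# Step 3: look up each word's vector in the dictionary.
vectors = [word_vectors[word] for word in words]

# Steps 4 and 5: combine the vectors into a single tensor, the text embedding.
text_embedding = torch.tensor(vectors)

print(text_embedding.shape)  # torch.Size([2, 3])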

Examples of dictionaries for representing text as a vector include:

  1. Word2Vec
  2. GloVe
  3. BERT
  4. GPT-2
  5. ELMo
  6. XLNet
  7. RoBERTa
  8. ALBERT
  9. DistilBERT
  10. ERNIE
  11. Flair
  12. FastText
  13. Swivel
  14. USE (Universal Sentence Encoder)
  15. ULMFit (Universal Language Model Fine-tuning)
  16. OpenAI GPT (Generative Pre-Training)
  17. OpenAI GPT-2 (Generative Pre-Training 2)
  18. OpenAI GPT-3 (Generative Pre-Training 3)
  19. OpenAI Transformer-XL (Transformer eXtra Long)
  20. OpenAI XLNet (Generalized Autoregressive Pretraining for Language Understanding)
  21. Google BERT (Bidirectional Encoder Representations from Transformers)
  22. Google ALBERT (A Lite BERT for Self-supervised Learning of Language Representations)
  23. Google RoBERTa (Robustly Optimized BERT Pretraining Approach)
  24. Google T5 (Text-To-Text Transfer Transformer)
  25. Facebook XLM (Cross-lingual Language Model Pretraining)
  26. Facebook XLM-R (Cross-lingual Language Model Pretraining for Reading Comprehension)
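As a concrete example of looking a word up in one of these dictionaries, pretrained GloVe vectors can be loaded through gensim's downloader. This is a sketch that assumes gensim is installed; the vectors are fetched the first time the script runs and are unrelated to the imagen_pytorch code above:

import gensim.downloader as api

# Download and load 50-dimensional GloVe vectors (Wikipedia + Gigaword).
glove = api.load("glove-wiki-gigaword-50")

vector = glove["hello"]   # a 50-dimensional vector for the word "hello"
print(vector.shape)       # (50,)

# Words whose vectors sit closest to "hello" in the embedding space.
print(glove.most_similar("hello", topn=3))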