How to Fine Tune GPT-J - The Basics

Unfinished guide to fine-tuning GPT-J.

Before anything else, you'll likely want to apply for access to the TPU Research Cloud (TRC). Combined with a Google Cloud free trial, that should let you do everything here for free. Once you're accepted into TRC, create a Google Cloud project, then fill out the form from the TRC welcome email with the name of the new project. Use create_tfrecords.py from the GPT-Neo repo to prepare your data as tfrecords; I might do a separate guide on that, but a sketch of the record format follows below. You might also want to fork the mesh-transformer-jax repo to make it easier to add and modify the config files.
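Roughly speaking, create_tfrecords.py tokenizes your text with the GPT-2 vocabulary and writes fixed-length sequences of token ids as tf.train.Example records. Here's a minimal sketch of that format; the "text" feature key and the flat chunking are assumptions based on the GPT-Neo and mesh-transformer-jax data loaders, so treat the repo's script as authoritative:

    # Sketch only: mimics the record format the real script produces.
    import tensorflow as tf
    from transformers import GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")  # GPT-J shares the GPT-2 vocab
    SEQ_LEN = 2048  # GPT-J's context length

    def write_tfrecord(texts, path):
        # concatenate all documents into one token stream, separated by EOS
        ids = []
        for text in texts:
            ids.extend(tokenizer.encode(text) + [tokenizer.eos_token_id])
        with tf.io.TFRecordWriter(path) as writer:
            # chunk the stream into fixed-length training sequences
            for i in range(0, len(ids) - SEQ_LEN + 1, SEQ_LEN):
                chunk = ids[i:i + SEQ_LEN]
                feature = {"text": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=chunk))}
                writer.write(tf.train.Example(
                    features=tf.train.Features(feature=feature)).SerializeToString())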

  1. Install the Google Cloud SDK. We'll need it later.

  2. If you haven't made a project and activated TPU access through TRC yet (or if you plan on paying out of pocket), make one now.

  3. TPUs use Google Cloud buckets for storage, so go ahead and create one now. Make sure it's in the same region the TPU VM will be in; the email from TRC will tell you which region(s) you can use free TPUs in. See the command sketch after this list.

  4. You'll need the full pretrained weights in order to fine-tune the model. Download those here.
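For steps 1-3, the command-line side looks something like this; the project ID, bucket name, and region are placeholders, and the TRC email tells you the real region:

    gcloud auth login
    gcloud config set project YOUR-PROJECT
    gsutil mb -l europe-west4 gs://YOUR-BUCKET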

Now that you have a bucket on the cloud and the weights on your PC, you need to upload the weights to the bucket.

  1. Extract GPT-J-6B/step_383500.tar.zstd so you're left with the uncompressed .tar (example commands follow after this list).

  2. Open the Google Cloud SDK shell and run the following command, replacing the path names as appropriate: gsutil cp LOCAL_PATH_TO.tar gs://YOUR-BUCKET. If that works, the console will show the file being uploaded; it took about 12 hours for me. You'll want to upload the tfrecords of your data as well. You can do that the same way or through the web interface, but trust me when I say you don't want to upload the nearly 70GB weights through the web interface.

  3. Follow Google's Cloud TPU VM guide up to and including the step "Connect to your Cloud TPU VM".
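Concretely, the extraction and uploads look something like this (requires zstd to be installed; all names are placeholders):

    zstd -d step_383500.tar.zstd
    gsutil cp step_383500.tar gs://YOUR-BUCKET
    gsutil -m cp data/*.tfrecords gs://YOUR-BUCKET/data/

Creating and connecting to the TPU VM uses the alpha commands from Google's guide; these flags are current as of mid-2021, so check the docs if they've changed:

    gcloud alpha compute tpus tpu-vm create YOUR-TPU --zone=europe-west4-a --accelerator-type=v3-8 --version=v2-alpha
    gcloud alpha compute tpus tpu-vm ssh YOUR-TPU --zone=europe-west4-a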

At this point you should have remote access to the TPU VM!
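A quick sanity check once you're connected; this assumes the v2-alpha VM image ships with JAX preinstalled (if it doesn't, install it before anything else):

    python3 -c "import jax; print(jax.device_count())"  # expect 8 on a v3-8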

  1. git clone https://github.com/kingoflolz/mesh-transformer-jax (or your fork), cd into it, and pip install -r requirements.txt

  2. Copy one of the example configs from the repo's configs/ directory and adapt it to your setup: your bucket, your data, and your TPU size (a sketch of the relevant fields follows after this list)

  3. run python device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500.tar

  4. ???

  5. profit
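For step 2, the fields you'd typically touch in a copy of the example 6B config look something like this. The key names are recalled from the repo and every value below is an illustrative assumption for a single v3-8; leave the architecture keys (layers, d_model, n_heads, and so on) at their 6B values, and verify everything against configs/ in the repo:

    {
      "tpu_size": 8,
      "cores_per_replica": 8,
      "per_replica_batch": 1,
      "gradient_accumulation_steps": 16,
      "warmup_steps": 100,
      "lr": 5e-5,
      "end_lr": 1e-5,
      "total_steps": 1000,
      "ckpt_every": 500,
      "bucket": "YOUR-BUCKET",
      "model_dir": "finetuned",
      "train_set": "your_data.train.index",
      "val_set": {"your_data": "your_data.val.index"}
    }

Note that train_set and val_set point at .index files listing the paths of your tfrecords, not at the tfrecords themselves; the repo's data/ directory has examples.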
