Before anything else, you'll likely want to apply for access to the TPU Research Cloud (TRC). Combined with a Google Cloud free trial, that should let you do everything here for free. Once you're accepted into TRC, create a project and fill out the form that was emailed to you with the new project's name. Use create_tfrecords.py from the GPT-Neo repo to prepare your data as tfrecords; I might do a separate guide on that. You might also want to fork the mesh-transformer-jax repo to make it easier to add and modify the config files.
- Install the Google Cloud SDK. We'll need it later.
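If you're on Linux or macOS, the install is roughly this (a sketch; check Google's install docs for your platform):
curl https://sdk.cloud.google.com | bash   # interactive installer
exec -l $SHELL                             # reload the shell so gcloud is on your PATH
gcloud init                                # log in and set defaults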
- If you haven't already made a project and activated TPU access through TRC (or if you plan on paying out of pocket), make one now.
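You can create the project from the web console or from the SDK; the project ID below is a placeholder, pick your own:
gcloud projects create YOUR-PROJECT-ID
gcloud config set project YOUR-PROJECT-ID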
- TPUs use Google Cloud buckets for storage, so go ahead and create one now. Make sure it's in the same region the TPU VM will be in; the email from TRC will tell you which region(s) you can use free TPUs in.
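The bucket can also be created from the SDK; the region below is just an example, use one of the regions from the TRC email:
gsutil mb -l europe-west4 gs://YOUR-BUCKET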
- You'll need the full pretrained weights in order to fine-tune the model. Download those here. Now that you have a bucket in the cloud and the weights on your PC, you need to upload the weights to the bucket.
- Extract GPT-J-6B/step_383500.tar.zstd so you're left with the uncompressed .tar.
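On Linux this is something like the following, assuming the zstd package is installed (leave the resulting .tar as-is; the training script reads the tar directly):
zstd -d GPT-J-6B/step_383500.tar.zstd   # leaves GPT-J-6B/step_383500.tar next to the archive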
- Open the Google Cloud SDK and run the following command, replacing the path names as appropriate:
gsutil cp LOCAL_PATH_TO.tar gs://YOUR-BUCKET
If that works, the console will show the file being uploaded; it took about 12 hours for me. You'll want to upload the tfrecords of your data as well, which you can do the same way (a sketch follows below) or through the web interface, but trust me when I say you don't want to upload the nearly 70 GB of weights through the web interface.
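For the tfrecords, the same approach works; the -m flag parallelizes the transfer, and the destination path is just an example:
gsutil -m cp -r ./tfrecords gs://YOUR-BUCKET/data/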
- Follow this guide up to and including the step "Connect to your Cloud TPU VM". At this point you should have remote access to the TPU VM!
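For reference, the create-and-connect steps from that guide look roughly like this (name, zone, and accelerator type are placeholders; v2-alpha was the TPU VM runtime version when this was written, so double-check the guide):
gcloud alpha compute tpus tpu-vm create YOUR-TPU --zone=europe-west4-a --accelerator-type=v3-8 --version=v2-alpha
gcloud alpha compute tpus tpu-vm ssh YOUR-TPU --zone=europe-west4-a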
- Run git clone https://github.com/kingoflolz/mesh-transformer-jax (or your fork) on the TPU VM.
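Then install the dependencies on the VM; the pinned jax version here is an assumption based on the repo's README at the time of writing, so verify it there:
cd mesh-transformer-jax
pip install -r requirements.txt
pip install jax==0.2.12   # assumed pin from the README; newer jax versions may break the code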
- Set up your config: copy one of the JSON files in the repo's configs directory and edit it so that tpu_size matches your TPU (8 for a v3-8) and bucket, model_dir, train_set, and val_set point at your bucket and the tfrecords you uploaded; a sketch follows below.
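As an illustrative excerpt, the fields you'd change look something like this (field names follow the repo's example configs; the values are placeholders for a v3-8 run):
"tpu_size": 8,
"cores_per_replica": 8,
"bucket": "YOUR-BUCKET",
"model_dir": "finetuned",
"train_set": "your_data.train.index",
"val_set": {"your_data": "your_data.val.index"}
The .index files are plain-text lists of the tfrecord paths in your bucket, following the convention in the repo's example configs.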
- Run the training script, pointing --tune-model-path at the weights you uploaded:
python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500.tar
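Assuming the config sketch above, checkpoints from the run are written back to gs://YOUR-BUCKET/finetuned/ (whatever you set model_dir to), and you should see loss values start printing once the pretrained checkpoint finishes loading.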
- ???
- profit