Forked from MattPitlyk/fine-tuning-gpt-2-on-a-custom-dataset.ipynb
Created
December 18, 2022 20:48
Fine-Tuning GPT-2 on a Custom Dataset
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Fine-Tuning GPT-2 on a Custom Dataset", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/MattPitlyk/45541145ad48b93da395f0a72ec2e7dc/fine-tuning-gpt-2-on-a-custom-dataset.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "nXQ9wn5dfiBz", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Fine-tuning GPT-2 on a Custom Dataset" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "sL4ht6hRHj1W", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Setup" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "iSLvhw96EXK5", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Make sure a GPU or TPU is enabled in the runtime: \n", | |
"Runtime -> Change runtime type -> Hardware accelerator" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "y-2RF_rof4o0", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Install the `gpt-2-simple` helper library for fine-tuning." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "KBkpRgBCBS2_", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"!pip install -q gpt-2-simple" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "FYfTaj9lfgDa", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Imports" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dJOSDm7EF00R", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Colab pre-installs many common libraries including TensorFlow.\n", | |
"# Use Colab magic command to tell Colab to not use TensorFlow 2.0.\n", | |
"%tensorflow_version 1.x\n", | |
"import gpt_2_simple as gpt2\n", | |
"from datetime import datetime\n", | |
"from google.colab import files" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Gvlc3CXwhDal", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Download the GPT-2 model. \n", | |
"When fine-tuning on a single GPU, only the 124M and 355M models can be used." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "P8wSlgXoDPCR", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"#gpt2.download_gpt2(model_name=\"124M\")\n", | |
"gpt2.download_gpt2(model_name=\"355M\")\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "-mLQqkm5G-A1", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Set variables" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "6OFnPCLADfll", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"file_name = \"my_custom_dataset.txt\"\n", | |
"run_name = 'fine_tuning_run_1'\n", | |
"model_size = '355M'" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "d5nqoJ_IhcAj", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"If your custom dataset is stored in Google Drive, mount the drive and copy the file to Colab with the code below. \n", | |
"Alternatively, upload the dataset directly to Colab using the \"Files\" pane on the left (not the \"File\" menu above). \n", | |
"Training examples in the dataset file should be separated by a blank line." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "puq4iC6vUAHc", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"gpt2.mount_gdrive()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-Z6okFD8VKtS", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"gpt2.copy_file_from_gdrive(file_name)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "L0qodq58HZmO", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Perform fine-tuning" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aeXshJM-Cuaf", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"sess = gpt2.start_tf_sess()\n", | |
"\n", | |
"gpt2.finetune(sess,\n", | |
"              dataset=file_name,\n", | |
"              model_name=model_size,\n", | |
"              steps=200,\n", | |
"              restore_from='fresh',\n", | |
"              run_name=run_name,\n", | |
"              print_every=10,\n", | |
"              sample_every=50,\n", | |
"              save_every=50)\n", | |
"# Optional: add learning_rate=.00003 to the call above to lower the learning rate." | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "VHdTL8NDbAh3", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# If G-Drive is mounted, save the fine-tuned model to the drive.\n", | |
"gpt2.copy_checkpoint_to_gdrive(run_name)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "rvL8l2rCcp-V", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"End of training" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "7P_W0Ir-itMu", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Explore results" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "lWGNVgMekNQr", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"print(run_name)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "DCcx5u7sbPTD", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Copy the model from G-Drive if it wasn't trained in this Colab session.\n", | |
"gpt2.copy_checkpoint_from_gdrive(run_name=run_name)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "KkSL_kfii591", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Must reset the graph if training was done in this Colab session.\n", | |
"import tensorflow as tf\n", | |
"tf.reset_default_graph()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-fxL77nvAMAX", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Load the model.\n", | |
"sess = gpt2.start_tf_sess()\n", | |
"gpt2.load_gpt2(sess, run_name=run_name)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Iofztc2f-58U", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Unconditional generation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4RNY6RBI9LmL", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"gpt2.generate(sess, run_name=run_name, temperature=.7, length=100, prefix=None, top_k=40, nsamples=10)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "15Es5HiU_GGm", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Conditional generation (give the model an input prompt)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "82hy6qlX_FtR", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"input_prompt = \"Today's weather is\"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_WY9OUZmMqEr", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"gpt2.generate(sess, run_name=run_name, temperature=.7, length=100, prefix=input_prompt, top_k=40, nsamples=10)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "uTBDSjm_Onwm", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Copy model to Google Cloud Storage \n", | |
"We have already copied the model to your G-Drive. \n", | |
"This code will copy the model to your Google Cloud Platform Cloud Storage account (not free), where it can be used in data science deployment pipelines.\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "NLuEs3LtMp0z", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"project_id = 'your_gcp_project_id'\n", | |
"bucket_name = 'your_gcp_bucket_name'" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JXaejaPOOncO", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from google.colab import auth\n", | |
"auth.authenticate_user()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Z0EmRyEYOnT6", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"!gcloud config set project {project_id}" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "D0VyAFzVXkFe", | |
"colab_type": "code", | |
"colab": {} | |
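}, | |
"source": [ | |
"# Zip only the contents of `checkpoint`, so the run folder sits at the top level of the archive.\n", | |
"# Currently, the name of the folder inside the zip needs to be `run1`,\n", | |
"# so rename `checkpoint/<run_name>` first if `run_name` is different.\n", | |
"!cd checkpoint; zip -r ../your_model_name.zip ." | |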
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aQH4FMhxYCO7", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"!gsutil -m cp your_model_name.zip gs://{bucket_name}" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "VlNmSlvSMNXr", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "bnZrnHLjF0of", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
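The notebook asks for training examples to be separated by a blank line but does not show how to produce such a file. The snippet below is a minimal sketch of one way to do it in plain Python. The helper name `write_dataset` and the sample strings are hypothetical and are not part of gpt-2-simple; only the file name `my_custom_dataset.txt` comes from the notebook. It assumes that no single example contains a blank line of its own.

```python
from pathlib import Path


def write_dataset(examples, path):
    """Write non-empty examples to `path`, separated by one blank line.

    Hypothetical helper for illustration; returns the number of examples written.
    """
    # Drop surrounding whitespace and skip examples that are empty afterwards.
    cleaned = [ex.strip() for ex in examples if ex.strip()]
    # A blank line between examples is two consecutive newline characters.
    Path(path).write_text("\n\n".join(cleaned) + "\n", encoding="utf-8")
    return len(cleaned)


# Placeholder examples; replace these with your own training texts.
examples = [
    "First training example.",
    "   ",
    "Second training example.\n",
]

count = write_dataset(examples, "my_custom_dataset.txt")
print(f"Wrote {count} examples")
```

The resulting file can then be uploaded through the Colab "Files" pane or placed in Google Drive, as described in the notebook.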