Skip to content

Instantly share code, notes, and snippets.

@innat
Created April 4, 2023 16:12
Show Gist options
  • Save innat/e6c4826382641f640cc91def95026ad3 to your computer and use it in GitHub Desktop.
Save innat/e6c4826382641f640cc91def95026ad3 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "b189b3f9-d343-48d5-a856-f1c08c6b8090",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('4.28.0.dev0', '2.0.0+cu117')"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"from accelerate import Accelerator\n",
"from accelerate import (\n",
" init_empty_weights, \n",
" infer_auto_device_map, \n",
" load_checkpoint_and_dispatch\n",
")\n",
"import transformers\n",
"from transformers import GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoConfig\n",
"from transformers import TextDataset, DataCollatorForLanguageModeling\n",
"from transformers import Trainer, TrainingArguments\n",
"import torch\n",
"\n",
"transformers.__version__, torch.__version__"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b696e671-4f36-4899-b9b4-c60fab89ecfc",
"metadata": {},
"outputs": [],
"source": [
"def print_trainable_parameters(model):\n",
"    \"\"\"\n",
"    Print the trainable / total parameter counts of ``model``.\n",
"\n",
"    Sums ``param.numel()`` over ``model.named_parameters()``; parameters with\n",
"    ``requires_grad`` set count as trainable. Prints one line of the form\n",
"    ``trainable params: T || all params: A || trainable%: P``.\n",
"    \"\"\"\n",
"    trainable_params = 0\n",
"    all_param = 0\n",
"    for _, param in model.named_parameters():\n",
"        all_param += param.numel()\n",
"        if param.requires_grad:\n",
"            trainable_params += param.numel()\n",
"    # A model without parameters would otherwise raise ZeroDivisionError.\n",
"    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0\n",
"    # Adjacent f-strings instead of a backslash continuation inside one\n",
"    # literal: the continuation would copy the next line's indentation\n",
"    # into the printed message.\n",
"    print(\n",
"        f\"trainable params: {trainable_params} || \"\n",
"        f\"all params: {all_param} || trainable%: {trainable_pct}\"\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "20988cf2-9486-4aa0-a458-e1354ea40c9b",
"metadata": {},
"outputs": [],
"source": [
"def load_dataset(file_path, tokenizer, block_size = 256):\n",
"    \"\"\"\n",
"    Wrap the text file at ``file_path`` in a ``transformers.TextDataset``.\n",
"\n",
"    ``tokenizer`` and ``block_size`` are forwarded unchanged.\n",
"    \"\"\"\n",
"    return TextDataset(\n",
"        file_path=file_path,\n",
"        tokenizer=tokenizer,\n",
"        block_size=block_size,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "49f4993c-46cf-4465-8333-a76f80ef9f66",
"metadata": {},
"outputs": [],
"source": [
"def load_data_collator(tokenizer, mlm = False):\n",
"    \"\"\"\n",
"    Return a ``DataCollatorForLanguageModeling`` bound to ``tokenizer``.\n",
"\n",
"    ``mlm`` is forwarded as-is; the default ``False`` is the setting used\n",
"    for causal (next-token) language modelling.\n",
"    \"\"\"\n",
"    return DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=mlm)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "792458f9-6e96-483f-b28e-b60b5f445c76",
"metadata": {},
"outputs": [],
"source": [
"def get_tokenizer(model_name):\n",
"    \"\"\"\n",
"    Load the tokenizer for ``model_name`` and reuse its EOS token as padding.\n",
"    \"\"\"\n",
"    # NOTE(review): padding/truncation are normally per-call tokenizer\n",
"    # arguments; passing them to from_pretrained likely does not change\n",
"    # how text is encoded - confirm before relying on it.\n",
"    load_kwargs = {'padding': 'max_length', 'truncation': True}\n",
"    tokenizer = AutoTokenizer.from_pretrained(model_name, **load_kwargs)\n",
"    tokenizer.pad_token = tokenizer.eos_token\n",
"    return tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7e96e437-dfd7-4c5e-a934-c960d52cbe24",
"metadata": {},
"outputs": [],
"source": [
"def get_model(model_name, model_parallel=True):\n",
"    \"\"\"\n",
"    Load ``model_name`` as a float16 causal LM prepared for fine-tuning.\n",
"\n",
"    Gradient checkpointing is switched on (when the model supports it) to\n",
"    reduce activation memory. With ``model_parallel=True`` the weights are\n",
"    placed with ``device_map='auto'`` and the ``model_parallel`` /\n",
"    ``is_parallelizable`` attributes are set to ``True``.\n",
"    \"\"\"\n",
"    model = AutoModelForCausalLM.from_pretrained(\n",
"        model_name,\n",
"        device_map='auto' if model_parallel else None,\n",
"        torch_dtype=torch.float16\n",
"    )\n",
"\n",
"    # Trade compute for memory. Assigning a `gradient_checkpointing` attribute\n",
"    # on the top-level model does not turn checkpointing on by itself; the\n",
"    # Hugging Face API for that is `gradient_checkpointing_enable()`.\n",
"    if getattr(model, 'supports_gradient_checkpointing', False):\n",
"        model.gradient_checkpointing_enable()\n",
"    # Kept so code that reads this attribute still sees True.\n",
"    setattr(model, 'gradient_checkpointing', True)\n",
"\n",
"    if model_parallel:\n",
"        setattr(model, 'model_parallel', True)\n",
"        setattr(model, 'is_parallelizable', True)\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d8ed7e9f-d959-47f4-8a93-180600139924",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 2651596800 || all params: 2651596800 || trainable%: 100.0\n"
]
}
],
"source": [
"model_name = 'facebook/opt-2.7b'\n",
"# model_parallel=True -> get_model loads the weights with device_map='auto'.\n",
"model = get_model(model_name=model_name, model_parallel=True)\n",
"tokenizer = get_tokenizer(model_name=model_name)\n",
"print_trainable_parameters(model=model)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d6cc640e-f4b6-4514-a6d2-acccf0db2829",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0, 'model.decoder.final_layer_norm': 0, 'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0, 'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0, 'model.decoder.layers.6': 0, 'model.decoder.layers.7': 0, 'model.decoder.layers.8': 0, 'model.decoder.layers.9': 0, 'model.decoder.layers.10': 0, 'model.decoder.layers.11': 0, 'model.decoder.layers.12': 0, 'model.decoder.layers.13': 0, 'model.decoder.layers.14': 0, 'model.decoder.layers.15': 0, 'model.decoder.layers.16': 1, 'model.decoder.layers.17': 1, 'model.decoder.layers.18': 1, 'model.decoder.layers.19': 1, 'model.decoder.layers.20': 1, 'model.decoder.layers.21': 1, 'model.decoder.layers.22': 1, 'model.decoder.layers.23': 1, 'model.decoder.layers.24': 1, 'model.decoder.layers.25': 1, 'model.decoder.layers.26': 1, 'model.decoder.layers.27': 1, 'model.decoder.layers.28': 1, 'model.decoder.layers.29': 1, 'model.decoder.layers.30': 1, 'model.decoder.layers.31': 1}\n"
]
}
],
"source": [
"# Show the module -> device placement recorded on the model, if any.\n",
"# NOTE(review): `hf_device_map` is presumably absent when the model was\n",
"# loaded without a device_map; getattr covers that case without the bare\n",
"# `except:` that also hid unrelated errors.\n",
"device_map = getattr(model, 'hf_device_map', None)\n",
"if device_map is not None:\n",
"    print(device_map)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "613facb7-d1d3-49fb-a1dd-1368d4b04ebb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ALL TRUE\n"
]
}
],
"source": [
"# get_model(..., model_parallel=True) sets both flags; getattr with a False\n",
"# default makes a missing `model_parallel` fail the check instead of raising.\n",
"if getattr(model, 'is_parallelizable', False) and getattr(model, 'model_parallel', False):\n",
"    print('ALL TRUE')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08346676-161b-476c-98e6-e2c7041a3fec",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e292e6fb-8e32-4737-b27c-8de63b8b7b9c",
"metadata": {},
"outputs": [],
"source": [
"# Plain-text training corpus, tokenized through the load_dataset helper.\n",
"train_file_path = './data.txt'\n",
"train_dataset = load_dataset(file_path=train_file_path, tokenizer=tokenizer)\n",
"data_collator = load_data_collator(tokenizer=tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af7590b4-262a-4421-be3e-2c0b6137f8b8",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='162' max='3231' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [ 162/3231 01:17 < 24:53, 2.05 it/s, Epoch 0.15/3]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Output directory handed to TrainingArguments; created up front.\n",
"output_dir = 'model'\n",
"os.makedirs(output_dir, exist_ok = True)\n",
"\n",
"# One example per device per optimizer step (no gradient accumulation), 3 epochs.\n",
"# NOTE(review): the model was loaded with torch_dtype=float16 and no fp16/bf16\n",
"# mixed-precision flag is set here, so the optimizer presumably updates the\n",
"# half-precision weights directly, which can be numerically unstable - confirm\n",
"# the training loss stays finite.\n",
"training_args = TrainingArguments(\n",
" output_dir=output_dir,\n",
" gradient_accumulation_steps=1,\n",
" per_device_train_batch_size=1,\n",
" overwrite_output_dir=False,\n",
" num_train_epochs=3,\n",
")\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" data_collator=data_collator,\n",
" train_dataset=train_dataset,\n",
")\n",
"# Turn off config.use_cache before training (presumably only needed for generation).\n",
"model.config.use_cache = False \n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "975966f6-96fd-4f94-9641-c649dbc84233",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"environment": {
"kernel": "gpt_neox",
"name": "pytorch-gpu.1-13.m104",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.1-13:m104"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@innat
Copy link
Author

innat commented Apr 4, 2023

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment