Created
April 4, 2023 16:12
-
-
Save innat/e6c4826382641f640cc91def95026ad3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "b189b3f9-d343-48d5-a856-f1c08c6b8090", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"('4.28.0.dev0', '2.0.0+cu117')" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import os\n", | |
"import warnings\n", | |
"warnings.filterwarnings(\"ignore\")\n", | |
"\n", | |
"from accelerate import Accelerator\n", | |
"from accelerate import (\n", | |
" init_empty_weights, \n", | |
" infer_auto_device_map, \n", | |
" load_checkpoint_and_dispatch\n", | |
")\n", | |
"import transformers\n", | |
"from transformers import GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoConfig\n", | |
"from transformers import TextDataset, DataCollatorForLanguageModeling\n", | |
"from transformers import Trainer, TrainingArguments\n", | |
"import torch\n", | |
"\n", | |
"transformers.__version__, torch.__version__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "b696e671-4f36-4899-b9b4-c60fab89ecfc", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def print_trainable_parameters(model):\n", | |
"    \"\"\"Print trainable vs. total parameter counts of `model`.\n", | |
"\n", | |
"    Works on any torch.nn.Module via model.named_parameters().\n", | |
"    \"\"\"\n", | |
"    trainable_params = 0\n", | |
"    all_param = 0\n", | |
"    for _, param in model.named_parameters():\n", | |
"        all_param += param.numel()\n", | |
"        if param.requires_grad:\n", | |
"            trainable_params += param.numel()\n", | |
"    # Implicit f-string concatenation replaces the old backslash line\n", | |
"    # continuation, which leaked the next line's indentation into the output.\n", | |
"    # max(all_param, 1) guards against a model with no parameters at all.\n", | |
"    print(\n", | |
"        f\"trainable params: {trainable_params} || \"\n", | |
"        f\"all params: {all_param} || trainable%: {100 * trainable_params / max(all_param, 1)}\"\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "20988cf2-9486-4aa0-a458-e1354ea40c9b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def load_dataset(file_path, tokenizer, block_size = 256):\n", | |
"    \"\"\"Build a TextDataset over `file_path`, chunked into `block_size` tokens.\"\"\"\n", | |
"    # NOTE(review): TextDataset is deprecated in recent transformers releases;\n", | |
"    # the `datasets` library is the suggested replacement for new code.\n", | |
"    return TextDataset(\n", | |
"        tokenizer = tokenizer,\n", | |
"        file_path = file_path,\n", | |
"        block_size = block_size,\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "49f4993c-46cf-4465-8333-a76f80ef9f66", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def load_data_collator(tokenizer, mlm = False):\n", | |
"    \"\"\"Return an LM data collator; mlm=False produces causal-LM labels.\"\"\"\n", | |
"    return DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=mlm)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "792458f9-6e96-483f-b28e-b60b5f445c76", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_tokenizer(model_name):\n", | |
"    \"\"\"Load the tokenizer for `model_name` with a usable pad token.\n", | |
"\n", | |
"    `padding` / `truncation` are per-call arguments of the tokenizer\n", | |
"    (tokenizer(text, padding=..., truncation=...)), not `from_pretrained`\n", | |
"    options; passing them at load time had no effect, so they were dropped.\n", | |
"    \"\"\"\n", | |
"    tokenizer = AutoTokenizer.from_pretrained(model_name)\n", | |
"    # GPT-style tokenizers ship without a pad token; reuse EOS for padding.\n", | |
"    tokenizer.pad_token = tokenizer.eos_token\n", | |
"    return tokenizer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "7e96e437-dfd7-4c5e-a934-c960d52cbe24", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_model(model_name, model_parallel=True):\n", | |
"    \"\"\"Load a fp16 causal LM, sharded across devices when `model_parallel`.\n", | |
"\n", | |
"    device_map='auto' lets accelerate place layers on the available GPUs;\n", | |
"    the model_parallel / is_parallelizable flags tell Trainer not to wrap\n", | |
"    the model in DataParallel.\n", | |
"    \"\"\"\n", | |
"    model = AutoModelForCausalLM.from_pretrained(\n", | |
"        model_name,\n", | |
"        device_map='auto' if model_parallel else None,\n", | |
"        torch_dtype=torch.float16\n", | |
"    )\n", | |
"    \n", | |
"    # Trades compute for GPU memory. Setting a bare `gradient_checkpointing`\n", | |
"    # attribute does nothing; the supported API updates the model config and\n", | |
"    # enables checkpointing on every submodule.\n", | |
"    model.gradient_checkpointing_enable()\n", | |
"    \n", | |
"    if model_parallel:\n", | |
"        setattr(model, 'model_parallel', True)\n", | |
"        setattr(model, 'is_parallelizable', True)\n", | |
"    return model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "d8ed7e9f-d959-47f4-8a93-180600139924", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"trainable params: 2651596800 || all params: 2651596800 || trainable%: 100.0\n" | |
] | |
} | |
], | |
"source": [ | |
"model_name = 'facebook/opt-2.7b'  # any causal-LM hub id works here\n", | |
"tokenizer = get_tokenizer(model_name)\n", | |
"model = get_model(model_name, model_parallel=True)\n", | |
"print_trainable_parameters(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "d6cc640e-f4b6-4514-a6d2-acccf0db2829", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0, 'model.decoder.final_layer_norm': 0, 'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0, 'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0, 'model.decoder.layers.6': 0, 'model.decoder.layers.7': 0, 'model.decoder.layers.8': 0, 'model.decoder.layers.9': 0, 'model.decoder.layers.10': 0, 'model.decoder.layers.11': 0, 'model.decoder.layers.12': 0, 'model.decoder.layers.13': 0, 'model.decoder.layers.14': 0, 'model.decoder.layers.15': 0, 'model.decoder.layers.16': 1, 'model.decoder.layers.17': 1, 'model.decoder.layers.18': 1, 'model.decoder.layers.19': 1, 'model.decoder.layers.20': 1, 'model.decoder.layers.21': 1, 'model.decoder.layers.22': 1, 'model.decoder.layers.23': 1, 'model.decoder.layers.24': 1, 'model.decoder.layers.25': 1, 'model.decoder.layers.26': 1, 'model.decoder.layers.27': 1, 'model.decoder.layers.28': 1, 'model.decoder.layers.29': 1, 'model.decoder.layers.30': 1, 'model.decoder.layers.31': 1}\n" | |
] | |
} | |
], | |
"source": [ | |
"# hf_device_map is only set when the model was loaded with a device_map;\n", | |
"# catch just AttributeError instead of a bare `except` that hides real bugs.\n", | |
"try:\n", | |
"    print(model.hf_device_map)\n", | |
"except AttributeError:\n", | |
"    pass" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "613facb7-d1d3-49fb-a1dd-1368d4b04ebb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"ALL TRUE\n" | |
] | |
} | |
], | |
"source": [ | |
"# Confirm Trainer will see the model as already parallelized (set in get_model).\n", | |
"if getattr(model, \"is_parallelizable\", False) and model.model_parallel:\n", | |
"    print('ALL TRUE')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "08346676-161b-476c-98e6-e2c7041a3fec", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "e292e6fb-8e32-4737-b27c-8de63b8b7b9c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_file_path = \"./data.txt\"  # plain-text corpus consumed by TextDataset\n", | |
"data_collator = load_data_collator(tokenizer)\n", | |
"train_dataset = load_dataset(train_file_path, tokenizer)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "af7590b4-262a-4421-be3e-2c0b6137f8b8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div>\n", | |
" \n", | |
" <progress value='162' max='3231' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
" [ 162/3231 01:17 < 24:53, 2.05 it/s, Epoch 0.15/3]\n", | |
" </div>\n", | |
" <table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>Step</th>\n", | |
" <th>Training Loss</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" </tbody>\n", | |
"</table><p>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"output_dir = 'model'\n", | |
"os.makedirs(output_dir, exist_ok = True)\n", | |
"\n", | |
"# Batch size 1, no accumulation: memory-bound fine-tuning of a 2.7B model.\n", | |
"training_args = TrainingArguments(\n", | |
"    output_dir=output_dir,\n", | |
"    overwrite_output_dir=False,\n", | |
"    num_train_epochs=3,\n", | |
"    per_device_train_batch_size=1,\n", | |
"    gradient_accumulation_steps=1,\n", | |
")\n", | |
"trainer = Trainer(\n", | |
"    model=model,\n", | |
"    args=training_args,\n", | |
"    train_dataset=train_dataset,\n", | |
"    data_collator=data_collator,\n", | |
")\n", | |
"# The KV cache is a generation-time optimization and conflicts with\n", | |
"# gradient checkpointing, so disable it for training.\n", | |
"model.config.use_cache = False \n", | |
"trainer.train()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "975966f6-96fd-4f94-9641-c649dbc84233", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"environment": { | |
"kernel": "gpt_neox", | |
"name": "pytorch-gpu.1-13.m104", | |
"type": "gcloud", | |
"uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.1-13:m104" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Author
innat
commented
Apr 4, 2023
- ModelParallel_LoRA.ipynb
- Issue
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment