Created
April 4, 2023 17:27
-
-
Save innat/48857f0796246d7852ae6e38a5010ffa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "b189b3f9-d343-48d5-a856-f1c08c6b8090", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'0.3.0.dev0'" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import os\n", | |
"import warnings\n", | |
"warnings.filterwarnings(\"ignore\")\n", | |
"\n", | |
"from accelerate import Accelerator\n", | |
"from accelerate import (\n", | |
" init_empty_weights, \n", | |
" infer_auto_device_map, \n", | |
" load_checkpoint_and_dispatch\n", | |
")\n", | |
"from transformers import GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoConfig\n", | |
"from transformers import TextDataset, DataCollatorForLanguageModeling\n", | |
"from transformers import Trainer, TrainingArguments\n", | |
"import torch\n", | |
"\n", | |
"import peft\n", | |
"from peft import (\n", | |
" prepare_model_for_int8_training,\n", | |
" LoraConfig,\n", | |
" get_peft_model,\n", | |
")\n", | |
"\n", | |
"peft.__version__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "b696e671-4f36-4899-b9b4-c60fab89ecfc", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def print_trainable_parameters(model):
    """
    Print how many of ``model``'s parameter elements are trainable.

    Walks ``model.named_parameters()``, summing ``numel()`` of every
    parameter into the total and, separately, of those whose
    ``requires_grad`` is set into the trainable count.

    Args:
        model: object whose ``named_parameters()`` yields ``(name, param)``
            pairs, where ``param`` provides ``numel()`` and ``requires_grad``
            (e.g. a ``torch.nn.Module`` or a PEFT-wrapped model).

    Returns:
        None. The summary line is written to stdout.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_elements = param.numel()
        all_param += num_elements
        if param.requires_grad:
            trainable_params += num_elements
    # A model with no parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    # Implicit string concatenation rather than a backslash continuation:
    # a backslash-newline inside a string literal also pulls the next line's
    # leading indentation into the printed text.
    print(
        f"trainable params: {trainable_params} || "
        f"all params: {all_param} || trainable%: {trainable_pct}"
    )
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "20988cf2-9486-4aa0-a458-e1354ea40c9b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def load_dataset(file_path, tokenizer, block_size=256):
    """
    Build a ``TextDataset`` over a plain-text file.

    Args:
        file_path: path of the text file handed to ``TextDataset``.
        tokenizer: tokenizer handed to ``TextDataset`` for encoding.
        block_size: forwarded as ``TextDataset``'s ``block_size`` (default 256).

    Returns:
        The constructed ``TextDataset``.

    NOTE(review): ``TextDataset`` is deprecated in recent ``transformers``
    releases in favour of the Hugging Face ``datasets`` library. This helper's
    name also shadows ``datasets.load_dataset`` if that is ever imported here.
    """
    return TextDataset(tokenizer=tokenizer, file_path=file_path, block_size=block_size)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "49f4993c-46cf-4465-8333-a76f80ef9f66", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def load_data_collator(tokenizer, mlm=False):
    """
    Create a ``DataCollatorForLanguageModeling`` bound to ``tokenizer``.

    Args:
        tokenizer: tokenizer given to the collator.
        mlm: forwarded as the collator's ``mlm`` flag; the default ``False``
            requests causal-LM collation instead of masked-LM.

    Returns:
        The configured ``DataCollatorForLanguageModeling``.
    """
    collator_kwargs = {"tokenizer": tokenizer, "mlm": mlm}
    return DataCollatorForLanguageModeling(**collator_kwargs)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "792458f9-6e96-483f-b28e-b60b5f445c76", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def get_tokenizer(model_name):
    """
    Load the tokenizer for ``model_name`` and ensure it has a pad token.

    ``padding`` / ``truncation`` are per-call encoding options (arguments to
    ``tokenizer(...)``), not construction options, so they are not passed to
    ``from_pretrained``.

    The pad token is set to the EOS token only when the checkpoint does not
    already define one. Overwriting an existing pad token with EOS would make
    ``DataCollatorForLanguageModeling`` treat every EOS position as padding
    and drop it from the labels.

    Args:
        model_name: Hugging Face model id or local path.

    Returns:
        The loaded tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "7e96e437-dfd7-4c5e-a934-c960d52cbe24", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def get_model(model_name, model_parallel=True, lora_r=16, lora_alpha=32, lora_dropout=0.05):
    """
    Load ``model_name`` in 8-bit and wrap it with a LoRA adapter via PEFT.

    Args:
        model_name: Hugging Face model id or local path of a causal LM.
        model_parallel: if True, load with ``device_map='auto'`` and flag the
            model as model-parallel; if False, ``device_map`` is None.
        lora_r: LoRA rank, passed as ``LoraConfig.r``.
        lora_alpha: passed as ``LoraConfig.lora_alpha``.
        lora_dropout: passed as ``LoraConfig.lora_dropout``.

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.
    """
    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_8bit=True,
        device_map='auto' if model_parallel else None,
        low_cpu_mem_usage=True,
    )
    # layer_norm_names=[] matches no layer-norm parameters, so none are
    # upcast by this helper. NOTE(review): confirm this is intentional.
    model = prepare_model_for_int8_training(
        model,
        output_embedding_layer_name="lm_head",
        layer_norm_names=[],
    )

    if model_parallel:
        # Presumably read by Trainer to treat the model as already spread
        # across devices (so it is not wrapped in DataParallel).
        # NOTE(review): verify against the installed transformers version.
        setattr(model, 'model_parallel', True)
        setattr(model, 'is_parallelizable', True)

    # Gradient (activation) checkpointing trades recompute for GPU memory.
    # It has to be switched on through the model API. Assigning a plain
    # `gradient_checkpointing` attribute on the top-level model does not reach
    # the decoder layers that actually consult the flag. Calling this again
    # is harmless if it is already enabled.
    model.gradient_checkpointing_enable()

    model = get_peft_model(model, config)
    return model
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "d8ed7e9f-d959-47f4-8a93-180600139924", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Overriding torch_dtype=None with `torch_dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in mixed int8. Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"trainable params: 5242880 || all params: 2656839680 || trainable%: 0.19733520390662038\n" | |
] | |
} | |
], | |
"source": [ | |
model_name = "facebook/opt-2.7b"
tokenizer = get_tokenizer(model_name)
model = get_model(model_name, model_parallel=True)
# With LoRA applied, only a small fraction of the parameters should be trainable.
print_trainable_parameters(model)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "d6cc640e-f4b6-4514-a6d2-acccf0db2829", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0, 'model.decoder.final_layer_norm': 0, 'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0, 'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0, 'model.decoder.layers.6': 0, 'model.decoder.layers.7': 0, 'model.decoder.layers.8': 0, 'model.decoder.layers.9': 0, 'model.decoder.layers.10': 0, 'model.decoder.layers.11': 0, 'model.decoder.layers.12': 0, 'model.decoder.layers.13': 0, 'model.decoder.layers.14': 0, 'model.decoder.layers.15': 0, 'model.decoder.layers.16': 1, 'model.decoder.layers.17': 1, 'model.decoder.layers.18': 1, 'model.decoder.layers.19': 1, 'model.decoder.layers.20': 1, 'model.decoder.layers.21': 1, 'model.decoder.layers.22': 1, 'model.decoder.layers.23': 1, 'model.decoder.layers.24': 1, 'model.decoder.layers.25': 1, 'model.decoder.layers.26': 1, 'model.decoder.layers.27': 1, 'model.decoder.layers.28': 1, 'model.decoder.layers.29': 1, 'model.decoder.layers.30': 1, 'model.decoder.layers.31': 1}\n" | |
] | |
} | |
], | |
"source": [ | |
# Show how layers were dispatched across devices. The attribute is presumably
# absent when the model was loaded without a device_map; in that case nothing
# is printed. Only a missing attribute is tolerated, unlike a bare `except:`,
# which would also swallow KeyboardInterrupt and genuine errors.
hf_device_map = getattr(model, "hf_device_map", None)
if hf_device_map is not None:
    print(hf_device_map)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "613facb7-d1d3-49fb-a1dd-1368d4b04ebb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"ALL TRUE\n" | |
] | |
} | |
], | |
"source": [ | |
# Both flags must be truthy for the model to count as model-parallel.
# NOTE(review): this appears to mirror Trainer's own check; confirm.
parallel_ready = getattr(model, "is_parallelizable", False) and model.model_parallel
if parallel_ready:
    print('ALL TRUE')
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "08346676-161b-476c-98e6-e2c7041a3fec", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "e292e6fb-8e32-4737-b27c-8de63b8b7b9c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
train_file_path = "./data.txt"
# Tokenised training blocks read from the text file.
train_dataset = load_dataset(file_path=train_file_path, tokenizer=tokenizer)
# Collator built with the helper's default mlm=False (causal LM).
data_collator = load_data_collator(tokenizer=tokenizer)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "af7590b4-262a-4421-be3e-2c0b6137f8b8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div>\n", | |
" \n", | |
" <progress value='78' max='3231' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
" [ 78/3231 01:03 < 43:53, 1.20 it/s, Epoch 0.07/3]\n", | |
" </div>\n", | |
" <table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>Step</th>\n", | |
" <th>Training Loss</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" </tbody>\n", | |
"</table><p>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
output_dir = 'model'
os.makedirs(output_dir, exist_ok=True)

training_args = TrainingArguments(
    output_dir=output_dir,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=1,
    overwrite_output_dir=False,
    num_train_epochs=3,
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)
# The KV cache only helps generation and conflicts with gradient
# checkpointing during training, so it is turned off here.
model.config.use_cache = False
train_result = trainer.train()

# Trainer only writes periodic checkpoints by default, so the weights from
# the steps after the last checkpoint would otherwise be lost.
# PeftModel.save_pretrained writes just the LoRA adapter weights and config,
# not the 8-bit base model.
model.save_pretrained(output_dir)
train_result
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "1c4374c1-ca51-4929-a764-fc057018161b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"environment": { | |
"kernel": "gpt_neox", | |
"name": "pytorch-gpu.1-13.m104", | |
"type": "gcloud", | |
"uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.1-13:m104" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Author
innat
commented
Apr 4, 2023
- ModelParallel.ipynb
- Issue
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment