Created
December 30, 2023 17:16
-
-
Save iamaziz/413cd4b5e2a7efa4457c40d7ea366fdb to your computer and use it in GitHub Desktop.
Running the Jais 13B LLM on an Apple M3 Max (64 GB) - for some reason inference is very slow and the memory footprint is too big for a 13B model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Based on: https://huggingface.co/core42/jais-13b" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Memory: 64 GB\n", | |
" Total Number of Cores: 16 (12 performance and 4 efficiency)\n", | |
" Chip: Apple M3 Max\n", | |
"\n", | |
"Sat Dec 30 12:11:51 EST 2023\n", | |
"aziz\n" | |
] | |
} | |
], | |
"source": [ | |
"%%bash\n", | |
"# Query the hardware report once, then pick out the lines of interest.\n", | |
"hw_report=\"$(system_profiler SPHardwareDataType)\"\n", | |
"echo \"$hw_report\" | grep \" Memory:\"\n", | |
"echo \"$hw_report\" | grep Cores:\n", | |
"echo \"$hw_report\" | grep Chip:\n", | |
"echo\n", | |
"# Record when and by whom the notebook was run.\n", | |
"date\n", | |
"whoami" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n", | |
"# Alternative model_path values (they look like Hugging Face Hub repo IDs), kept for reference:\n", | |
"# model_path = \"inception-mbzuai/jais-13b\"\n", | |
"# model_path = \"core42/jais-13b\"\n", | |
"model_path = \"./jais-13b\" # local directory (relative path)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Using device: mps\n" | |
] | |
} | |
], | |
"source": [ | |
"# Check if CUDA is available, else check for MPS, otherwise default to CPU\n", | |
"if torch.cuda.is_available():\n", | |
" device = torch.device(\"cuda\")\n", | |
"elif torch.backends.mps.is_available():\n", | |
" device = torch.device(\"mps\")\n", | |
"else:\n", | |
" device = torch.device(\"cpu\")\n", | |
"\n", | |
"print(f\"Using device: {device}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Load the tokenizer from model_path, the same location the model weights are loaded from.\n", | |
"tokenizer = AutoTokenizer.from_pretrained(model_path)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Loading checkpoint shards: 100%|██████████| 6/6 [00:28<00:00, 4.77s/it]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 14.2 s, sys: 29.5 s, total: 43.7 s\n", | |
"Wall time: 2min 45s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# Load model directly\n", | |
"from transformers import AutoModelForCausalLM\n", | |
"model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_response(text,tokenizer=tokenizer,model=model):\n", | |
" input_ids = tokenizer(text, return_tensors=\"pt\").input_ids\n", | |
" inputs = input_ids.to(device)\n", | |
" input_len = inputs.shape[-1]\n", | |
" generate_ids = model.generate(\n", | |
" inputs,\n", | |
" top_p=0.9,\n", | |
" temperature=0.3,\n", | |
" max_length=200-input_len,\n", | |
" min_length=input_len + 4,\n", | |
" repetition_penalty=1.2,\n", | |
" do_sample=True,\n", | |
" )\n", | |
" response = tokenizer.batch_decode(\n", | |
" generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True\n", | |
" )[0]\n", | |
" return response" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"عاصمة دولة الإمارات العربية المتحدة هيمدينة أبوظبي, وهي أكبر مدينة في البلاد. تقع على جزيرة أبو ظبي, عاصمة دولة الإمارات العربية المتحدة هي أيضا واحدة من أكثر المدن اكتظاظا بالسكان في العالم مع ما يقرب من 1 مليون نسمة يعيشون فيها.\n" | |
] | |
} | |
], | |
"source": [ | |
"# NOTE: this call took 191 min 25 s to complete!\n", | |
"# Prompt (Arabic): \"The capital of the United Arab Emirates\"\n", | |
"text= \"عاصمة دولة الإمارات العربية المتحدة\"\n", | |
"print(get_response(text))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".env", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
See updated version: https://gist.github.com/iamaziz/1f14dc9263ec96de7c0b7c6de3d38185