Last active
April 16, 2024 16:59
-
-
Save iamaziz/1f14dc9263ec96de7c0b7c6de3d38185 to your computer and use it in GitHub Desktop.
Running the Jais LLM on an M3 Max chip with 64 GB of RAM, using `torch_dtype=torch.float16`. Loading and generation are now much faster, but the generated output quality is way off.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Based on: https://huggingface.co/core42/jais-13b" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| " Memory: 64 GB\n", | |
| " Total Number of Cores: 16 (12 performance and 4 efficiency)\n", | |
| " Chip: Apple M3 Max\n", | |
| "\n", | |
| "Wed Jan 3 01:34:00 EST 2024\n", | |
| "aziz\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%bash\n", | |
| "system_profiler SPHardwareDataType | grep \" Memory:\"\n", | |
| "system_profiler SPHardwareDataType | grep Cores:\n", | |
| "system_profiler SPHardwareDataType | grep Chip:\n", | |
| "echo\n", | |
| "date\n", | |
| "whoami" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "from transformers import AutoTokenizer, AutoModelForCausalLM\n", | |
| "# model_path = \"inception-mbzuai/jais-13b\"\n", | |
| "# model_path = \"core42/jais-13b\"\n", | |
| "model_path = \"./jais-13b\" # local" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Using device: mps\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Check if CUDA is available, else check for MPS, otherwise default to CPU\n", | |
| "if torch.cuda.is_available():\n", | |
| " device = torch.device(\"cuda\")\n", | |
| "elif torch.backends.mps.is_available():\n", | |
| " device = torch.device(\"mps\")\n", | |
| "else:\n", | |
| " device = torch.device(\"cpu\")\n", | |
| "\n", | |
| "print(f\"Using device: {device}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "tokenizer = AutoTokenizer.from_pretrained(model_path)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Loading checkpoint shards: 100%|██████████| 6/6 [00:34<00:00, 5.81s/it]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 2min 28s, sys: 1min 14s, total: 3min 43s\n", | |
| "Wall time: 41.3 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "# Load model directly\n", | |
| "from transformers import AutoModelForCausalLM\n", | |
| "model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to(device)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def get_response(text,tokenizer=tokenizer,model=model):\n", | |
| " input_ids = tokenizer(text, return_tensors=\"pt\").input_ids\n", | |
| " inputs = input_ids.to(device)\n", | |
| " input_len = inputs.shape[-1]\n", | |
| " generate_ids = model.generate(\n", | |
| " inputs,\n", | |
| " top_p=0.9,\n", | |
| " temperature=0.3,\n", | |
| " max_length=200-input_len,\n", | |
| " min_length=input_len + 4,\n", | |
| " repetition_penalty=1.2,\n", | |
| " do_sample=True,\n", | |
| " )\n", | |
| " response = tokenizer.batch_decode(\n", | |
| " generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True\n", | |
| " )[0]\n", | |
| " return response" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "عاصمة دولة الإمارات العربية المتحدة /,,- (;,.\n", | |
| "\n", | |
| " and-... of(;, from, \"/ by,[.\n", | |
| " of the:\n", | |
| "\n", | |
| "?\n", | |
| "CPU times: user 4.85 s, sys: 1.16 s, total: 6.01 s\n", | |
| "Wall time: 15.5 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "# this took: 191min 25seconds !!\n", | |
| "text= \"عاصمة دولة الإمارات العربية المتحدة\"\n", | |
| "print(get_response(text))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "كم عدد سكان الإمارات العربية المتحدة؟؟ with\n", | |
| ".:,;\n", | |
| ",- is,..'s(..-/,.,?... the:*,..., \"._; ( [/ and, the+. -,.-\n", | |
| ".\n", | |
| ". of and for..\n", | |
| " when from by,,,, -..- the\n", | |
| ",/..,\n", | |
| "\n", | |
| "CPU times: user 10.8 s, sys: 1.38 s, total: 12.2 s\n", | |
| "Wall time: 22 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "text= \"كم عدد سكان الإمارات العربية المتحدة؟\"\n", | |
| "print(get_response(text))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "JAISLMHeadModel(\n", | |
| " (transformer): JAISModel(\n", | |
| " (wte): Embedding(84992, 5120)\n", | |
| " (drop): Dropout(p=0.0, inplace=False)\n", | |
| " (h): ModuleList(\n", | |
| " (0-39): 40 x JAISBlock(\n", | |
| " (ln_1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)\n", | |
| " (attn): JAISAttention(\n", | |
| " (c_attn): Conv1D()\n", | |
| " (c_proj): Conv1D()\n", | |
| " (attn_dropout): Dropout(p=0.0, inplace=False)\n", | |
| " (resid_dropout): Dropout(p=0.0, inplace=False)\n", | |
| " )\n", | |
| " (ln_2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)\n", | |
| " (mlp): JAISMLP(\n", | |
| " (c_fc): Conv1D()\n", | |
| " (c_fc2): Conv1D()\n", | |
| " (c_proj): Conv1D()\n", | |
| " (act): SwiGLUActivation()\n", | |
| " (dropout): Dropout(p=0.0, inplace=False)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (ln_f): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)\n", | |
| " (relative_pe): AlibiPositionEmbeddingLayer()\n", | |
| " )\n", | |
| " (lm_head): Linear(in_features=5120, out_features=84992, bias=False)\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "PreTrainedTokenizerFast(name_or_path='./jais-13b', vocab_size=84992, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n", | |
| "\t0: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", | |
| "}" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "tokenizer" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": ".env", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.6" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment