Use CodeLlama on macOS with MPS quickly
# Usage
# =====
#
# ## Prerequisite
#
# Prepare Python 3, for example, install Homebrew and `brew install python`.
#
# ## Install dependencies
#
# $ python3 -m venv .venv
# $ .venv/bin/pip3 install --pre --index-url https://download.pytorch.org/whl/nightly/cpu torch
# $ .venv/bin/pip3 install git+https://github.com/huggingface/transformers.git
#
# Due to a known issue, the latest nightly PyTorch build is required to use MPS.
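#
# To verify that the installed build can use MPS (a quick sanity check, not
# part of the original steps):
#
# $ .venv/bin/python3 -c "import torch; print(torch.backends.mps.is_available())"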
#
# ## Patch Transformers
#
# Patch `lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py` to change
# `long()` to `int()` on the line `position_ids = attention_mask.long().cumsum(-1) - 1`.
# This is required to use MPS.
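#
# After the patch, that line should read:
#
# position_ids = attention_mask.int().cumsum(-1) - 1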
#
# ## Run test script on Python REPL
#
# $ .venv/bin/python3 -i test.py
# >>> gen('Write a function that prints "Hello World"')
#
# The first time you run this script, it downloads the large model weights and
# metadata into `~/.cache`. Be on a fast internet connection and have enough
# storage ready.
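#
# The cache location can be overridden with the `HF_HOME` environment variable
# (standard Hugging Face behavior, not specific to this script).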
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "mps"
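
# Optional sanity check (not in the original script): fail early if this
# PyTorch build was compiled without MPS support.
assert torch.backends.mps.is_available(), "MPS is not available; see the PyTorch install note above"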
model_name = "codellama/CodeLlama-7b-Instruct-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
)
#print(model)
model.to(device)
@torch.no_grad()
def gen(instruction, system=None):
    # Wrap the instruction in the Llama 2 chat template; an optional system
    # prompt goes inside <<SYS>> tags.
    if system is not None:
        instruction = f"<<SYS>>\n{system}\n<</SYS>>\n\n{instruction}"
    prompt = f"[INST] {instruction} [/INST]"
    print(prompt)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
    )
    inputs.to(device)
    #print(inputs)
    tokens = model.generate(
        **inputs,
        max_new_tokens=128,
        temperature=0.2,
        do_sample=True,
    )
    #print(tokens)
    print(tokenizer.decode(tokens[0], skip_special_tokens=True))
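
# Sampling note: `temperature=0.2` with `do_sample=True` keeps completions
# close to greedy decoding; raise the temperature for more varied output.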
# gen(
#     'Write a function that computes the set of sums of all contiguous sublists of a given list in Python.',
#     system="Provide answers in JavaScript",
# )