Running the Alpacoom model on MPS (Apple Silicon) using HuggingFace Transformers and PEFT
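Prerequisites: a PyTorch build with MPS support, plus the transformers and peft packages. Before loading the 7B model, it can help to confirm the Apple Silicon GPU is actually usable; a minimal sketch using only standard PyTorch APIs:

import torch

# True only if this PyTorch build was compiled with MPS support.
print("MPS built:", torch.backends.mps.is_built())
# True only if an MPS-capable device is available at runtime.
print("MPS available:", torch.backends.mps.is_available())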
import os
import sys

# PYTORCH_ENABLE_MPS_FALLBACK must be set before torch is imported so that
# operators not yet implemented for MPS fall back to the CPU instead of failing.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Load the BLOOM-7b1 base model, its tokenizer, and the Alpacoom LoRA adapter,
# then move everything to the Apple Silicon GPU (MPS).
peft_model_id = "mrm8488/Alpacoom"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=False).to("mps")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-7b1")
model = PeftModel.from_pretrained(model, peft_model_id).to("mps")
model.eval()


# Based on the inference code by `tloen/alpaca-lora`
def generate_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:"""


def generate(
    instruction,
    input=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    **kwargs,
):
    prompt = generate_prompt(instruction, input)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("mps")
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    # Keep only the model's answer: drop the prompt before "### Response:" and
    # anything the model generates after it that starts a new prompt.
    return output.split("### Response:")[1].strip().split("Below")[0]


# The instruction is the first command-line argument; an optional input that
# provides extra context can be passed as the second argument.
instruction = sys.argv[1]
extra_input = sys.argv[2] if len(sys.argv) > 2 else None
print("Instruction:", instruction)
print("Response:", generate(instruction, extra_input))