Last active
May 7, 2024 22:33
-
-
Save brandon-lockaby/0e357aecfe51bbd53a4c41457c29d484 to your computer and use it in GitHub Desktop.
MultiClassifier
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Use the GPU when one is available; otherwise fall back to CPU so the script
# still runs (slowly) on machines without CUDA.
DEV = "cuda" if torch.cuda.is_available() else "cpu"
# I want a base model and this is instruct-tuned, but it will fit on my gpu
model_path = "microsoft/Phi-3-mini-128k-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=DEV,
    torch_dtype="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
class MultiClassifier:
    """Multi-label yes/no classifier driven by a few-shot prompt on a causal LM.

    The prompt is run through the model once and its key/value cache is kept.
    Each ``classify`` call appends the held-out text followed by one
    ``"\\n<class_name>:"`` query per class. After each colon it compares the
    model's logits for the " yes" and " no" tokens, then feeds the chosen
    answer back in before querying the next class.
    """

    def __init__(self, dev, model, tokenizer, prompt, class_names):
        """
        Args:
            dev: device that the prompt and token-id tensors are moved to.
            model: causal LM called as ``model(input_ids, past_key_values=...,
                attention_mask=..., return_dict=True)``. Its output must expose
                ``.logits`` and ``.past_key_values``.
            tokenizer: tokenizer whose ``encode`` is used for the prompt, the
                per-class queries and the " yes"/" no" answer tokens.
            prompt: few-shot prefix; the held-out text is appended right after it.
            class_names: labels queried in order, one "<name>:" line each.
        """
        self.dev = dev
        self.model = model
        self.tokenizer = tokenizer
        self.prompt = prompt
        self.class_names = class_names
        # Tokenize the prompt and run it once; the resulting KV cache is the
        # shared starting point for every classify() call.
        self.prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(dev)
        with torch.no_grad():
            self.kv_cache = self.model(self.prompt_ids, return_dict=True).past_key_values
        # Answer strings. The leading space supplies the space after "Name:".
        self.yes = " yes"
        self.no = " no"
        # Only the last token of each answer is compared.
        self._yes_token = tokenizer.encode(self.yes, add_special_tokens=False)[-1]
        self._no_token = tokenizer.encode(self.no, add_special_tokens=False)[-1]
        # Same (1, 1) integer tensors the original exposed, kept for compatibility.
        self.yes_id = torch.tensor([[self._yes_token]], device=dev)
        self.no_id = torch.tensor([[self._no_token]], device=dev)

    def classify(self, held_out_example, return_probs=False):
        """Return the class names the model answers "yes" for.

        Args:
            held_out_example: text appended directly after the prompt.
            return_probs: when true, also return ``{class_name: {"yes": p, "no": q}}``.
                Here p and q are a softmax over just the " yes" and " no" logits.

        Returns:
            The list of chosen class names, or ``(list, probs)`` when
            *return_probs* is true.
        """
        output_class_list = []
        output_probs = {}
        # Work on a private copy of the prompt cache. Some cache types
        # (e.g. transformers' DynamicCache) are extended in place by the forward
        # pass. Reusing self.kv_cache directly would leak this example into
        # every later call.
        kv_cache = copy.deepcopy(self.kv_cache)
        # Number of tokens currently held in the cache (starts as the prompt length).
        seen_len = self.prompt_ids.shape[-1]
        new_text = held_out_example
        with torch.no_grad():
            for class_name in self.class_names:
                step_ids = self.tokenizer.encode(
                    f"{new_text}\n{class_name}:",
                    add_special_tokens=False,
                    return_tensors="pt",
                ).to(self.dev)
                seen_len += step_ids.shape[-1]
                # One mask entry per cached + new token (single sequence, no padding).
                attention_mask = torch.ones((1, seen_len), dtype=torch.long, device=self.dev)
                outputs = self.model(
                    step_ids,
                    past_key_values=kv_cache,
                    attention_mask=attention_mask,
                    return_dict=True,
                )
                kv_cache = outputs.past_key_values
                # Keep only the two answer logits at the last position.
                # Softmax runs in float32.
                pair = outputs.logits[0, -1, [self._yes_token, self._no_token]].float()
                yes_prob, no_prob = torch.softmax(pair, dim=-1).tolist()
                if yes_prob >= no_prob:
                    output_class_list.append(class_name)
                    new_text = self.yes
                else:
                    new_text = self.no
                if return_probs:
                    output_probs[class_name] = {"yes": yes_prob, "no": no_prob}
        return (output_class_list, output_probs) if return_probs else output_class_list
# Labels queried by the classifier, in the order their answer lines appear.
class_names = ["Apples", "Oranges"]

# Few-shot examples: (text, answers), with one "yes"/"no" per entry of class_names.
_FEW_SHOT_EXAMPLES = [
    ("I ate an apple and then a few oranges.", ("yes", "yes")),
    ("Do you sell chocolate oranges?", ("no", "yes")),
    ("I want something red to eat.", ("yes", "no")),
    ("Orange you glad I didn't say apple?", ("yes", "yes")),
    ("I hate oranges and I hate apples!", ("yes", "yes")),
    ("My car is orange", ("no", "no")),
    ("Red!", ("no", "no")),
    ("These can sometimes be red.", ("no", "no")),
    ("orange", ("no", "yes")),
    ("What are you eating?", ("no", "no")),
]


def _render_few_shot(examples, names):
    """Render examples as "Text: ..." / "<Name>: yes|no" lines.

    The result ends with an open "Text: " so a held-out example can follow directly.
    """
    lines = []
    for text, answers in examples:
        lines.append(f"Text: {text}")
        lines.extend(f"{name}: {answer}" for name, answer in zip(names, answers))
    lines.append("Text: ")
    return "\n".join(lines)


prompt = _render_few_shot(_FEW_SHOT_EXAMPLES, class_names)
# Build the classifier once; construction runs the few-shot prompt through the
# model and keeps its key/value cache for the classify() calls below.
classifier = MultiClassifier(
    dev=DEV,
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    class_names=class_names,
)
def test(text):
    """Classify *text* with the module-level ``classifier`` and print it.

    Prints the text preceded by a blank line, then the (classes, probabilities)
    tuple on a tab-indented line.
    """
    outcome = classifier.classify(text, return_probs=True)
    print("\n{}\n\t{}".format(text, outcome))
# Exercise the classifier on clear, negated, misspelled and off-topic inputs.
# Two of them also appear as few-shot examples in the prompt.
for sample in (
    "You can't squeeze ketchup from a banana.",
    "Do you like apple pie?",
    "Too bad. I baked an orange pie.",
    "DO NOT give me apple pie.",
    "red",
    "These can sometimes be red.",
    "orangey",
    "No apples and no oranges",
    "What are you eating?",
    "Orples",
    "What about a-p-p-l-e",
):
    test(sample)
# Output (as recorded by the author; exact probabilities may differ across
# library versions, dtypes and hardware):
# You can't squeeze ketchup from a banana.
# ([], {'Apples': {'yes': 0.0011695101857185364, 'no': 0.9988304972648621}, 'Oranges': {'yes': 0.005220125894993544, 'no': 0.9947799444198608}})
# Do you like apple pie?
# (['Apples'], {'Apples': {'yes': 0.9997387528419495, 'no': 0.00026119028916582465}, 'Oranges': {'yes': 6.144174221844878e-06, 'no': 0.9999938011169434}})
# Too bad. I baked an orange pie.
# (['Oranges'], {'Apples': {'yes': 0.0010322310263291001, 'no': 0.9989677667617798}, 'Oranges': {'yes': 0.9999938011169434, 'no': 6.144174221844878e-06}})
# DO NOT give me apple pie.
# (['Apples'], {'Apples': {'yes': 0.9740425944328308, 'no': 0.02595735713839531}, 'Oranges': {'yes': 2.6729447100137804e-08, 'no': 1.0}})
# red
# ([], {'Apples': {'yes': 0.02595735713839531, 'no': 0.9740425944328308}, 'Oranges': {'yes': 0.2018132209777832, 'no': 0.7981867790222168}})
# These can sometimes be red.
# ([], {'Apples': {'yes': 0.0019267346942797303, 'no': 0.9980732202529907}, 'Oranges': {'yes': 0.0534033328294754, 'no': 0.9465966820716858}})
# orangey
# (['Oranges'], {'Apples': {'yes': 0.00026119028916582465, 'no': 0.9997387528419495}, 'Oranges': {'yes': 0.9890130758285522, 'no': 0.01098694372922182}})
# No apples and no oranges
# ([], {'Apples': {'yes': 0.0008040859247557819, 'no': 0.9991958737373352}, 'Oranges': {'yes': 0.0024726232513785362, 'no': 0.9975274205207825}})
# What are you eating?
# ([], {'Apples': {'yes': 0.00048785717808641493, 'no': 0.9995121955871582}, 'Oranges': {'yes': 0.0011695101857185364, 'no': 0.9988304972648621}})
# Orples
# ([], {'Apples': {'yes': 3.120191104244441e-05, 'no': 0.9999687671661377}, 'Oranges': {'yes': 0.007577240467071533, 'no': 0.9924227595329285}})
# What about a-p-p-l-e
# (['Apples'], {'Apples': {'yes': 0.9046505093574524, 'no': 0.09534946084022522}, 'Oranges': {'yes': 0.0035936026833951473, 'no': 0.9964063763618469}})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment