Created
December 27, 2023 02:04
-
-
Save fancyerii/6a4c5ee63d8935a42e3f615cff1c1413 to your computer and use it in GitHub Desktop.
llama debug
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import os | |
import random | |
import numpy as np | |
import torch | |
def set_all_seeds(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus every CUDA device), and configures cuDNN for deterministic
    kernel selection.

    Args:
        seed: Integer seed handed to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every GPU, not only the current one; the model in
    # this script is loaded with device_map="auto" and may span several devices.
    # The call is safe on machines without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Keep the cuDNN autotuner off so kernel choice cannot vary between runs.
    torch.backends.cudnn.benchmark = False
# Seed all RNGs up front, before the model is loaded or any sampling happens.
set_all_seeds(1234)

# Weights and tokenizer are both read from the same local checkpoint directory.
model_path = "/nas/lili/models_hf/70B-chat"
tokenizer_path = "/nas/lili/models_hf/70B-chat"
print(f"model: {model_path}, tokenizer: {tokenizer_path}")

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    # Use the tokenizer's EOS id as the padding id for the loaded model.
    pad_token_id=tokenizer.eos_token_id,
)

# Keyword arguments forwarded unchanged to every model.generate() call below:
# sampling is enabled with a low temperature, a single beam, and at most
# 128 newly generated tokens.
arg_dict = {
    "num_beams": 1,
    "temperature": 0.1,
    "max_new_tokens": 128,
    "do_sample": True,
}
# Classification prompt labelled "good": lists the primary categories, supplies
# one product name delimited by triple backticks, and asks for the category
# name only. The literal is kept exactly as it appears, since every character
# of it is sent to the model.
good_prompt = ''' | |
You will be provided with a product name. The product name will be delimited by 3 backticks, i.e.```. | |
Classify the product into a primary category. | |
Primary categories: | |
Clothing, Shoes & Jewelry | |
Automotive | |
Home & Kitchen | |
Beauty & Personal Care | |
Electronics | |
Sports & Outdoors | |
Patio, Lawn & Garden | |
Handmade Products | |
Grocery & Gourmet Food | |
Health & Household | |
Musical Instruments | |
Toys & Games | |
Baby Products | |
Pet Supplies | |
Tools & Home Improvement | |
Appliances | |
Office Products | |
Cell Phones & Accessories | |
Product name:```Mepase 4 Pcs Halloween Adult Women Ladybug Costume Set Include Dress Headband Socks Sunglasses for Lady Cosplay Party```. | |
Only answer the category name, no other words. | |
''' | |
# Classification prompt labelled "bad". Its visible text reads the same as
# good_prompt. NOTE(review): the two literals presumably differ only in
# whitespace or hidden Unicode characters that do not render here -- confirm
# by comparing the repr() output printed further down.
bad_prompt = ''' | |
You will be provided with a product name. The product name will be delimited by 3 backticks, i.e.```. | |
Classify the product into a primary category. | |
Primary categories: | |
Clothing, Shoes & Jewelry | |
Automotive | |
Home & Kitchen | |
Beauty & Personal Care | |
Electronics | |
Sports & Outdoors | |
Patio, Lawn & Garden | |
Handmade Products | |
Grocery & Gourmet Food | |
Health & Household | |
Musical Instruments | |
Toys & Games | |
Baby Products | |
Pet Supplies | |
Tools & Home Improvement | |
Appliances | |
Office Products | |
Cell Phones & Accessories | |
Product name:```Mepase 4 Pcs Halloween Adult Women Ladybug Costume Set Include Dress Headband Socks Sunglasses for Lady Cosplay Party```. | |
Only answer the category name, no other words. | |
''' | |
def _print_completions(prompt, runs=10):
    """Generate and print ``runs`` completions for ``prompt``.

    Args:
        prompt: Text to tokenize and feed to the model.
        runs: Number of independent ``model.generate`` calls to make.
    """
    # The encoded prompt is identical on every iteration, so tokenize once.
    model_inputs = tokenizer(prompt, return_tensors='pt').to('cuda')
    input_length = model_inputs["input_ids"].shape[1]
    for _ in range(runs):
        generated = model.generate(**model_inputs, **arg_dict)
        # Skip the first input_length positions so only the tokens that follow
        # the prompt are decoded and printed.
        completion = tokenizer.decode(
            generated[0][input_length:], skip_special_tokens=True
        )
        print(completion)


# Show both prompts via repr() so whitespace / invisible characters are visible.
print(f"good\n{repr(good_prompt)}")
print(f"bad\n{repr(bad_prompt)}")

print("test good:")
_print_completions(good_prompt)

print("test bad:")
# Bug fix: this section previously tokenized good_prompt again, so bad_prompt
# was never actually sent to the model.
_print_completions(bad_prompt)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment