@morganmcg1
Created August 17, 2024 18:58
Chain of Thought Decoding, from https://arxiv.org/pdf/2402.10200
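The script below sketches the CoT-decoding idea from that paper: instead of committing to the single most likely first token, it takes the model's top-k candidates for the first answer token, greedily decodes a continuation from each, and scores every path by the average probability gap between the top-1 and top-2 tokens at each decoding step; the path with the largest gap is treated as the most confident answer and checked against the year-parity label.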
# From Unsloth Llama 3.1 fine-tuning notebook
from unsloth import FastLanguageModel
import torch
def print_gpu_stats():
    gpu_stats = torch.cuda.get_device_properties(0)
    start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
    print(f"{start_gpu_memory} GB of memory reserved.")

max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
model_name = "unsloth/Meta-Llama-3.1-8B"
# model_name = "unsloth/Meta-Llama-3.1-8B-Instruct"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
print_gpu_stats()
import re
import torch.nn.functional as F
import torch
def get_top_n_tokens(o, n=2):
    # Softmax over the last position's logits to get next-token probabilities
    probs = F.softmax(o.logits[0, -1, :], dim=-1)
    top_n = torch.topk(probs, n)
    return top_n.values, top_n.indices
def find_last_even_or_odd(s):
    # Match the last occurrence of "even" or "odd" in the string
    match = re.search(r'(even|odd)(?!.*\b(even|odd)\b)', s)
    if match:
        return match.group(0)  # The last "even" or "odd" found
    else:
        return ""  # Neither "even" nor "odd" was found
def score_output(output: str, label: str):
    answer = find_last_even_or_odd(output.lower())
    return answer == label
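
# Quick sanity check of the scoring helpers (the strings below are illustrative
# examples, not model outputs): only the last stated parity is compared to the label.
assert find_last_even_or_odd("it could be odd, but 1964 is even") == "even"
assert score_output("Nicolas Cage was born in 1964, which is an even year.", label="even")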
top_k = 10  # Number of top first-token candidates to consider
# For the Year Parity task in Chain-of-Thought Reasoning Without Prompting (https://arxiv.org/pdf/2402.10200),
# Mistral used 50 tokens for the base model and 100 for the Instruct model
sequence_length = 100  # Number of tokens to generate for each top_k token

sequences = {}
avg_deltas = {}
results = {}

# t = "Continue the fibonacci sequence."
# t = alpaca_prompt.format(
#     t,  # instruction
#     "1, 1, 2, 3, 5, 8",  # input
#     "",  # output - leave this blank for generation!
# )

base_template = "Q: {question}\nA:"
instruct_template = "[INST] {question} [/INST]"

t = "Was {celebrity} born in an even or odd year?"
evals = [{"index": 0, "celebrity": "nicolas cage", "label": "even"}]

t = base_template.format(question = t.format(celebrity = evals[0]["celebrity"]))
# t = instruct_template.format(question = t.format(celebrity = evals[0]["celebrity"]))
print(t)

inp = tokenizer([t], return_tensors="pt").to("cuda")
print(f"Model input:\n```{tokenizer.decode(inp.input_ids[0])}```")
# Get the logits and find the top_k indices
model = model.eval()
FastLanguageModel.for_inference(model)
o = model(**inp)
top_k_indices = torch.topk(o.logits[0, -1, :], top_k).indices
# Generate a greedy continuation for each top_k first-token candidate
with torch.no_grad():
    for rank, idx in enumerate(top_k_indices):
        first_token_decoded = tokenizer.decode(idx)
        # print(f"First token: {first_token_decoded}")
        gens = [first_token_decoded]  # To store the generated sequence
        deltas = []

        # Start with the initial input + the current top_k token
        current_inp = {
            'input_ids': torch.cat((inp['input_ids'], idx.unsqueeze(0).unsqueeze(0)), dim=1)
        }
        current_inp['attention_mask'] = torch.ones(
            current_inp['input_ids'].shape, dtype=torch.long, device=inp['input_ids'].device
        )

        # Greedily generate `sequence_length` tokens
        for _ in range(sequence_length):
            o = model(**current_inp)
            next_token = torch.topk(o.logits[0, -1, :], 1).indices

            # Token confidence: probability gap between the top-1 and top-2 tokens
            top_two_probs, top_two_tokens = get_top_n_tokens(o, n=2)
            top_two_delta = top_two_probs[0] - top_two_probs[1]
            deltas.append(top_two_delta)

            gens.append(tokenizer.decode(next_token))

            # Update the input with the new token
            current_inp['input_ids'] = torch.cat((current_inp['input_ids'], next_token.unsqueeze(0)), dim=1)
            current_inp['attention_mask'] = torch.ones(current_inp['input_ids'].shape, dtype=torch.long, device=current_inp['input_ids'].device)

        # Store the generated sequence and its average confidence for this first token
        sequences[idx.item()] = "".join(gens)
        avg_deltas[idx.item()] = torch.mean(torch.stack(deltas))
        # print(sequences[idx.item()])
        print(f"Sequence for rank {rank}, token {idx}, '{first_token_decoded}', "
              f"confidence: {avg_deltas[idx.item()]}:\n{sequences[idx.item()]}\nEND GENERATION\n")

        correct = score_output(output = sequences[idx.item()], label = evals[0]["label"])
        results[idx.item()] = correct
        print(f"Correct: {correct}\n")
        # print_gpu_stats()

        # Free up memory by deleting variables
        del o, next_token, current_inp, gens
        torch.cuda.empty_cache()
# Print the sequence with the highest avg_delta
confident_idx = max(avg_deltas, key=avg_deltas.get)
print(f"Most confident path {confident_idx}, confidence: {avg_deltas[confident_idx]}, "
      f"sequence: {sequences[confident_idx]}, result: {results[confident_idx]}")