Chain-of-Thought Decoding, from "Chain-of-Thought Reasoning Without Prompting" (https://arxiv.org/pdf/2402.10200)
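The script below is a rough sketch of CoT decoding as described in the paper: rather than committing to the single greedy first token, it branches over the top-k candidates for the first generated token, greedily decodes a continuation for each branch, and scores each path by the average probability margin between the top-1 and top-2 tokens at every decoding step (an approximation of the paper's delta metric, which is computed over the answer tokens only). The most confident path is then taken as the model's answer. It is demonstrated here on the paper's year-parity task ("Was {celebrity} born in an even or odd year?") with an Unsloth-loaded Llama 3.1 8B.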
# From Unsloth Llama 3.1 fine-tuning notebook
from unsloth import FastLanguageModel
import torch

def print_gpu_stats():
    gpu_stats = torch.cuda.get_device_properties(0)
    start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
    print(f"{start_gpu_memory} GB of memory reserved.")

max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.

model_name = "unsloth/Meta-Llama-3.1-8B"
# model_name = "unsloth/Meta-Llama-3.1-8B-Instruct"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...",  # use one if using gated models like meta-llama/Llama-2-7b-hf
)
print_gpu_stats()
import re
import torch.nn.functional as F

def get_top_n_tokens(o, n=2):
    # Softmax over the logits at the last position, then take the top-n tokens.
    probs = F.softmax(o.logits[0, -1, :], dim=-1)
    top_n = torch.topk(probs, n)
    return top_n.values, top_n.indices
def find_last_even_or_odd(s):
    # Return the last standalone "even" or "odd" in the string, or "" if neither occurs
    matches = re.findall(r'\b(even|odd)\b', s)
    return matches[-1] if matches else ""

def score_output(output: str, label: str):
    # The parity answer is taken to be the last "even"/"odd" the model produced
    answer = find_last_even_or_odd(output.lower())
    return answer == label
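
# A couple of sanity checks for the scoring helpers (illustrative strings,
# not outputs from this gist's model runs):
assert find_last_even_or_odd("1964 is even, so the answer is even.") == "even"
assert score_output("Nicolas Cage was born in 1964, an even year.", label="even") is True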
top_k = 10  # Number of top first-token candidates to branch on

# For the year-parity task, "Chain-of-Thought Reasoning Without Prompting"
# (https://arxiv.org/pdf/2402.10200) used 50 generated tokens for the Mistral
# base model and 100 for the instruction-tuned model.
sequence_length = 100  # Number of tokens to generate for each top_k branch

sequences = {}
avg_deltas = {}
results = {}

# t = "Continue the Fibonacci sequence."
# t = alpaca_prompt.format(
#     t,  # instruction
#     "1, 1, 2, 3, 5, 8",  # input
#     "",  # output - leave this blank for generation!
# )

base_template = "Q: {question}\nA:"
instruct_template = "[INST] {question} [/INST]"

t = "Was {celebrity} born in an even or odd year?"
evals = [{"index": 0, "celebrity": "nicolas cage", "label": "even"}]

t = base_template.format(question = t.format(celebrity = evals[0]["celebrity"]))
print(t)
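# At this point the rendered base-model prompt is:
#   Q: Was nicolas cage born in an even or odd year?
#   A: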
# t = instruct_template.format(question = t.format(celebrity = evals[0]["celebrity"]))

inp = tokenizer([t], return_tensors="pt").to("cuda")
print(f"Model input:\n```{tokenizer.decode(inp.input_ids[0])}```")

# Get the logits at the first position after the prompt and find the top_k candidates
model = model.eval()
FastLanguageModel.for_inference(model)
with torch.no_grad():
    o = model(**inp)
top_k_indices = torch.topk(o.logits[0, -1, :], top_k).indices
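
# (Optional) peek at the k candidate first tokens before decoding each branch;
# per the paper, for a base model these typically mix direct answers with
# tokens that open a longer reasoning path:
for candidate in top_k_indices:
    print(repr(tokenizer.decode(candidate)))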
# Greedily decode a continuation for each top_k first token
with torch.no_grad():
    for rank, idx in enumerate(top_k_indices):
        first_token_decoded = tokenizer.decode(idx)
        # print(f"First token: {first_token_decoded}")
        gens = [first_token_decoded]  # To store the generated sequence
        deltas = []  # Per-step probability margins between the top-2 tokens
        # Start with the initial input + the current top_k token
        current_inp = {
            'input_ids': torch.cat((inp['input_ids'], idx.unsqueeze(0).unsqueeze(0)), dim=1)
        }
        current_inp['attention_mask'] = torch.ones(current_inp['input_ids'].shape, dtype=torch.long, device=inp['input_ids'].device)
        # Greedily generate sequence_length tokens
        for _ in range(sequence_length):
            o = model(**current_inp)
            next_token = torch.topk(o.logits[0, -1, :], 1).indices
            # Confidence at this step: gap between the top-1 and top-2 token probabilities.
            # Note: the paper computes this margin over the answer tokens only; here it is
            # averaged over every generated token.
            top_two_probs, top_two_tokens = get_top_n_tokens(o, n=2)
            deltas.append(top_two_probs[0] - top_two_probs[1])
            gens.append(tokenizer.decode(next_token))
            # Update the input with the new token
            current_inp['input_ids'] = torch.cat((current_inp['input_ids'], next_token.unsqueeze(0)), dim=1)
            current_inp['attention_mask'] = torch.ones(current_inp['input_ids'].shape, dtype=torch.long, device=current_inp['input_ids'].device)
        # Store the generated sequence and its average confidence for this branch
        sequences[idx.item()] = "".join(gens)
        avg_deltas[idx.item()] = torch.mean(torch.stack(deltas)).item()
        print(f"Sequence for rank {rank}, token {idx.item()}, '{first_token_decoded}', "
              f"confidence: {avg_deltas[idx.item()]}:\n{sequences[idx.item()]}\nEND GENERATION\n")
        correct = score_output(output = sequences[idx.item()], label = evals[0]["label"])
        results[idx.item()] = correct
        print(f"Correct: {correct}\n")
        # print_gpu_stats()
        # Free up memory by deleting variables
        del o, next_token, current_inp, gens
        torch.cuda.empty_cache()

# Print the sequence with the highest avg delta, i.e. the most confident decoding path
confident_idx = max(avg_deltas, key=avg_deltas.get)
print(f"Most confident path {confident_idx}, confidence: {avg_deltas[confident_idx]}, sequence: {sequences[confident_idx]}, result: {results[confident_idx]}")