Created
August 7, 2024 20:13
-
-
Save Blaizzy/23380d73a6e9ba42ffb58c6a982b23fa to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments | |
| from trl import SFTTrainer, SFTConfig | |
| from datasets import load_dataset | |
| import numpy as np | |
# Hyperparameters (module-level: read by compute_distillation_loss and
# DistillationTrainer.compute_loss below)
temperature = 0.7  # softening temperature for the distillation softmax
alpha = 0.5  # weight of the distillation term; (1 - alpha) weights the CE term
def preprocess_dataset(examples):
    """Cast the token-id fields of a dataset batch to torch tensors, in place.

    Used as a ``datasets.map(..., batched=True)`` callback; returns the same
    mapping with ``input_ids`` and ``attention_mask`` converted.
    """
    for field in ("input_ids", "attention_mask"):
        examples[field] = torch.tensor(examples[field])
    return examples
def compute_distillation_loss(teacher_logits, student_logits, temperature, ce_loss, alpha=0.5, top_k=50):
    """Blend a temperature-scaled KL distillation term with the hard-label CE loss.

    Args:
        teacher_logits: teacher logits over a top-k vocabulary slice
            (last dimension must equal ``top_k``).
        student_logits: student logits over the full vocabulary.
        temperature: softmax softening temperature (> 0).
        ce_loss: the student's cross-entropy loss on the hard labels.
        alpha: weight of the distillation term; ``1 - alpha`` weights
            ``ce_loss``. Defaults to 0.5, matching the module-level value.
        top_k: number of highest student logits kept for the KL term
            (previously hard-coded to 50).

    Returns:
        Scalar tensor: ``alpha * KL + (1 - alpha) * ce_loss``.
    """
    # Keep only the student's top-k logits so shapes line up with the
    # teacher's stored top-k logits.
    # NOTE(review): the student's top-k token ids are not guaranteed to be
    # the ids the teacher's top-k logits were stored for — confirm the
    # dataset's teacher logits are aligned with this selection.
    top_values, _ = student_logits.topk(top_k, dim=-1)

    # FIX: temperature must divide the *logits* before log_softmax/softmax
    # (the original divided the log-probabilities, which is not a valid
    # log-distribution). kl_div expects probabilities as its target and,
    # per the PyTorch docs, 'batchmean' for the mathematically correct KL.
    distillation_loss = F.kl_div(
        F.log_softmax(top_values / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction='batchmean',
    ) * (temperature ** 2)

    # Convex combination of the soft (distillation) and hard (CE) objectives.
    return alpha * distillation_loss + (1 - alpha) * ce_loss
class DistillationTrainer(SFTTrainer):
    """SFTTrainer variant that adds a logit-distillation term to the loss.

    Expects each batch to carry precomputed teacher ``normalized_logits``
    alongside ``input_ids``/``attention_mask``. Written for
    per_device_train_batch_size == 1: ``compute_loss`` indexes batch row 0
    directly.
    """
    def __init__(self, *args, **kwargs):
        # Stash trainer-level options before SFTTrainer consumes kwargs.
        # NOTE(review): self.remove_unused_columns is popped but never read
        # again in this class; TrainingArguments(remove_unused_columns=False)
        # is what actually preserves the extra dataset columns.
        self.remove_unused_columns = kwargs.pop('remove_unused_columns', None)
        self.max_seq_length = kwargs.get('max_seq_length', 8192)
        super(DistillationTrainer, self).__init__(*args, **kwargs)
    def _prepare_inputs(self, inputs):
        # Move every tensor in the batch to the training device; pass
        # non-tensor entries through untouched.
        return {k: v.to(self.args.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    def compute_loss(self, model, inputs, return_outputs=False):
        # Find the position of the assistant token
        assistant_token_id = self.tokenizer.convert_tokens_to_ids("assistant")
        # First occurrence of the assistant token in the sequence.
        # NOTE(review): np.where on a 2-D batch returns (rows, cols); taking
        # [1]...[0] assumes batch size 1 and at least one match — an input
        # with no "assistant" token raises IndexError here.
        assistant_positions = np.where(inputs["input_ids"].clone().cpu() == assistant_token_id)[1].tolist()[0]
        # Mask everything up to and including the assistant token with -100
        # so CE is computed only on the assistant's reply tokens.
        labels = torch.full_like(inputs["input_ids"], -100)
        seq_end = inputs["attention_mask"].sum()  # number of non-pad tokens
        labels[0, assistant_positions+1:seq_end] = inputs["input_ids"][0,assistant_positions+1:seq_end]
        # Teacher logits stored in the dataset; popped so they are not
        # forwarded to the model call below as an unexpected kwarg.
        normalized_logits = inputs.pop("normalized_logits")
        # NOTE(review): [:seq_end] slices the *batch* dimension, which is a
        # no-op at batch size 1 — truncating the sequence would need
        # [:, :seq_end] (with labels truncated to match). Confirm intent.
        outputs = model(input_ids=inputs["input_ids"][:seq_end], attention_mask=inputs["attention_mask"][:seq_end], labels=labels)
        student_logits = outputs.logits
        # Blend the KL distillation term (module-level temperature/alpha)
        # with the hard-label CE loss produced by the model.
        loss = compute_distillation_loss(normalized_logits, student_logits, temperature, outputs.loss)
        return (loss, outputs) if return_outputs else loss
# ---------------------------------------------------------------------------
# Data preparation
# ---------------------------------------------------------------------------
dataset = load_dataset("arcee-train/finetome-llama-3.1-8b-10k-logits", split="train")  # Replace with your actual dataset name or loading method
dataset = dataset.take(100)  # small subset for a quick smoke-test run
preprocessed_dataset = dataset.map(preprocess_dataset, batched=True, batch_size=16)
print("Dataset preparation complete. Loading models...")

# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
model_kwargs = {"torch_dtype": torch.bfloat16}
model_name = "arcee-train/Llama-3.1-6B-Instruct-Granite-v0"

# Define the student model.
# FIX: model_kwargs was built but never used, so the bfloat16 dtype was
# silently ignored and the model loaded in its default dtype; pass it through.
student_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", **model_kwargs)

# Define the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Drop teacher-side columns the trainer does not consume
# (compute_loss only reads "normalized_logits").
preprocessed_dataset = preprocessed_dataset.remove_columns(["token_ids", "top_values"])

# Reuse EOS as the padding token.
student_model.config.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

# Training arguments
training_args = TrainingArguments(
    output_dir="./distillation_output",
    max_steps=100,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    remove_unused_columns=False,  # keep "normalized_logits" for compute_loss
    warmup_steps=50,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
)

# Initialize the DistillationTrainer
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=preprocessed_dataset,
)

# Start training
trainer.train()
| from transformers import TextStreamer | |
def count_parameters(model):
    """Print the trainable-parameter count in billions and return it
    truncated to an int (so any model under 1B parameters returns 0)."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    size_b = trainable / 10**9
    print(f"Model size: {size_b:.3f}B parameters")
    return int(size_b)
def generate(model, tokenizer, inputs, max_new_tokens=100):
    """Generate up to ``max_new_tokens`` tokens from ``inputs``, streaming
    the decoded text to stdout as it is produced.

    NOTE(review): temperature=0.7 is passed without do_sample=True, so
    decoding stays greedy and the temperature is ignored — confirm intent.
    """
    streamer = TextStreamer(tokenizer)
    model.generate(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
    )
# Chat-formatted prompt used to compare the two checkpoints side by side.
prompt = "<|start_header_id|>user<|end_header_id|>Create a simple python program to add two numbers.<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

# Reload the un-distilled pruned checkpoint as a baseline.
model_name = "arcee-train/Llama-3.1-6B-Instruct-Granite-v0"
base_pruned_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# Streamed generation: pruned base first, then the distilled student.
generate(base_pruned_model, tokenizer, inputs)
generate(student_model, tokenizer, inputs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment