Skip to content

Instantly share code, notes, and snippets.

@Blaizzy
Created August 7, 2024 20:13
Show Gist options

  • Save Blaizzy/23380d73a6e9ba42ffb58c6a982b23fa to your computer and use it in GitHub Desktop.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import numpy as np
# Hyperparameters
temperature = 0.7  # softmax temperature used to soften logits for distillation
alpha = 0.5  # weight of the distillation (KL) term vs. the plain CE loss
def preprocess_dataset(examples):
    """Convert the list-encoded token fields of a batch into torch tensors.

    Mutates and returns ``examples`` with ``input_ids`` and
    ``attention_mask`` replaced by ``torch.Tensor`` equivalents.
    """
    for field in ("input_ids", "attention_mask"):
        examples[field] = torch.tensor(examples[field])
    return examples
def compute_distillation_loss(teacher_logits, student_logits, temperature, ce_loss,
                              alpha=0.5, top_k=50):
    """Blend a temperature-scaled KL distillation term with the student CE loss.

    Args:
        teacher_logits: teacher logits, either over the full vocabulary
            (same last dim as ``student_logits``) or already restricted to the
            student's top-k token ids.
        student_logits: student logits over the full vocabulary.
        temperature: softmax temperature; >1 softens both distributions.
        ce_loss: the student's ordinary cross-entropy loss (a scalar tensor).
        alpha: weight of the distillation term; (1 - alpha) weights ``ce_loss``.
        top_k: number of highest-scoring student tokens to distill over.

    Returns:
        Scalar tensor: ``alpha * kl_term + (1 - alpha) * ce_loss``.
    """
    # Restrict the comparison to the student's top-k vocabulary entries.
    top_values, top_token_ids = student_logits.topk(top_k, dim=-1)

    # Align the teacher with the same token ids when it still spans the full
    # vocabulary; otherwise assume it is already top-k aligned with the student.
    if teacher_logits.size(-1) == student_logits.size(-1):
        teacher_logits = teacher_logits.gather(-1, top_token_ids)

    # Temperature must scale the *logits before* the softmax. The original
    # code divided the log-probabilities by T afterwards, which is not a
    # valid temperature softening.
    student_log_probs = F.log_softmax(top_values / temperature, dim=-1)
    # kl_div expects log-probabilities as input and probabilities as target;
    # the original passed raw scaled logits as the target.
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

    # 'batchmean' matches the mathematical definition of KL divergence;
    # the T^2 factor keeps gradient magnitudes comparable across temperatures.
    distillation_loss = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction='batchmean',
    ) * (temperature ** 2)

    # Combine distillation and supervised losses.
    return alpha * distillation_loss + (1 - alpha) * ce_loss
class DistillationTrainer(SFTTrainer):
    """SFTTrainer variant that mixes a distillation loss into the CE loss.

    Each batch is expected to carry a ``normalized_logits`` field holding the
    teacher's pre-computed logits alongside ``input_ids`` / ``attention_mask``.
    The loss computation assumes batch size 1 (see ``compute_loss``).
    """

    def __init__(self, *args, **kwargs):
        # Pull out options that SFTTrainer.__init__ does not accept directly.
        self.remove_unused_columns = kwargs.pop('remove_unused_columns', None)
        self.max_seq_length = kwargs.get('max_seq_length', 8192)
        super().__init__(*args, **kwargs)

    def _prepare_inputs(self, inputs):
        # Move tensor values to the training device; leave non-tensors as-is.
        return {k: v.to(self.args.device) if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()}

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute combined CE + distillation loss for a single-sample batch.

        Only the tokens after the first "assistant" token are supervised;
        everything before it (the prompt) is masked with -100.
        """
        # Position of the first "assistant" token in the (batch-size-1) sample.
        assistant_token_id = self.tokenizer.convert_tokens_to_ids("assistant")
        assistant_pos = np.where(
            inputs["input_ids"].clone().cpu() == assistant_token_id
        )[1].tolist()[0]

        labels = torch.full_like(inputs["input_ids"], -100)
        seq_end = inputs["attention_mask"].sum()
        labels[0, assistant_pos + 1:seq_end] = inputs["input_ids"][0, assistant_pos + 1:seq_end]

        # Teacher logits are carried in the batch; remove before the forward pass.
        normalized_logits = inputs.pop("normalized_logits")

        # BUG FIX: the original sliced the *batch* dimension ([:seq_end]),
        # a no-op for batch size 1. Truncate the sequence dimension instead,
        # keeping labels the same length as the truncated inputs.
        outputs = model(
            input_ids=inputs["input_ids"][:, :seq_end],
            attention_mask=inputs["attention_mask"][:, :seq_end],
            labels=labels[:, :seq_end],
        )
        student_logits = outputs.logits
        # NOTE(review): assumes normalized_logits aligns with the truncated
        # student sequence — confirm against the dataset's logit layout.
        loss = compute_distillation_loss(normalized_logits, student_logits,
                                         temperature, outputs.loss)
        return (loss, outputs) if return_outputs else loss
# Load and preprocess the dataset (replace with your actual dataset).
dataset = load_dataset("arcee-train/finetome-llama-3.1-8b-10k-logits", split="train")
dataset = dataset.take(100)  # small subset for a quick run
preprocessed_dataset = dataset.map(preprocess_dataset, batched=True, batch_size=16)
print("Dataset preparation complete. Loading models...")

model_kwargs = {"torch_dtype": torch.bfloat16}
model_name = "arcee-train/Llama-3.1-6B-Instruct-Granite-v0"

# Define the student model.
# BUG FIX: model_kwargs was built but never used, so the bfloat16 dtype was
# silently dropped; pass it through to from_pretrained.
student_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", **model_kwargs)
# Define the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Drop teacher-side columns the trainer does not consume.
preprocessed_dataset = preprocessed_dataset.remove_columns("token_ids")
preprocessed_dataset = preprocessed_dataset.remove_columns("top_values")

# The tokenizer ships without a pad token; reuse EOS for padding.
student_model.config.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
# Configure and launch distillation training.
training_args = TrainingArguments(
    output_dir="./distillation_output",
    logging_dir="./logs",
    logging_steps=10,
    max_steps=100,
    warmup_steps=50,
    weight_decay=0.01,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # Keep the teacher-logit columns; the custom trainer consumes them.
    remove_unused_columns=False,
)

trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=preprocessed_dataset,
)
trainer.train()
from transformers import TextStreamer
def count_parameters(model):
    """Print the trainable-parameter count in billions; return it truncated to an int.

    Note: models under 1B parameters return 0 (integer truncation).
    """
    billions = sum(p.numel() for p in model.parameters() if p.requires_grad) / 10**9
    print(f"Model size: {billions:.3f}B parameters")
    return int(billions)
def generate(model, tokenizer, inputs, max_new_tokens=100):
    """Stream a sampled continuation of `inputs` from `model` to stdout."""
    streamer = TextStreamer(tokenizer)
    _ = model.generate(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
    )
# Smoke test: compare the freshly loaded base model against the distilled
# student on a single prompt.
prompt = (
    "<|start_header_id|>user<|end_header_id|>Create a simple python program to add two numbers."
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
)
inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

model_name = "arcee-train/Llama-3.1-6B-Instruct-Granite-v0"
# Define the student model
base_pruned_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

generate(base_pruned_model, tokenizer, inputs)
generate(student_model, tokenizer, inputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment