# ===== Chat script: interact with the fine-tuned GPT-2 model =====
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
# Build the prompt, generate a continuation, and strip the echoed prompt
def chat_with_model(model, tokenizer, instruction, max_length=50):
    prompt = f"### User: {instruction}\n\n### AI: "
    inputs = tokenizer.encode(prompt, return_tensors='pt').to(device)
    outputs = model.generate(inputs, max_length=max_length, num_return_sequences=1, top_k=50)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.replace(prompt, "")
    # Truncate at the next "###" marker if the model starts a new turn
    user_start = response.find("###")
    if user_start != -1:
        response = response[:user_start].strip()
    return response
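# Optional variant (added; not in the original gist): sampling-based decoding with the
# same prompt template. do_sample, temperature, top_p, and pad_token_id are standard
# arguments of transformers' generate().
def chat_with_model_sampling(model, tokenizer, instruction, max_length=50,
                             temperature=0.7, top_p=0.9):
    prompt = f"### User: {instruction}\n\n### AI: "
    inputs = tokenizer.encode(prompt, return_tensors='pt').to(device)
    outputs = model.generate(
        inputs,
        max_length=max_length,
        do_sample=True,                        # sample instead of greedy decoding
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,   # avoids the missing-pad-token warning
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.replace(prompt, "")
    # Truncate at the next turn marker, as in chat_with_model above
    user_start = response.find("###")
    if user_start != -1:
        response = response[:user_start].strip()
    return response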
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the fine-tuned model
model = GPT2LMHeadModel.from_pretrained('fine_tuned_gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('fine_tuned_gpt2')
model.to(device)
# Interactive chat loop with the fine-tuned model
# test_instruction = "Hello!"
while True:
    test_instruction = input(">> User: ")
    test_output = chat_with_model(model, tokenizer, test_instruction)
    print(">> AI: ", test_output)
# ===== Training script: fine-tune GPT-2 on instruction data =====
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.optim import AdamW  # transformers.AdamW is deprecated in favor of torch.optim.AdamW
from tqdm import tqdm
from datasets import load_dataset
class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        instruction = item['instruction']
        response = item['response']
        # Format the sequence using the specified template
        instruction_text = f"### User: {instruction}\n\n### AI: "
        response_text = f"{response}"
        # Tokenize the instruction and response separately
        instruction_encodings = self.tokenizer(instruction_text, truncation=True, max_length=self.max_length, padding=False)
        response_encodings = self.tokenizer(response_text, truncation=True, max_length=self.max_length - len(instruction_encodings['input_ids']), padding=False)
        # Concatenate the instruction and response encodings
        input_ids = instruction_encodings['input_ids'] + response_encodings['input_ids']
        attention_mask = [1] * len(input_ids)
        # Create labels: -100 for the instruction tokens (masked out of the loss),
        # the response token ids elsewhere. GPT2LMHeadModel shifts labels internally,
        # so labels must stay aligned with input_ids (no manual shifting needed).
        labels = [-100] * len(instruction_encodings['input_ids']) + response_encodings['input_ids']
        # Ensure length does not exceed max_length
        if len(input_ids) > self.max_length:
            input_ids = input_ids[:self.max_length]
            attention_mask = attention_mask[:self.max_length]
            labels = labels[:self.max_length]
        return {
            'instruction': instruction,
            'response': response,
            'instruction_ids': instruction_encodings['input_ids'],
            'response_ids': response_encodings['input_ids'],
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
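# Illustration (added; not in the original gist) of the label layout produced above:
#   input_ids: [inst_0, inst_1, ..., inst_n, resp_0, resp_1, ..., resp_m]
#   labels:    [ -100 ,  -100 , ...,  -100 , resp_0, resp_1, ..., resp_m]
# With GPT2LMHeadModel's internal shift, the loss is computed only on response tokens.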
def collate_fn(batch):
    max_len = max(len(item['input_ids']) for item in batch)
    input_ids = []
    attention_mask = []
    labels = []
    for item in batch:
        # Pad input_ids and attention_mask to the left
        padding_length = max_len - len(item['input_ids'])
        input_ids.append([tokenizer.pad_token_id] * padding_length + item['input_ids'])
        attention_mask.append([0] * padding_length + item['attention_mask'])
        # Pad labels to the left with -100
        labels.append([-100] * padding_length + item['labels'])
    # Convert to tensors
    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    labels = torch.tensor(labels)
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }
# Initialize tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2')
# # Prepare dataset and dataloader
# data = [
# {'instruction': 'Write a poem about spring.', 'response': 'Flowers bloom in vibrant hues...'},
# {'instruction': 'Explain quantum computing.', 'response': 'Quantum computing is a type of computation...'},
# # Add more data items here
# ]
# Load from Magpie data
hf_data = load_dataset("Magpie-Align/Magpie-Pro-300K-Filtered", split="train")
data = []
for item in hf_data:
    data.append({
        "instruction": item["conversations"][0]["value"],
        "response": item["conversations"][1]["value"],
    })
data = data[:100000]
dataset = InstructionDataset(data, tokenizer, max_length=1024)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
# Visualize the data for debugging
if False:
    sample = dataset[0]
    print("Instruction:", sample['instruction'])
    print("Response:", sample['response'])
    print("Instruction IDs:", sample['instruction_ids'][:50])
    print("Response IDs:", sample['response_ids'][:30])
    print("Input IDs:", sample['input_ids'][:100])
    print("Attention Mask:", sample['attention_mask'])
    print("Labels:", sample['labels'])
    exit()
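# Optional sanity check (added; not in the original gist): inspect one collated batch
# to confirm the left padding from collate_fn lines up. Flip the flag to run it.
if False:
    debug_batch = next(iter(dataloader))
    print("input_ids:", debug_batch['input_ids'].shape)          # (batch_size, max_len)
    print("attention_mask:", debug_batch['attention_mask'].shape)
    print("labels:", debug_batch['labels'].shape)
    # Padded positions carry attention_mask == 0 and labels == -100
    print("masked label positions:", (debug_batch['labels'] == -100).sum().item())
    exit()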
# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = AdamW(model.parameters(), lr=3e-4)
# Training loop
num_epochs = 1
grad_acc_steps = 4 # Number of steps to accumulate gradients
K = 5 # Report loss every K steps
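# Optional (added; not in the original gist): a linear warmup/decay schedule is often
# paired with AdamW. If enabled, call scheduler.step() right after optimizer.step() below.
# from transformers import get_linear_schedule_with_warmup
# scheduler = get_linear_schedule_with_warmup(
#     optimizer,
#     num_warmup_steps=100,  # illustrative value
#     num_training_steps=num_epochs * (len(dataloader) // grad_acc_steps),
# )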
model.train()
for epoch in range(num_epochs):
    step_loss = 0
    for i, batch in enumerate(tqdm(dataloader)):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss / grad_acc_steps  # Normalize loss for gradient accumulation
        loss.backward()
        step_loss += outputs.loss.item()  # Track the un-normalized loss for reporting
        if (i + 1) % grad_acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        if (i + 1) % K == 0:
            avg_step_loss = step_loss / K
            print(f"Step {i+1}/{len(dataloader)}, Loss: {avg_step_loss:.4f}")
            step_loss = 0
# Save the fine-tuned model
model.save_pretrained('fine_tuned_gpt2')
tokenizer.save_pretrained('fine_tuned_gpt2')
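# Optional smoke test (added; not in the original gist), using the same prompt template
# as the chat script above. Flip the flag to run it right after training.
if False:
    model.eval()
    test_prompt = "### User: Hello!\n\n### AI: "
    test_inputs = tokenizer.encode(test_prompt, return_tensors='pt').to(device)
    test_outputs = model.generate(test_inputs, max_length=64, top_k=50,
                                  pad_token_id=tokenizer.eos_token_id)
    print(tokenizer.decode(test_outputs[0], skip_special_tokens=True))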