@Blaizzy · Created August 7, 2024 20:14

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
import os
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm
def process_and_save_parquet(all_logits, all_input_ids, all_attention_masks,
                             top_50_token_ids, top_50_values, output_dir,
                             shard_idx, total_shards):
    """Write one batch of teacher outputs to a Parquet shard, storing the
    original array shapes in the schema metadata."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Zero-pad shard_idx so shard files sort lexicographically
    formatted_shard_idx = str(shard_idx).zfill(len(str(total_shards)))

    # Store original shapes and data
    data = {
        'normalized_logits': all_logits.numpy(),
        'input_ids': all_input_ids.numpy(),
        'attention_mask': all_attention_masks.numpy(),
        'token_ids': top_50_token_ids.numpy(),
        'top_values': top_50_values.numpy()
    }
    shapes = {key: arr.shape for key, arr in data.items()}

    # Flatten the batch into one row (dictionary) per example
    rows = []
    batch_size = data['input_ids'].shape[0]  # first dimension is the batch size
    for i in range(batch_size):
        row = {}
        for key, arr in data.items():
            if arr.ndim > 1:
                row[key] = arr[i].tolist()  # convert 2D+ slices to nested lists
            else:
                row[key] = arr[i]
        rows.append(row)

    # Create a PyArrow table and attach the shapes as schema metadata
    table = pa.Table.from_pylist(rows)
    metadata = dict(table.schema.metadata or {})
    metadata[b'shapes'] = str(shapes).encode()
    table = table.replace_schema_metadata(metadata)

    # Save as Parquet
    shard_path = os.path.join(output_dir, f"data-shard-{formatted_shard_idx}.parquet")
    pq.write_table(table, shard_path)
    print(f"Shard {shard_idx} saved to {shard_path}")
def sharegpt_format(example):
    """Convert a ShareGPT-style conversation into chat-template text.

    Relies on the module-level `tokenizer`, which is defined further down but
    exists by the time `dataset.map` calls this function."""
    conversations = example['conversations']
    message = []
    if isinstance(conversations, list):
        for conversation in conversations:
            if isinstance(conversation, dict):
                if conversation.get('from') == 'human':
                    message.append({"role": "user", "content": conversation.get('value', '')})
                elif conversation.get('from') == 'gpt':
                    message.append({"role": "assistant", "content": conversation.get('value', '')})
                elif conversation.get('from') == 'system':
                    message.insert(0, {"role": "system", "content": conversation.get('value', '')})

    # Fall back to a default system prompt if the conversation lacks one
    if not any(msg.get('role') == 'system' for msg in message):
        message.insert(0, {"role": "system", "content": "You are a helpful assistant."})

    text = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    return {"text": text}
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct" # Replace with your model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # shard the model across available devices via accelerate
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("mlabonne/FineTome-100k", split="train")
original_columns = dataset.column_names
dataset = dataset.map(sharegpt_format, remove_columns=original_columns)
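# Quick sanity check (illustrative, not required): print the start of one
# formatted example to confirm the chat template was applied.
print(dataset[0]["text"][:300])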
dataset = dataset.take(10000)  # keep the first 10k examples
samples = len(dataset)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token
batch_size = 16
total_shards = samples // batch_size  # one shard per batch (10000 / 16 = 625)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
output_dir = "test"
# Generate logits, input_ids, and attention masks
with torch.no_grad():
    for i, batch in enumerate(tqdm(dataloader)):
        inputs = tokenizer(batch['text'], truncation=True, padding=True,
                           max_length=8192, return_tensors='pt')
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model(**inputs)
        logits = outputs.logits.cpu()

        # Keep only the top-k logits along the last (vocabulary) dimension
        top_k = 50
        top_values, top_token_ids = logits.topk(top_k, dim=-1)

        # Softmax over the top-k values renormalizes them into a probability
        # distribution restricted to those 50 tokens
        normalized_logits = torch.nn.functional.softmax(top_values, dim=-1)
        print(normalized_logits.shape)

        input_ids = inputs['input_ids'].cpu()
        attention_mask = inputs['attention_mask'].cpu()
        process_and_save_parquet(normalized_logits, input_ids, attention_mask,
                                 top_token_ids, top_values, output_dir, i, total_shards)

        # Free GPU memory between batches
        del outputs, logits, input_ids, attention_mask
        torch.cuda.empty_cache()

print(f"All shards saved in {output_dir}")