import os

import pyarrow as pa
import pyarrow.parquet as pq
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
def process_and_save_parquet(all_logits, all_input_ids, all_attention_masks,
                             top_50_token_ids, top_50_values,
                             output_dir, shard_idx, total_shards):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Format shard_idx to have leading zeros
    formatted_shard_idx = str(shard_idx).zfill(len(str(total_shards)))

    # Store original shapes and data
    data = {
        'normalized_logits': all_logits.numpy(),
        'input_ids': all_input_ids.numpy(),
        'attention_mask': all_attention_masks.numpy(),
        'token_ids': top_50_token_ids.numpy(),
        'top_values': top_50_values.numpy()
    }
    shapes = {key: arr.shape for key, arr in data.items()}

    # Create a list to store our data as dictionaries
    rows = []
    # Assuming the first dimension is the batch size
    batch_size = data['input_ids'].shape[0]
    for i in range(batch_size):
        row = {}
        for key, arr in data.items():
            if arr.ndim > 1:
                row[key] = arr[i].tolist()  # Convert 2D+ arrays to list for each item in batch
            else:
                row[key] = arr[i]
        rows.append(row)

    # Create PyArrow Table
    table = pa.Table.from_pylist(rows)

    # Add shapes as metadata
    metadata = table.schema.metadata
    if metadata is None:
        metadata = {}
    metadata[b'shapes'] = str(shapes).encode()
    table = table.replace_schema_metadata(metadata)

    # Save as Parquet
    shard_path = os.path.join(output_dir, f"data-shard-{formatted_shard_idx}.parquet")
    pq.write_table(table, shard_path)
    print(f"Shard {shard_idx} saved to {shard_path}")
def sharegpt_format(example):
    conversations = example['conversations']
    message = []
    if isinstance(conversations, list):
        for conversation in conversations:
            if isinstance(conversation, dict):
                if conversation.get('from') == 'human':
                    message.append({"role": "user", "content": conversation.get('value', '')})
                elif conversation.get('from') == 'gpt':
                    message.append({"role": "assistant", "content": conversation.get('value', '')})
                elif conversation.get('from') == 'system':
                    message.insert(0, {"role": "system", "content": conversation.get('value', '')})

    if not any(msg.get('role') == 'system' for msg in message):
        message.insert(0, {"role": "system", "content": "You are a helpful assistant."})

    text = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    return {"text": text}
| model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct" # Replace with your model | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| device_map="auto", # Enables tensor parallelism | |
| ).eval() | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| dataset = load_dataset("mlabonne/FineTome-100k", split="train") | |
| original_columns = dataset.column_names | |
| dataset = dataset.map(sharegpt_format, remove_columns=original_columns) | |
| dataset = dataset.take(10000) | |
| samples = len(dataset) | |
| tokenizer.pad_token = tokenizer.eos_token | |
| batch_size = 16 | |
| total_shards = samples // batch_size | |
| dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) | |
| output_dir = "test" | |
# Generate logits, input_ids, and attention masks
with torch.no_grad():
    for i, batch in enumerate(tqdm(dataloader)):
        inputs = tokenizer(batch['text'], truncation=True, padding=True, max_length=8192, return_tensors='pt')
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        outputs = model(**inputs)
        logits = outputs.logits.cpu()

        top_k = 50
        # Apply topk across the last dimension (vocabulary dimension)
        top_values, top_token_ids = logits.topk(top_k, dim=-1)
        # Normalize the top-k logits along the last dimension
        normalized_logits = torch.nn.functional.softmax(top_values, dim=-1)
        print(normalized_logits.shape)

        input_ids = inputs['input_ids'].cpu()
        attention_mask = inputs['attention_mask'].cpu()

        process_and_save_parquet(normalized_logits, input_ids, attention_mask,
                                 top_token_ids, top_values, output_dir, i, total_shards)

        del outputs, logits, input_ids, attention_mask
        torch.cuda.empty_cache()

print(f"All shards saved in {output_dir}")