Skip to content

Instantly share code, notes, and snippets.

@kalomaze
Created July 29, 2024 20:16
Show Gist options
  • Save kalomaze/74a5cbbc3046e35024b657d1c1b0d9c6 to your computer and use it in GitHub Desktop.
Save kalomaze/74a5cbbc3046e35024b657d1c1b0d9c6 to your computer and use it in GitHub Desktop.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import random
import os
import shutil
# Seed Python's RNG first so the random channel selection done later
# (random.sample in the pruning loop) picks the same channels on every run.
random.seed(42)

# Source checkpoint directory. Its config and tokenizer are read from the same
# place as the weights; the weights are loaded in bfloat16.
model_path = "/home/gcpuser/models/NemoBase"
config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
)
# Halve the MLP width of every decoder layer by keeping a random subset of its
# intermediate channels.
#
# Intermediate channel j is row j of gate_proj / up_proj and column j of
# down_proj, so the same sorted index set is applied to all three projections.
# NOTE(review): assumes the usual gated MLP (down(act(gate(x)) * up(x))) as in
# Mistral/Nemo-style models -- confirm for other architectures.


def _keep_output_channels(linear, index):
    """Keep only the output features (weight rows / bias entries) in ``index``."""
    linear.weight = torch.nn.Parameter(linear.weight[index].contiguous())
    if linear.bias is not None:
        linear.bias = torch.nn.Parameter(linear.bias[index].contiguous())
    # Keep the module's metadata consistent with its new weight shape.
    linear.out_features = linear.weight.shape[0]


def _keep_input_channels(linear, index):
    """Keep only the input features (weight columns) in ``index``.

    The bias, if any, is indexed by output features and so is left unchanged.
    """
    # .contiguous(): column indexing must yield a contiguous tensor, because
    # safetensors serialization rejects non-contiguous tensors.
    linear.weight = torch.nn.Parameter(linear.weight[:, index].contiguous())
    linear.in_features = linear.weight.shape[1]


# no_grad: this is pure weight surgery, so no autograd graph is needed.
with torch.no_grad():
    for i, layer in enumerate(model.model.layers):
        mlp = layer.mlp
        # Current (unpruned) intermediate size of this layer.
        intermediate_size = mlp.gate_proj.weight.shape[0]
        # One random.sample per layer, in layer order, so the seeded run is
        # reproducible; sorting preserves the original channel order.
        dims_to_keep = sorted(random.sample(range(intermediate_size), intermediate_size // 2))
        index = torch.tensor(dims_to_keep, dtype=torch.long, device=mlp.gate_proj.weight.device)

        _keep_output_channels(mlp.gate_proj, index)
        _keep_output_channels(mlp.up_proj, index)
        _keep_input_channels(mlp.down_proj, index)

        print(f"Layer {i} MLP pruned. New shapes:")
        print(f"gate_proj: {mlp.gate_proj.weight.shape}")
        print(f"up_proj: {mlp.up_proj.weight.shape}")
        print(f"down_proj: {mlp.down_proj.weight.shape}")
        print()
# Record the new MLP width in the configuration. Derive it from the pruned
# weights themselves rather than from a variable left over from the pruning
# loop, and require it to be the same in every layer, because
# ``intermediate_size`` is a single integer in the config.
pruned_sizes = {
    decoder_layer.mlp.gate_proj.weight.shape[0]
    for decoder_layer in model.model.layers
}
if len(pruned_sizes) != 1:
    raise ValueError(
        f"Expected a uniform pruned intermediate size across layers, got {sorted(pruned_sizes)}"
    )
new_intermediate_size = pruned_sizes.pop()
config.intermediate_size = new_intermediate_size
# Keep the model's own config in sync as well. model.save_pretrained() saves
# the model together with its configuration (model.config), which otherwise
# would still carry the unpruned width.
model.config.intermediate_size = new_intermediate_size

# Create the output directory if it doesn't exist.
output_path = "/home/gcpuser/models/Nemo_RNG_Prune"
os.makedirs(output_path, exist_ok=True)

# Save the modified model, tokenizer, and configuration. The config is saved
# last, so its config.json is the one left on disk.
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
config.save_pretrained(output_path)
# Copy auxiliary tokenizer files verbatim from the source checkpoint, for
# whichever of them exist there.
# NOTE(review): this runs after tokenizer.save_pretrained(), so any of these
# files that the save already wrote will be overwritten by the originals.
files_to_copy = ['special_tokens_map.json', 'tokenizer_config.json', 'vocab.json', 'merges.txt']
for filename in files_to_copy:
    src = os.path.join(model_path, filename)
    dst = os.path.join(output_path, filename)
    # isfile rather than exists: shutil.copy2 cannot copy a directory, so only
    # regular files are copied.
    if os.path.isfile(src):
        shutil.copy2(src, dst)
print(f"Modified model, tokenizer, and configuration saved to {output_path}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment