# Created July 29, 2024 20:16
# Source: GitHub gist kalomaze/74a5cbbc3046e35024b657d1c1b0d9c6
# NOTE: GitHub flagged this file as containing bidirectional Unicode text, which
# may be interpreted or compiled differently than it appears. Review it in an
# editor that reveals hidden Unicode characters.
import os
import random
import shutil

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# A fixed seed makes the random choice of MLP channels (made with the `random`
# module further down) repeatable from run to run.
random.seed(42)

# Source checkpoint directory. The model weights are loaded in bfloat16. The
# tokenizer and the config come from the same directory. Note that `config` is
# loaded by its own from_pretrained call, independently of `model.config`.
model_path = "/home/gcpuser/models/NemoBase"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)
def _index_param(param, dim, index):
    """Return a new Parameter holding ``param`` restricted to ``index`` along ``dim``.

    ``index_select`` produces a fresh tensor that does not share storage with
    ``param``. The original parameter's ``requires_grad`` flag is carried over.
    """
    return torch.nn.Parameter(
        param.index_select(dim, index.to(param.device)),
        requires_grad=param.requires_grad,
    )


# Randomly prune half of the MLP intermediate channels in every decoder layer.
# The rows of gate_proj/up_proj produce the intermediate activations, and the
# columns of down_proj consume them. The same index set must therefore be
# removed from all three projections for the MLP to stay shape-consistent.
num_layers = len(model.model.layers)
expected_size = None  # the single width that config.intermediate_size will describe

with torch.no_grad():  # pure weight surgery; no autograd graph is needed
    for i, layer in enumerate(model.model.layers):
        mlp = layer.mlp
        # Current intermediate width = number of output rows of gate_proj.
        intermediate_size = mlp.gate_proj.weight.shape[0]
        if expected_size is None:
            expected_size = intermediate_size
        elif intermediate_size != expected_size:
            raise ValueError(
                f"Layer {i} has intermediate size {intermediate_size}, expected "
                f"{expected_size}; a single config.intermediate_size cannot describe it."
            )

        # The random.sample call and its draw order are unchanged, so seed 42
        # still selects exactly the same channels as earlier runs of this script.
        dims_to_keep = random.sample(range(intermediate_size), intermediate_size // 2)
        dims_to_keep.sort()
        keep = torch.tensor(dims_to_keep, dtype=torch.long)

        # Output side: drop rows (and matching bias entries, if any).
        for proj in (mlp.gate_proj, mlp.up_proj):
            proj.weight = _index_param(proj.weight, 0, keep)
            if getattr(proj, "bias", None) is not None:
                proj.bias = _index_param(proj.bias, 0, keep)
            # Keep nn.Linear's shape metadata in sync with the new weight.
            proj.out_features = keep.numel()

        # Input side: drop columns. down_proj's bias (if any) is sized by its
        # output dimension, which is unchanged, so it is left alone.
        mlp.down_proj.weight = _index_param(mlp.down_proj.weight, 1, keep)
        mlp.down_proj.in_features = keep.numel()

        print(f"Layer {i} MLP pruned. New shapes:")
        print(f"gate_proj: {mlp.gate_proj.weight.shape}")
        print(f"up_proj: {mlp.up_proj.weight.shape}")
        print(f"down_proj: {mlp.down_proj.weight.shape}")
        print()
# Record the pruned intermediate width. `config` and `model.config` are separate
# objects (each comes from its own from_pretrained call), and
# model.save_pretrained() serializes model.config. Update both, so that the
# in-memory model and every config.json written below agree on the new size.
config.intermediate_size = intermediate_size // 2
model.config.intermediate_size = config.intermediate_size

# Create the output directory if it doesn't exist.
output_path = "/home/gcpuser/models/Nemo_RNG_Prune"
os.makedirs(output_path, exist_ok=True)

# Save the modified model, tokenizer, and configuration. The config is saved
# last, so its file is the one left on disk if the save calls overlap.
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
config.save_pretrained(output_path)
# Copy selected tokenizer-side files verbatim from the source checkpoint into
# the output directory. Any file missing from the source directory is skipped
# without error. Files that exist may replace same-named files written by the
# save calls above.
for name in ('special_tokens_map.json', 'tokenizer_config.json', 'vocab.json', 'merges.txt'):
    source_file = os.path.join(model_path, name)
    if not os.path.exists(source_file):
        continue
    shutil.copy2(source_file, os.path.join(output_path, name))

print(f"Modified model, tokenizer, and configuration saved to {output_path}")
# (GitHub page footer: sign up for free, or sign in, to comment on this gist.)