Last active
October 14, 2024 18:07
-
-
Save silphendio/90f7e23b2b1ab6949fd4b35e7dd705cf to your computer and use it in GitHub Desktop.
Cutting up a llama and putting it back together
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A simple script to demonstrate the slicing and recombination of models at runtime
# inspired by mergekit
# Sadly, it doesn't work with quantized models.
#
# public domain - silphendio
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import torch

model_path = 'gpt2'  # huggingface name or local folder
output_folder = 'sliced_llama'

# Indices of the source layers to stack into the new model:
# layers 0-7 followed by a second pass over layers 4-11
# (the middle/late layers are duplicated).
layer_arrangement = list(range(0, 8)) + list(range(4, 12))

model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# --- rearrange layers ---
# Fetch the state dict ONCE: model.state_dict() rebuilds the whole mapping
# on every call, and the original code called it inside the nested loop
# (once per layer per key).
state_dict = model.state_dict()
new_state_dict = state_dict.copy()

# Keys belonging to layer 0, turned into per-layer templates, e.g.
# 'transformer.h.0.attn.c_attn.weight' -> 'transformer.h.{}.attn.c_attn.weight'
layer_keys_template = [key.replace('.0.', '.{}.') for key in state_dict if '.0.' in key]

# Point each destination layer's tensors at the chosen source layer's tensors
# (reference remap, no tensor copies).
for new_layer, old_layer in enumerate(layer_arrangement):
    for key in layer_keys_template:
        new_state_dict[key.format(new_layer)] = state_dict[key.format(old_layer)]

# NOTE(review): this aliases model.config and mutates it in place; harmless
# here because the model is immediately reloaded with this config below.
new_config = model.config
new_config.n_layer = len(layer_arrangement)            # gpt2-style configs
new_config.num_hidden_layers = len(layer_arrangement)  # mistral / llama-style configs

# save the merged model (save_pretrained creates the folder if needed)
new_config.save_pretrained(output_folder)
torch.save(new_state_dict, output_folder + '/pytorch_model.bin')

# load the merged model from memory
model = AutoModelForCausalLM.from_pretrained(model_path, config=new_config, state_dict=new_state_dict)
del state_dict, new_state_dict  # free both dicts (transformers copies rather than reuses the memory)

##### test the merged model
prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains."
inputs = tokenizer(prompt, return_tensors="pt")
streamer = TextStreamer(tokenizer)
model.generate(**inputs, streamer=streamer, do_sample=True, max_new_tokens=250)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment