Created February 1, 2024 22:02
-
-
Save cjpais/59fb7fcb5256ed0aea339b0a35eac899 to your computer and use it in GitHub Desktop.
llava 1.6 hack
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""LLaVA 1.6 checkpoint surgery for GGUF conversion.

Scans every ``model*.safetensors`` shard in the model directory and:

* exports ``model.mm_projector*`` tensors (cast with ``.float()``) to
  ``<model>/llava.projector`` via ``torch.save``;
* exports ``model.vision_tower*`` tensors (BakLLaVA models contain CLIP
  tensors) to ``<model>/llava.clip``;
* drops ``model.image_newline*`` tensors;
* rewrites every shard that contained any of the above without them,
  passing the shard's original safetensors metadata through;
* blanks ``added_tokens.json`` (if present) so Mistral models can convert.

Afterwards the directory can be converted like a regular LLaMA model.
"""

import argparse
import glob
import os

import torch
from safetensors import safe_open
from safetensors.torch import save_file

PROJECTOR_PREFIX = "model.mm_projector"
CLIP_PREFIX = "model.vision_tower"
NEWLINE_PREFIX = "model.image_newline"
# Every tensor whose name starts with one of these is removed from the shards.
REMOVED_PREFIXES = (PROJECTOR_PREFIX, CLIP_PREFIX, NEWLINE_PREFIX)


def _parse_args(argv=None):
    """Parse the command line (``argv=None`` means the process arguments)."""
    ap = argparse.ArgumentParser()
    # required: without it the glob below would search "None/..." and crash.
    ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5 model")
    return ap.parse_args(argv)


def _collect(shards):
    """Read projector/CLIP tensors from all shards and list shards to rewrite.

    Returns ``(projector, clip, dirty)``: two name -> tensor dicts (values
    converted with ``.float()``) and the paths of the shards that contain at
    least one tensor matching ``REMOVED_PREFIXES``.
    """
    projector, clip, dirty = {}, {}, []
    for shard_path in shards:
        with safe_open(shard_path, framework="pt") as shard:
            keys = list(shard.keys())
            for key in keys:
                if key.startswith(PROJECTOR_PREFIX):
                    projector[key] = shard.get_tensor(key).float()
                elif key.startswith(CLIP_PREFIX):
                    clip_name = key.replace("vision_tower.vision_tower.", "")
                    clip[clip_name] = shard.get_tensor(key).float()
        if any(key.startswith(REMOVED_PREFIXES) for key in keys):
            dirty.append(shard_path)
    return projector, clip, dirty


def _rewrite_without_removed(shard_path):
    """Rewrite one shard in place, dropping every ``REMOVED_PREFIXES`` tensor."""
    with safe_open(shard_path, framework="pt") as shard:
        # Pass the original header metadata through; save_file writes none
        # unless it is given explicitly.
        metadata = shard.metadata()
        kept = {
            key: shard.get_tensor(key)
            for key in shard.keys()
            if not key.startswith(REMOVED_PREFIXES)
        }
    # The read handle is closed before the same file is overwritten.
    save_file(kept, shard_path, metadata=metadata)


def main(argv=None):
    """Run the surgery on the model directory given by ``--model``."""
    args = _parse_args(argv)

    # Search every shard instead of guessing an index (the original hack
    # hard-coded safetensors[-2], or [-1] for 34b, and shard 0 for the newline).
    shards = sorted(glob.glob(f"{glob.escape(args.model)}/model*.safetensors"))
    if not shards:
        raise SystemExit(f"No model*.safetensors files found in {args.model}")

    projector, clip, dirty = _collect(shards)
    if not projector:
        # Refuse to overwrite llava.projector with an empty dict, e.g. when the
        # script is re-run on an already processed model.
        raise SystemExit(
            f"No {PROJECTOR_PREFIX} tensors found in {args.model} "
            "(already processed?)"
        )

    # Write the extracted tensors before touching any shard, so a failure
    # while rewriting shards cannot lose them.
    torch.save(projector, f"{args.model}/llava.projector")
    if clip:
        torch.save(clip, f"{args.model}/llava.clip")

    # NOTE(review): image_newline tensors are dropped, not exported - confirm
    # the downstream converter does not need them.
    for shard_path in dirty:
        _rewrite_without_removed(shard_path)

    # added tokens should be removed to be able to convert Mistral models
    added_tokens = f"{args.model}/added_tokens.json"
    if os.path.exists(added_tokens):
        with open(added_tokens, "w", encoding="utf-8") as f:
            f.write("{}\n")

    print("Done!")
    print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
    print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.