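"""Convert an Idefics2 checkpoint (sharded safetensors) to GGUF.

Writes two files: "{name}-vision-model-f16.gguf" (arch "clip": the vision
tower plus the perceiver-resampler/modality-projection connector) and
"{name}-text-model-f16.gguf" (arch "llama": the Mistral-based text model).

Usage (the script filename here is illustrative):

    python convert_idefics2.py -d /path/to/idefics2-model-dir
"""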
import os
import json
import typing
import pathlib
import argparse

import numpy as np
import numpy.typing as npt

import gguf
from gguf import KEY_ATTENTION_HEAD_COUNT, KEY_ATTENTION_LAYERNORM_EPS, KEY_BLOCK_COUNT, KEY_EMBEDDING_LENGTH, KEY_FEED_FORWARD_LENGTH, GGUFWriter, TokenType, SpecialVocab

from safetensors import safe_open


class SafetensorsIndexFile(typing.TypedDict):
    weight_map: typing.Dict[str, str]


class SafetensorsIndex:
    def __init__(self, index_file_path: str):
        directory = os.path.dirname(index_file_path)
        self.index = typing.cast(SafetensorsIndexFile, json.load(open(index_file_path)))
        self.weight_map = self.index["weight_map"]
        files = set(self.weight_map.values())
        self.tensors = {file: safe_open(os.path.join(directory, file), framework="np") for file in files}

    def get_tensor(self, key: str) -> npt.NDArray[np.float32]:
        return typing.cast(npt.NDArray[np.float32], self.tensors[self.weight_map[key]].get_tensor(key))  # type: ignore
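
# SafetensorsIndex opens each shard once up front. model.safetensors.index.json
# uses the standard Hugging Face sharded layout, mapping each tensor name to
# the shard file that holds it, e.g.
# {"weight_map": {"model.text_model.embed_tokens.weight": "model-00001-of-00002.safetensors", ...}}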


def extract_key(raw_key: str, arch: str) -> str:
    return raw_key.format(arch=arch)
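
# gguf's KEY_* constants are templates like "{arch}.embedding_length", so e.g.
# extract_key(KEY_EMBEDDING_LENGTH, "clip.vision") yields "clip.vision.embedding_length".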


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--dir-model",
        required=True,
        help="path to the model directory (weights, config, and tokenizer)",
    )
    args = parser.parse_args()

    dir_model = pathlib.Path(args.dir_model)

    # set model name to folder name
    name = dir_model.name

    tensors = SafetensorsIndex((dir_model / "model.safetensors.index.json").as_posix())

    # Load the model config
    config = json.load(open(dir_model / "config.json"))
    # text config defaults are based on Mistral v0.1 and are overridden by
    # config.json's "text_config"
    text_config = {
        "vocab_size": 32000,
        "hidden_size": 4096,
        "intermediate_size": 14336,
        "num_hidden_layers": 32,
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
        "hidden_act": "silu",
        "max_position_embeddings": 4096 * 32,
        "rms_norm_eps": 1e-05,
        "bos_token_id": 1,
        "eos_token_id": 2,
        "tie_word_embeddings": False,
        "rope_theta": 10000.0,
        "sliding_window": 4096,
    }
    text_config.update(config["text_config"])

    vision_config = config["vision_config"]

    # defaults from https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/configuration_idefics2.py#L129
    perceiver_config = config.get("perceiver_config", {
        "hidden_act": "silu",
        "resampler_n_latents": 64,
        "resampler_depth": 3,
        "resampler_n_heads": 16,
        "resampler_head_dim": 96,
        "num_key_value_heads": 4,
        "attention_dropout": 0.0,
    })
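    # Of these, only resampler_depth is consumed below (it drives the
    # perceiver block export loop); the remaining fields are kept for
    # reference.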

    ### Vision encoder

    ftype = 1  # fp16

    fname_out = f"{name}-vision-model-f16.gguf"
    fout = GGUFWriter(fname_out, arch="clip")

    fout.add_bool("clip.has_text_encoder", False)
    fout.add_bool("clip.has_vision_encoder", True)
    fout.add_bool("clip.has_llava_projector", True)
    fout.add_file_type(ftype)

    model_name = "idefics2"
    fout.add_name(model_name)
    fout.add_description("Vision encoder for " + model_name)
    fout.add_string("clip.projector_type", "idefics2")

    n_layers_clip = vision_config["num_hidden_layers"]

    # vision model hparams
    VISION = "clip.vision"
    fout.add_uint32("clip.vision.image_size", vision_config["image_size"])
    fout.add_uint32("clip.vision.patch_size", vision_config["patch_size"])
    fout.add_uint32(extract_key(KEY_EMBEDDING_LENGTH, VISION), vision_config["hidden_size"])
    fout.add_uint32(extract_key(KEY_FEED_FORWARD_LENGTH, VISION), vision_config["intermediate_size"])
    fout.add_uint32("clip.vision.projection_dim", text_config["hidden_size"])  # projector output matches the text model's hidden size (4096)
    fout.add_uint32(extract_key(KEY_ATTENTION_HEAD_COUNT, VISION), vision_config["num_attention_heads"])
    fout.add_float32(extract_key(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
    fout.add_uint32(extract_key(KEY_BLOCK_COUNT, VISION), n_layers_clip + 1)
    fout.add_array("clip.vision.image_mean", [0.5, 0.5, 0.5])
    fout.add_array("clip.vision.image_std", [0.5, 0.5, 0.5])
    fout.add_bool("clip.use_gelu", True)  # using regular GELU instead of quick GELU

    # connector
    # model.connector.modality_projection.down_proj.weight [4096, 14336] F32
    fout.add_tensor(
        "mm.mp.ffn_down.weight",
        tensors.get_tensor("model.connector.modality_projection.down_proj.weight").astype(np.float16),
    )
    # model.connector.modality_projection.gate_proj.weight [14336, 1152] F32
    fout.add_tensor(
        "mm.mp.ffn_gate.weight",
        tensors.get_tensor("model.connector.modality_projection.gate_proj.weight").astype(np.float16),
    )
    # model.connector.modality_projection.up_proj.weight [14336, 1152] F32
    fout.add_tensor(
        "mm.mp.ffn_up.weight",
        tensors.get_tensor("model.connector.modality_projection.up_proj.weight").astype(np.float16),
    )
    # model.connector.perceiver_resampler.latents [64, 4096] F32
    fout.add_tensor(
        "mm.pr.latents.weight",
        tensors.get_tensor("model.connector.perceiver_resampler.latents").astype(np.float32),
    )
    # model.connector.perceiver_resampler.norm.weight [4096] F32
    fout.add_tensor(
        "mm.pr.ln0.weight",
        tensors.get_tensor("model.connector.perceiver_resampler.norm.weight").astype(np.float32),
    )
    for i in range(perceiver_config["resampler_depth"]):  # 3 resampler blocks by default
        # model.connector.perceiver_resampler.layers.{i}.input_context_norm.weight [4096] F32
        fout.add_tensor(
            f"mm.pr.blk.{i}.ln0.weight",
            tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.input_context_norm.weight").astype(np.float32),
        )
        # model.connector.perceiver_resampler.layers.{i}.input_latents_norm.weight [4096] F32
        fout.add_tensor(
            f"mm.pr.blk.{i}.ln1.weight",
            tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.input_latents_norm.weight").astype(np.float32),
        )
        # model.connector.perceiver_resampler.layers.{i}.mlp.down_proj.weight [4096, 16384] F32
        fout.add_tensor(
            f"mm.pr.blk.{i}.ffn_down.weight",
            tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.mlp.down_proj.weight").astype(np.float16),
        )
        # model.connector.perceiver_resampler.layers.{i}.mlp.gate_proj.weight [16384, 4096] F32
        fout.add_tensor(
            f"mm.pr.blk.{i}.ffn_gate.weight",
            tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.mlp.gate_proj.weight").astype(np.float16),
        )
        # model.connector.perceiver_resampler.layers.{i}.mlp.up_proj.weight [16384, 4096] F32
        fout.add_tensor(
            f"mm.pr.blk.{i}.ffn_up.weight",
            tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.mlp.up_proj.weight").astype(np.float16),
        )
        # model.connector.perceiver_resampler.layers.{i}.post_attention_layernorm.weight [4096] F32
        fout.add_tensor(
            f"mm.pr.blk.{i}.ln2.weight",
            tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.post_attention_layernorm.weight").astype(np.float32),
        )
        # model.connector.perceiver_resampler.layers.{i}.self_attn.k_proj.weight [384, 4096] F32
        fout.add_tensor(
            f"mm.pr.blk.{i}.attn_k.weight",
            tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.self_attn.k_proj.weight").astype(np.float16),
        )
        # model.connector.perceiver_resampler.layers.{i}.self_attn.o_proj.weight [4096, 1536] F32
        fout.add_tensor(
            f"mm.pr.blk.{i}.attn_o.weight",
            tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.self_attn.o_proj.weight").astype(np.float16),
        )
        # model.connector.perceiver_resampler.layers.{i}.self_attn.q_proj.weight [1536, 4096] F32
        fout.add_tensor(
            f"mm.pr.blk.{i}.attn_q.weight",
            tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.self_attn.q_proj.weight").astype(np.float16),
        )
        # model.connector.perceiver_resampler.layers.{i}.self_attn.v_proj.weight [384, 4096] F32
        fout.add_tensor(
            f"mm.pr.blk.{i}.attn_v.weight",
            tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.self_attn.v_proj.weight").astype(np.float16),
        )

    # vision_model
    fout.add_tensor(
        "v.position_embd.weight",
        tensors.get_tensor("model.vision_model.embeddings.position_embedding.weight").astype(np.float16),
    )
    fout.add_tensor(
        "v.patch_embd.weight",
        tensors.get_tensor("model.vision_model.embeddings.patch_embedding.weight")
        .reshape(vision_config["hidden_size"], 3, vision_config["patch_size"], vision_config["patch_size"])
        .astype(np.float16),
    )
    fout.add_tensor(
        "v.patch_embd.bias",
        tensors.get_tensor("model.vision_model.embeddings.patch_embedding.bias").astype(np.float32),
    )
    fout.add_tensor(
        "v.post_ln.weight",
        tensors.get_tensor("model.vision_model.post_layernorm.weight").astype(np.float32),
    )
    fout.add_tensor(
        "v.post_ln.bias",
        tensors.get_tensor("model.vision_model.post_layernorm.bias").astype(np.float32),
    )

    def add_vision_tensor(blk_id: int, gguf_id: typing.Optional[int] = None):
        if gguf_id is None:
            gguf_id = blk_id
        attn_prefix = f"model.vision_model.encoder.layers.{blk_id}.self_attn."
        fout.add_tensor(
            f"v.blk.{gguf_id}.attn_q.weight",
            tensors.get_tensor(f"{attn_prefix}q_proj.weight").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.attn_q.bias",
            tensors.get_tensor(f"{attn_prefix}q_proj.bias").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.attn_k.weight",
            tensors.get_tensor(f"{attn_prefix}k_proj.weight").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.attn_k.bias",
            tensors.get_tensor(f"{attn_prefix}k_proj.bias").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.attn_v.weight",
            tensors.get_tensor(f"{attn_prefix}v_proj.weight").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.attn_v.bias",
            tensors.get_tensor(f"{attn_prefix}v_proj.bias").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.attn_out.weight",
            tensors.get_tensor(f"{attn_prefix}out_proj.weight").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.attn_out.bias",
            tensors.get_tensor(f"{attn_prefix}out_proj.bias").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.ln1.weight",
            tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.layer_norm1.weight").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.ln1.bias",
            tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.layer_norm1.bias").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.ffn_down.weight",
            tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.mlp.fc1.weight").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.ffn_down.bias",
            tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.mlp.fc1.bias").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.ffn_up.weight",
            tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.mlp.fc2.weight").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.ffn_up.bias",
            tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.mlp.fc2.bias").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.ln2.weight",
            tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.layer_norm2.weight").astype(np.float32),
        )
        fout.add_tensor(
            f"v.blk.{gguf_id}.ln2.bias",
            tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.layer_norm2.bias").astype(np.float32),
        )

    for i in range(n_layers_clip):
        add_vision_tensor(i)

    # Duplicate the last block (llava-cli skips over this)
    add_vision_tensor(n_layers_clip - 1, n_layers_clip)
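    # Why duplicate: llama.cpp's CLIP graph skips the final encoder block
    # (LLaVA projectors read the penultimate layer), while Idefics2's
    # resampler expects the true last layer's output. Declaring
    # n_layers_clip + 1 blocks (see KEY_BLOCK_COUNT above) and appending a
    # copy gives the loader a block to skip without losing the real one.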

    fout.write_header_to_file()
    fout.write_kv_data_to_file()
    fout.write_tensors_to_file()
    fout.close()

    ### Text Model

    # general GGUF init
    fname_out = f"{name}-text-model-f16.gguf"
    fout = GGUFWriter(fname_out, arch="llama")
    ftype = 1

    block_count = text_config["num_hidden_layers"]

    fout.add_name(name)
    fout.add_block_count(block_count)
    fout.add_context_length(text_config["max_position_embeddings"])
    fout.add_embedding_length(text_config["hidden_size"])
    fout.add_feed_forward_length(text_config["intermediate_size"])
    fout.add_head_count(text_config["num_attention_heads"])
    fout.add_head_count_kv(text_config["num_key_value_heads"])
    fout.add_rope_freq_base(text_config["rope_theta"])
    fout.add_layer_norm_rms_eps(text_config["rms_norm_eps"])
    fout.add_file_type(ftype)
    fout.add_vocab_size(text_config["vocab_size"])
    fout.add_rope_dimension_count(
        text_config["hidden_size"] // text_config["num_attention_heads"]
    )

    tokenizer_config_file = dir_model / "tokenizer_config.json"
    if tokenizer_config_file.is_file():
        with open(tokenizer_config_file, "r", encoding="utf-8") as f:
            tokenizer_config_json = json.load(f)
            if "add_prefix_space" in tokenizer_config_json:
                fout.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])

    ### Tokenizer
    # Taken from _set_vocab_sentencepiece
    from enum import IntEnum

    class SentencePieceTokenTypes(IntEnum):
        NORMAL = 1
        UNKNOWN = 2
        CONTROL = 3
        USER_DEFINED = 4
        UNUSED = 5
        BYTE = 6

    from sentencepiece import SentencePieceProcessor

    tokenizer_path = dir_model / "tokenizer.model"

    if not tokenizer_path.is_file():
        raise FileNotFoundError(f"File not found: {tokenizer_path}")

    tokenizer = SentencePieceProcessor()
    tokenizer.LoadFromFile(str(tokenizer_path))

    vocab_size = text_config["vocab_size"]

    tokens: typing.List[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
    scores: typing.List[float] = [-10000.0] * vocab_size
    toktypes: typing.List[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
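    # Every slot up to vocab_size is pre-filled with a [PAD{i}] placeholder;
    # real pieces from the SentencePiece model (and any added_tokens.json
    # entries) overwrite them below, so ids beyond tokenizer.vocab_size()
    # remain as padding tokens.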
    for token_id in range(tokenizer.vocab_size()):
        piece = tokenizer.IdToPiece(token_id)
        text = piece.encode("utf-8")
        score = tokenizer.GetScore(token_id)

        toktype = SentencePieceTokenTypes.NORMAL
        if tokenizer.IsUnknown(token_id):
            toktype = SentencePieceTokenTypes.UNKNOWN
        elif tokenizer.IsControl(token_id):
            toktype = SentencePieceTokenTypes.CONTROL
        elif tokenizer.IsUnused(token_id):
            toktype = SentencePieceTokenTypes.UNUSED
        elif tokenizer.IsByte(token_id):
            toktype = SentencePieceTokenTypes.BYTE

        tokens[token_id] = text
        scores[token_id] = score
        toktypes[token_id] = toktype

    added_tokens_file = dir_model / "added_tokens.json"
    if added_tokens_file.is_file():
        with open(added_tokens_file, "r", encoding="utf-8") as f:
            added_tokens_json = json.load(f)
            for key in added_tokens_json:
                token_id = added_tokens_json[key]
                if token_id >= vocab_size:
                    print(f"ignore token {token_id}: id is out of range, max={vocab_size - 1}")
                    continue
                tokens[token_id] = key.encode("utf-8")
                scores[token_id] = -1000.0
                toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED

    if vocab_size > len(tokens):
        pad_count = vocab_size - len(tokens)
        print(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
        for i in range(1, pad_count + 1):
            tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
            scores.append(-1000.0)
            toktypes.append(SentencePieceTokenTypes.UNUSED)

    fout.add_tokenizer_model("llama")
    fout.add_tokenizer_pre("default")
    fout.add_token_list(tokens)
    fout.add_token_scores(scores)
    fout.add_token_types(toktypes)

    special_vocab = gguf.SpecialVocab(dir_model, n_vocab=len(tokens))
    special_vocab.add_to_gguf(fout)

    def permute(weights: npt.NDArray[np.float16], n_head: int, n_head_kv: typing.Optional[int]):
        if n_head_kv is not None and n_head != n_head_kv:
            n_head = n_head_kv
        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))
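    # Hugging Face stores rotary Q/K projection rows with each head's
    # dimensions in two contiguous halves, while llama.cpp expects them
    # interleaved in pairs; this is the same per-head permutation that
    # llama.cpp's convert scripts apply to attn_q/attn_k weights.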

    n_head = typing.cast(int, text_config["num_attention_heads"])
    n_kv_head = typing.cast(int, text_config["num_key_value_heads"])

    fout.add_tensor(
        "token_embd.weight",
        tensors.get_tensor("model.text_model.embed_tokens.weight").astype(np.float32),
    )

    def add_text_tensor(i: int):
        fout.add_tensor(
            f"blk.{i}.attn_norm.weight",
            tensors.get_tensor(f"model.text_model.layers.{i}.input_layernorm.weight").astype(np.float32),
        )
        fout.add_tensor(
            f"blk.{i}.ffn_down.weight",
            tensors.get_tensor(f"model.text_model.layers.{i}.mlp.down_proj.weight").astype(np.float16),
        )
        fout.add_tensor(
            f"blk.{i}.ffn_gate.weight",
            tensors.get_tensor(f"model.text_model.layers.{i}.mlp.gate_proj.weight").astype(np.float16),
        )
        fout.add_tensor(
            f"blk.{i}.ffn_up.weight",
            tensors.get_tensor(f"model.text_model.layers.{i}.mlp.up_proj.weight").astype(np.float16),
        )
        fout.add_tensor(
            f"blk.{i}.ffn_norm.weight",
            tensors.get_tensor(f"model.text_model.layers.{i}.post_attention_layernorm.weight").astype(np.float32),
        )
        fout.add_tensor(
            f"blk.{i}.attn_k.weight",
            permute(
                tensors.get_tensor(f"model.text_model.layers.{i}.self_attn.k_proj.weight").astype(np.float16),
                n_head,
                n_kv_head,
            ),
        )
        fout.add_tensor(
            f"blk.{i}.attn_output.weight",
            tensors.get_tensor(f"model.text_model.layers.{i}.self_attn.o_proj.weight").astype(np.float16),
        )
        fout.add_tensor(
            f"blk.{i}.attn_q.weight",
            permute(
                tensors.get_tensor(f"model.text_model.layers.{i}.self_attn.q_proj.weight").astype(np.float16),
                n_head,
                n_head,
            ),
        )
        fout.add_tensor(
            f"blk.{i}.attn_v.weight",
            tensors.get_tensor(f"model.text_model.layers.{i}.self_attn.v_proj.weight").astype(np.float16),
        )

    for i in range(block_count):
        add_text_tensor(i)

    fout.add_tensor(
        "output_norm.weight",
        tensors.get_tensor("model.text_model.norm.weight").astype(np.float32),
    )
    fout.add_tensor(
        "output.weight",
        tensors.get_tensor("lm_head.weight").astype(np.float32),
    )

    fout.write_header_to_file()
    fout.write_kv_data_to_file()
    fout.write_tensors_to_file()
    fout.close()