@ngxson
Last active July 10, 2024 22:27
import logging
import argparse
import contextlib
import json
import os
import re
import sys
import numpy as np
import math
import torch
from pathlib import Path
if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf
# Download non-MOE model here:
# https://huggingface.co/ngxson/test_gguf_lora_adapter/blob/main/stories15M.gguf
#
# This works by repeating the weights of the base model to create 4 experts
#
# Run: ./llama-cli -m ./fake_moe.gguf -n 20 --temp 0
# expected output: Once upon a time, there was a little girl named Lily. She loved to play outside in
#
# With lora: ./llama-cli -m ./fake_moe.gguf -n 20 --temp 0 --lora ./fake_moe_lora.gguf
# expected output: (should be the same as the case below)
#
# With merged model: ./llama-cli -m ./fake_moe_lora_merged.gguf -n 20 --temp 0
# expected output: other othergggggggggggggggggg
### decode_field and get_field_data are copied from gguf_new_metadata.py
def decode_field(field: gguf.ReaderField | None):
    if field and field.types:
        main_type = field.types[0]
        if main_type == gguf.GGUFValueType.ARRAY:
            sub_type = field.types[-1]
            if sub_type == gguf.GGUFValueType.STRING:
                return [str(bytes(field.parts[idx]), encoding='utf-8') for idx in field.data]
            else:
                return [pv for idx in field.data for pv in field.parts[idx].tolist()]
        if main_type == gguf.GGUFValueType.STRING:
            return str(bytes(field.parts[-1]), encoding='utf-8')
        else:
            return field.parts[-1][0]
    return None

def get_field_data(reader: gguf.GGUFReader, key: str):
    field = reader.get_field(key)
    return decode_field(field)
reader = gguf.GGUFReader(path='stories15M.gguf')
orig_tensor_map = {}
for t in reader.tensors:
    orig_tensor_map[t.name] = t.data
    #print(t.name)
#print(orig_tensor_map)
#exit(1)
#### generate base model MOE
N_HEAD = 6
N_LAYERS = 6
N_EXPERTS = 4
N_FF = 768
N_EMBD = 288
gguf_writer = gguf.GGUFWriter(path=None, arch='llama')
def set_hparams():
    gguf_writer.add_name('mixtral_fake')
    gguf_writer.add_block_count(N_LAYERS)
    gguf_writer.add_context_length(128)
    gguf_writer.add_embedding_length(N_EMBD)
    gguf_writer.add_feed_forward_length(N_FF)
    gguf_writer.add_head_count(N_HEAD)
    gguf_writer.add_head_count_kv(N_HEAD)
    gguf_writer.add_expert_count(N_EXPERTS)
    gguf_writer.add_expert_used_count(2)
    gguf_writer.add_rope_dimension_count(48)
    gguf_writer.add_layer_norm_rms_eps(0.000001)
    gguf_writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)
    gguf_writer.add_bos_token_id(1)
    gguf_writer.add_eos_token_id(2)
    gguf_writer.add_unk_token_id(0)
    gguf_writer.add_tokenizer_model("llama")
    gguf_writer.add_token_list(get_field_data(reader, gguf.Keys.Tokenizer.LIST))
    gguf_writer.add_token_scores(get_field_data(reader, gguf.Keys.Tokenizer.SCORES))
    gguf_writer.add_token_types(get_field_data(reader, gguf.Keys.Tokenizer.TOKEN_TYPE))
set_hparams()
tensors = []
tensors.append(('token_embd.weight', (32000, N_EMBD)))
for il in range(N_LAYERS):
    tensors.append((f'blk.{il}.attn_k.weight', (N_EMBD, N_EMBD)))
    tensors.append((f'blk.{il}.attn_q.weight', (N_EMBD, N_EMBD)))
    tensors.append((f'blk.{il}.attn_v.weight', (N_EMBD, N_EMBD)))
    tensors.append((f'blk.{il}.attn_norm.weight', (N_EMBD,)))
    tensors.append((f'blk.{il}.attn_output.weight', (N_EMBD, N_EMBD)))
    tensors.append((f'blk.{il}.ffn_norm.weight', (N_EMBD,)))
    tensors.append((f'blk.{il}.ffn_gate_inp.weight', (N_EXPERTS, N_EMBD)))
    tensors.append((f'blk.{il}.ffn_down_exps.weight', (N_EXPERTS, N_EMBD, N_FF)))
    tensors.append((f'blk.{il}.ffn_gate_exps.weight', (N_EXPERTS, N_FF, N_EMBD)))
    tensors.append((f'blk.{il}.ffn_up_exps.weight', (N_EXPERTS, N_FF, N_EMBD)))
tensors.append(('output.weight', (32000, N_EMBD)))
tensors.append(('output_norm.weight', (N_EMBD,)))
def get_orig_tensor(name):
    return orig_tensor_map[name.replace('_exps', '')]
np.random.seed(0)
base_tensors = []
for name, shape in tensors:
    if 'ffn_gate_inp' in name:
        base_tensors.append((name, np.random.rand(*shape)))
    elif 'ffn_norm' in name:
        base_tensors.append((name, get_orig_tensor(name)))
    elif 'ffn_' in name:
        # repeat the base model's 2D FFN weight N_EXPERTS times to fake a stacked-experts tensor
        exp_3d_tensor = np.stack([get_orig_tensor(name)] * N_EXPERTS, axis=0)
        base_tensors.append((name, exp_3d_tensor))
        print('to 3D', name, exp_3d_tensor.shape)
    elif name in orig_tensor_map:
        base_tensors.append((name, get_orig_tensor(name)))
    else:
        print(f'unhandled tensor {name}')
        exit(1)
print('>> base_tensors')
for name, t in base_tensors:
    dtype = gguf.GGMLQuantizationType.F32
    print(name, t.shape)
    t = t.squeeze().astype(np.float32)
    if t.ndim != 1:
        # multi-dimensional weights are stored as F16, 1D norms stay F32
        t = t.squeeze().astype(np.float16)
        dtype = gguf.GGMLQuantizationType.F16
    gguf_writer.add_tensor(name, t, raw_dtype=dtype)
gguf_writer.write_header_to_file('./fake_moe.gguf')
gguf_writer.write_kv_data_to_file()
gguf_writer.write_tensors_to_file(progress=True)
gguf_writer.close()
#### generate LoRA with random weights
print('LoRA')
N_RANK = 32
gguf_writer = gguf.GGUFWriter(path=None, arch='llama')
gguf_writer.add_name('mixtral_fake')
gguf_writer.add_string('training.type', 'finetune_lora')
lora_tensors = []
for name, shape in tensors:
    if any(x in name for x in ['attn_k', 'attn_q', 'attn_v', 'attn_output', 'ffn_gate_inp']):
        lora_a = np.random.rand(N_RANK, shape[1]) * 0.015
        lora_b = np.random.rand(shape[0], N_RANK) * 0.015
        lora_tensors.append((f'{name}.lora_a', lora_a))
        lora_tensors.append((f'{name}.lora_b', lora_b))
    elif any(x in name for x in ['ffn_down', 'ffn_gate', 'ffn_up']):
        lora_a = np.random.rand(N_EXPERTS, N_RANK, shape[2]) * 0.015
        lora_b = np.random.rand(N_EXPERTS, shape[1], N_RANK) * 0.015
        lora_tensors.append((f'{name}.lora_a', lora_a))
        lora_tensors.append((f'{name}.lora_b', lora_b))
for name, t in lora_tensors:
    dtype = gguf.GGMLQuantizationType.F32
    t = t.squeeze().astype(np.float32)
    print(name, t.shape)
    gguf_writer.add_tensor(name, t, raw_dtype=dtype)
gguf_writer.write_header_to_file('./fake_moe_lora.gguf')
gguf_writer.write_kv_data_to_file()
gguf_writer.write_tensors_to_file(progress=True)
gguf_writer.close()
#### generate the merged model (LoRA merged into the base weights)
gguf_writer = gguf.GGUFWriter(path=None, arch='llama')
set_hparams()
lora_tensor_map = {}
for name, t in lora_tensors:
    lora_tensor_map[name] = t
for name, t in base_tensors:
    print(name, t.shape)
    # merge LoRA into the base weight: W' = W + B @ A
    # (for stacked experts, np.matmul broadcasts over the leading expert dimension)
    if any(x in name for x in ['attn_k', 'attn_q', 'attn_v', 'attn_output', 'ffn_gate_inp', 'ffn_down', 'ffn_gate', 'ffn_up']):
        lora_a = lora_tensor_map[f'{name}.lora_a']
        lora_b = lora_tensor_map[f'{name}.lora_b']
        print(lora_a.shape)
        print(lora_b.shape)
        t = np.add(t, np.matmul(lora_b, lora_a))
    dtype = gguf.GGMLQuantizationType.F32
    t = t.squeeze().astype(np.float32)
    if t.ndim != 1:
        t = t.squeeze().astype(np.float16)
        dtype = gguf.GGMLQuantizationType.F16
    gguf_writer.add_tensor(name, t, raw_dtype=dtype)
gguf_writer.write_header_to_file('./fake_moe_lora_merged.gguf')
gguf_writer.write_kv_data_to_file()
gguf_writer.write_tensors_to_file(progress=True)
gguf_writer.close()
@compilade commented Jul 10, 2024

To make this work with 3D stacked experts, I suggest using Numpy-style shapes (row size at the end) instead of manually reversing them to/from the GGML shape order (row size first). GGUFWriter handles the reversing internally so that it can be given Numpy tensors. If a tensor is from GGUFReader, the corresponding Numpy tensor is in some_reader_tensor.data (like you're using to build orig_tensor_map), and its shape is in Numpy style (row size last).

Also, the unpacking operator can be useful (e.g. hello(*some_tuple) is the same as hello(some_tuple[0], some_tuple[1], ...)), and slices (e.g. some_tuple[-2:] to get a tuple with the last 2 elements) help to get subsets of shapes.
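
A tiny sketch of what that looks like in practice (hypothetical shapes, only to illustrate the Numpy order, unpacking and slicing described above; GGUFWriter then reverses the shape to GGML order when writing):

import numpy as np

shape = (4, 288, 768)            # Numpy order: (n_expert, rows, row size), row size last
t = np.random.rand(*shape)       # unpacking: same as np.random.rand(4, 288, 768)
rows, row_size = shape[-2:]      # slicing the shape tuple: last two dims
print(t.shape, rows, row_size)   # (4, 288, 768) 288 768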

@ngxson (Author) commented Jul 10, 2024

@compilade OK, thanks for the suggestion. I've reversed the shapes in my tensors list.

The script is still a rough draft, so I may have made a mistake somewhere. Feel free to let me know whether I understood it correctly.

@ngxson (Author) commented Jul 10, 2024

OK, that's done. It's still very messy, but at least the merged version still outputs the same result as before.

Expert tensors are now 3D:

blk.0.ffn_down_exps.weight (4, 288, 768)
blk.0.ffn_gate_exps.weight (4, 768, 288)
blk.0.ffn_up_exps.weight (4, 768, 288)

LoRA:
blk.0.ffn_down_exps.weight.lora_a (4, 32, 768)
blk.0.ffn_down_exps.weight.lora_b (4, 288, 32)
blk.0.ffn_gate_exps.weight.lora_a (4, 32, 288)
blk.0.ffn_gate_exps.weight.lora_b (4, 768, 32)
blk.0.ffn_up_exps.weight.lora_a (4, 32, 288)
blk.0.ffn_up_exps.weight.lora_b (4, 768, 32)
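
For reference, a minimal sketch (using the shapes listed above) of why the per-expert merge needs no loop: np.matmul treats the leading expert dimension as a batch dimension.

import numpy as np

lora_a = np.zeros((4, 32, 768))    # blk.0.ffn_down_exps.weight.lora_a
lora_b = np.zeros((4, 288, 32))    # blk.0.ffn_down_exps.weight.lora_b
delta = np.matmul(lora_b, lora_a)  # batched matmul over the 4 experts
print(delta.shape)                 # (4, 288, 768), same shape as blk.0.ffn_down_exps.weight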

@compilade
Nice! For testing differences between the merged and the hot-applied-LoRA model, I think llama-perplexity works too.
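
Besides llama-perplexity, a rough in-Python sanity check is also possible; this is only a sketch, assuming GGUFReader returns tensors in Numpy shape (as noted above) and reusing the file names from the script:

import numpy as np
import gguf

base   = {t.name: t.data for t in gguf.GGUFReader('fake_moe.gguf').tensors}
lora   = {t.name: t.data for t in gguf.GGUFReader('fake_moe_lora.gguf').tensors}
merged = {t.name: t.data for t in gguf.GGUFReader('fake_moe_lora_merged.gguf').tensors}

name = 'blk.0.attn_q.weight'
expected = base[name].astype(np.float32) + lora[f'{name}.lora_b'] @ lora[f'{name}.lora_a']
# the difference should be on the order of float16 rounding error
print(np.abs(expected - merged[name].astype(np.float32)).max())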
