Skip to content

Instantly share code, notes, and snippets.

@silphendio
Last active February 15, 2024 23:55
Show Gist options
  • Save silphendio/535cd9c1821aa1290aa10d587b76a49c to your computer and use it in GitHub Desktop.
Save silphendio/535cd9c1821aa1290aa10d587b76a49c to your computer and use it in GitHub Desktop.
Create LLM slices at runtime with exllamav2
# to use this, first install python and exllamav2 (https://github.com/turboderp/exllamav2)
# load a model, rearrange the layers as you like, set generation parameters, and run it
# duplicate layers share tensors, but still need extra memory for the cache
# thanks to @dnhkng for showing that the cache needs to be re-created
# licensed under WTFPL (http://www.wtfpl.net/about/) - Silphendio
from exllamav2 import *
from exllamav2.generator import *
import sys, torch
from copy import copy
# Load the quantized model, a cache, the tokenizer and default sampler settings.
config = ExLlamaV2Config()
config.model_dir = "./TinyLlama-1.1B-Chat-v1.0-5.0bpw-h6-exl2"  # local exl2 model directory
config.prepare()  # NOTE(review): presumably reads the model's config from model_dir - confirm
model = ExLlamaV2(config)
# Lazy 8-bit cache; presumably allocation is deferred until load_autosplit() below (confirm).
# This cache is discarded and re-created after the layers are rearranged.
cache = ExLlamaV2Cache_8bit(model, lazy = True)
print("Loading model...")
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)
gen_settings = ExLlamaV2Sampler.Settings()  # sampling parameters, adjusted further below
## mix layers here
# Each entry is an index into the ORIGINAL decoder layers; repeating an index
# duplicates that layer (here 0-13 followed by 8-21, so layers 8-13 appear twice).
layer_arrangement = list(range(0, 14)) + list(range(8, 22))

# modules arrangement: [embedding, [...layers], rms-norm, head]
# where each layer is [attention, mlp]
old_modules = model.modules

# Number of decoder layers in the loaded model, derived from the layout above
# (1 embedding + 2 modules per layer + rms-norm + head).
num_original_layers = (len(old_modules) - 3) // 2
# Validate before touching the model: an out-of-range index would otherwise
# silently pick up the norm/head modules (or wrap around when negative), and
# an empty arrangement would leave no attention module for the cache.
if not layer_arrangement:
    raise ValueError("layer_arrangement must contain at least one layer index")
invalid_indices = [idx for idx in layer_arrangement
                   if not 0 <= idx < num_original_layers]
if invalid_indices:
    raise ValueError(
        f"layer_arrangement indices must be in range(0, {num_original_layers}); "
        f"got invalid indices {invalid_indices}"
    )

model.modules = old_modules[:1]  # start with the embedding module
for i, idx in enumerate(layer_arrangement):
    # Shallow copy: the attention module keeps sharing its tensors with the
    # original, but gets its own layer_idx so that duplicate layers use a
    # different cache slot.
    attn = copy(old_modules[idx * 2 + 1])
    attn.layer_idx = i
    model.modules += [attn]
    # The MLP module is reused as-is (not copied).
    model.modules += [old_modules[idx * 2 + 2]]
model.modules += old_modules[-2:]  # rms-norm and head

# Point the model's bookkeeping at the new module list.
model.head_layer_idx = len(model.modules) - 1     # last module: head
model.config.num_hidden_layers = len(layer_arrangement)
model.last_kv_layer_idx = len(model.modules) - 4  # attention module of the last layer

# The cache needs memory per layer (duplicates included), so it has to be
# re-created after rearranging layers.
print('Re-creating cache')
del cache
model.cache_map = {}
model.set_cache_map()
cache = ExLlamaV2Cache_8bit(model)
# The generator is constructed with the new cache, so it is created only now.
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
generator.set_stop_conditions([tokenizer.eos_token_id])
## mix layers end
# adjust generation settings
gen_settings.temperature = 0.0 # for deterministic results
# Optional sampling knobs; uncomment to enable.
#gen_settings.top_k = 50
#gen_settings.top_p = 0.8
#gen_settings.min_p = 0
# Upper bound on the number of generator.stream() steps in the loop below.
max_response_length = 512
print("starting generation")
# Chat prompt using <|system|>/<|user|>/<|assistant|> turn markers; the model's
# reply is streamed right after the trailing "<|assistant|>" tag.
text = """<|system|>
You are a chatbot who can help code!</s>
<|user|>
Write me a python script to blink an LED on a raspberry PI.</s>
<|assistant|>"""
print("\n" + text, end="")
instruction_ids = tokenizer.encode(text, add_bos=True)
# If the generator already holds a token sequence (presumably from an earlier
# turn), continue from it; otherwise start from the prompt alone.
if generator.sequence_ids is None:
    context_ids = instruction_ids
else:
    context_ids = torch.cat([generator.sequence_ids, instruction_ids], dim=-1)
generator.begin_stream(context_ids, gen_settings)
# Stream at most max_response_length chunks, echoing each one as it arrives.
for _step in range(max_response_length):
    chunk, eos, _ = generator.stream()
    # Handle the chunk before checking eos, so any text returned together
    # with the eos flag is not dropped.
    text += chunk
    # The prompt uses <|user|> turn markers; stop once the model starts
    # writing a new user turn.
    if text.endswith("<|user|>"):
        break
    print(chunk, end="", flush=True)  # flush to show partial output immediately
    if eos:
        break
text += "\n"
print()  # terminate the streamed output line
# cleanup
# Restore the original module list before unloading; presumably this makes
# unload() process each original module once, rather than the shallow attention
# copies and repeated MLP modules of the rearranged list.
model.modules = old_modules
model.unload()
del cache
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment