Create LLM slices at runtime with exllamav2
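The script below loads an exl2-quantized model, rebuilds its module list in a rearranged (and partly duplicated) layer order, re-creates the cache, and streams a chat completion from the resulting slice.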
# to use this, first install python and exllamav2 (https://github.com/turboderp/exllamav2)
# load a model, rearrange the layers as you like, set generation parameters, and run it
# duplicate layers share tensors, but still need extra memory for the cache
# thanks to @dnhkng for showing that the cache needs to be re-created
# licensed under WTFPL (http://www.wtfpl.net/about/) - Silphendio
from exllamav2 import *
from exllamav2.generator import *
import sys, torch
from copy import copy

config = ExLlamaV2Config()
config.model_dir = "./TinyLlama-1.1B-Chat-v1.0-5.0bpw-h6-exl2"
config.prepare()

model = ExLlamaV2(config)
cache = ExLlamaV2Cache_8bit(model, lazy = True)
print("Loading model...")
model.load_autosplit(cache)

tokenizer = ExLlamaV2Tokenizer(config)
gen_settings = ExLlamaV2Sampler.Settings()
## mix layers here
layer_arrangement = list(range(0,14)) + list(range(8,22))
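# layers 0-13 followed by layers 8-21, so layers 8-13 run twice (a simple passthrough-style self-merge)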
# module arrangement: [embedding, [...layers], rms-norm, head]
# where each layer is [attention, mlp]
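# i.e. the attention block of original layer idx sits at modules[idx*2 + 1] and its mlp at modules[idx*2 + 2]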
old_modules = model.modules
model.modules = old_modules[:1] # keep only the embedding module for now
for i, idx in enumerate(layer_arrangement):
    model.modules += [copy(old_modules[idx*2 + 1])]
    model.modules[-1].layer_idx = i # for duplicate layers to use a different cache
    model.modules += [old_modules[idx*2 + 2]]
model.modules += old_modules[-2:] # re-attach rms-norm and head
model.head_layer_idx = len(model.modules) - 1
model.config.num_hidden_layers = len(layer_arrangement)
model.last_kv_layer_idx = len(model.modules) - 4
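# head_layer_idx points at the lm head (the last module); last_kv_layer_idx at the final attention block (modules[-4])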
print('Re-creating cache')
del cache
model.cache_map = {}
model.set_cache_map()
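# the cache map must be rebuilt to match the new module list before allocating a fresh cache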
cache = ExLlamaV2Cache_8bit(model)
# this needs to be re-created after rearranging layers
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
generator.set_stop_conditions([tokenizer.eos_token_id])
## mix layers end

# adjust generation settings
gen_settings.temperature = 0.0 # for deterministic results
#gen_settings.top_k = 50
#gen_settings.top_p = 0.8
#gen_settings.min_p = 0

max_response_length = 512

print("starting generation")
text = """<|system|> | |
You are a chatbot who can help code!</s> | |
<|user|> | |
Write me a python script to blink an LED on a raspberry PI.</s> | |
<|assistant|>""" | |
print("\n" + text, end="") | |
instruction_ids = tokenizer.encode(text, add_bos = True) | |
context_ids = instruction_ids if generator.sequence_ids is None \ | |
else torch.cat([generator.sequence_ids, instruction_ids], dim = -1) | |
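# sequence_ids is None on the first turn; on later calls the new prompt is appended to the running context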
generator.begin_stream(context_ids, gen_settings)

for _ in range(max_response_length):
    chunk, eos, _ = generator.stream()
    if eos: break
    text += chunk
    if text.endswith("<|user|>"):
        break
    print(chunk, end = "")
    sys.stdout.flush()
text += "\n"
# cleanup
model.modules = old_modules # restore the original layer order before unloading
model.unload()
del cache
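To try a different slice, only layer_arrangement needs to change: repeating a span of indices duplicates those blocks at inference time (with the extra cache memory noted above), while e.g. list(range(0, 22)) should reproduce TinyLlama's original 22-layer stack.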