exllama minimum example
# Flask + Socket.IO front end for the streaming generator (server entry point).
from flask import Flask, render_template
import torch
from flask_socketio import SocketIO, emit
from generation.make_instruct import get_generator_func
from generation.exllama_generator_wrapper import encode_message, encode_system, encode_header

app = Flask(__name__)
socketio = SocketIO(app)

# Load the model and build the streaming generate() closure. The gist never
# shows the model path; "path/to/exl2-model" is a placeholder.
tokenizer, generate = get_generator_func("path/to/exl2-model")

system_prompt = "Respond to all inputs with EEE"

# Seed the running conversation: system prompt, one user turn, one assistant turn.
seed_msg = encode_message(tokenizer, "user", "hello world")
init_msg = encode_message(tokenizer, "assistant", "EEE")
init_head = encode_header(tokenizer, "assistant")

enc_sys_prompt = encode_system(tokenizer, system_prompt)
enc_sys_prompt.extend(seed_msg)
enc_sys_prompt.extend(init_msg)

@app.route('/')
def index():
    return render_template('index.html')

@socketio.on('send_message')
def handle_send_message(message):
    global generate, enc_sys_prompt, init_head
    emit('user_message', message, broadcast=True)
    emit('start_generation', broadcast=True)  # Emit event to indicate generation start

    # Append the new user turn and an open assistant header, then stream the reply.
    # Note: the streamed reply is never appended back to enc_sys_prompt, so later
    # turns will not see it in the context.
    next_message = encode_message(tokenizer, "user", message)
    enc_sys_prompt.extend(next_message)
    enc_sys_prompt.extend(init_head)
    instruction_ids = torch.tensor(enc_sys_prompt).unsqueeze(dim=0)

    for fragment, count in generate(instruction_ids=instruction_ids):
        emit('stream_response', {'fragment': fragment}, broadcast=True)

    emit('end_generation', broadcast=True)  # Emit event to indicate generation end

if __name__ == '__main__':
    socketio.run(app)
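Because index.html is not included in the gist, a quick way to exercise the server is a script client. The sketch below is not part of the original: it assumes the python-socketio client package (pip install "python-socketio[client]") with a Socket.IO protocol version compatible with the server, and the default Flask-SocketIO address of http://127.0.0.1:5000.

# Hypothetical test client (not in the gist).
import socketio

sio = socketio.Client()

@sio.on('stream_response')
def on_fragment(data):
    # Each fragment is one streamed piece of the model's reply.
    print(data['fragment'], end='', flush=True)

@sio.on('end_generation')
def on_done():
    print()
    sio.disconnect()

sio.connect('http://127.0.0.1:5000')
sio.emit('send_message', 'hello again')
sio.wait()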
# generation/exllama_generator_wrapper.py
import sys, os, random
import torch

# A requirement for using the exllamav2 API from a checkout one directory up.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)

def load_model(model_directory):
    config = ExLlamaV2Config()
    config.model_dir = model_directory
    config.prepare()

    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)

    tokenizer = ExLlamaV2Tokenizer(config)
    generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
    return config, tokenizer, cache, generator

def encode_system(tokenizer, system_prompt):
    # System block: <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{prompt}<|eot_id|>
    bos_token = tokenizer.single_id("<|begin_of_text|>")
    eot_token = tokenizer.single_id("<|eot_id|>")
    tokens = [bos_token]
    tokens.extend(encode_header(tokenizer, "system"))
    system_ids = tokenizer.encode(system_prompt, add_bos=False).view(-1).tolist()
    tokens.extend(system_ids)
    tokens.append(eot_token)
    return tokens

def encode_header(tokenizer, username):
    # Turn header: <|start_header_id|>{role}<|end_header_id|>\n\n
    tokens = []
    start_header = tokenizer.single_id("<|start_header_id|>")
    end_header = tokenizer.single_id("<|end_header_id|>")
    tokens.append(start_header)
    tokens.extend(tokenizer.encode(username, add_bos=False).view(-1).tolist())
    tokens.append(end_header)
    tokens.extend(tokenizer.encode("\n\n", add_bos=False).view(-1).tolist())
    return tokens

def encode_message(tokenizer, username, message):
    # One full chat turn: header, message text, <|eot_id|>.
    eot_token = tokenizer.single_id("<|eot_id|>")
    tokens = encode_header(tokenizer, username)
    tokens.extend(
        tokenizer.encode(message.strip(), add_bos=False).view(-1).tolist()
    )
    tokens.append(eot_token)
    return tokens

def generate_response_stream(instruction_ids, tokenizer, generator, settings, stop_sequences=None):
    # Copy the stop list so neither the caller's list nor a shared default is
    # mutated across calls.
    stop_sequences = list(stop_sequences) if stop_sequences else []
    generator.begin_stream_ex(instruction_ids, settings)
    stop_sequences.append(tokenizer.eos_token_id)
    stop_sequences.append(128009)  # <|eot_id|> for Llama 3
    generator.set_stop_conditions(stop_sequences)

    while True:
        res = generator.stream_ex()
        if res["eos"]:
            return
        chunk = res["chunk"]
        counts = len(res["chunk_token_ids"])
        yield chunk, counts
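encode_system, encode_header, and encode_message follow the Llama 3 Instruct chat layout (<|begin_of_text|>, role headers, <|eot_id|> terminators). The sketch below is not part of the gist; it shows one way to drive the wrapper directly without the Flask layer, assuming an EXL2-quantized Llama 3 model at a placeholder path.

# Hypothetical standalone driver (not in the gist); the model directory is a placeholder.
import torch
from exllamav2.generator import ExLlamaV2Sampler
from generation.exllama_generator_wrapper import (
    load_model, encode_system, encode_message, encode_header, generate_response_stream
)

config, tokenizer, cache, generator = load_model("/path/to/llama3-exl2")

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.8
settings.top_k = 50
settings.min_p = 0.1

# Build the prompt, then leave an open assistant header so the model completes that turn.
ids = encode_system(tokenizer, "You are a terse assistant.")
ids.extend(encode_message(tokenizer, "user", "Name three prime numbers."))
ids.extend(encode_header(tokenizer, "assistant"))

instruction_ids = torch.tensor(ids).unsqueeze(dim=0)
for chunk, token_count in generate_response_stream(instruction_ids, tokenizer, generator, settings):
    print(chunk, end="", flush=True)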
# generation/make_instruct.py
import sys, os
from functools import partial

# Needed for the exllamav2 lib when it lives one directory up from this package.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2.generator import (
    ExLlamaV2Sampler
)

from .exllama_generator_wrapper import load_model, generate_response_stream

def get_generator_func(model_path):
    abs_path = os.path.abspath(model_path)
    config, tokenizer, cache, generator = load_model(abs_path)

    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 2
    settings.top_k = 30
    settings.min_p = 0.1

    # Bind the generator, sampler settings, and tokenizer so callers only pass instruction_ids.
    generate = partial(generate_response_stream, generator=generator, settings=settings, tokenizer=tokenizer)
    return tokenizer, generate