# chat.py: streaming chat session built on ExLlama (file name inferred from
# "from chat import ..." in the server file below).
import torch, os, sys, glob

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model import ExLlama, ExLlamaCache, ExLlamaConfig
from lora import ExLlamaLora
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
from initialize import initialize_session, initialize_model
from token_processor import KillSequenceProcessor

class ChatSession:
    def __init__(self):
        initialize_session()
        self.model, self.generator, self.tokenizer, self.config = initialize_model()
        self.kill_processor = KillSequenceProcessor(["test"])
        self.max_response_tokens = 2048

def generate_response(prompt, chat_session):
    generator = chat_session.generator
    kill_processor = chat_session.kill_processor

    # Beam search is only used when both a beam count and a beam length are set.
    beam_search = (generator.settings.beams and generator.settings.beams >= 1
                   and generator.settings.beam_length and generator.settings.beam_length >= 1)

    ids = generator.tokenizer.encode(prompt)
    generator.gen_begin_reuse(ids)

    # Track how much of the decoded sequence has already been yielded.
    seq_length = len(generator.tokenizer.decode(generator.sequence_actual[0]))
    print(seq_length)
    tail = seq_length

    if beam_search:
        generator.begin_beam_search()
        token_getter = generator.beam_search
    else:
        token_getter = generator.gen_single_token

    while generator.gen_num_tokens() <= chat_session.max_response_tokens - 8:
        token = token_getter()
        titem = token.item()

        # If it's the ending token, replace it and end the generation.
        if titem == generator.tokenizer.eos_token_id:
            generator.replace_last_token(generator.tokenizer.newline_token_id)
            break

        # Decode the full sequence and slice off only the newly generated text.
        decoded = generator.tokenizer.decode(generator.sequence_actual[0])
        head = len(decoded)
        chunk = decoded[tail:head]
        tail = head

        kill_result, kill_buf = kill_processor.process_stream_chunk(chunk)
        if kill_result == KillSequenceProcessor.KillProcessorResult.KILL:
            # A kill sequence was generated: rewind it out of the context and stop.
            rewind_length = generator.tokenizer.encode(kill_buf).shape[-1]
            generator.gen_rewind(rewind_length)
            break
        elif kill_result == KillSequenceProcessor.KillProcessorResult.YIELD:
            yield kill_buf

    generator.end_beam_search()
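
# Minimal usage sketch, runnable as a script (assumes the ExLlama model args
# that initialize_model() parses are supplied on the command line, e.g.
# "python chat.py -d <model_dir>"):
if __name__ == "__main__":
    session = ChatSession()
    for chunk in generate_response("Hello! How are you?", session):
        print(chunk, end="", flush=True)
    print()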
<!-- templates/index.html: served by the Flask route below
     (path inferred from render_template('index.html')). -->
<!DOCTYPE html>
<html>
<head>
    <title>SocketIO Test</title>
    <script src="https://cdn.socket.io/4.6.0/socket.io.min.js" integrity="sha384-c79GN5VsunZvi+Q/WObgk2in0CbZsHnjEqvFxC5DxHn9lTfNce2WW6h2pH6u/kF+" crossorigin="anonymous"></script>
    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='styles.css') }}">
</head>
<body>
    <div id="primary-container"></div>
    <input type="text" id="message" placeholder="Type your message here">
    <button id="send-button" onclick="sendMessage()">Send</button>

    <script type="text/javascript" charset="utf-8">
        var socket = io.connect('http://' + document.domain + ':' + location.port);
        var to_right = true;
        var primaryContainer = document.getElementById('primary-container');
        var sendButton = document.getElementById('send-button');

        socket.on('connect', function() {
            socket.emit('startup', 'User has connected!');
        });

        // The server announces each new response with a unique id; create an
        // empty bubble that later chunks will be appended to.
        socket.on('new_response', function(response_id) {
            sendButton.disabled = true;
            // Create a new text item (as a paragraph element)
            var newItem = document.createElement('p');
            // Set the text of the item
            newItem.textContent = "";
            // Store the response_id in a 'data-response-id' attribute
            newItem.setAttribute('data-response-id', response_id);
            // Alternate bubbles between the right and left side
            newItem.classList.add('speech-bubble');
            if (to_right) {
                newItem.classList.add('right-bubble');
            } else {
                newItem.classList.add('left-bubble');
            }
            to_right = !to_right;
            primaryContainer.appendChild(newItem);
        });

        socket.on('response_chunk', function(msg) {
            let id = msg.id;
            let data = msg.data;
            // Find the text item with the matching response_id
            let item = document.querySelector(`p[data-response-id="${id}"]`);
            var newSpan = document.createElement('span');
            // Set the text of the span
            newSpan.textContent = data;
            newSpan.style.color = "white";
            // Depending on the content of the data, add a different CSS class
            // if (ping_pong) {
            //     newSpan.style.color = "red";
            // } else {
            //     newSpan.style.color = "blue";
            // }
            // ping_pong = !ping_pong
            item.appendChild(newSpan);
        });

        socket.on('response_end', function(response_id) {
            // Re-enable the send button once the stream is finished
            sendButton.disabled = false;
        });

        function sendMessage() {
            var message = document.getElementById('message').value;
            socket.emit('message', message);
        }
    </script>
</body>
</html>
# initialize.py: session and model setup (file name inferred from
# "from initialize import ..." in chat.py).
import torch, os, sys, glob

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model import ExLlama, ExLlamaCache, ExLlamaConfig
from lora import ExLlamaLora
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import model_init
import argparse

def initialize_session():
    # Inference only: no gradients, and make sure CUDA is initialized.
    torch.set_grad_enabled(False)
    torch.cuda._lazy_init()

def get_model_path_at(path):
    patterns = ["*.safetensors", "*.bin", "*.pt"]
    model_paths = []
    for pattern in patterns:
        full_pattern = os.path.join(path, pattern)
        model_paths = glob.glob(full_pattern)
        if model_paths:  # If there are any files matching the current pattern
            break        # Exit the loop as soon as we find a matching file
    if model_paths:      # If there are any files matching any of the patterns
        return model_paths[0]
    else:
        return None      # Return None if no matching files were found

def initialize_model():
    parser = argparse.ArgumentParser(description = "Simple chatbot example for ExLlama")
    model_init.add_args(parser)
    parser.add_argument("-lora", "--lora", type = str, help = "Path to LoRA binary to use during benchmark")
    parser.add_argument("-loracfg", "--lora_config", type = str, help = "Path to LoRA config to use during benchmark")
    parser.add_argument("-ld", "--lora_dir", type = str, help = "Path to LoRA config and binary to use during benchmark")
    parser.add_argument("-temp", "--temperature", type = float, help = "Temperature", default = 1.6)
    parser.add_argument("-topk", "--top_k", type = int, help = "Top-K", default = 32)
    parser.add_argument("-topp", "--top_p", type = float, help = "Top-P", default = 0.65)
    parser.add_argument("-minp", "--min_p", type = float, help = "Min-P", default = 0.00)
    parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty", default = 1.65)
    parser.add_argument("-repps", "--repetition_penalty_sustain", type = int, help = "Past length for repetition penalty", default = 256)
    parser.add_argument("-beams", "--beams", type = int, help = "Number of beams for beam search", default = 0)
    parser.add_argument("-beamlen", "--beam_length", type = int, help = "Number of future tokens to consider", default = 0)
    args = parser.parse_args()

    model_init.post_parse(args)
    print(args)
    model_init.get_model_files(args)

    config = model_init.make_config(args)
    model = ExLlama(config)
    cache = ExLlamaCache(model)
    tokenizer = ExLlamaTokenizer(args.tokenizer)
    model_init.print_stats(model)

    generator = ExLlamaGenerator(model, tokenizer, cache)
    generator.settings.temperature = args.temperature
    generator.settings.top_k = args.top_k
    generator.settings.top_p = args.top_p
    generator.settings.min_p = args.min_p
    generator.settings.token_repetition_penalty_max = args.repetition_penalty
    generator.settings.token_repetition_penalty_sustain = args.repetition_penalty_sustain
    generator.settings.token_repetition_penalty_decay = generator.settings.token_repetition_penalty_sustain // 2
    generator.settings.beams = args.beams
    generator.settings.beam_length = args.beam_length
    generator.settings.no_fused_attn = True

    # Optionally attach a LoRA adapter if a directory was supplied.
    if args.lora_dir is not None:
        lora_config_path = os.path.join(args.lora_dir, "adapter_config.json")
        lora_path = get_model_path_at(args.lora_dir)
        lora = ExLlamaLora(model, lora_config_path, lora_path)
        generator.lora = lora

    return model, generator, tokenizer, config
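
# Example invocation through the server below (hypothetical paths and file
# name; -d is assumed to be model_init's model-directory flag, and -ld points
# at a folder holding adapter_config.json plus the LoRA weights):
#
#     python server.py -d ./models/llama-13b-4bit-128g -ld ./loras/my_lora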
/* static/styles.css: chat layout (path inferred from
   url_for('static', filename='styles.css') in index.html). */
#primary-container {
    width: 100%;
    height: 80vh;
    overflow-y: auto;
    overflow-x: hidden;
}

.speech-bubble {
    position: relative;
    background-color: #111111;
    border-radius: 5px;
    padding: 20px;
    width: 66vw;
}

.left-bubble {
    float: left;
    clear: both; /* This ensures that the item appears below the previous floating element */
}

.right-bubble {
    float: right;
    clear: both; /* This ensures that the item appears below the previous floating element */
}

body {
    background: #11111111;
}
# token_processor.py: streaming text post-processing (file name inferred from
# "from token_processor import ..." in chat.py).
from enum import Enum

class StreamingTokenProcessor:
    def __init__(self):
        self.current_buffer = ""

class KillSequenceProcessor(StreamingTokenProcessor):
    """Buffers streamed text and watches for "kill" sequences that should end generation."""

    def __init__(self, kill_sequences):
        super().__init__()
        self.kill_sequences = kill_sequences

    class KillProcessorResult(Enum):
        KILL = 1
        BUFFER = 0
        YIELD = 2

    def examine_buffer(self):
        sequence = self.current_buffer.strip().lower()
        for death_seq in self.kill_sequences:
            if death_seq == sequence:
                # The buffer is exactly a kill sequence: stop generation.
                return self.KillProcessorResult.KILL
            elif death_seq.startswith(sequence):
                # The buffer could still grow into a kill sequence: hold it back.
                return self.KillProcessorResult.BUFFER
        # No kill sequence can match anymore: safe to release the buffer.
        return self.KillProcessorResult.YIELD

    def process_stream_chunk(self, chunk):
        self.current_buffer += chunk
        seqres = self.examine_buffer()
        return_seq = self.current_buffer
        if seqres == self.KillProcessorResult.YIELD:
            self.current_buffer = ""
        return seqres, return_seq
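
# Usage sketch: BUFFER withholds text that may still become a kill sequence,
# YIELD releases the buffered text, KILL tells the caller to stop and rewind.
if __name__ == "__main__":
    proc = KillSequenceProcessor(["user:"])
    for piece in ["Hello ", "there. ", "us", "er:"]:
        result, buf = proc.process_stream_chunk(piece)
        print(result, repr(buf))
    # Expected: YIELD "Hello ", YIELD "there. ", BUFFER "us", KILL "user:"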
# Flask + SocketIO server (file name not shown in the gist). Imports chat.py,
# so run it with the same ExLlama command-line args.
from flask import Flask, render_template
from flask_socketio import SocketIO, emit
from chat import generate_response, ChatSession
import uuid
import logging

app = Flask(__name__)
app.config['SECRET_KEY'] = 'mysecret'
socketio = SocketIO(app, cors_allowed_origins="*")

# Quiet werkzeug's per-request logging.
log = logging.getLogger('werkzeug')
log.setLevel(logging.WARNING)

@app.route('/')
def index():
    return render_template('index.html')

# The model is loaded lazily on the first message rather than at import time.
LAZY_SESSION = None

@socketio.on('message')
def handle_message(data):
    global LAZY_SESSION
    if LAZY_SESSION is None:
        LAZY_SESSION = ChatSession()
        print("Initialized new model!")

    prompt = data
    response_id = str(uuid.uuid4())
    print(prompt)
    print(response_id)

    # Announce the response id, stream chunks as they are generated, then close.
    emit('new_response', response_id)
    for chunk in generate_response(prompt, LAZY_SESSION):
        emit('response_chunk', {'id': response_id, 'data': chunk})
    emit('response_end', response_id)

if __name__ == '__main__':
    socketio.run(app)
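
# To run (assuming this file is saved as server.py, a hypothetical name, and
# that flask_socketio serves on its default http://127.0.0.1:5000):
#
#     python server.py -d <model_dir>
#
# then open the page in a browser and send a message.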