Last active
August 6, 2023 07:54
-
-
Save CoffeeVampir3/fe09b5eede4a7375f2c0a0a64e792b36 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch, os, sys, glob | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
from model import ExLlama, ExLlamaCache, ExLlamaConfig | |
from lora import ExLlamaLora | |
from tokenizer import ExLlamaTokenizer | |
from generator import ExLlamaGenerator | |
from initialize import initialize_session, intitialize_model | |
from token_processor import KillSequenceProcessor | |
class ChatSession:
    """Owns all per-session inference state: model, generator, and filters."""

    def __init__(self):
        # One-time torch/CUDA setup must happen before the model loads.
        initialize_session()
        self.model, self.generator, self.tokenizer, self.config = intitialize_model()
        # Hard cap on how many tokens a single response may contain.
        self.max_response_tokens = 2048
        # Streamed output matching any of these sequences is cut from the reply.
        self.kill_processor = KillSequenceProcessor(["test"])
def generate_response(prompt, chat_session):
    """Stream a model response to *prompt*, yielding decoded text chunks.

    Generates token-by-token (or via beam search when the generator settings
    enable it), filters the stream through the session's kill-sequence
    processor, and stops on EOS, on a kill sequence, or when the response
    reaches ``chat_session.max_response_tokens``.

    Args:
        prompt: The full text prompt to condition generation on.
        chat_session: A ChatSession providing the generator and kill processor.

    Yields:
        str: chunks of newly generated text cleared by the kill processor.
    """
    generator = chat_session.generator
    kill_processor = chat_session.kill_processor

    # Beam search is only meaningful when both beam count and length are >= 1.
    settings = generator.settings
    use_beam_search = bool(settings.beams and settings.beams >= 1
                           and settings.beam_length and settings.beam_length >= 1)

    ids = generator.tokenizer.encode(prompt)
    generator.gen_begin_reuse(ids)

    # Track the decoded length so each iteration emits only the new text.
    tail = len(generator.tokenizer.decode(generator.sequence_actual[0]))

    if use_beam_search:
        generator.begin_beam_search()
        token_getter = generator.beam_search
    else:
        token_getter = generator.gen_single_token

    try:
        while generator.gen_num_tokens() <= chat_session.max_response_tokens - 8:
            token = token_getter()
            # If it's the ending token, replace it and end the generation.
            if token.item() == generator.tokenizer.eos_token_id:
                generator.replace_last_token(generator.tokenizer.newline_token_id)
                break

            decoded = generator.tokenizer.decode(generator.sequence_actual[0])
            chunk = decoded[tail:]
            tail = len(decoded)

            kill_result, kill_buf = kill_processor.process_stream_chunk(chunk)
            if kill_result == KillSequenceProcessor.KillProcessorResult.KILL:
                # Rewind the generated kill sequence back out of the context.
                rewind_length = generator.tokenizer.encode(kill_buf).shape[-1]
                generator.gen_rewind(rewind_length)
                break
            elif kill_result == KillSequenceProcessor.KillProcessorResult.YIELD:
                yield kill_buf
    finally:
        # Runs even when the consumer closes this generator early; the
        # original unconditionally called end_beam_search() on normal exit
        # only, which could leave beam-search state dangling.
        if use_beam_search:
            generator.end_beam_search()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<!-- Minimal SocketIO chat client: sends user messages to the server and
     renders streamed response chunks into alternating speech bubbles. -->
<html>
<head>
<title>SocketIO Test</title>
<script src="https://cdn.socket.io/4.6.0/socket.io.min.js" integrity="sha384-c79GN5VsunZvi+Q/WObgk2in0CbZsHnjEqvFxC5DxHn9lTfNce2WW6h2pH6u/kF+" crossorigin="anonymous"></script>
<link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='styles.css') }}">
</head>
<body>
<div id="primary-container"></div>
<input type="text" id="message" placeholder="Type your message here">
<button id="send-button" onclick="sendMessage()">Send</button>
</body>
<script type="text/javascript" charset="utf-8">
// Connect back to the same host/port that served this page.
var socket = io.connect('http://' + document.domain + ':' + location.port);
// Alternates bubble side per response so messages zig-zag down the page.
var to_right = true;
var primaryContainer = document.getElementById('primary-container');
var sendButton = document.getElementById('send-button');
socket.on('connect', function() {
    socket.emit('startup', 'User has connected!');
});
// Server announces a new response; create an empty bubble for its chunks.
socket.on('new_response', function(response_id) {
    // Disable sending until the full response has streamed in.
    sendButton.disabled = true;
    // Create a new text item (as a paragraph element)
    var newItem = document.createElement('p');
    // Set the text of the item
    newItem.textContent = "";
    // Store the response_id in a 'data-response-id' attribute
    newItem.setAttribute('data-response-id', response_id);
    // Add the item to the body of the document
    newItem.classList.add('speech-bubble');
    if (to_right) {
        newItem.classList.add('right-bubble');
    } else {
        newItem.classList.add('left-bubble');
    }
    to_right = !to_right;
    primaryContainer.appendChild(newItem);
});
// Each streamed chunk is appended as a span inside its response's bubble.
socket.on('response_chunk', function(msg) {
    let id = msg.id;
    let data = msg.data;
    // Find the text item with the matching response_id
    // NOTE(review): assumes a 'new_response' for this id always arrives
    // first; if not, 'item' is null and appendChild throws -- verify ordering.
    let item = document.querySelector(`p[data-response-id="${id}"]`);
    var newSpan = document.createElement('span');
    // Set the text of the span
    newSpan.textContent = data;
    newSpan.style.color = "white"
    // Depending on the content of the data, add a different CSS class
    // if (ping_pong) {
    //     newSpan.style.color = "red";
    // } else {
    //     newSpan.style.color = "blue";
    // }
    // ping_pong = !ping_pong
    item.appendChild(newSpan);
});
socket.on('response_end', function(response_id) {
    // Enable the send button
    sendButton.disabled = false;
});
// Fired by the Send button; server replies via the three events above.
function sendMessage() {
    var message = document.getElementById('message').value;
    socket.emit('message', message);
}
</script>
</html>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch, os, sys, glob | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
from model import ExLlama, ExLlamaCache, ExLlamaConfig | |
from lora import ExLlamaLora | |
from tokenizer import ExLlamaTokenizer | |
from generator import ExLlamaGenerator | |
import model_init | |
import argparse | |
def initialize_session():
    """Global torch setup for an inference session.

    Disables autograd (inference only -- no gradients are ever needed) and
    eagerly initializes CUDA so the cost is paid up front rather than on the
    first generation call.
    """
    torch.set_grad_enabled(False)
    # Use the public torch.cuda.init() instead of the private _lazy_init(),
    # and skip entirely on CPU-only machines instead of crashing.
    if torch.cuda.is_available():
        torch.cuda.init()
def get_model_path_at(path):
    """Return the first model weight file found under *path*, or None.

    Extensions are tried in priority order: .safetensors, then .bin, then
    .pt; the first pattern with any match wins and its first hit is returned.
    """
    for pattern in ("*.safetensors", "*.bin", "*.pt"):
        matches = glob.glob(os.path.join(path, pattern))
        if matches:
            # Stop at the highest-priority extension that has files.
            return matches[0]
    # No pattern matched anything under this directory.
    return None
def intitialize_model():
    """Parse CLI arguments, load the ExLlama model, and return it with a
    fully configured generator.

    NOTE(review): the name is misspelled ("intitialize"); callers import it
    under this spelling, so it is left unchanged.

    Returns:
        Tuple of (model, generator, tokenizer, config).
    """
    # model_init contributes the shared model-loading flags; the rest are
    # sampling / LoRA options specific to this chatbot.
    parser = argparse.ArgumentParser(description = "Simple chatbot example for ExLlama")
    model_init.add_args(parser)
    parser.add_argument("-lora", "--lora", type = str, help = "Path to LoRA binary to use during benchmark")
    parser.add_argument("-loracfg", "--lora_config", type = str, help = "Path to LoRA config to use during benchmark")
    parser.add_argument("-ld", "--lora_dir", type = str, help = "Path to LoRA config and binary. to use during benchmark")
    parser.add_argument("-temp", "--temperature", type = float, help = "Temperature", default = 1.6)
    parser.add_argument("-topk", "--top_k", type = int, help = "Top-K", default = 32)
    parser.add_argument("-topp", "--top_p", type = float, help = "Top-P", default = 0.65)
    parser.add_argument("-minp", "--min_p", type = float, help = "Min-P", default = 0.00)
    parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty", default = 1.65)
    parser.add_argument("-repps", "--repetition_penalty_sustain", type = int, help = "Past length for repetition penalty", default = 256)
    parser.add_argument("-beams", "--beams", type = int, help = "Number of beams for beam search", default = 0)
    parser.add_argument("-beamlen", "--beam_length", type = int, help = "Number of future tokens to consider", default = 0)
    args = parser.parse_args()
    model_init.post_parse(args)
    print(args)
    # Locate weight files and build the model, cache, and tokenizer.
    model_init.get_model_files(args)
    config = model_init.make_config(args)
    model = ExLlama(config)
    cache = ExLlamaCache(model)
    tokenizer = ExLlamaTokenizer(args.tokenizer)
    model_init.print_stats(model)
    # Copy the sampling settings from the CLI flags onto the generator.
    generator = ExLlamaGenerator(model, tokenizer, cache)
    generator.settings.temperature = args.temperature
    generator.settings.top_k = args.top_k
    generator.settings.top_p = args.top_p
    generator.settings.min_p = args.min_p
    generator.settings.token_repetition_penalty_max = args.repetition_penalty
    generator.settings.token_repetition_penalty_sustain = args.repetition_penalty_sustain
    # Decay over half the sustain window.
    generator.settings.token_repetition_penalty_decay = generator.settings.token_repetition_penalty_sustain // 2
    generator.settings.beams = args.beams
    generator.settings.beam_length = args.beam_length
    generator.settings.no_fused_attn = True
    # Optionally attach a LoRA adapter found in --lora_dir.
    if args.lora_dir is not None:
        lora_config_path = os.path.join(args.lora_dir, "adapter_config.json")
        # NOTE(review): get_model_path_at may return None if the dir holds no
        # weights; ExLlamaLora would then receive a None path -- verify upstream.
        lora_path = get_model_path_at(args.lora_dir)
        lora = ExLlamaLora(model, lora_config_path, lora_path)
        generator.lora = lora
    return model, generator, tokenizer, config
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* Scrollable chat transcript area. */
#primary-container {
    width: 100%;
    height: 80vh;
    overflow-y: auto;
    overflow-x: hidden;
}
/* One chat message; side is chosen by left-bubble / right-bubble. */
.speech-bubble {
    position: relative;
    background-color: #111111;
    border-radius: 5px;
    padding: 20px;
    width: 66vw;
}
.left-bubble {
    float: left;
    clear: both; /* This ensures that the item appears below the previous floating element */
}
.right-bubble {
    float: right;
    clear: both; /* This ensures that the item appears below the previous floating element */
}
body {
    /* NOTE(review): 8-digit hex means alpha 0x11 (nearly transparent);
       if a solid dark page background was intended, this should likely
       be #111111 -- confirm. */
    background: #11111111;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from enum import Enum | |
class StreamingTokenProcessor:
    """Base class for processors that accumulate streamed text chunks."""

    def __init__(self):
        # Text received so far that has not yet been released downstream.
        self.current_buffer = ""
class KillSequenceProcessor(StreamingTokenProcessor):
    """Buffers a token stream and suppresses configured "kill" sequences.

    Chunks accumulate in ``current_buffer``; each chunk is classified as
    BUFFER (could still grow into a kill sequence), KILL (buffer exactly
    matches a kill sequence), or YIELD (safe to emit, buffer is flushed).
    """

    class KillProcessorResult(Enum):
        # Values preserved from the original definition.
        KILL = 1    # buffer exactly matches a kill sequence; drop it and stop
        BUFFER = 0  # buffer is a prefix of a kill sequence; keep holding it
        YIELD = 2   # buffer cannot become a kill sequence; release it

    def __init__(self, kill_sequences):
        """Args:
            kill_sequences: iterable of strings whose exact occurrence in the
                stream should terminate generation.
        """
        super().__init__()
        # Fix: normalize once here, since examine_buffer compares against the
        # stripped, lower-cased buffer -- previously a caller-supplied sequence
        # with uppercase letters or surrounding whitespace could never match.
        self.kill_sequences = [seq.strip().lower() for seq in kill_sequences]

    def examine_buffer(self):
        """Classify the current buffer against the kill sequences."""
        sequence = self.current_buffer.strip().lower()
        for death_seq in self.kill_sequences:
            if death_seq == sequence:
                return self.KillProcessorResult.KILL
            elif death_seq.startswith(sequence):
                return self.KillProcessorResult.BUFFER
        return self.KillProcessorResult.YIELD

    def process_stream_chunk(self, chunk):
        """Append *chunk* and classify the buffer.

        Returns:
            (result, text) where text is the full buffered content at the
            time of classification. On YIELD the internal buffer is cleared
            so the caller can emit the returned text exactly once.
        """
        self.current_buffer += chunk
        result = self.examine_buffer()
        pending = self.current_buffer
        if result == self.KillProcessorResult.YIELD:
            self.current_buffer = ""
        return result, pending
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from flask import Flask, render_template | |
from flask_socketio import SocketIO, emit | |
from chat import generate_response, ChatSession | |
import uuid | |
import logging | |
# Flask + SocketIO application wiring for the chat server.
app = Flask(__name__)
# NOTE(review): hard-coded secret key is fine for local demos only; load
# from the environment before deploying anywhere shared.
app.config['SECRET_KEY'] = 'mysecret'
# Open CORS so the client can connect from any origin (demo setting).
socketio = SocketIO(app, cors_allowed_origins="*")
# Quiet werkzeug's per-request access logging.
log = logging.getLogger('werkzeug')
log.setLevel(logging.WARNING)
@app.route('/')
def index():
    """Serve the single-page chat client."""
    return render_template('index.html')
# Lazily-created singleton ChatSession, built on the first message so the
# model is not loaded at import time. (The original module-level `global`
# statement was a no-op -- `global` only has meaning inside a function --
# so it has been removed.)
LAZY_SESSION = None
@socketio.on('message')
def handle_message(data):
    """Handle an incoming chat message.

    Lazily constructs the (expensive) model session on first use, then
    streams the model's reply to the client over three socket events:
    'new_response', repeated 'response_chunk', and 'response_end'.
    """
    global LAZY_SESSION
    if LAZY_SESSION is None:
        LAZY_SESSION = ChatSession()
        print("Initialized new model!")  # fixed typo: was "Initialezed"
    prompt = data
    # Unique id lets the client match streamed chunks to one response bubble.
    response_id = str(uuid.uuid4())
    print(prompt)
    print(response_id)
    emit('new_response', response_id)
    for chunk in generate_response(prompt, LAZY_SESSION):
        emit('response_chunk', {'id': response_id, 'data': chunk})
    emit('response_end', response_id)
socketio.run(app) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment