Created
May 13, 2024 23:06
-
-
Save CoffeeVampir3/8c691b07f6f251315b7ca746db848f0b to your computer and use it in GitHub Desktop.
Base RP Test
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Web UI + model bootstrap: Flask serves the page, Flask-SocketIO streams
# generated tokens to the browser, and the generation.* helpers wrap an
# exllamav2 model.
from flask import Flask, render_template
import torch
from flask_socketio import SocketIO, emit
from generation.make_instruct import get_generator_func
from generation.exllama_generator_wrapper import encode_message, encode_system, encode_header, encode_header_prefilled, encode_message_with_eot, encode_completion
from collections import deque
import time
import os,sys
# Make the repository root importable so the local exllamav2 checkout is found
# before the `from exllamav2 import ...` below.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import (
    ExLlamaV2Lora,
)
# Few-shot roleplay priming transcript fed to the model ahead of the live chat.
# NOTE(review): uses <end_user>/<end_turn> as turn delimiters plus a Llama-3
# style <|begin_of_text|> token — presumably matching the template expected by
# generation/exllama_generator_wrapper; confirm against that module.
system_prompt = '''<|begin_of_text|>Edgar: I'm so tired of these empty eyes watching us in the distance.<end_user>
Amelia: It's becoming quite a nuisance. Should we do something about it, Edgar?<end_turn>
Edgar: Yes, I think it's time we took matters into our own hands.<end_user>
Amelia: Then I will bring out the beacon upon which to shine the end upon their mortal visages. It'll be a swift and horrible fate.<end_turn>
Edgar: Stop stop I can only get so aroused. But please continue, actually.<end_user>
Amelia: I shall. Watch as their lives become ash in an instant. Tremble mortals, I am Amelia Deathheart!<end_turn>'''
app = Flask(__name__)
socketio = SocketIO(app)
# Model path/identifier is taken from the first CLI argument.
model, tokenizer, generator, generate = get_generator_func(sys.argv[1])
# Conversation state shared across socket events.
message_pairs = {}  # NOTE(review): never populated anywhere in this file
turn_counter = 0
current_ctx_size = 0  # NOTE(review): written nowhere else — appears vestigial
ctx_buffer_trim_token_thresh = 7000  # NOTE(review): defined but not yet enforced
# Token ids of the running conversation; grows as user/model turns are appended.
enc_sys_prompt = encode_completion(tokenizer, system_prompt)
sys_role = "\t c2 "
@app.route('/')
def index():
    """Serve the chat page, pre-filling it with the priming transcript."""
    page = render_template('index.html', system_prompt=system_prompt)
    return page
@socketio.on('purge_messages')
def handle_purge_messages():
    """Reset the per-chat turn bookkeeping.

    The handler was previously a no-op with its reset logic commented out;
    it now performs the purge its event name advertises.

    NOTE(review): the encoded context buffer (``enc_sys_prompt``) is
    deliberately left untouched here — confirm whether a purge should also
    re-encode the base system prompt so generation restarts from a clean
    transcript.
    """
    global message_pairs, turn_counter
    message_pairs = {}
    turn_counter = 0
@socketio.on('send_message')
def handle_send_message(message, new_sys_prompt):
    """Handle one user turn: echo it to clients, stream the model's reply,
    append both turns to the running context, and report the token count.

    Parameters
    ----------
    message : str
        The user's chat message.
    new_sys_prompt : str
        Sent by the client but currently unused.
        # NOTE(review): either wire this up or drop it from the client.
    """
    # Only names that are rebound need the global declaration; the original
    # also listed never-rebound names plus `max_turns`, which is defined
    # nowhere in this file.
    global enc_sys_prompt, turn_counter

    emit('user_message', message, broadcast=True)
    emit('start_generation', broadcast=True)  # tell clients streaming starts
    turn_counter += 1

    # Append the user's turn to the running token buffer and build the
    # (1, seq_len) tensor the generator expects.
    enc_sys_prompt.extend(encode_completion(tokenizer, message))
    instruction_ids = torch.tensor(enc_sys_prompt).unsqueeze(dim=0)

    tokens_out = 0
    message_buffer = ""
    print("*" * 30)
    # Stream each generated fragment to every connected client as it arrives.
    for fragment, ids, count in generate(
        instruction_ids=instruction_ids,
        generator=generator,
        loras=None,
        stop_sequences=["<end_turn>"],
    ):
        emit('stream_response', {'fragment': fragment}, broadcast=True)
        tokens_out += count
        message_buffer += fragment

    # Append the model's full reply to the running context.
    enc_sys_prompt.extend(encode_completion(tokenizer, message_buffer))

    # Debug: dump the decoded context so the turn structure can be inspected.
    debug_out = torch.tensor(enc_sys_prompt).unsqueeze(dim=0)
    print(tokenizer.decode(debug_out, decode_special_tokens=True))

    # enc_sys_prompt already contains the re-encoded reply, so the original
    # `len(enc_sys_prompt) + tokens_out` counted the response twice.
    emit('update_token_count', len(enc_sys_prompt), broadcast=True)
    emit('end_generation', broadcast=True)  # tell clients streaming finished
# Start the Flask-SocketIO development server (blocking call).
if __name__ == '__main__':
    socketio.run(app)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment