@CoffeeVampir3 · Created May 13, 2024 23:06
Base RP Test
from flask import Flask, render_template
import torch
from flask_socketio import SocketIO, emit
from generation.make_instruct import get_generator_func
from generation.exllama_generator_wrapper import encode_message, encode_system, encode_header, encode_header_prefilled, encode_message_with_eot, encode_completion
from collections import deque
import time
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import (
    ExLlamaV2Lora,
)
system_prompt = '''<|begin_of_text|>Edgar: I'm so tired of these empty eyes watching us in the distance.<end_user>
Amelia: It's becoming quite a nuisance. Should we do something about it, Edgar?<end_turn>
Edgar: Yes, I think it's time we took matters into our own hands.<end_user>
Amelia: Then I will bring out the beacon upon which to shine the end upon their mortal visages. It'll be a swift and horrible fate.<end_turn>
Edgar: Stop stop I can only get so aroused. But please continue, actually.<end_user>
Amelia: I shall. Watch as their lives become ash in an instant. Tremble mortals, I am Amelia Deathheart!<end_turn>'''
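# The seed context above is a few-shot roleplay transcript written with the
# custom turn delimiters <end_user> and <end_turn>; generation below stops on
# <end_turn> so each reply ends at a turn boundary.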
app = Flask(__name__)
socketio = SocketIO(app)
model, tokenizer, generator, generate = get_generator_func(sys.argv[1])
message_pairs = {}  # currently unused; kept for the stubbed purge handler below
turn_counter = 0
current_ctx_size = 0
ctx_buffer_trim_token_thresh = 7000  # intended context-trim threshold; not yet wired up
enc_sys_prompt = encode_completion(tokenizer, system_prompt)  # running token buffer, seeded with the system prompt
sys_role = "\t c2 "  # currently unused
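
# Sketch (not in the original gist): ctx_buffer_trim_token_thresh is declared but
# never applied below. A naive way to honor it is to drop the oldest tokens once
# the buffer exceeds the threshold; a real implementation would cut on turn
# boundaries and keep the system prompt intact.
def trim_context(buf, thresh=ctx_buffer_trim_token_thresh):
    overflow = len(buf) - thresh
    return buf[overflow:] if overflow > 0 else buf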
@app.route('/')
def index():
    return render_template('index.html', system_prompt=system_prompt)
@socketio.on('purge_messages')
def handle_purge_messages():
    # History purging is stubbed out for this test build.
    #global message_pairs, turn_counter
    #message_pairs = {}
    #turn_counter = 0
    pass
@socketio.on('send_message')
def handle_send_message(message, new_sys_prompt):
    global enc_sys_prompt, turn_counter  # turn_counter is the only name rebound here
    emit('user_message', message, broadcast=True)
    emit('start_generation', broadcast=True)  # tell clients generation has started
    turn_counter += 1

    # Append the user's turn to the running token buffer and batch it for the model.
    next_message = encode_completion(tokenizer, message)
    enc_sys_prompt.extend(next_message)
    testing = torch.tensor(enc_sys_prompt).unsqueeze(dim=0)

    tokens_out = 0
    print("*" * 30)
    message_buffer = ""
    # Stream fragments to the client as they arrive, stopping at the turn delimiter.
    for fragment, ids, count in generate(instruction_ids=testing, generator=generator, loras=None, stop_sequences=["<end_turn>"]):
        emit('stream_response', {'fragment': fragment}, broadcast=True)
        tokens_out += count
        message_buffer += fragment
        #print(str(ids) + " ", flush=True, end="")
        #if tokens_out + len(enc_sys_prompt) > 7500:
        #    break

    # Fold the model's reply back into the context so the next turn sees it.
    enc_turn = encode_completion(tokenizer, message_buffer)
    enc_sys_prompt.extend(enc_turn)
    debug_out = torch.tensor(enc_sys_prompt).unsqueeze(dim=0)
    debug_v = tokenizer.decode(debug_out, decode_special_tokens=True)
    print(debug_v)
    emit('update_token_count', len(enc_sys_prompt) + tokens_out, broadcast=True)
    emit('end_generation', broadcast=True)  # tell clients generation has finished
if __name__ == '__main__':
    socketio.run(app)
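
For quick manual testing, a minimal Socket.IO client can drive the server from a second terminal. This is a sketch, not part of the gist: it assumes the server is running on localhost:5000 and that the python-socketio client package is installed; the event names mirror the ones emitted above.

# test_client.py -- hypothetical smoke-test client for the server above.
# Assumes: pip install "python-socketio[client]" and the server on localhost:5000.
import socketio

sio = socketio.Client()

@sio.on('stream_response')
def on_fragment(data):
    # Print streamed fragments as they arrive.
    print(data['fragment'], end='', flush=True)

@sio.on('end_generation')
def on_done():
    print()
    sio.disconnect()

sio.connect('http://localhost:5000')
# handle_send_message takes (message, new_sys_prompt); the second argument is unused there.
sio.emit('send_message', ('Edgar: What do we do next?', ''))
sio.wait()  # block until the server finishes and we disconnect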