@CoffeeVampir3
Last active August 10, 2024 13:05
Rolling system prompt: a minimal Flask + Socket.IO chat server over ExLlamaV2 (Gemma-2 chat format) that re-encodes the system prompt at the head of the context on every turn and trims the oldest turns once the prompt grows past a token threshold.
# formats/gemma2_format.py -- Gemma-2 chat-format helpers
# (filename inferred from the import in the app below)
from exllamav2 import (
    ExLlamaV2Tokenizer,
)


def encode_message(tokenizer: ExLlamaV2Tokenizer, role: str, message: str) -> list:
    """Encode one complete turn: <start_of_turn>{role}\\n{message}<end_of_turn>\\n"""
    tokens = []
    start_token = tokenizer.single_id("<start_of_turn>")
    end_token = tokenizer.single_id("<end_of_turn>")
    tokens.append(start_token)
    tokens.extend(tokenizer.encode(f"{role}\n", add_bos=False).view(-1).tolist())
    tokens.extend(tokenizer.encode(message.strip(), add_bos=False).view(-1).tolist())
    tokens.append(end_token)
    tokens.extend(tokenizer.encode("\n", add_bos=False).view(-1).tolist())
    return tokens


def terminate_message(tokenizer: ExLlamaV2Tokenizer, message: str) -> list:
    """Encode a message body and close it with <end_of_turn>; used to seal the model's streamed reply."""
    tokens = []
    end_token = tokenizer.single_id("<end_of_turn>")
    tokens.extend(tokenizer.encode(message.strip(), add_bos=False).view(-1).tolist())
    tokens.append(end_token)
    return tokens


def encode_header(tokenizer: ExLlamaV2Tokenizer, role: str) -> list:
    """Encode an open turn header (<start_of_turn>{role}\\n) so the model continues as that role."""
    tokens = []
    start_token = tokenizer.single_id("<start_of_turn>")
    tokens.append(start_token)
    tokens.extend(tokenizer.encode(f"{role}\n", add_bos=False).view(-1).tolist())
    return tokens
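Taken together, these helpers build a Gemma-2 prompt as plain token lists that can be concatenated freely. A minimal sketch of how they compose (it assumes a loaded ExLlamaV2Tokenizer named tokenizer; the turn contents here are illustrative, not from the gist):

# system turn, user turn, then an open model header so the model speaks next
prompt_ids = []
prompt_ids += encode_message(tokenizer, "system", "Example.")
prompt_ids += encode_message(tokenizer, "user", "Hello!")
prompt_ids += encode_header(tokenizer, "model")
# feed torch.tensor(prompt_ids).unsqueeze(0) to the generator, then seal the
# streamed reply afterwards with terminate_message(tokenizer, reply_text)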
# Main app -- imports the format helpers above as formats.gemma2_format
import os
import sys
import json

import torch
from flask import Flask, render_template
from flask_socketio import SocketIO, emit

# Make the project root importable before loading the local packages below.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from generation.exllama_generator_wrapper import get_generator_func
from formats.gemma2_format import encode_header, encode_message, terminate_message
from exllamav2 import (
    ExLlamaV2Lora,
)
app = Flask(__name__)
socketio = SocketIO(app)

sys_role = "model"  # Gemma-2 uses "model" as the assistant role
user_role = "user"
system_prompt = "Example."

# sys.argv[1]: model directory; sys.argv[2] (optional): LoRA directory
model, tokenizer, generator, generate = get_generator_func(sys.argv[1])

jsonl_file = 'conversation.jsonl'


def write_system_prompt(system_prompt):
    """Start a fresh conversation log with the system prompt as its first record."""
    with open(jsonl_file, 'w') as f:
        json.dump({"system": system_prompt}, f)
        f.write('\n')


lora = None
if len(sys.argv) > 2:
    lora = ExLlamaV2Lora.from_directory(model, sys.argv[2])
    print(lora)

message_pairs = {}  # turn number -> [user tokens, open model header tokens, sealed reply tokens]
turn_counter = 0
current_ctx_size = 0
ctx_buffer_trim_token_thresh = 30000  # trim oldest turns once the prompt exceeds this many tokens
@app.route('/')
def index():
    return render_template('index.html', system_prompt=system_prompt)
@socketio.on('purge_messages')
def handle_purge_messages():
    """Clear the rolling history (the conversation log on disk is left untouched)."""
    global message_pairs, turn_counter
    message_pairs = {}
    turn_counter = 0
@socketio.on('send_message')
def handle_send_message(message, new_sys_prompt):
    global generate, message_pairs, turn_counter, generator, sys_role, user_role, current_ctx_size, lora
    emit('user_message', message, broadcast=True)
    emit('start_generation', broadcast=True)  # signal generation start to the client

    # Rolling system prompt: the (possibly edited) system prompt is re-encoded
    # fresh every turn and always sits at the head of the context.
    enc_sys_prompt = encode_message(tokenizer, "system", new_sys_prompt)
    debug_out = torch.tensor(enc_sys_prompt).unsqueeze(dim=0)
    debug_v = tokenizer.decode(debug_out, decode_special_tokens=True)
    print(debug_v)

    turn_counter += 1
    next_message = encode_message(tokenizer, user_role, message)
    # Store the user turn plus an open model header; the sealed reply is appended after generation.
    message_pairs[turn_counter] = [next_message, encode_header(tokenizer, sys_role)]

    with open(jsonl_file, 'a') as f:
        json.dump({user_role: message}, f)
        f.write('\n')

    # Once the context grows past the threshold, drop up to the 20 oldest turns,
    # guarding against emptying the history (the current turn must survive).
    if current_ctx_size > ctx_buffer_trim_token_thresh:
        for _ in range(20):
            if len(message_pairs) <= 1:
                break
            oldest_turn = min(message_pairs.keys())
            del message_pairs[oldest_turn]

    # Rebuild the full prompt: system prompt followed by the surviving turns in order.
    for turn in sorted(message_pairs.keys()):
        for msg in message_pairs[turn]:
            enc_sys_prompt.extend(msg)

    prompt_ids = torch.tensor(enc_sys_prompt).unsqueeze(dim=0)

    tokens_out = 0
    print("*" * 30)
    message_buffer = ""
    # Stream fragments to the client as they are generated.
    for fragment, ids, count in generate(instruction_ids=prompt_ids, generator=generator, loras=lora):
        emit('stream_response', {'fragment': fragment}, broadcast=True)
        tokens_out += count
        message_buffer += fragment

    # Seal the reply with <end_of_turn>, log it, and complete the stored turn.
    enc_turn = terminate_message(tokenizer, message_buffer)
    with open(jsonl_file, 'a') as f:
        json.dump({sys_role: message_buffer}, f)
        f.write('\n')
    message_pairs[turn_counter].append(enc_turn)
    current_ctx_size = len(enc_turn) + len(enc_sys_prompt)

    dbg = []
    for item in message_pairs[turn_counter]:
        dbg.extend(item)
    debug_out = torch.tensor(dbg).unsqueeze(dim=0)
    debug_v = tokenizer.decode(debug_out, decode_special_tokens=True)
    print(debug_v)

    emit('update_token_count', len(enc_sys_prompt) + tokens_out, broadcast=True)
    emit('end_generation', broadcast=True)  # signal generation end to the client
if __name__ == '__main__':
    socketio.run(app)
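To run the server (a sketch: the script name app.py and the paths are placeholders, and a templates/index.html with a Socket.IO front end is assumed but not included in the gist):

    python app.py /path/to/gemma2-exl2-model                # model only
    python app.py /path/to/gemma2-exl2-model /path/to/lora  # with a LoRA adapter

The front end is not part of the gist, but the event names above are enough to talk to the server directly. A minimal terminal-client sketch, assuming the python-socketio package (pip install "python-socketio[client]") and Flask-SocketIO's default port:

import socketio

sio = socketio.Client()

@sio.on('stream_response')
def on_fragment(data):
    # print streamed fragments as they arrive
    print(data['fragment'], end='', flush=True)

@sio.on('end_generation')
def on_done():
    print()

sio.connect('http://localhost:5000')
# handle_send_message takes (message, new_sys_prompt); a tuple is sent as separate arguments
sio.emit('send_message', ('Hello!', 'Example.'))
sio.wait()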