# exl-with-stopbuffering
def per_polygon_translate(generate_fn, input_text):
    prompt = (
        f"""::JAPANESE TEXT::
むかし、 ある ところ に
[[0.0, 8.0], [252.0, 8.0], [252.0, 27.0], [0.0, 26.0]]
おじいさん と おばあさん が いました。
[[0.0, 33.0], [289.0, 32.0], [289.0, 50.0], [0.0, 50.0]]
おじいさん が 山(やま) へ 木(き) を きり に いけば、
[[0.0, 57.0], [416.0, 56.0], [417.0, 74.0], [0.0, 75.0]]
おばあさん は 川(かわ)へ せんたく に でかけます。
[[0.0, 80.0], [393.0, 81.0], [393.0, 99.0], [0.0, 98.0]]
「おじいさん、 はよう もどって きなされ。」
[[8.0, 104.0], [335.0, 104.0], [335.0, 123.0], [8.0, 122.0]]
::END TEXT::
::ENGLISH TEXT::
Once upon a time, in a certain place,
[[0.0, 8.0], [252.0, 8.0], [252.0, 27.0], [0.0, 26.0]]
there lived an old man and an old woman.
[[0.0, 33.0], [289.0, 32.0], [289.0, 50.0], [0.0, 50.0]]
When the old man went to the mountain to cut wood,
[[0.0, 57.0], [416.0, 56.0], [417.0, 74.0], [0.0, 75.0]]
the old woman would go to the river to do laundry.
[[0.0, 80.0], [393.0, 81.0], [393.0, 99.0], [0.0, 98.0]]
"Old man, please come back early."
[[8.0, 104.0], [335.0, 104.0], [335.0, 123.0], [8.0, 122.0]]
::END TEXT::
::JAPANESE TEXT::
{input_text}
::END TEXT::
::ENGLISH TEXT::""")
    resp = generate_fn(prompt=prompt, stop_sequences=["::END TEXT"])
    return resp.strip()
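
# A contract sketch (assumption, not part of the original gist): generate_fn
# only needs to accept keyword arguments `prompt` and `stop_sequences` and
# return the finished text, e.g. a thin wrapper over generate_response_fold
# defined below:
#
#   gen = lambda prompt, stop_sequences: generate_response_fold(
#       prompt, tokenizer, generator, settings, 512, stop_sequences)
#   translated = per_polygon_translate(gen, japanese_page_text)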
import sys, os
import torch
from enum import Enum
import pygtrie

# Add the repository root to the import path so the exllamav2 package can be
# imported when running this script from inside the repo without installing it.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)
class StopStatus(Enum):
    CONTINUE = 1  # No stop sequence matches; safe to yield the buffer.
    WAITING = 2   # Buffer is a prefix of a stop sequence; hold output.
    STOP = 3      # A stop sequence matched; end generation.
class StopBuffer:
    def __init__(self, stop_sequences, case_sensitive=False):
        self.trie = pygtrie.CharTrie()
        self.case_sensitive = case_sensitive
        for seq in stop_sequences:
            if not self.case_sensitive:
                seq = seq.lower()
            self.trie[seq] = True
    def append_check(self, text):
        text = text.strip()
        if not self.case_sensitive:
            text = text.lower()
        # If a stop sequence is a prefix of the buffer (including an exact
        # match), stop generating. shortest_prefix also catches a buffer that
        # ran slightly past the stop sequence within a single chunk, which
        # the original exact-match check would have missed.
        if self.trie.shortest_prefix(text):
            return StopStatus.STOP
        # If the buffer is a prefix of some stop sequence, continue
        # generating but do not yield anything yet.
        if self.trie.has_subtrie(text):
            return StopStatus.WAITING
        # No match; the caller can safely yield the current buffer.
        return StopStatus.CONTINUE
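
# Illustrative behavior (a sketch, assuming the stop sequence "::END TEXT"
# used by per_polygon_translate above):
#
#   sb = StopBuffer(["::END TEXT"])
#   sb.append_check("Hello")       # -> StopStatus.CONTINUE: safe to yield
#   sb.append_check("::END")       # -> StopStatus.WAITING: hold the buffer
#   sb.append_check("::END TEXT")  # -> StopStatus.STOP: end generation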
def load_model(model_directory):
    config = ExLlamaV2Config()
    config.model_dir = model_directory
    config.prepare()
    config.max_seq_len = 2048
    config.max_attention_size = 2048 ** 2
    model = ExLlamaV2(config)
    # A lazy cache plus load_autosplit lets exllamav2 split the model's
    # layers across the available GPUs automatically.
    cache = ExLlamaV2Cache(model, lazy = True)
    model.load_autosplit(cache)
    tokenizer = ExLlamaV2Tokenizer(config)
    generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
    return config, tokenizer, cache, generator
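
# Sampler settings are constructed by the caller, not by load_model. A minimal
# sketch (the specific values are illustrative assumptions):
#
#   settings = ExLlamaV2Sampler.Settings()
#   settings.temperature = 0.7
#   settings.top_p = 0.9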
def generate_response_fold(prompt, tokenizer, generator, settings, max_length, stop_sequences=None):
    # Consume the stream, echoing fragments as they arrive, and return the
    # concatenated response. (None instead of a mutable [] default.)
    buffer = ""
    for fragment in generate_response_stream(prompt, tokenizer, generator, settings, max_length, stop_sequences or []):
        buffer += fragment
        print(fragment, end="", flush=True)
    return buffer
def generate_response_stream(prompt, tokenizer, generator, settings, max_length, stop_sequences):
    instruction_ids = tokenizer.encode(prompt, add_bos = True)
    context_ids = instruction_ids if generator.sequence_ids is None \
        else torch.cat([generator.sequence_ids, instruction_ids], dim = -1)
    # Note: stop detection assumes a stop sequence begins at a chunk boundary;
    # if one starts mid-chunk, the CONTINUE branch flushes the buffer before
    # the match can be seen.
    stop_buffer = StopBuffer(stop_sequences)
    sbuf = ""
    generated = 0
    generator.begin_stream_ex(context_ids, settings)
    while True:
        res = generator.stream_ex()
        if res["eos"]:
            yield sbuf
            return
        chunk = res["chunk"]
        sbuf += chunk
        generated += 1
        status = stop_buffer.append_check(sbuf)
        if status == StopStatus.CONTINUE:
            # No pending stop-sequence match: flush the buffer to the caller.
            yield sbuf
            sbuf = ""
        elif status == StopStatus.STOP:
            # A stop sequence matched; discard the buffered match and finish.
            return
        # StopStatus.WAITING falls through: keep buffering until resolved.
        if generated >= max_length:
            # Enforce max_length, which the original loop never consulted;
            # each stream_ex() step emits roughly one token's worth of text.
            yield sbuf
            return
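
# End-to-end wiring sketch (not part of the original gist): the model path is
# a hypothetical placeholder, default sampler settings are assumed, and
# max_length=512 is an assumed generation budget.
if __name__ == "__main__":
    model_dir = "/path/to/exl2-model"  # hypothetical placeholder
    config, tokenizer, cache, generator = load_model(model_dir)
    settings = ExLlamaV2Sampler.Settings()

    def gen(prompt, stop_sequences):
        # Thin adapter matching the generate_fn contract expected by
        # per_polygon_translate.
        return generate_response_fold(
            prompt, tokenizer, generator, settings,
            max_length=512, stop_sequences=stop_sequences)

    sample = "こんにちは、世界。\n[[0.0, 0.0], [100.0, 0.0], [100.0, 20.0], [0.0, 20.0]]"
    print(per_polygon_translate(gen, sample))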