exl-with-stopbuffering
def per_polygon_translate(generate_fn, input_text):
    # Few-shot prompt: each OCR'd Japanese line is followed by its bounding
    # polygon, and the English example mirrors that line/polygon layout so the
    # model emits one translated line per polygon.
    prompt = (
        f"""::JAPANESE TEXT::
むかし、 ある ところ に
[[0.0, 8.0], [252.0, 8.0], [252.0, 27.0], [0.0, 26.0]]
おじいさん と おばあさん が いました。
[[0.0, 33.0], [289.0, 32.0], [289.0, 50.0], [0.0, 50.0]]
おじいさん が 山(やま) へ 木(き) を きり に いけば、
[[0.0, 57.0], [416.0, 56.0], [417.0, 74.0], [0.0, 75.0]]
おばあさん は 川(かわ)へ せんたく に でかけます。
[[0.0, 80.0], [393.0, 81.0], [393.0, 99.0], [0.0, 98.0]]
「おじいさん、 はよう もどって きなされ。」
[[8.0, 104.0], [335.0, 104.0], [335.0, 123.0], [8.0, 122.0]]
::END TEXT::
::ENGLISH TEXT::
Once upon a time, in a certain place,
[[0.0, 8.0], [252.0, 8.0], [252.0, 27.0], [0.0, 26.0]]
there lived an old man and an old woman.
[[0.0, 33.0], [289.0, 32.0], [289.0, 50.0], [0.0, 50.0]]
When the old man went to the mountain to cut wood,
[[0.0, 57.0], [416.0, 56.0], [417.0, 74.0], [0.0, 75.0]]
the old woman would go to the river to do laundry.
[[0.0, 80.0], [393.0, 81.0], [393.0, 99.0], [0.0, 98.0]]
"Old man, please come back early."
[[8.0, 104.0], [335.0, 104.0], [335.0, 123.0], [8.0, 122.0]]
::END TEXT::
::JAPANESE TEXT::
{input_text}
::END TEXT::
::ENGLISH TEXT::""")
    resp = generate_fn(prompt=prompt, stop_sequences=["::END TEXT"])
    return resp.strip()
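per_polygon_translate only needs a callable that accepts prompt and stop_sequences keyword arguments. One way to wire it to the streaming generator defined in the script below is a small adapter closure; this is a minimal sketch, and the settings and max_length values here are assumptions rather than part of the gist:

def make_generate_fn(tokenizer, generator, settings, max_length=512):
    # Adapt generate_response_fold (defined in the script below) to the
    # (prompt, stop_sequences) interface per_polygon_translate expects.
    def generate_fn(prompt, stop_sequences=()):
        return generate_response_fold(prompt, tokenizer, generator,
                                      settings, max_length, stop_sequences)
    return generate_fn

# For example:
# translate = make_generate_fn(tokenizer, generator, settings)
# english = per_polygon_translate(translate, ocr_japanese_with_polygons)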
import sys, os
import torch
from enum import Enum
import pygtrie  # third-party trie library: pip install pygtrie

# Make the exllamav2 package importable when this script lives in a
# subdirectory of the repository (e.g. examples/).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)
class StopStatus(Enum):
    CONTINUE = 1
    WAITING = 2
    STOP = 3
class StopBuffer:
    def __init__(self, stop_sequences, case_sensitive=False):
        self.trie = pygtrie.CharTrie()
        self.case_sensitive = case_sensitive
        for seq in stop_sequences:
            if not self.case_sensitive:
                seq = seq.lower()
            self.trie[seq] = True

    def append_check(self, text):
        text = text.strip()
        if not self.case_sensitive:
            text = text.lower()
        # Exact match: the buffer is a stop sequence, so stop generating.
        if text in self.trie:
            return StopStatus.STOP
        # Partial match: the buffer is a prefix of some stop sequence, so
        # keep generating but do not yield anything yet.
        if self.trie.has_subtrie(text):
            return StopStatus.WAITING
        # No match: safe to yield the current buffer.
        return StopStatus.CONTINUE
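# A quick illustration of the three states (hypothetical values, not part
# of the gist):
#
#   sb = StopBuffer(["::END TEXT"])
#   sb.append_check("Hello")        # -> StopStatus.CONTINUE (yield, reset buffer)
#   sb.append_check("::END")        # -> StopStatus.WAITING  (hold: prefix of a stop)
#   sb.append_check("::END TEXT")   # -> StopStatus.STOP     (swallow and halt)
#
# The trie makes the prefix test proportional to len(text) regardless of how
# many stop sequences are registered, which is why pygtrie is used here
# instead of scanning every sequence per chunk.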
def load_model(model_directory):
    config = ExLlamaV2Config()
    config.model_dir = model_directory
    config.prepare()
    config.max_seq_len = 2048
    config.max_attention_size = 2048**2

    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy = True)
    model.load_autosplit(cache)

    tokenizer = ExLlamaV2Tokenizer(config)
    generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
    return config, tokenizer, cache, generator
def generate_response_fold(prompt, tokenizer, generator, settings, max_length, stop_sequences=()):
    # Collect the streamed fragments into one string, echoing them as they arrive.
    buffer = ""
    for fragment in generate_response_stream(prompt, tokenizer, generator, settings, max_length, stop_sequences):
        buffer += fragment
        print(fragment, end="", flush=True)
    return buffer
def generate_response_stream(prompt, tokenizer, generator, settings, max_length, stop_sequences):
    instruction_ids = tokenizer.encode(prompt, add_bos = True)
    # Append to any existing context so successive calls share the cache.
    context_ids = instruction_ids if generator.sequence_ids is None \
        else torch.cat([generator.sequence_ids, instruction_ids], dim = -1)

    stop_buffer = StopBuffer(stop_sequences)
    sbuf = ""
    num_steps = 0

    generator.begin_stream_ex(context_ids, settings)
    while True:
        res = generator.stream_ex()
        num_steps += 1
        # Flush whatever is buffered at EOS or once max_length decode steps
        # have been taken.
        if res["eos"] or num_steps >= max_length:
            yield sbuf
            return
        sbuf += res["chunk"]
        status = stop_buffer.append_check(sbuf)
        if status == StopStatus.CONTINUE:
            # No stop sequence in sight: emit the buffer and start a new one.
            yield sbuf
            sbuf = ""
        elif status == StopStatus.STOP:
            # The buffer matched a stop sequence exactly; swallow it and end.
            return
        # StopStatus.WAITING: the buffer is a prefix of a stop sequence,
        # so keep accumulating without yielding yet.
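The gist stops short of a driver. A minimal end-to-end sketch, assuming a local EXL2 model directory and illustrative sampler values (none of which appear in the original):

if __name__ == "__main__":
    # Path and sampler values are illustrative assumptions.
    config, tokenizer, cache, generator = load_model("/path/to/exl2-model")

    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 0.8
    settings.top_p = 0.9

    generate_response_fold(
        "Once upon a time", tokenizer, generator, settings,
        max_length = 256, stop_sequences = ["::END TEXT"])
    print()  # final newline after the streamed output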