A demonstration of https://o565.com/llm-text-compression/
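
In short: the script asks the model to predict the document one token at a time. Runs of correct predictions are replaced by a marker such as @17@ ("regenerate the next 17 tokens here"); characters the model fails to predict are stored literally. Decompression replays the same deterministic model to expand the markers back into text.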
import llama_cpp
import re
import json

# Model configuration
# tested with mistral, llama2, llama3, and phi3
model_path = "/path/to/model"
base_llm = llama_cpp.Llama(model_path, seed=42, n_gpu_layers=-1, n_ctx=4096, verbose=False, temperature=0.0)
# Terminal color codes for debugging output
GREEN = '\033[32m'
RED = '\033[31m'
RESET = '\033[0m'
BLUE = '\033[34m'

DELIMITER = '\x1F'  # ASCII unit separator: unlikely to collide with document text
DELIMITER = '@'     # overrides the line above; '@' keeps the debug output readable
debug = True

def load_document(filename):
    """Load and tokenize the document, handling token overflow."""
    with open(filename, 'r', encoding='utf-8') as f:
        source_text = f.read()
    tokens = base_llm.tokenize(source_text.encode("utf-8"))
    source_text_list = []
    if len(tokens) > base_llm.n_ctx() - 5:
        print("[!] Document is too long for the context window: {}. Splitting.".format(base_llm.n_ctx()))
        for i in range(0, len(tokens), int(base_llm.n_ctx() * 0.8)):
            text_chunk = base_llm.detokenize(tokens[i:i + int(base_llm.n_ctx() * 0.8)]).decode()
            source_text_list.append(text_chunk)
    else:
        source_text_list.append(source_text)
    return source_text_list
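
# Note (added): chunks are cut at 80% of the context window so the growing
# prompt passed to generate_text() stays within n_ctx while a single chunk
# is being compressed or decompressed.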

def generate_text(prompt, max_tokens=1):
    """Generate text from a prompt using the LLM."""
    output = base_llm(prompt, max_tokens=max_tokens, echo=False, temperature=0.0, seed=42)
    return output['choices'][0]['text']
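
# Determinism note (added): compression and decompression must run the exact
# same model with greedy sampling (temperature=0.0, fixed seed); otherwise the
# token counts stored in the compressed markers will replay different text.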

def compress_text(source_text):
    """Compress text by generating and comparing segments to the source text."""
    generated_text = ""
    compressed_string = ""
    gen_count = 0
    i = 0
    # loop until we have generated the entire source text
    while generated_text != source_text:
        # get a new token
        part = generate_text(generated_text)
        # if the generated text aligns with the source text, tally it
        if source_text.startswith(generated_text + part) and len(part) > 0:
            gen_count += 1
            generated_text += part
            i = len(generated_text)
            if debug:
                print(BLUE + part + RESET, end="", flush=True)
        # if not, flush the tally as a marker and copy one literal character
        # from the source document; hopefully we'll be back on track next loop
        else:
            i += 1
            if gen_count > 0:
                # write the delimiter literally (re.escape is only for regex patterns)
                compressed_string += f"{DELIMITER}{gen_count}{DELIMITER}"
                gen_count = 0
            generated_text += source_text[i - 1]
            compressed_string += source_text[i - 1]
            if debug:
                print(source_text[i - 1], end="", flush=True)
    # flush a generation run still pending when the loop exits, otherwise the
    # final predicted tokens would be silently dropped from the output
    if gen_count > 0:
        compressed_string += f"{DELIMITER}{gen_count}{DELIMITER}"
    return compressed_string
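
# Illustrative (hypothetical) example with DELIMITER = '@':
#   compress_text("The quick brown fox") could return "The q@3@ fox",
#   i.e. copy "The q" literally, have the model regenerate 3 tokens
#   (ideally "uick brown"), then copy " fox" literally.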

def decompress_text(compressed_text):
    """Decompress text from a compressed string."""
    decompressed_text = ""
    # split the string into sections: literal text and generation counts
    parts = re.split(rf'({re.escape(DELIMITER)}\d+{re.escape(DELIMITER)})', compressed_text)
    for part in parts:
        # if we're looking at a generation count, regenerate that many tokens
        if re.match(rf'{re.escape(DELIMITER)}\d+{re.escape(DELIMITER)}', part):
            number = int(part[1:-1])  # strip the single-character delimiters
            for count in range(number):
                part = generate_text(decompressed_text)
                if debug:
                    print(GREEN + part + RESET, end="", flush=True)
                decompressed_text = decompressed_text + part
        else:
            # literal text: just append it to the decompressed string
            decompressed_text += part
            if debug:
                print(part, end="", flush=True)
    return decompressed_text
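
# Round-trip sanity check (added sketch, not part of the original demo):
#   chunk = load_document("pg11.txt")[0]
#   assert decompress_text(compress_text(chunk)) == chunk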

if __name__ == "__main__":
    # Process each document chunk: compress, then decompress
    print("\n[.] Loading Text...")
    source_text_list = load_document("pg11.txt")

    print("\n[.] Compressing Text...")
    compressed_text_list = [compress_text(text) for text in source_text_list]

    # Save compressed data
    with open("compressed.json", "w") as f:
        json.dump(compressed_text_list, f)

    # Read compressed data and decompress
    with open("compressed.json", "r") as f:
        compressed_text_list = json.load(f)

    output_text = ""
    print("\n[.] Decompressing Text...")
    base_llm.reset()
    for compressed_text in compressed_text_list:
        output_text += decompress_text(compressed_text)

    print("\nDecompressed Output:")
    print(output_text)
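
    # Rough compression-ratio report (added sketch; reuses the variables above;
    # measures characters, not bytes)
    source_len = sum(len(t) for t in source_text_list)
    compressed_len = sum(len(t) for t in compressed_text_list)
    print(f"\nCompressed to {compressed_len / source_len:.1%} of the original size")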