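"""Watch a codebase and cache model embeddings of changed Python files.

Whenever a .py file under ``codebase_dir`` is modified, its contents are
tokenized, run through a DistilBERT model, and the outputs are saved to
``cache_dir`` as a .pt file, with the file's modification time recorded in
last_update_dates.json.
"""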
import json
import os
import pathlib
import time

import torch
from transformers import AutoModel, AutoTokenizer
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
# Settings
codebase_dir = "project/path"
cache_dir = "./cache"
model_name = "distilbert-base-uncased"
exclude_dirs = ["venv", "node_modules"]
# Load model and tokenizer
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()  # inference only: disable dropout
class ChangeHandler(FileSystemEventHandler):
    def __init__(self, last_update_dates):
        self.last_update_dates = last_update_dates

    def on_modified(self, event):
        # Re-embed a .py file whenever it changes, unless it sits in an excluded directory
        if (
            not event.is_directory
            and event.src_path.endswith(".py")
            and is_not_excluded(pathlib.Path(event.src_path))
        ):
            load_file(pathlib.Path(event.src_path), self.last_update_dates)
def load_file(file_path, last_update_dates):
    try:
        print(f"Loading file: {file_path}", end=" ", flush=True)
        # Load file contents
        with open(file_path, "r") as f:
            contents = f.read()
        if not contents.strip():
            print(f"File is empty: {file_path}")
            return

        # Split the contents into 512-character chunks as a rough proxy for
        # the model's 512-token limit (adjust as needed)
        chunk_size = 512
        tokenized_chunks = []
        for i in range(0, len(contents), chunk_size):
            chunk = contents[i : i + chunk_size]
            # Tokenize the chunk, padding/truncating to the model's max length
            inputs = tokenizer(
                chunk,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=512,
            )
            tokenized_chunks.append(inputs)

        # Stack the chunks into a single batch, keeping the attention masks so
        # the model ignores padding tokens
        input_ids = torch.cat([c["input_ids"] for c in tokenized_chunks], dim=0)
        attention_mask = torch.cat(
            [c["attention_mask"] for c in tokenized_chunks], dim=0
        )

        # Run the batch through the model
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)

        # Save the outputs to the cache, flattening the directory path into the file name
        file_name = file_path.name
        file_dir = file_path.parent
        path_name = os.path.join(
            cache_dir, str(file_dir).replace(os.sep, "_") + "_" + file_name + ".pt"
        )
        os.makedirs(os.path.dirname(path_name), exist_ok=True)
        torch.save(outputs, path_name)

        # Record the file's last modification time in a JSON file
        last_update_dates[str(file_path)] = file_path.stat().st_mtime
        with open("last_update_dates.json", "w") as f:
            json.dump(last_update_dates, f)
        print("Saved")
    except Exception as e:
        print(f"Error loading file: {file_path} - {e}")
def is_not_excluded(file_path):
    # True if no component of the path is one of the excluded directories
    return not any(d in file_path.parts for d in exclude_dirs)
def main():
    print("Starting cache update...")
    # Load the last update dates, starting fresh if the file is missing or invalid
    last_update_dates = {}
    try:
        with open("last_update_dates.json", "r") as f:
            last_update_dates = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        with open("last_update_dates.json", "w") as f:
            json.dump({}, f)

    event_handler = ChangeHandler(last_update_dates)
    observer = Observer()
    observer.schedule(event_handler, path=codebase_dir, recursive=True)
    observer.start()
    try:
        # Sleep instead of busy-waiting so the watcher doesn't peg a CPU core
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        observer.stop()
    observer.join()
    print("Cache update complete!")
if __name__ == "__main__":
    main()
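
# Usage sketch: reading an embedding back out of the cache. The path below is
# hypothetical; it assumes codebase_dir = "project/path" and a file "main.py"
# at its root, following the naming scheme in load_file(). weights_only=False
# is needed on newer PyTorch versions because the saved object is a
# ModelOutput, not a plain tensor.
#
#   cached = torch.load("./cache/project_path_main.py.pt", weights_only=False)
#   file_vector = cached.last_hidden_state.mean(dim=(0, 1))  # one 768-d vector per file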