Idea for cached inference with RocksDB, Ray, and a high-IOPS NAS
from typing import Dict, Any

import rocksdb
import ray
import ray.data


# Step 1: Initialize the RocksDB instance.
# Note: the directory is mounted in the pod via a high-IOPS ReadWriteMany volume
# backed by GCP Hyperdisk, so all workers see the same cache.
db = rocksdb.DB("rocksdb_dir", rocksdb.Options(create_if_missing=True))


# Step 2: Define a Predictor class for inference.
class HuggingFacePredictor:
    def __init__(self):
        from transformers import pipeline

        # Initialize a pre-trained GPT-2 Hugging Face pipeline.
        self.model = pipeline("text-generation", model="gpt2")

    # Logic for inference on one row of data.
    def __call__(self, row: Dict[str, Any]) -> Dict[str, Any]:
        input_text = row["input"]

        # Check whether the prediction is already in the cache.
        cached = db.get(input_text.encode())
        if cached is not None:
            row["result"] = cached.decode()
            return row

        # Cache miss: get the prediction from the model.
        predictions = self.model(input_text, max_length=20, num_return_sequences=1)
        prediction = predictions[0]["generated_text"]

        # Store the prediction in the cache.
        db.put(input_text.encode(), prediction.encode())
        row["result"] = prediction
        return row


# Step 3: Define the Kedro node function.
def predict_and_cache(input_ds: ray.data.Dataset) -> ray.data.Dataset:
    # Step 4: Map the Predictor over the Dataset to get predictions.
    # Passing the class (rather than an instance) lets Ray Data run it as
    # long-lived actors; recent Ray versions require a concurrency argument
    # for callable classes.
    return input_ds.map(HuggingFacePredictor, concurrency=1)
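
A minimal usage sketch (not part of the original gist): it assumes Ray is already running in the cluster, that input rows carry an "input" column as above, and that the sample texts are placeholders. In practice the dataset would come from the Kedro catalog rather than from in-memory items.

if __name__ == "__main__":
    import ray

    ray.init(ignore_reinit_error=True)

    # Build a tiny in-memory dataset with the "input" column the predictor expects.
    input_ds = ray.data.from_items(
        [{"input": "Hello world"}, {"input": "Cached inference with Ray"}]
    )

    # The first run computes predictions and writes them to RocksDB; a second run
    # with the same inputs should be served from the cache without touching the model.
    result_ds = predict_and_cache(input_ds)
    for row in result_ds.take(2):
        print(row["input"], "->", row["result"])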