@pascalwhoop
Created September 16, 2024 11:10
Idea: cached inference with RocksDB, Ray & a high-IOPS NAS
from typing import Dict, Any
import rocksdb
import ray
# Step 1: Initialize the RocksDB instance.
# Note: this directory is mounted into the pod via a high-IOPS ReadWriteMany volume backed by GCP Hyperdisk.
db = rocksdb.DB("rocksdb_dir", rocksdb.Options(create_if_missing=True))
# Step 2: Define a Predictor class for inference.
class HuggingFacePredictor:
    def __init__(self):
        from transformers import pipeline

        # Initialize a pre-trained GPT-2 Hugging Face pipeline.
        self.model = pipeline("text-generation", model="gpt2")

    # Logic for inference on one row of data.
    def __call__(self, row: Dict[str, Any]) -> Dict[str, Any]:
        input_text = row["input"]

        # Check whether a prediction for this input is already in the cache.
        prediction = db.get(input_text.encode())
        if prediction is not None:
            row["result"] = prediction.decode()
            return row

        # Cache miss: get the prediction from the model.
        predictions = self.model(input_text, max_length=20, num_return_sequences=1)
        prediction = predictions[0]["generated_text"]

        # Store the prediction in the cache for subsequent rows.
        db.put(input_text.encode(), prediction.encode())
        row["result"] = prediction
        return row
# Step 3: Define the Kedro node function.
def predict_and_cache(input_ds: ray.data.Dataset) -> ray.data.Dataset:
    # Step 4: Map the Predictor over the Dataset to get predictions.
    # A callable class runs in a Ray actor pool; `concurrency` sets its size.
    return input_ds.map(HuggingFacePredictor, concurrency=1)
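
A minimal driver sketch (not part of the original gist; the prompts and entry point are illustrative): it builds a small Ray Dataset with one duplicated prompt, runs the node, and prints the results. The second occurrence of the duplicated prompt is served from the RocksDB cache instead of the model.

# --- Hypothetical usage sketch, assuming a local Ray runtime and the module above ---
if __name__ == "__main__":
    ray.init()

    # Two distinct prompts plus one duplicate; the duplicate hits the RocksDB cache.
    prompts = [
        {"input": "hello world"},
        {"input": "ray data batch inference"},
        {"input": "hello world"},
    ]
    ds = ray.data.from_items(prompts)

    for row in predict_and_cache(ds).take_all():
        print(row["input"], "->", row["result"])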