|
import functools |
|
from typing import Any, Callable, Optional, Dict |
|
import psycopg2 |
|
from psycopg2.extras import Json |
|
from contextlib import contextmanager |
|
import threading |
|
|
|
class PostgresCache:
    """Key/value cache backed by the PostgreSQL table ``api_cache``, fronted by
    a per-instance in-memory dict.

    Every instance reads and writes only rows whose ``scope`` column equals
    ``self.scope``. Values are written through ``psycopg2.extras.Json`` and so
    must be JSON-serializable. Each database operation opens its own short-lived
    connection.

    The in-memory dict is a read-through copy of the table: ``get`` fills it on
    a database hit, ``load_into_memory`` bulk-fills it, and ``set``/``clear``
    invalidate it after committing so later reads go back to the database.
    """

    def __init__(self, dsn: str, scope: str = "default"):
        """
        Initialize cache with connection string and scope.

        Also creates the ``api_cache`` table if it does not exist yet, so the
        constructor performs a database round trip.

        Args:
            dsn: PostgreSQL connection string
            scope: Namespace for the cache to separate different uses
        """
        self.dsn = dsn
        self.scope = scope
        # In-memory copy of rows for this scope, keyed by cache key.
        self._memory_cache: Dict[str, Any] = {}
        # Guards _memory_cache and _loaded_patterns.
        self._lock = threading.Lock()
        # LIKE patterns whose matching rows have been bulk-loaded into memory.
        self._loaded_patterns: set[str] = set()
        self._init_db()

    def _init_db(self):
        """Create the ``api_cache`` table (shared by all scopes) if missing."""
        with self._get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS api_cache (
                        scope TEXT,
                        key TEXT,
                        data JSONB,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        PRIMARY KEY (scope, key)
                    )
                """)
            conn.commit()

    @contextmanager
    def _get_conn(self):
        """Yield a new connection to ``self.dsn`` and always close it afterwards.

        Nothing is committed here; callers commit explicitly.
        """
        conn = psycopg2.connect(self.dsn)
        try:
            yield conn
        finally:
            conn.close()

    def _create_key(self, *, func_name: str, func: Callable, args: tuple, kwargs: dict, version: int, ignore_params: tuple) -> str:
        """Create a unique cache key from function name and arguments.

        Positional and keyword arguments are bound to ``func``'s signature with
        defaults applied, so ``f(1)`` and ``f(x=1)`` yield the same key.
        Parameters named in ``ignore_params`` are left out.

        The key has the form ``v{version}_{func_name}_{sorted (name, value) list}``.
        It relies on ``str()`` of that list, i.e. on each value's ``repr``.
        A value whose repr changes between calls (e.g. a default object repr
        that embeds a memory address) never produces a hit. Exclude such
        values via ``ignore_params``.

        Raises:
            TypeError: if ``args``/``kwargs`` do not fit ``func``'s signature.
        """
        import inspect

        bound_args = inspect.signature(func).bind(*args, **kwargs)
        bound_args.apply_defaults()

        filtered_params = {
            name: value
            for name, value in bound_args.arguments.items()
            if name not in ignore_params
        }
        # Parameter names are unique, so sorting never has to compare values.
        return f"v{version}_{func_name}_{str(sorted(filtered_params.items()))}"

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or ``None`` when there is none.

        The in-memory copy is consulted first. On a miss, the row for
        ``(self.scope, key)`` is read from the database and remembered in
        memory. A stored value that decodes to ``None`` cannot be told apart
        from a miss.
        """
        with self._lock:
            if key in self._memory_cache:
                return self._memory_cache[key]

        with self._get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT data
                    FROM api_cache
                    WHERE scope = %s AND key = %s
                """, (self.scope, key))
                row = cur.fetchone()

        if row is None:
            return None
        with self._lock:
            self._memory_cache[key] = row[0]
        return row[0]

    def set(self, key: str, value: Any) -> None:
        """
        Set a value in the cache.

        Upserts the ``(self.scope, key)`` row and refreshes its ``created_at``.
        Any in-memory copy of ``key`` is dropped after the commit, so the next
        ``get`` reads the new value rather than a stale one.

        Args:
            key: Cache key
            value: Value to cache; must be serializable by ``psycopg2.extras.Json``
        """
        with self._get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO api_cache (scope, key, data)
                    VALUES (%s, %s, %s)
                    ON CONFLICT (scope, key)
                    DO UPDATE SET
                        data = EXCLUDED.data,
                        created_at = CURRENT_TIMESTAMP
                """, (self.scope, key, Json(value)))
            conn.commit()
        # Invalidate only after the commit. Doing it earlier would let a
        # concurrent get() re-cache the old row.
        self._invalidate_memory(key)

    def clear(self, pattern: Optional[str] = None):
        """
        Clear cache entries.

        Deletes matching rows of this scope (all rows when ``pattern`` is empty
        or ``None``) and then discards the whole in-memory copy. The memory
        copy cannot evaluate SQL LIKE patterns, and it is rebuilt from the
        database on demand.

        Args:
            pattern: Optional SQL LIKE pattern to match keys
        """
        with self._get_conn() as conn:
            with conn.cursor() as cur:
                if pattern:
                    cur.execute(
                        "DELETE FROM api_cache WHERE scope = %s AND key LIKE %s",
                        (self.scope, pattern)
                    )
                else:
                    cur.execute(
                        "DELETE FROM api_cache WHERE scope = %s",
                        (self.scope,)
                    )
            conn.commit()
        self._invalidate_memory()

    def _invalidate_memory(self, key: Optional[str] = None) -> None:
        """Drop ``key`` from the in-memory copy, or drop everything if ``key`` is None.

        Dropping everything also forgets which patterns were bulk-loaded,
        because their entries are no longer in memory.
        """
        with self._lock:
            if key is None:
                self._memory_cache.clear()
                self._loaded_patterns.clear()
            else:
                self._memory_cache.pop(key, None)

    def load_into_memory(self, pattern: str) -> None:
        """
        Load all cache entries matching a pattern into memory if not already loaded.

        The database query runs without holding the lock. The memory update
        and the "loaded" bookkeeping happen together under the lock.

        Args:
            pattern: SQL LIKE pattern to match keys
        """
        with self._lock:
            if pattern in self._loaded_patterns:
                return

        with self._get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT key, data
                    FROM api_cache
                    WHERE scope = %s AND key LIKE %s
                """, (self.scope, pattern))
                rows = cur.fetchall()

        with self._lock:
            self._memory_cache.update({row[0]: row[1] for row in rows})
            self._loaded_patterns.add(pattern)
|
|
|
def cached(cache: "PostgresCache", func_name: str, version: int = 1, ignore_params: tuple = ()):
    """
    Decorator for caching function results in PostgreSQL.

    Calling ``cached(...)`` immediately bulk-loads this function's existing
    entries for ``version`` from ``cache`` into memory, so it hits the
    database at decoration time.

    The decorated function accepts one extra keyword argument, ``refetch``.
    When it is truthy, the function runs even if a cached value exists, and
    the fresh result overwrites the cache entry. ``refetch`` is removed before
    the call and is never forwarded to the wrapped function.

    Results equal to ``None`` are stored but never served from the cache, so
    such calls always re-run the function.

    Args:
        cache: PostgresCache instance to use
        func_name: Name used as the cache-key prefix; functions that share it
            share cache entries
        version: Version number for cache invalidation (default: 1)
        ignore_params: Tuple of parameter names to ignore in cache key generation
    """
    # Keys built by PostgresCache._create_key start with "v{version}_{func_name}_".
    # Escape LIKE metacharacters so "_" matches only a literal underscore, and
    # keep the trailing "_" separator so that e.g. "get" does not also preload
    # the entries of "get_user".
    escaped_name = (
        func_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    )
    cache.load_into_memory(f"v{version}\\_{escaped_name}\\_%")

    def decorator(func: Callable):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # "refetch" is a control flag for this wrapper, not an argument of func.
            refetch = kwargs.pop("refetch", False)
            cache_key = cache._create_key(
                func_name=func_name,
                func=func,
                args=args,
                kwargs=kwargs,
                version=version,
                ignore_params=ignore_params,
            )
            if not refetch:
                result = cache.get(cache_key)
                if result is not None:
                    return result

            result = func(*args, **kwargs)
            cache.set(cache_key, result)
            return result

        return wrapper

    return decorator