Postgres-backed caching system

Postgres-based caching system. Features:

  1. Doesn't recompute on refactoring: keys are based on an explicit scope and function name rather than on where the function lives, so you can move functions around without causing cache misses
  2. Doesn't recompute on code changes: something that drove me mad with joblib's Memory: every time I added a print statement the cache would invalidate and fetch again. Here, if your code has changed enough, bump the version number to force a recompute
  3. Refetch on demand: I had situations where I wanted to refetch particular method calls without deleting the whole cache. This is possible through the refetch keyword argument
  4. Human-readable keys: in order to delete certain sets of keys later, I wanted the keys to be human-readable in the database and not just hashes (see the key sketch after this list)
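
As a sketch of what such a key looks like (using the places example from the usage section below, with an illustrative keyword value):

# With scope='gmaps', func_name='places_nearby', version=1 and api_key ignored,
# a call like places(keyword='pizza', api_key='...') is stored as the row
#   scope: 'gmaps'
#   key:   "v1_places_nearby_[('keyword', 'pizza')]"
# which you can inspect, match with LIKE, or delete directly in the database.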

Usage:

from typing import Dict, List

from db_cache import PostgresCache, cached

cache = PostgresCache(dsn=DATABASE_URL, scope='gmaps')

@cached(cache=cache, func_name='places_nearby', version=1, ignore_params=('api_key',))
def places(*, keyword: str, api_key: str) -> List[Dict]:
…
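
Two things worth showing beyond the basic decorator (the keyword and API key values below are made up): refetch=True forces one specific call past the cache, and clear() takes a SQL LIKE pattern to delete a targeted slice of keys:

# Recompute this one call and overwrite its cached row
places(keyword='pizza', api_key=API_KEY, refetch=True)

# Drop every cached result for version 1 of places_nearby, leaving the rest of the scope intact
cache.clear('v1_places_nearby%')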

import functools
import inspect
import threading
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional

import psycopg2
from psycopg2.extras import Json


class PostgresCache:
    def __init__(self, dsn: str, scope: str = "default"):
        """
        Initialize cache with connection string and scope.

        Args:
            dsn: PostgreSQL connection string
            scope: Namespace for the cache to separate different uses
        """
        self.dsn = dsn
        self.scope = scope
        self._memory_cache: Dict[str, Any] = {}
        self._lock = threading.Lock()
        self._loaded_patterns: set[str] = set()  # Track which patterns have been loaded
        self._init_db()

    def _init_db(self):
        with self._get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS api_cache (
                        scope TEXT,
                        key TEXT,
                        data JSONB,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        PRIMARY KEY (scope, key)
                    )
                """)
            conn.commit()

    @contextmanager
    def _get_conn(self):
        conn = psycopg2.connect(self.dsn)
        try:
            yield conn
        finally:
            conn.close()

    def _create_key(self, *, func_name: str, func: Callable, args: tuple, kwargs: dict, version: int, ignore_params: tuple) -> str:
        """Create a unique, human-readable cache key from the function name and arguments."""
        # Convert positional args to named parameters using the function signature
        sig = inspect.signature(func)
        bound_args = sig.bind(*args, **kwargs)
        bound_args.apply_defaults()
        all_params = bound_args.arguments
        # Filter out ignored parameters (e.g. API keys that shouldn't affect the key)
        filtered_params = {
            k: v for k, v in all_params.items()
            if k not in ignore_params
        }
        # Create a stable key using sorted parameters
        return f"v{version}_{func_name}_{str(sorted(filtered_params.items()))}"

    def get(self, key: str) -> Optional[Any]:
        # First try the in-memory cache
        with self._lock:
            if key in self._memory_cache:
                return self._memory_cache[key]
        # If not in memory, try the database
        with self._get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT data
                    FROM api_cache
                    WHERE scope = %s AND key = %s
                """, (self.scope, key))
                result = cur.fetchone()
                if result:
                    # Update the in-memory cache before returning
                    with self._lock:
                        self._memory_cache[key] = result[0]
                    return result[0]
        return None

    def set(self, key: str, value: Any) -> None:
        """
        Set a value in the cache.

        Args:
            key: Cache key
            value: Value to cache (must be JSON-serializable)
        """
        with self._get_conn() as conn:
            with conn.cursor() as cur:
                # Upsert: a second write to the same (scope, key) overwrites the row
                cur.execute("""
                    INSERT INTO api_cache (scope, key, data)
                    VALUES (%s, %s, %s)
                    ON CONFLICT (scope, key)
                    DO UPDATE SET
                        data = EXCLUDED.data,
                        created_at = CURRENT_TIMESTAMP
                """, (self.scope, key, Json(value)))
            conn.commit()

    def clear(self, pattern: Optional[str] = None):
        """
        Clear cache entries.

        Args:
            pattern: Optional SQL LIKE pattern to match keys; clears the whole scope if omitted
        """
        with self._get_conn() as conn:
            with conn.cursor() as cur:
                if pattern:
                    cur.execute(
                        "DELETE FROM api_cache WHERE scope = %s AND key LIKE %s",
                        (self.scope, pattern)
                    )
                else:
                    cur.execute(
                        "DELETE FROM api_cache WHERE scope = %s",
                        (self.scope,)
                    )
            conn.commit()

    def load_into_memory(self, pattern: str) -> None:
        """
        Load all cache entries matching a pattern into memory if not already loaded.

        Args:
            pattern: SQL LIKE pattern to match keys
        """
        with self._lock:
            if pattern in self._loaded_patterns:
                return
            with self._get_conn() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        SELECT key, data
                        FROM api_cache
                        WHERE scope = %s AND key LIKE %s
                    """, (self.scope, pattern))
                    self._memory_cache.update({row[0]: row[1] for row in cur.fetchall()})
            self._loaded_patterns.add(pattern)


def cached(cache: PostgresCache, func_name: str, version: int = 1, ignore_params: tuple = ()):
    """
    Decorator for caching function results in PostgreSQL.

    Args:
        cache: PostgresCache instance to use
        func_name: Stable name used in the cache key, independent of where the function lives
        version: Version number for cache invalidation (default: 1)
        ignore_params: Tuple of parameter names to ignore in cache key generation

    The decorated function accepts an extra refetch keyword argument: pass
    refetch=True to call the function and overwrite the cached result even
    if the key is already in the cache.
    """
    # Load existing cache entries for this function into memory up front
    cache.load_into_memory(f"v{version}_{func_name}%")

    def decorator(func: Callable):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # refetch is consumed here and never passed on to the wrapped function
            refetch = kwargs.pop('refetch', False)
            cache_key = cache._create_key(func_name=func_name, func=func, args=args, kwargs=kwargs, version=version, ignore_params=ignore_params)
            if not refetch:
                result = cache.get(cache_key)
                # Note: a cached value of None is indistinguishable from a miss
                if result is not None:
                    return result
            result = func(*args, **kwargs)
            cache.set(cache_key, result)
            return result
        return wrapper
    return decorator
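
Finally, a minimal end-to-end sketch, assuming a Postgres instance reachable at a local DSN (the DSN and the slow_lookup function are made up for illustration):

import time

from db_cache import PostgresCache, cached

cache = PostgresCache(dsn='postgresql://localhost/cache_test', scope='demo')

@cached(cache=cache, func_name='slow_lookup', version=1)
def slow_lookup(*, term: str) -> dict:
    time.sleep(2)  # stand-in for an expensive API call
    return {'term': term, 'hits': 42}

slow_lookup(term='foo')                # slow: computed, stored as v1_slow_lookup_[('term', 'foo')]
slow_lookup(term='foo')                # fast: served from memory/Postgres
slow_lookup(term='foo', refetch=True)  # slow again: forced recompute, cache row overwritten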