Created
January 5, 2009 15:17
-
-
Save Arachnid/43429 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Modifies the App Engine datastore to support local caching of entities. | |
This is achieved by monkeypatching google.appengine.ext.db to recognise model | |
classes that should be cached and store them locally for the duration of a | |
single page request. | |
Note that only datastore gets (and anything that relies on them, such as | |
ReferenceProperty fetches) are cached; queries will neither return cached | |
entities nor update the cache. | |
To use, wrap your WSGI application in an instance of CacheSession, and | |
modify any models that you expect to be fetched more than once per request to | |
extend CachedModel instead of db.Model. | |
""" | |
import logging

from google.appengine.api import datastore
from google.appengine.ext import db
def splitByCond(l, func):
    """Partition the items of ``l`` according to a predicate.

    ``func`` is called once per item, in order. Returns a 2-tuple of lists:
    the items for which ``func`` was true, then the items for which it was
    false. Relative order is kept within each list.
    """
    matched, unmatched = [], []
    for element in l:
        bucket = matched if func(element) else unmatched
        bucket.append(element)
    return (matched, unmatched)
def joinByCond(l, a, b, func):
    """Merge ``a`` and ``b`` back into the order described by ``l``.

    This is the inverse of ``splitByCond``: for each item of ``l``, the next
    value is taken from ``a`` when ``func(item)`` is true and from ``b``
    otherwise.

    Args:
      l: The sequence that defines the output order.
      a: Iterable supplying values for items where ``func`` is true.
      b: Iterable supplying values for items where ``func`` is false.
      func: Predicate applied to each item of ``l``.

    Yields:
      One value from ``a`` or ``b`` per item of ``l``. Iteration stops
      quietly if the selected source runs out of values.
    """
    ia = iter(a)
    ib = iter(b)
    for item in l:
        source = ia if func(item) else ib
        try:
            # The builtin next() works on both Python 2.6+ and Python 3,
            # unlike the Python 2-only iterator.next() method.
            value = next(source)
        except StopIteration:
            # End the generator explicitly; letting StopIteration escape a
            # generator body is an error on newer Pythons (PEP 479).
            return
        yield value
class CachedModel(db.Model):
    """Base class marking a model's entities as eligible for local caching.

    The module-level ``get`` and ``put`` functions only store instances of
    this class in the session cache.
    """

    def put(self):
        """Store this entity through the module-level ``put``; return its result."""
        stored_key = put(self)
        return stored_key

    def delete(self):
        """Delete this entity through the module-level ``delete``.

        Afterwards the instance's ``_entity`` attribute is cleared.
        """
        delete(self)
        self._entity = None
# Keep references to the unpatched datastore functions so the caching
# wrappers defined in this module can delegate to them.
db_get, db_put, db_delete = db.get, db.put, db.delete

# The CacheSession that is currently handling a request, or None when no
# session is active (in which case the wrappers skip the cache).
_current_session = None
def get(keys):
    """Cache-aware replacement for ``db.get``.

    With no active session, or while inside a transaction, the call is passed
    straight to the original ``db.get``. Otherwise keys already present in the
    session cache are served from it, the remainder are fetched in one
    datastore call, and fetched ``CachedModel`` instances are added to the
    cache. Session statistics are updated along the way.

    Returns a list when ``keys`` was normalised as multiple, otherwise the
    single result.
    """
    session = _current_session
    if not session or datastore._CurrentTransactionKey():
        return db_get(keys)

    keys, multiple = datastore.NormalizeAndTypeCheckKeys(keys)
    cache = session.cache

    def is_cached(key):
        return key in cache

    # Partition the request into cache hits and keys that need fetching.
    hit_keys, miss_keys = splitByCond(keys, is_cached)
    hits = [cache[key] for key in hit_keys]

    if miss_keys:
        fetched = db_get(miss_keys)
    else:
        # Every key was served locally: the whole get was avoided.
        session.cached_gets += 1
        fetched = []

    # Only CachedModel instances count as misses and get cached.
    cacheable = [model for model in fetched if isinstance(model, CachedModel)]
    session.hit_count += len(hits)
    session.miss_count += len(cacheable)
    session.total_gets += 1

    # Build the result BEFORE touching the cache: is_cached must still give
    # the same answers it gave when the keys were partitioned above.
    results = list(joinByCond(keys, hits, fetched, is_cached))

    for model in cacheable:
        cache[model.key()] = model

    return results if multiple else results[0]
def put(models):
    """Cache-aware replacement for ``db.put``.

    With no active session the call is passed straight to the original
    ``db.put``. Otherwise the models are stored, and then:

    * outside a transaction, stored ``CachedModel`` instances are written
      into the session cache under their keys;
    * inside a transaction, the stored keys are evicted from the cache.

    Returns the list of keys when ``models`` was normalised as multiple,
    otherwise the single key.
    """
    session = _current_session
    if not session:
        return db_put(models)

    models, multiple = datastore.NormalizeAndTypeCheck(models, db.Model)
    keys = db_put(models)

    if datastore._CurrentTransactionKey():
        # In transactions, evict rather than cache: we don't know whether
        # the write will be committed or rolled back.
        for key in keys:
            session.cache.pop(key, None)
    else:
        for key, model in zip(keys, models):
            if isinstance(model, CachedModel):
                session.cache[key] = model

    return keys if multiple else keys[0]
def delete(models):
    """Cache-aware replacement for ``db.delete``.

    The deletion is always forwarded to the original ``db.delete`` first.
    If a session is active, each deleted item is then resolved to a key and
    that key is evicted from the session cache. ``CachedModel`` instances
    are resolved via ``key()``, strings via ``db.Key(...)``, and anything
    else is used as the key directly.
    """
    db_delete(models)

    session = _current_session
    if not session:
        return

    targets, _ = datastore.NormalizeAndTypeCheck(
        models, (db.Model, db.Key, basestring))
    for target in targets:
        if isinstance(target, CachedModel):
            key = target.key()
        elif isinstance(target, basestring):
            key = db.Key(target)
        else:
            key = target
        session.cache.pop(key, None)
# Install the cache-aware wrappers in place of the stock db functions.
db.get, db.put, db.delete = get, put, delete
class CacheSession(object):
    """WSGI wrapper that activates a per-instance entity cache during a call.

    While the wrapped application is being called, this object is installed
    as the module-level ``_current_session``. The global is cleared again
    when the call returns or raises, and a summary of the session's cache
    statistics is logged.

    Attributes:
      wrapped: The WSGI application being wrapped.
      cache: Dict used as the entity cache.
      hit_count: Number of entities served from the cache.
      miss_count: Number of cacheable entities fetched from the datastore.
      total_gets: Number of get calls seen while the session was active.
      cached_gets: Number of get calls answered entirely from the cache.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        self.cache = {}
        self.hit_count = 0
        self.miss_count = 0
        self.total_gets = 0
        self.cached_gets = 0

    def __call__(self, environ, start_response):
        """Call the wrapped application with this session active.

        Returns whatever the wrapped application returns; exceptions it
        raises propagate unchanged.
        """
        global _current_session
        _current_session = self
        try:
            return self.wrapped(environ, start_response)
        finally:
            _current_session = None
            self._log_stats()

    @staticmethod
    def _percent(part, whole):
        """Return ``part`` as a percentage of ``whole``; 0 when ``whole`` is 0."""
        if not whole:
            # A request that performed no gets must not blow up the
            # finally-block with ZeroDivisionError.
            return 0
        return 100.0 * part / whole

    def _log_stats(self):
        """Log how many entity fetches and get calls the cache saved."""
        entity_total = self.hit_count + self.miss_count
        logging.info(
            "datastore_cache saved %d/%d entity fetches (%d%% hit rate), "
            "%d/%d requests (%d%% hit rate)",
            self.hit_count, entity_total,
            self._percent(self.hit_count, entity_total),
            self.cached_gets, self.total_gets,
            self._percent(self.cached_gets, self.total_gets))

    def getStats(self):
        """Return the session counters as a dict."""
        return {
            'hits': self.hit_count,
            'misses': self.miss_count,
            'gets': self.total_gets,
            'cached_gets': self.cached_gets,
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import datastore_cache | |
import unittest | |
from google.appengine.api import apiproxy_stub_map | |
from google.appengine.api import datastore_file_stub | |
from google.appengine.ext import db | |
class Foo(datastore_cache.CachedModel):
    """Test model that is eligible for caching."""

    one = db.IntegerProperty()
class Bar(db.Model):
    """Test model that is NOT cached but references a cached ``Foo``."""

    ref = db.ReferenceProperty(Foo)
class DatastoreCacheTest(unittest.TestCase):
    """Tests for datastore_cache, run against a local datastore file stub."""

    def setUp(self):
        # Give every test a fresh API proxy and datastore stub.
        os.environ['APPLICATION_ID'] = 'test'
        apiproxy_stub_map.apiproxy = apiproxy_stub_map.APIProxyStubMap()
        datastore = datastore_file_stub.DatastoreFileStub('test', None, None)
        apiproxy_stub_map.apiproxy.RegisterStub('datastore_v3', datastore)

    def tearDown(self):
        # testCache installs a session by hand; clear it so the other tests,
        # which assert that no session is active, do not depend on run order.
        datastore_cache._current_session = None

    def testOutsideSession(self):
        """The patched db functions still work when no session is active."""
        self.failUnlessEqual(datastore_cache._current_session, None)

        # Single put
        foo = Foo(one=1)
        foo_id = foo.put()

        # Multiple put
        foo.one = 2
        bar = Bar(ref=foo)
        foo_id, bar_id = db.put([foo, bar])

        # Single get
        self.failUnlessEqual(Foo.get(foo_id).one, 2)

        # Multiple get
        self.failUnlessEqual(len(db.get([foo_id, bar_id])), 2)

        # Single delete
        baz = Foo(one=3)
        baz.put()
        baz.delete()

        # Multiple delete
        db.delete([foo, bar])

    def testCacheSession(self):
        """CacheSession sets the global during a call and always clears it."""
        class MyException(Exception): pass

        def app(environ, start_response):
            self.failIfEqual(datastore_cache._current_session, None)
            return 'app'

        def failApp(environ, start_response):
            self.failIfEqual(datastore_cache._current_session, None)
            raise MyException()

        self.failUnlessEqual(datastore_cache._current_session, None)
        session = datastore_cache.CacheSession(app)
        self.failUnlessEqual(session(None, None), 'app')
        self.failUnlessEqual(datastore_cache._current_session, None)

        session = datastore_cache.CacheSession(failApp)
        self.failUnlessRaises(MyException, session, None, None)
        self.failUnlessEqual(datastore_cache._current_session, None)

    def testCache(self):
        """Puts, gets and deletes update the cache and statistics as expected."""
        session = datastore_cache.CacheSession(lambda x,y: [])
        datastore_cache._current_session = session

        # Single and multiple put of cached entities
        foo = Foo(one=1)
        foo.put()
        self.failUnless(foo.key() in session.cache)
        foo2 = Foo(one=2)
        db.put([foo, foo2])
        self.failUnless(foo2.key() in session.cache)

        # Put of uncached entities
        bar = Bar(ref=foo)
        bar.put()
        self.failIf(bar.key() in session.cache)
        bar2 = Bar(ref=foo2)
        db.put([bar, bar2])
        self.failIf(bar2.key() in session.cache)

        self.failUnlessEqual(
            session.getStats(),
            {'hits': 0, 'misses': 0, 'gets': 0, 'cached_gets': 0})

        # Cache hit
        session.cache[foo.key()].test = 'test'
        self.failUnlessEqual(db.get(foo.key()).test, 'test')
        self.failUnlessEqual(
            session.getStats(),
            {'hits': 1, 'misses': 0, 'gets': 1, 'cached_gets': 1})

        # Cache miss
        del session.cache[foo2.key()]
        got = db.get([foo.key(), foo2.key()])
        self.failUnlessEqual(got[0], foo)
        self.failIfEqual(got[1], foo2)
        gotkeys = [x.key() for x in got]
        self.failUnlessEqual(gotkeys, [foo.key(), foo2.key()])
        self.failUnless(foo2.key() in session.cache)
        self.failUnlessEqual(
            session.getStats(),
            {'hits': 2, 'misses': 1, 'gets': 2, 'cached_gets': 1})

        # Deletion
        foo2_key = foo2.key()
        foo2.delete()
        self.failIf(foo2_key in session.cache)

        # Reference property lookup
        bar = Bar.get(bar.key())
        self.failUnlessEqual(bar.ref, foo)
        self.failUnlessEqual(
            session.getStats(),
            {'hits': 3, 'misses': 1, 'gets': 4, 'cached_gets': 2})
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment