Skip to content

Instantly share code, notes, and snippets.

@mentat
Created December 21, 2011 17:09
Show Gist options
Save mentat/1506798 to your computer and use it in GitHub Desktop.
Add pre/post create hooks to NDB
diff --exclude='*.pyc' -Naur ndb.old/key.py ndb/key.py
--- ndb.old/key.py 2011-12-20 10:27:11.000000000 -0600
+++ ndb/key.py 2011-12-21 11:23:49.000000000 -0600
@@ -514,6 +514,8 @@
if not cls._is_default_hook(model.Model._default_post_get_hook,
post_hook):
fut.add_immediate_callback(post_hook, self, fut)
+ # Flag the fetched entity as saved so a later put() skips the create hooks.
+ fut.add_immediate_callback(model.Model._post_get_hook_internal, self, fut)
return fut
def delete(self, **ctx_options):
diff --exclude='*.pyc' -Naur ndb.old/model.py ndb/model.py
--- ndb.old/model.py 2011-12-20 10:27:11.000000000 -0600
+++ ndb/model.py 2011-12-21 12:54:38.000000000 -0600
@@ -2165,6 +2165,8 @@
# TODO: should this be restricted to string ids?
self._key = Key(self._get_kind(), id, parent=parent)
+ # Becomes True after a successful put() or a fetch via Key.get().
+ self._is_saved = False
self._values = {}
self._set_attributes(kwds)
@@ -2528,8 +2530,19 @@
if self._key is None:
self._key = Key(self._get_kind(), None)
self._pre_put_hook()
+ is_create = not self._is_saved
+ if is_create:
+ self._pre_create_hook()
fut = ctx.put(self, **ctx_options)
+ if is_create:
+ # Mark saved only once the put succeeds, so a failed put still counts
+ # as a create on retry. Registered before the user's post-create hook.
+ fut.add_immediate_callback(self._mark_saved_after_put, fut)
+ post_create_hook = self._post_create_hook
+ if not self._is_default_hook(Model._default_post_create_hook,
+ post_create_hook):
+ fut.add_immediate_callback(post_create_hook, fut)
post_hook = self._post_put_hook
if not self._is_default_hook(Model._default_post_put_hook, post_hook):
fut.add_immediate_callback(post_hook, fut)
return fut
@@ -2683,6 +2696,40 @@
pass
_default_post_put_hook = _post_put_hook
+ def _pre_create_hook(self):
+ """Runs in put_async() before the put when _is_saved is False."""
+ _default_pre_create_hook = _pre_create_hook
+
+ def _post_create_hook(self, future):
+ """Called with the put future after a put_async() begun with _is_saved False."""
+ _default_post_create_hook = _post_create_hook
+
+ def _mark_saved_after_put(self, future):
+ """Internal put callback: marks this entity saved once the put succeeds.
+
+ Leaving the flag unset after a failed put lets a retry run the create
+ hooks again. Not a user hook; do not override.
+ """
+ if future.get_exception() is None:
+ self._is_saved = True
+
+ @classmethod
+ def _post_get_hook_internal(cls, key, future):
+ """Internal post-get callback that marks a fetched entity as saved.
+
+ Registered by Key.get_async(); not a user hook, do not override. A
+ failed get is skipped so its exception is not re-raised from inside
+ this callback, and a missing entity (None result) is ignored.
+
+ NOTE(review): entities loaded by queries do not pass through this hook
+ and keep _is_saved False -- TODO confirm whether _from_pb should set it.
+ """
+ if future.get_exception() is not None:
+ return
+ entity = future.get_result()
+ if entity is not None:
+ entity._is_saved = True
+
@staticmethod
def _is_default_hook(default_hook, hook):
"""Checks whether a specific hook is in its default state.
diff --exclude='*.pyc' -Naur ndb.old/model_test.py ndb/model_test.py
--- ndb.old/model_test.py 2011-12-20 10:27:11.000000000 -0600
+++ ndb/model_test.py 2011-12-21 11:03:52.000000000 -0600
@@ -2856,6 +2856,41 @@
self.assertEqual(self.post_counter, 11,
'Post put hooks not called on put_multi')
+ def testCreateHooksCalled(self):
+ test = self # Closure for inside hooks
+ self.pre_counter = 0
+ self.post_counter = 0
+
+ class HatStand(model.Model):
+ def _pre_create_hook(self):
+ test.pre_counter += 1
+ def _post_create_hook(self, future):
+ test.post_counter += 1
+ test.assertEqual(future.get_result(), test.entity.key)
+
+ furniture = HatStand()
+ self.entity = furniture
+ self.assertEqual(self.pre_counter, 0, 'Pre create hook called early')
+ future = furniture.put_async()
+ self.assertEqual(self.pre_counter, 1, 'Pre create hook not called')
+ self.assertEqual(self.post_counter, 0, 'Post create hook called early')
+ future.get_result()
+ self.assertEqual(self.post_counter, 1, 'Post create hook not called')
+
+ # All counters now read 1, calling put_multi for 10 entities makes this 11
+ new_furniture = [HatStand() for _ in range(10)]
+ multi_future = model.put_multi_async(new_furniture)
+ self.assertEqual(self.pre_counter, 11,
+ 'Pre create hooks not called on put_multi')
+ self.assertEqual(self.post_counter, 1,
+ 'Post create hooks called early on put_multi')
+ for fut, ent in zip(multi_future, new_furniture):
+ self.entity = ent
+ fut.get_result()
+
+ self.assertEqual(self.post_counter, 11,
+ 'Post create hooks not called on put_multi')
+
def testGetByIdHooksCalled(self):
# See issue 95. http://goo.gl/QSRQH
# Adapted from testGetHooksCalled in key_test.py.
@@ -2916,6 +2951,10 @@
test.pre_put_counter += 1
def _post_put_hook(self, future):
test.post_put_counter += 1
+ def _pre_create_hook(self):
+ test.pre_create_counter += 1
+ def _post_create_hook(self, future):
+ test.post_create_counter += 1
# First call creates it. This calls get() twice (once outside the
# transaction and once inside it) and put() once (from inside the
@@ -2924,27 +2963,45 @@
self.post_get_counter = 0
self.pre_put_counter = 0
self.post_put_counter = 0
+ self.pre_create_counter = 0
+ self.post_create_counter = 0
HatStand.get_or_insert('classic')
self.assertEqual(self.pre_get_counter, 2)
self.assertEqual(self.post_get_counter, 2)
self.assertEqual(self.pre_put_counter, 1)
self.assertEqual(self.post_put_counter, 1)
+ self.assertEqual(self.pre_create_counter, 1)
+ self.assertEqual(self.post_create_counter, 1)
# Second call gets it without needing a transaction.
self.pre_get_counter = 0
self.post_get_counter = 0
self.pre_put_counter = 0
self.post_put_counter = 0
- HatStand.get_or_insert_async('classic').get_result()
+ self.pre_create_counter = 0
+ self.post_create_counter = 0
+ ret = HatStand.get_or_insert_async('classic').get_result()
self.assertEqual(self.pre_get_counter, 1)
self.assertEqual(self.post_get_counter, 1)
self.assertEqual(self.pre_put_counter, 0)
self.assertEqual(self.post_put_counter, 0)
+ self.assertEqual(self.pre_create_counter, 0)
+ self.assertEqual(self.post_create_counter, 0)
+
+ # Re-putting a fetched entity runs the put hooks but not the create hooks
+ ret.put()
+ self.assertEqual(self.pre_get_counter, 1)
+ self.assertEqual(self.post_get_counter, 1)
+ self.assertEqual(self.pre_put_counter, 1)
+ self.assertEqual(self.post_put_counter, 1)
+ self.assertEqual(self.pre_create_counter, 0)
+ self.assertEqual(self.post_create_counter, 0)
def testMonkeyPatchHooks(self):
test = self # Closure for inside put hooks
hook_attr_names = ('_pre_allocate_ids_hook', '_post_allocate_ids_hook',
- '_pre_put_hook', '_post_put_hook')
+ '_pre_put_hook', '_post_put_hook', '_pre_create_hook',
+ '_post_create_hook')
original_hooks = {}
# Backup the original hooks
@@ -2955,6 +3012,8 @@
self.post_allocate_ids_flag = False
self.pre_put_flag = False
self.post_put_flag = False
+ self.pre_create_flag = False
+ self.post_create_flag = False
# TODO: Should the unused arguments to Monkey Patched tests be tested?
class HatStand(model.Model):
@@ -2969,6 +3028,10 @@
test.pre_put_flag = True
def _post_put_hook(self, unused_future):
test.post_put_flag = True
+ def _pre_create_hook(self):
+ test.pre_create_flag = True
+ def _post_create_hook(self, unused_future):
+ test.post_create_flag = True
# Monkey patch the hooks
for name in hook_attr_names:
@@ -2987,6 +3050,10 @@
'Pre put hook not called when model is monkey patched')
self.assertTrue(self.post_put_flag,
'Post put hook not called when model is monkey patched')
+ self.assertTrue(self.pre_create_flag,
+ 'Pre create hook not called when model is monkey patched')
+ self.assertTrue(self.post_create_flag,
+ 'Post create hook not called when model is monkey patched')
finally:
# Restore the original hooks
for name in hook_attr_names:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment