Created
December 21, 2011 17:09
-
-
Save mentat/1506798 to your computer and use it in GitHub Desktop.
Add pre/post create hooks to NDB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --exclude='*.pyc' -Naur ndb.old/key.py ndb/key.py | |
--- ndb.old/key.py 2011-12-20 10:27:11.000000000 -0600 | |
+++ ndb/key.py 2011-12-21 11:23:49.000000000 -0600 | |
@@ -514,6 +514,8 @@ | |
if not cls._is_default_hook(model.Model._default_post_get_hook, | |
post_hook): | |
fut.add_immediate_callback(post_hook, self, fut) | |
+ internal_post_hook = model.Model._post_get_hook_internal | |
+ fut.add_immediate_callback(internal_post_hook, self, fut) | |
return fut | |
def delete(self, **ctx_options): | |
diff --exclude='*.pyc' -Naur ndb.old/model.py ndb/model.py | |
--- ndb.old/model.py 2011-12-20 10:27:11.000000000 -0600 | |
+++ ndb/model.py 2011-12-21 12:54:38.000000000 -0600 | |
@@ -2165,6 +2165,7 @@ | |
# TODO: should this be restricted to string ids? | |
self._key = Key(self._get_kind(), id, parent=parent) | |
+ self._is_saved = False | |
self._values = {} | |
self._set_attributes(kwds) | |
@@ -2528,8 +2529,15 @@ | |
if self._key is None: | |
self._key = Key(self._get_kind(), None) | |
self._pre_put_hook() | |
+ if not self._is_saved: | |
+ self._pre_create_hook() | |
fut = ctx.put(self, **ctx_options) | |
+ if not self._is_saved: | |
+ post_create_hook = self._post_create_hook | |
+ if not self._is_default_hook(Model._default_post_create_hook, post_create_hook): | |
+ fut.add_immediate_callback(post_create_hook, fut) | |
post_hook = self._post_put_hook | |
+ self._is_saved = True | |
if not self._is_default_hook(Model._default_post_put_hook, post_hook): | |
fut.add_immediate_callback(post_hook, fut) | |
return fut | |
@@ -2683,6 +2691,20 @@ | |
pass | |
_default_post_put_hook = _post_put_hook | |
+ def _pre_create_hook(self): | |
+ pass | |
+ _default_pre_create_hook = _pre_create_hook | |
+ | |
+ def _post_create_hook(self, future): | |
+ pass | |
+ _default_post_create_hook = _post_create_hook | |
+ | |
+ @classmethod | |
+ def _post_get_hook_internal(cls, key, future): | |
+ """An internal hook to set saved state. Do not touch.""" | |
+ if future.get_result(): | |
+ future.get_result()._is_saved = True | |
+ | |
@staticmethod | |
def _is_default_hook(default_hook, hook): | |
"""Checks whether a specific hook is in its default state. | |
diff --exclude='*.pyc' -Naur ndb.old/model_test.py ndb/model_test.py | |
--- ndb.old/model_test.py 2011-12-20 10:27:11.000000000 -0600 | |
+++ ndb/model_test.py 2011-12-21 11:03:52.000000000 -0600 | |
@@ -2856,6 +2856,42 @@ | |
self.assertEqual(self.post_counter, 11, | |
'Post put hooks not called on put_multi') | |
+ def testCreateHooksCalled(self): | |
+ test = self # Closure for inside hooks | |
+ self.pre_counter = 0 | |
+ self.post_counter = 0 | |
+ | |
+ class HatStand(model.Model): | |
+ def _pre_create_hook(self): | |
+ test.pre_counter += 1 | |
+ def _post_create_hook(self, future): | |
+ test.post_counter += 1 | |
+ test.assertEqual(future.get_result(), test.entity.key) | |
+ | |
+ furniture = HatStand() | |
+ self.entity = furniture | |
+ self.assertEqual(self.pre_counter, 0, 'Pre create hook called early') | |
+ future = furniture.put_async() | |
+ self.assertEqual(self.pre_counter, 1, 'Pre create hook not called') | |
+ self.assertEqual(self.post_counter, 0, 'Post create hook called early') | |
+ future.get_result() | |
+ self.assertEqual(self.post_counter, 1, 'Post create hook not called') | |
+ | |
+ # Both counters now read 1; calling put_multi for 10 new entities brings each to 11 | |
+ new_furniture = [HatStand() for _ in range(10)] | |
+ results = [] | |
+ multi_future = model.put_multi_async(new_furniture) | |
+ self.assertEqual(self.pre_counter, 11, | |
+ 'Pre create hooks not called on put_multi') | |
+ self.assertEqual(self.post_counter, 1, | |
+ 'Post create hooks called early on put_multi') | |
+ for fut, ent in zip(multi_future, new_furniture): | |
+ self.entity = ent | |
+ fut.get_result() | |
+ | |
+ self.assertEqual(self.post_counter, 11, | |
+ 'Post create hooks not called on put_multi') | |
+ | |
def testGetByIdHooksCalled(self): | |
# See issue 95. http://goo.gl/QSRQH | |
# Adapted from testGetHooksCalled in key_test.py. | |
@@ -2916,6 +2952,10 @@ | |
test.pre_put_counter += 1 | |
def _post_put_hook(self, future): | |
test.post_put_counter += 1 | |
+ def _pre_create_hook(self): | |
+ test.pre_create_counter += 1 | |
+ def _post_create_hook(self, future): | |
+ test.post_create_counter += 1 | |
# First call creates it. This calls get() twice (once outside the | |
# transaction and once inside it) and put() once (from inside the | |
@@ -2924,27 +2964,45 @@ | |
self.post_get_counter = 0 | |
self.pre_put_counter = 0 | |
self.post_put_counter = 0 | |
+ self.pre_create_counter = 0 | |
+ self.post_create_counter = 0 | |
HatStand.get_or_insert('classic') | |
self.assertEqual(self.pre_get_counter, 2) | |
self.assertEqual(self.post_get_counter, 2) | |
self.assertEqual(self.pre_put_counter, 1) | |
self.assertEqual(self.post_put_counter, 1) | |
+ self.assertEqual(self.pre_create_counter, 1) | |
+ self.assertEqual(self.post_create_counter, 1) | |
# Second call gets it without needing a transaction. | |
self.pre_get_counter = 0 | |
self.post_get_counter = 0 | |
self.pre_put_counter = 0 | |
self.post_put_counter = 0 | |
- HatStand.get_or_insert_async('classic').get_result() | |
+ self.pre_create_counter = 0 | |
+ self.post_create_counter = 0 | |
+ ret = HatStand.get_or_insert_async('classic').get_result() | |
self.assertEqual(self.pre_get_counter, 1) | |
self.assertEqual(self.post_get_counter, 1) | |
self.assertEqual(self.pre_put_counter, 0) | |
self.assertEqual(self.post_put_counter, 0) | |
+ self.assertEqual(self.pre_create_counter, 0) | |
+ self.assertEqual(self.post_create_counter, 0) | |
+ | |
+ # Ensure the create hooks do not fire again when an already-saved entity is put | |
+ ret.put() | |
+ self.assertEqual(self.pre_get_counter, 1) | |
+ self.assertEqual(self.post_get_counter, 1) | |
+ self.assertEqual(self.pre_put_counter, 1) | |
+ self.assertEqual(self.post_put_counter, 1) | |
+ self.assertEqual(self.pre_create_counter, 0) | |
+ self.assertEqual(self.post_create_counter, 0) | |
def testMonkeyPatchHooks(self): | |
test = self # Closure for inside put hooks | |
hook_attr_names = ('_pre_allocate_ids_hook', '_post_allocate_ids_hook', | |
- '_pre_put_hook', '_post_put_hook') | |
+ '_pre_put_hook', '_post_put_hook', '_pre_create_hook', | |
+ '_post_create_hook') | |
original_hooks = {} | |
# Backup the original hooks | |
@@ -2955,6 +3013,8 @@ | |
self.post_allocate_ids_flag = False | |
self.pre_put_flag = False | |
self.post_put_flag = False | |
+ self.pre_create_flag = False | |
+ self.post_create_flag = False | |
# TODO: Should the unused arguments to Monkey Patched tests be tested? | |
class HatStand(model.Model): | |
@@ -2969,6 +3029,10 @@ | |
test.pre_put_flag = True | |
def _post_put_hook(self, unused_future): | |
test.post_put_flag = True | |
+ def _pre_create_hook(self): | |
+ test.pre_create_flag = True | |
+ def _post_create_hook(self, unused_future): | |
+ test.post_create_flag = True | |
# Monkey patch the hooks | |
for name in hook_attr_names: | |
@@ -2987,6 +3051,10 @@ | |
'Pre put hook not called when model is monkey patched') | |
self.assertTrue(self.post_put_flag, | |
'Post put hook not called when model is monkey patched') | |
+ self.assertTrue(self.pre_create_flag, | |
+ 'Pre create hook not called when model is monkey patched') | |
+ self.assertTrue(self.post_create_flag, | |
+ 'Post create hook not called when model is monkey patched') | |
finally: | |
# Restore the original hooks | |
for name in hook_attr_names: |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment