Created
July 10, 2012 17:34
-
-
Save someone1/3084933 to your computer and use it in GitHub Desktop.
Testing code for reliably saving hundreds of transactions on GAE (Google App Engine)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import webapp2 | |
import decimal | |
import logging | |
import random | |
import string | |
from google.appengine.api import datastore_errors | |
from google.appengine.datastore import entity_pb | |
from google.appengine.ext import db | |
from google.appengine.ext import ndb | |
from google.appengine.ext.ndb import metadata | |
class DecimalProperty(ndb.Property):
    """NDB property holding a decimal.Decimal, persisted as its string form."""

    def _datastore_type(self, value):
        # The datastore-facing representation is the plain string of the value.
        return str(value)

    def _validate(self, value):
        # Accept genuine Decimal instances only; reject everything else.
        if isinstance(value, decimal.Decimal):
            return value
        raise datastore_errors.BadValueError(
            'Expected decimal.Decimal, got %r' % (value,))

    def _db_set_value(self, v, p, value):
        # Serialise the decimal into the protobuf value as a string.
        v.set_stringvalue(str(value))
        if self._indexed:
            return
        # Unindexed values are additionally tagged with the TEXT meaning.
        p.set_meaning(entity_pb.Property.TEXT)

    def _db_get_value(self, v, _):
        # A protobuf value without a string payload deserialises to None.
        if v.has_stringvalue():
            return decimal.Decimal(v.stringvalue())
        return None
class Shard(ndb.Model):
    """A single shard of a named counter."""

    # The shard's name is not stored as a property; the key id already is it.
    count = DecimalProperty(
        name='c',
        default=decimal.Decimal('0.00'),
        indexed=False,
    )
class Counter(ndb.Model):
    """Tracks the number of shards for each named counter.

    The counter's name is its key id. Shard ``i`` of a counter is the Shard
    entity whose key id is the counter name followed by ``str(i)``.
    """

    # No need for the name property, as it is the same as the key id
    # NOTE: num_shards can only be set to a maximum of 5 due to xg limitations
    num_shards = ndb.IntegerProperty(default=4, indexed=False)

    @property
    def shards(self):
        """Return the list of this counter's Shard entities that exist.

        Looks up shards 0..num_shards-1 in one batch, bypassing memcache, and
        drops the lookups that came back empty.
        """
        prefix = self.key.id()  # Cache for use in loop
        dbkeys = [ndb.Key(Shard, prefix + str(index))
                  for index in range(self.num_shards)]
        found = ndb.get_multi(dbkeys, use_memcache=False)
        return [shard for shard in found if shard]

    @ndb.tasklet
    def compress_shards_async(self):
        """Fold every shard's count into the first shard and delete the rest.

        To be used when reducing num_shards. Runs as one cross-group (xg)
        transaction. Does nothing when the counter has no shards yet.
        """
        @ndb.tasklet
        def _compress_shards_tx():
            shards = self.shards
            if not shards:
                # No shard has been written yet: nothing to merge or delete.
                return
            first_shard = shards[0]
            dbkeys = []
            for shard in shards[1:]:
                first_shard.count += shard.count
                dbkeys.append(shard.key)
            del_fut = ndb.delete_multi_async(dbkeys)
            put_fut = first_shard.put_async()
            yield del_fut, put_fut

        yield ndb.transaction_async(_compress_shards_tx, use_memcache=False,
                                    xg=True)

    def compress_shards(self):
        """Synchronous wrapper around compress_shards_async()."""
        return self.compress_shards_async().get_result()

    @property
    def total(self):
        """Return the sum of all existing shard counts (0.00 if none)."""
        # Start from the initial value so an empty counter still yields 0.00.
        return sum((shard.count for shard in self.shards),
                   decimal.Decimal('0.00'))

    @ndb.tasklet
    def incr_async(self, value):
        """Add ``value`` to one randomly chosen shard inside a transaction.

        Args:
            value: amount to add to the shard's count.

        Raises:
            db.InternalError: the transaction failed with an internal error
                and the shard's entity group version changed in the meantime,
                so it is unknown whether the increment was applied.
            Other errors raised by the transaction propagate unchanged.
        """
        index = random.randint(0, self.num_shards - 1)  # Use random shard
        name = self.key.id() + str(index)
        key = ndb.Key(Shard, name)
        # Snapshot the entity group version so that, after an internal error,
        # we can tell whether anything was committed to this shard.
        # NOTE: this lookup is a blocking call (its result is used directly).
        version = metadata.get_entity_group_version(key)

        @ndb.tasklet
        def _incr_tx():
            shard = yield Shard.get_by_id_async(name, use_memcache=False)
            if not shard:
                # Setting the parent key for future queries and maintenance
                # removes the benefit of using shards (shared entity group)
                shard = Shard(id=name)
            shard.count += value
            yield shard.put_async()

        try:
            yield ndb.transaction_async(_incr_tx)
        except db.InternalError:
            if version == metadata.get_entity_group_version(key):
                # Version unchanged: nothing was committed, so retry. The
                # retry must be yielded, otherwise this tasklet would report
                # success before the retry finished and its errors would be
                # lost. The retry picks a fresh random shard.
                logging.warning('Almost corrupted shard %s' % name)
                yield self.incr_async(value)
            else:
                logging.error('Shard %s is corrupted' % name)
                raise  # bare raise keeps the original traceback

    def incr(self, value):
        """Synchronous wrapper around incr_async()."""
        return self.incr_async(value).get_result()
@ndb.tasklet
def increment_batch(data_set):
    """Apply a batch of counter increments, repeating until all succeed.

    Args:
        data_set: dict mapping counter name -> decimal.Decimal amount.
            NOTE: data_set is modified in place. Zero amounts are removed up
            front and every other entry is removed once its increment
            succeeds, so the dict is empty when this tasklet finishes.

    Returns (via ndb.Return):
        dict mapping counter name -> Counter for every non-zero entry.
    """
    max_shards = 5  # Upper bound for Counter.num_shards (xg limitation)
    zero = decimal.Decimal('0.00')

    # (1/3) filter and fire off counter gets
    # so the futures can autobatch
    counters = {}
    ctr_futs = {}
    ctr_put_futs = []
    zero_values = set()
    for name, value in data_set.iteritems():
        if value != zero:
            ctr_futs[name] = Counter.get_by_id_async(name)  # Use cache(s)
        else:
            # Zero amounts are skipped; collected here because the dict
            # cannot be changed while it is being iterated.
            zero_values.add(name)
    for name in zero_values:
        del data_set[name]  # Remove all zero values from the data_set

    while data_set:  # Repeat until all transactions succeed
        # (2/3) wait on counter gets and fire off increment transactions
        # this way autobatchers should fill time
        incr_futs = {}
        for name, value in data_set.iteritems():
            counter = counters.get(name)
            if not counter:
                # First pass for this name: resolve the pending get.
                counter = counters[name] = yield ctr_futs.pop(name)
                if not counter:
                    logging.info('Creating new counter %s' % name)
                    counter = counters[name] = Counter(id=name)
                    ctr_put_futs.append(counter.put_async())
            incr_futs[(name, value)] = counter.incr_async(value)

        # (3/3) wait on increments and handle errors
        # by using a tuple key for variable access
        for (name, value), incr_fut in incr_futs.iteritems():
            counter = counters[name]
            try:
                yield incr_fut
            except db.TransactionFailedError:
                # Add a shard (up to the cap); the entry stays in data_set
                # and is retried on the next pass.
                if counter.num_shards < max_shards:
                    counter.num_shards += 1
                    logging.info('Increasing number of shards for %s to %i.' %
                                 (name, counter.num_shards))
                    ctr_put_futs.append(counter.put_async())
            except db.InternalError:
                # Deliberately not fatal: the entry stays in data_set and is
                # retried on the next pass. Log it instead of hiding it.
                logging.warning(
                    'Internal error incrementing %s; will retry.' % name)
            else:
                del data_set[name]
        if data_set:
            logging.warning('%i increments failed this batch.' % len(data_set))

    yield ctr_put_futs  # In case you get() the Counters later in the handler
    raise ndb.Return(counters)
class ShardTestHandler(webapp2.RequestHandler):
    """Test endpoint for the sharded counters.

    GET with a truthy ``delete`` query parameter deletes every Shard and
    Counter entity; any other GET increments 250 randomly named counters by
    random amounts through increment_batch().
    """

    @ndb.toplevel
    def get(self):
        if self.request.GET.get('delete'):
            # The delete futures are not yielded here; they are left for
            # @ndb.toplevel to finish before the request completes.
            ndb.delete_multi_async(Shard.query().fetch(keys_only=True))
            ndb.delete_multi_async(Counter.query().fetch(keys_only=True))
        else:
            alphabet = string.ascii_letters + string.digits
            data_set_test = {}
            for _ in range(250):
                name = ''.join(random.choice(alphabet) for _i in range(12))
                # Build the Decimal from a two-place string: constructing it
                # from the float directly would carry the float's inexact
                # binary expansion into the stored amount.
                amount = decimal.Decimal('%.2f' % (random.random() * 100))
                data_set_test[name] = amount
            yield increment_batch(data_set_test)
        self.response.out.write("Done!")
# WSGI application object: maps /shard_test/ to ShardTestHandler.
app = webapp2.WSGIApplication([
    ('/shard_test/', ShardTestHandler),
])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment