Celery base task that adds some niceties for longish-running or singleton jobs.
""" | |
Celery base task aimed at longish-running jobs that return a result. | |
``AwesomeResultTask`` adds thundering herd avoidance, result caching, progress | |
reporting, error fallback and JSON encoding of results. | |
""" | |
from __future__ import division

import logging
import time
from contextlib import contextmanager
from hashlib import md5
from uuid import uuid4

import simplejson

from celery import states
from celery.backends import default_backend
from celery.result import BaseAsyncResult
from celery.signals import task_prerun, task_postrun
from celery.task import Task

from django.conf import settings
from django.core.cache import cache

from djcelery.models import TaskMeta

from johnny.middleware import QueryCacheMiddleware, LocalStoreClearMiddleware
def task_prerun_handler(*args, **kwargs):
    """
    Before each task is run, instantiate Johnny's query cache monkey patch.

    This makes sure that any table writes invalidate table caches and that
    reads pull from any existing caches.
    """
    QueryCacheMiddleware()

task_prerun.connect(task_prerun_handler)
def task_postrun_handler(*args, **kwargs):
    """
    After each task is run, clear the LocalStore cache (similar to
    threadlocals), just as Johnny does at the end of each view.
    """
    clear_middleware = LocalStoreClearMiddleware()
    clear_middleware.process_response(None, None)

task_postrun.connect(task_postrun_handler)
@contextmanager
def acquire_lock(lock_name):
    """
    Use the cache as a counter-based lock around the wrapped block.

    If the lock can't be acquired after 10 attempts, fail open and run the
    block anyway, resetting the counter afterwards.
    """
    for _ in range(10):
        try:
            value = cache.incr(lock_name)
        except ValueError:
            # The counter doesn't exist yet; initialize it and try again
            cache.set(lock_name, 0)
            value = cache.incr(lock_name)
        if value == 1:
            # We acquired the lock
            break
        else:
            # Someone else holds the lock; undo our increment and retry
            cache.decr(lock_name)
    else:
        # Couldn't acquire the lock after 10 tries. Fail open so the work
        # still happens, and reset the counter to recover from a stuck lock.
        yield
        cache.set(lock_name, 0)
        return
    yield
    cache.decr(lock_name)
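
def _example_locked_refresh():
    """
    Illustrative sketch only (not used by this module; the lock name is
    hypothetical): ``acquire_lock`` serializes a critical section across
    workers that share the same Django cache. Because the lock is advisory
    and fails open after 10 attempts, the block may occasionally run
    concurrently.
    """
    with acquire_lock('lock:example-refresh'):
        # Only one worker at a time should normally reach this point
        logging.info('Refreshing under the example lock')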
def get_task_meta_error(exception):
    """
    Take an exception and turn it into a Celery result tombstone that mimics
    what would happen if that error were thrown during a Task run.

    This is copy/pasted from pstat.core.utils
    """
    # Need to return an object that has a task_id attribute, and need to
    # store the result in the cache
    fake_result = TaskMeta()
    task_id = str(uuid4())
    fake_result.task_id = task_id
    default_backend.store_result(task_id, exception, status=states.FAILURE)
    return fake_result
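
def _example_fake_failure():
    """
    Illustrative sketch only: the tombstone written by
    ``get_task_meta_error`` behaves like any other failed task result, so
    callers can treat broker outages and in-task exceptions uniformly. The
    ``IOError`` here stands in for a broker connection error.
    """
    fake = get_task_meta_error(IOError('broker unreachable'))
    result = BaseAsyncResult(fake.task_id, default_backend)
    assert result.status == states.FAILURE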
class AwesomeResultTask(Task):
    """
    A base ``Celery.Task`` class that provides some common niceties for
    running tasks that return some kind of result for which you need to wait.

    To create a task that uses these helpers, subclass ``AwesomeResultTask``
    and define a ``calculate_result`` method that returns a dictionary to be
    turned into JSON (see the illustrative sketch at the bottom of this
    file). You will also need to define the following class variables:

    * ``cache_prefix`` A unique string representing this task, e.g.
      ``foo.bar.tasks.BazzTask``
    * ``significant_kwargs`` The kwarg values that will be converted to
      strings and hashed to determine whether two versions of the same task
      are equivalent. This is a list of 2-tuples with the first item being
      the kwarg string and the second being a callable that converts the
      value to a hashable string. If no second item is given, it's assumed
      that calling ``str()`` on the value works just fine.
    * ``herd_avoidance_timeout`` Number of seconds to hold a lock on this
      task for other equivalent runs. Generally, this should be set to the
      longest estimated amount of time the task could consume.
    * ``cache_duration`` The number of seconds for which the result of this
      task should be cached, meaning subsequent equivalent runs will skip
      computation. Set this to ``-1`` to disable caching.

    Provided are helpers for:

    1. Handling failures to connect to the task broker by either directly
       running the task (``delay_or_run``) or by returning a task that
       contains the connection error (``delay_or_fail``). This minimizes
       the user-facing impact of a dead task broker.
    2. Defeating any thundering herd issues by ensuring that only one task
       with specific arguments can be running at a time, directing
       subsequent calls to latch on to the appropriate result.
    3. Caching the final result for a designated time period so that
       subsequent equivalent calls return quickly.
    4. Returning the results as JSON, so that they can be processed easily
       by client-side javascript.
    5. Returning time-based, continually updating progress estimates to
       front-end code so that users know what to expect.
    """
    def delay_or_run(self, *args, **kwargs):
        """
        Attempt to call self.delay, or if that fails, call self.run.

        Returns a tuple, (result, fallback). ``result`` is the result of
        calling delay or run. ``fallback`` is a boolean that is True when
        self.run was called instead of self.delay.
        """
        try:
            result = self.delay(*args, **kwargs)
            fallback = False
        except IOError:
            result = self.run(*args, **kwargs)
            fallback = True
        return result, fallback
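
    # Illustrative call site (``SomeReportTask`` is hypothetical). Note that
    # on fallback, ``result`` is the JSON string returned by ``run()``, not
    # an async result object:
    #
    #     result, fallback = SomeReportTask().delay_or_run(user_id=42)
    #     if not fallback:
    #         result = result.wait()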
    def delay_or_fail(self, *args, **kwargs):
        """
        Attempt to call self.delay, but if that fails with an exception, we
        fake the task completion using the exception as the result. This
        allows us to seamlessly handle errors on task creation the same way
        we handle errors when a task runs, simplifying the user interface.

        Returns a ``TaskMeta`` object that is either the result of calling
        delay or the faked task result.
        """
        try:
            return self.delay(*args, **kwargs)
        except IOError as e:
            return get_task_meta_error(e)
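
    # Illustrative call site (``SomeReportTask`` is hypothetical). Success
    # and broker failure look the same to the caller, which can just hand
    # the task_id to front-end polling code:
    #
    #     result = SomeReportTask().delay_or_fail(user_id=42)
    #     return HttpResponse(simplejson.dumps({'task_id': result.task_id}))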
    def delay(self, *args, **kwargs):
        """
        Put this task on the Celery queue as a singleton. Only one of this
        type of task with its distinguishing args/kwargs will be allowed on
        the queue at a time. Subsequent duplicate tasks called while this
        task is still running will just latch on to the results of the
        running task by synchronizing the task uuid. Additionally, identical
        task calls will return those results for the next ``cache_duration``
        seconds.

        Passing a ``cache_duration`` keyword argument controls how long
        identical task calls will latch on to previously cached results.
        """
        self._validate_required_class_vars()

        cache_key = self._get_cache_key(**kwargs)

        # Check for an already-computed and cached result
        task_id = cache.get(cache_key)
        if task_id:
            # We've already built this result; just latch on to the task
            # that did the work
            logging.info(
                'Found existing cached and completed task: %s', task_id)
            return BaseAsyncResult(task_id, self.backend)

        # Check for an in-progress equivalent task to avoid duplicating work
        task_id = cache.get('herd:%s' % cache_key)
        if task_id:
            logging.info('Found existing in-progress task: %s', task_id)
            return BaseAsyncResult(task_id, self.backend)

        # It's not cached and it's not already running. Use an atomic lock
        # to start the task, ensuring there isn't a race condition that
        # could result in multiple identical tasks being fired at once.
        with acquire_lock('lock:%s' % cache_key):
            task_meta = super(AwesomeResultTask, self).delay(*args, **kwargs)
            logging.info('Current status: %s', task_meta.status)
            if task_meta.status in ['PROGRESS', 'PENDING']:
                cache.set(
                    'herd:%s' % cache_key,
                    task_meta.task_id,
                    timeout=self.herd_avoidance_timeout)
                logging.info(
                    'Setting herd-avoidance cache for task: %s', cache_key)
            return task_meta
    def calc_progress(self, completed_count, total_count):
        """
        Calculate the percentage progress and estimated remaining time based
        on the current number of items completed of the total.

        Returns a tuple of ``(percentage_complete, seconds_remaining)``.
        """
        self.logger.debug(
            "calc_progress(%s, %s)",
            completed_count,
            total_count,
        )
        current_time = time.time()
        time_spent = current_time - self.start_time
        self.logger.debug("Progress time spent: %s", time_spent)

        if total_count == 0:
            return 100, 1

        completion_fraction = completed_count / total_count
        if completion_fraction == 0:
            # Nothing has completed yet, so we can't estimate the remaining
            # time. Use -1, matching the initial value stored by ``run()``.
            return 0, -1

        total_time = time_spent / completion_fraction
        time_remaining = total_time - time_spent

        completion_display = completion_fraction * 100
        if completion_display == 100:
            return 100, 1  # 1 second to finish up

        return completion_display, time_remaining
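
    # Worked example: 25 of 100 units done after 30 seconds of work gives a
    # completion fraction of 0.25, an estimated total of 30 / 0.25 = 120
    # seconds, and therefore 120 - 30 = 90 seconds remaining, so this
    # method returns (25.0, 90.0).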
    def update_progress(self, completed_count, total_count):
        """
        Update the task backend with both an estimated percentage complete
        and number of seconds remaining until completion.

        ``completed_count`` Number of task "units" that have been completed
        out of ``total_count`` total "units."
        """
        # Store progress for display
        progress_percent, time_remaining = self.calc_progress(
            completed_count, total_count)
        self.logger.info(
            "Updating progress: %s percent, %s remaining",
            progress_percent,
            time_remaining)
        self.backend.store_result(
            self.task_id,
            result={
                "progress_percent": progress_percent,
                "time_remaining": time_remaining,
            },
            status="PROGRESS")
    def run(self, *args, **kwargs):
        self.logger = self.get_logger(**kwargs)
        self.logger.info("Starting %s", self.__class__.__name__)

        self.cache_key = self._get_cache_key(**kwargs)

        # Record the start time to give estimated-time-remaining estimates
        self.start_time = time.time()

        # Report to the backend that work has been started
        self.task_id = kwargs.get('task_id', None)
        self.backend.store_result(
            self.task_id,
            result={
                "progress_percent": 0,
                "time_remaining": -1,
            },
            status="PROGRESS",
        )

        self.logger.info("Calculating result")
        try:
            result = self.calculate_result(*args, **kwargs)
        except Exception:
            # Don't want other tasks waiting for this task to finish, since
            # it won't
            self._break_thundering_herd_cache()
            raise  # We can use normal celery exception handling for this

        self.logger.info("Serializing result")
        json_result = simplejson.dumps(result)

        if self.cache_duration >= 0:
            # If we're configured to cache this result, do so
            cache.set(self.cache_key, self.task_id, self.cache_duration)

        # Now that the task is finished, we can stop all of the thundering
        # herd avoidance
        self._break_thundering_herd_cache()
        return json_result
    def _validate_required_class_vars(self):
        """
        Ensure that this subclass has defined all of the required class
        variables.
        """
        required_members = [
            'cache_prefix',
            'significant_kwargs',
            'herd_avoidance_timeout',
            'cache_duration',
        ]
        for required_member in required_members:
            if not hasattr(self, required_member):
                raise Exception(
                    "AwesomeResultTask subclasses must define %s"
                    % required_member)
    def on_success(self, retval, task_id, args, kwargs):
        """
        Store results in the backend even if we're always eager. This helps
        for testing.
        """
        if getattr(settings, 'CELERY_ALWAYS_EAGER', False):
            # Store the result because celery wouldn't otherwise
            self.backend.store_result(task_id, retval, status=states.SUCCESS)
    def _break_thundering_herd_cache(self):
        cache.delete('herd:%s' % self.cache_key)
    def _get_cache_key(self, **kwargs):
        """
        Take this task's configured ``significant_kwargs`` and build a hash
        that all equivalent task calls will match.

        Takes in kwargs and returns a string.

        To change the way the cache key is generated or do more in-depth
        processing, override this method.
        """
        m = md5()
        for significant_kwarg in self.significant_kwargs:
            try:
                key, to_str = significant_kwarg
            except ValueError:
                # No converter given; assume ``str()`` works just fine
                key, to_str = significant_kwarg[0], str
            m.update(to_str(kwargs[key]))
        return '%s:%s' % (self.cache_prefix, m.hexdigest())
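
# Illustrative usage sketch. ``ExampleScoreTask`` and its kwargs are
# hypothetical; a real subclass just needs the four required class variables
# and a ``calculate_result`` that returns a JSON-serializable dict.
class ExampleScoreTask(AwesomeResultTask):
    cache_prefix = 'example.tasks.ExampleScoreTask'
    # Only user_id distinguishes one run of this task from another
    significant_kwargs = [('user_id', str)]
    herd_avoidance_timeout = 300  # Longest expected runtime, in seconds
    cache_duration = 600  # Keep finished results for 10 minutes

    def calculate_result(self, user_id=None, **kwargs):
        total_count = 10
        score = 0
        for completed_count in range(total_count):
            score += completed_count  # Stand-in for one unit of real work
            self.update_progress(completed_count + 1, total_count)
        return {'user_id': user_id, 'score': score}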