"""
A poor man's implementation of a Celery-like async task manager.
Hacked in under 2 hours.

Author: Thulasi

Usage:

    app = Flask(__name__)
    tasker = Tasker(
        app,
        rabbitmq_params={'hostname': 'amqp://guest:guest@localhost:5672/reseller'},
    )
    # or
    tasker = Tasker(app, rabbitmq_params=app.config['RABBITMQ_CONFIG'])

    @tasker.task
    def long_process(key, value):
        pass

    @tasker.task(unique_name='other_long_process')  # explicit registry name
    def another_long_process(key, value):
        pass

    long_process(key, value)        # executes synchronously
    long_process.defer(key, value)  # executes asynchronously

Workers are started with the ``run_tasker`` command that ``Tasker`` registers
on ``app.cli`` (Click may expose it as ``run-tasker``).
"""
import logging
from typing import Any, Callable, Dict, Optional

from kombu import Queue, Exchange, Connection, connections, producers, uuid
from kombu.mixins import ConsumerMixin

logger = logging.getLogger(__name__)

# Durable topic exchange shared by normal and dead-lettered tasks. The Worker's
# queues bind 'task.#' (tasks) and 'dead.task.#' (dead-tasks) on it.
exchange = Exchange('tasker', 'topic', durable=True)


class Tasker:
    """Register functions as tasks and execute them asynchronously via RabbitMQ.

    Creating a ``Tasker`` also adds a ``run_tasker`` command to ``app.cli``
    that starts a blocking :class:`Worker` consuming deferred tasks.
    """

    # NOTE: class-level on purpose (kept for backward compatibility): the
    # registry is shared by every Tasker instance in the process.
    registry: Dict[str, Callable[..., Any]] = {}

    def __init__(self, app, rabbitmq_params: Dict):
        """
        :param app: Flask application; a ``run_tasker`` CLI command is added
            to ``app.cli``.
        :param rabbitmq_params: keyword arguments for ``kombu.Connection``.
            The dict is copied, never modified. ``confirm_publish`` is forced
            on and merged into any ``transport_options`` already supplied.
        """
        self.app = app
        params = dict(rabbitmq_params)
        # Ask the transport for publisher confirms so a publish is only
        # considered done once the broker has confirmed it.
        params['transport_options'] = {
            **(params.get('transport_options') or {}),
            'confirm_publish': True,
        }
        self.connection = Connection(**params)

        @app.cli.command()
        def run_tasker():
            """Consume deferred tasks and execute them (blocks)."""
            with self.connection as conn:
                worker = Worker(connection=conn, callback=self.callback)
                worker.run()

    def task(self, func: Optional[Callable[..., Any]] = None, unique_name: str = ''):
        """Register ``func`` as a task and attach a ``defer`` method to it.

        Works both as ``@tasker.task`` and as ``@tasker.task(unique_name=...)``.
        The function is returned unchanged apart from the new ``defer``
        attribute, so calling it directly still runs synchronously.

        :param func: the task function; when omitted a decorator is returned.
        :param unique_name: registry name; defaults to ``func.__qualname__``.
        :return: ``func`` (or a decorator when ``func`` is omitted).
        :raises RuntimeError: if a different function is already registered
            under the same name.
        """
        if func is None:
            # Called as @tasker.task(unique_name=...): return the real decorator.
            def decorator(f):
                return self.task(f, unique_name=unique_name)
            return decorator

        task_name = unique_name or func.__qualname__
        self._register(func, task_name)

        def defer(*args, **kwargs):
            """Publish a call of this task and return its task (message) id."""
            data = {
                'task_name': task_name,
                'args': args,
                'kwargs': kwargs,
            }
            # Publish on a concrete key: '#' is wildcard syntax for bindings
            # only. The Worker's 'task.#' binding still matches this key.
            return publish(self.connection, routing_key=f'task.{task_name}', data=data)

        func.defer = defer
        return func

    def callback(self, body):
        """Execute the registered task described by a message ``body``.

        ``body`` is the dict built by ``defer`` (``task_name``, ``args``,
        ``kwargs``). Exceptions raised by the task propagate to the caller.

        :raises KeyError: if no task is registered under ``task_name``.
        """
        func_name = body['task_name']
        try:
            func = self.registry[func_name]
        except KeyError:
            raise KeyError(f'No task registered under the name {func_name!r}') from None
        func(*body['args'], **body['kwargs'])

    def _register(self, func, task_name):
        """Store ``func`` under ``task_name``.

        Re-registering the same function is a no-op.

        :raises RuntimeError: if another function already uses ``task_name``.
        """
        if task_name in self.registry and self.registry[task_name] != func:
            raise RuntimeError(
                f'Duplicate task received with same name {task_name!r}. '
                'Use @task(unique_name=...)'
            )
        self.registry[task_name] = func


def publish(connection, routing_key, data, unique_id=None):
    """Publish ``data`` to the ``tasker`` exchange and return its message id.

    A connection and a producer are borrowed from kombu's global pools; the
    acquisitions block for at most 300 s and 30 s respectively. ``data`` is
    modified in place: an empty ``errors`` list is added if it is missing.

    :param connection: ``kombu.Connection`` used as the pool key.
    :param routing_key: routing key on the topic exchange.
    :param data: message body.
    :param unique_id: message id to use; a fresh uuid is generated if falsy.
    :return: the message id.
    """
    unique_id = unique_id or uuid()
    with connections[connection].acquire(block=True, timeout=300) as conn:
        with producers[conn].acquire(block=True, timeout=30) as producer:
            logger.info('Publishing message %s with %s', unique_id, data)
            data.setdefault('errors', [])
            producer.publish(
                exchange=exchange,
                body=data,
                routing_key=routing_key,
                declare=[exchange],
                message_id=unique_id,
            )
    return unique_id


class Worker(ConsumerMixin):
    """Consume the ``tasks`` queue and dead-letter tasks that raise.

    A failing task is republished with routing key ``dead.task`` (reaching
    the ``dead-tasks`` queue). Before republishing, the ``repr`` of the
    exception is appended to the body's ``errors`` list.
    """

    queue = Queue('tasks', exchange, 'task.#')
    dead_queue = Queue('dead-tasks', exchange, 'dead.task.#')

    def __init__(self, connection, callback):
        """
        :param connection: ``kombu.Connection`` to consume from.
        :param callback: called with each message body; an exception marks
            the task as failed.
        """
        self.connection = connection
        self.callback = callback
        # publish() only declares the exchange, so declare and bind the dead
        # queue up front; otherwise dead-lettered messages would not be routed.
        channel = self.connection.channel()
        try:
            # Bind a copy so the class-level Queue shared by all Workers is
            # not mutated.
            self.dead_queue(channel).declare()
        finally:
            channel.close()

    def get_consumers(self, Consumer, channel):
        return [Consumer(queues=[self.queue], callbacks=[self.on_task])]

    def on_task(self, body, message):
        """Run one task message.

        The message is acked after the task succeeds, or after it has been
        dead-lettered. If dead-lettering itself fails, the message is
        requeued instead of acked so it is not lost.
        """
        unique_id = message.properties.get('message_id') or uuid()
        logger.info('Got message with task_id: %s', unique_id)
        try:
            self.callback(body)
        except Exception as exc:
            logger.exception('Task %s failed', unique_id)
            body.setdefault('errors', []).append(repr(exc))
            try:
                publish(
                    self.connection,
                    routing_key='dead.task',
                    data=body,
                    unique_id=unique_id,
                )
            except Exception:
                # NOTE: requeueing means the task will run again; this trades
                # a possible retry for not silently dropping the message.
                logger.exception('Could not dead-letter task %s; requeueing', unique_id)
                message.requeue()
                return
        message.ack()