"""
A poor man's implementation of a Celery-like async task manager built on kombu.
Hacked in under 2 hours.

Author: Thulasi

Usage:
app = Flask(__name__)
tasker = Tasker(app, rabbitmq_params={'hostname': 'amqp://guest:guest@localhost:5672/reseller'})
        or
tasker = Tasker(app, rabbitmq_params=app.config['RABBITMQ_CONFIG'])

@tasker.task
def long_process(key, value):
    pass

long_process(key, value) # executes synchronously
long_process.defer(key, value)  # executes asynchronously
"""

import logging
from typing import Dict

from kombu import Queue, Exchange, Connection, connections, producers, uuid
from kombu.mixins import ConsumerMixin

# Logger for this module.
logger = logging.getLogger(name=__name__)

# Durable topic exchange that every message in this module is published to;
# both the worker's task queue and the dead-task queue are bound to it.
exchange = Exchange(name='tasker', type='topic', durable=True)


class Tasker:
    """Register functions as tasks that can run inline or through RabbitMQ.

    :meth:`task` registers a function under a task name and attaches a
    ``defer`` attribute that publishes the call instead of running it. The
    constructor registers a ``run_tasker`` command on ``app.cli`` that starts a
    :class:`Worker`, which hands every consumed message body to
    :meth:`callback`.
    """

    # Task name -> registered function. NOTE: this is a class attribute, so it
    # is shared by every Tasker instance in the process.
    registry = {}

    def __init__(self, app, rabbitmq_params: Dict):
        """Create the broker connection and register the worker CLI command.

        Args:
            app: application object; only ``app.cli.command()`` is used here.
            rabbitmq_params: keyword arguments for ``kombu.Connection``. The
                dict is copied rather than modified, any existing
                ``transport_options`` are kept, and ``confirm_publish`` is
                forced to True.
        """
        self.app = app
        self.connection = Connection(**self._build_connection_params(rabbitmq_params))

        @app.cli.command()
        def run_tasker():
            # Consume queued tasks on this connection and dispatch each
            # message body to self.callback.
            with self.connection as conn:
                worker = Worker(connection=conn, callback=self.callback)
                worker.run()

    @staticmethod
    def _build_connection_params(rabbitmq_params):
        """Return a copy of *rabbitmq_params* with ``confirm_publish`` enabled."""
        params = dict(rabbitmq_params)
        # Merge instead of overwriting so caller-supplied transport options
        # survive; publisher confirms are always requested.
        params['transport_options'] = {
            **(params.get('transport_options') or {}),
            'confirm_publish': True,
        }
        return params

    def task(self, func=None, unique_name=''):
        """Register *func* as a task and attach a ``defer`` function to it.

        Usable both as ``@tasker.task`` and as
        ``@tasker.task(unique_name='...')``. When *func* is omitted, a
        decorator is returned.

        Args:
            func: the function to register.
            unique_name: registry key; defaults to ``func.__qualname__``.

        Returns:
            *func* itself (not wrapped). Calling it still runs synchronously;
            ``func.defer(*args, **kwargs)`` publishes the call and returns the
            message id.

        Raises:
            RuntimeError: if a different function is already registered under
                the same task name.
        """
        if func is None:
            # Called as @tasker.task(...) with keyword options only.
            def decorator(wrapped):
                return self.task(wrapped, unique_name=unique_name)

            return decorator

        task_name = unique_name or func.__qualname__

        def defer(*args, **kwargs):
            """Publish this call for a worker to run; return the message id."""
            data = {
                'task_name': task_name,
                'args': args,
                'kwargs': kwargs,
            }
            # Routing key kept as in the original; it matches the worker's
            # 'task.#' binding.
            return publish(self.connection, routing_key='task.#', data=data)

        self._register(func, task_name)
        func.defer = defer
        return func

    def callback(self, body):
        """Run the registered task described by a decoded message *body*.

        *body* must contain 'task_name', 'args' and 'kwargs' (the shape built
        by ``defer``). A KeyError is raised for an unknown task name, and any
        exception raised by the task itself propagates.
        """
        func = self.registry[body['task_name']]
        func(*body['args'], **body['kwargs'])

    def _register(self, func, task_name):
        """Store *func* under *task_name*; re-registering the same func is a no-op.

        Raises:
            RuntimeError: if a different function already uses *task_name*.
        """
        if task_name in self.registry and not self.registry[task_name] == func:
            raise RuntimeError('Duplicate task received with same name. Use @task(unique_name=...)')

        self.registry[task_name] = func


def publish(connection, routing_key, data, unique_id=None):
    """Publish *data* to the module's ``exchange`` and return its message id.

    A connection and a producer are acquired from kombu's global pools keyed
    on *connection* (blocking up to 300 s and 30 s respectively). The exchange
    is declared as part of the publish.

    Args:
        connection: kombu ``Connection`` used as the pool key.
        routing_key: topic routing key for the message.
        data: dict message body. It is NOT modified; the published copy always
            carries an ``errors`` list (an empty one if *data* has none).
        unique_id: message id to use; a new uuid is generated when falsy.

    Returns:
        The message id the body was published with.
    """
    unique_id = unique_id or uuid()
    # Publish a shallow copy so adding 'errors' is not a side effect on the
    # caller's dict.
    body = dict(data)
    body.setdefault('errors', [])
    with connections[connection].acquire(block=True, timeout=300) as conn:
        with producers[conn].acquire(block=True, timeout=30) as producer:
            logger.info('Publishing message %s with %s', unique_id, body)
            producer.publish(
                exchange=exchange,
                body=body,
                routing_key=routing_key,
                declare=[exchange],
                message_id=unique_id,
            )
    return unique_id


class Worker(ConsumerMixin):
    """Consume task messages from the ``tasks`` queue and run them.

    Each decoded message body is passed to *callback*. If the callback raises,
    the body is re-published with routing key ``'dead.task'`` (which the
    ``dead-tasks`` queue is bound to), with ``repr`` of the error appended to
    its ``errors`` list. Every message is acked, whether it succeeded or
    failed.
    """

    # Receives messages published with routing keys under 'task.'.
    queue = Queue('tasks', exchange, 'task.#')
    # Parking queue for failed tasks. It is declared up front so that
    # dead-lettered messages have a bound queue to land in.
    dead_queue = Queue('dead-tasks', exchange, 'dead.task.#')

    def __init__(self, connection, callback):
        """
        Args:
            connection: kombu ``Connection``; ConsumerMixin reads it from
                ``self.connection``.
            callback: called with each decoded message body.
        """
        self.connection = connection
        self.callback = callback

        # Declare on the connection's default channel, which the connection
        # closes itself (the original opened a channel that was never closed).
        # bind() returns a bound copy, so the shared class-level Queue object
        # is not mutated.
        self.dead_queue.bind(self.connection.default_channel).declare()

    def get_consumers(self, Consumer, channel):
        return [Consumer(queues=[self.queue], callbacks=[self.on_task])]

    def on_task(self, body, message):
        """Run one task message; dead-letter it on failure, then ack it."""
        unique_id = message.properties.get('message_id') or uuid()

        logger.info('Got message with task_id: %s', unique_id)
        try:
            self.callback(body)
        except Exception as exc:
            logger.exception('Task %s failed', unique_id)
            # Bodies not shaped like Tasker's (non-dict, or missing 'errors')
            # must not crash the consumer here; normalise them first.
            dead_body = body if isinstance(body, dict) else {'body': body}
            dead_body.setdefault('errors', []).append(repr(exc))
            publish(
                self.connection,
                routing_key='dead.task',
                data=dead_body,
                unique_id=unique_id,
            )
        finally:
            # Always ack: failures go to the dead queue instead of being
            # requeued.
            message.ack()