import time
import logging
import threading
log = logging.getLogger(__name__)

class MonitorInstance:
    """One named health check registered with a Monitor.

    The instance counts consecutive failures of ``func`` and tells its parent
    when it becomes unhealthy (``parent._error``) or recovers (``parent._ok``).

    Args:
        parent: owning monitor; must provide ``_ok(inst)`` and ``_error(inst)``.
        label: name used in log messages.
        func: zero-argument callable; raising any ``Exception`` is a failure.
        threshold: consecutive failures at which the instance is unhealthy and
            the parent is alerted (values below 1 alert on the first failure).
        active: stored as-is; not read by this class.
        metric: optional object with ``set(value)`` and ``inc()``; set to 0 on
            success and incremented on every failure.
    """

    def __init__(self, parent, label, func, threshold, active, metric):
        self.parent = parent
        self.label = label
        self.func = func
        self.threshold = threshold
        self.active = active
        self.metric = metric
        # Consecutive failure count; None until the first ok()/error().
        self.__errors = None

    def ok(self):
        """Record a successful check and reset the failure count.

        The parent is notified on the very first result and whenever failures
        were outstanding, so it can clear any alert for this instance.
        """
        if self.__errors is None or self.__errors:
            self.parent._ok(self)
        self.__errors = 0
        if self.metric is not None:
            self.metric.set(0)

    def error(self):
        """Record a failed check.

        The parent is alerted exactly once per failure streak: when the
        consecutive failure count first reaches ``threshold``. That is the
        same moment ``healthy`` turns False.
        """
        previous = self.__errors or 0
        self.__errors = previous + 1
        # Fire only on the crossing, so repeated failures don't re-alert.
        if previous < max(self.threshold, 1) <= self.__errors:
            self.parent._error(self)
        if self.metric is not None:
            self.metric.inc()

    def check(self):
        """Run ``func`` once and record the outcome via ok() or error().

        Only exceptions raised by ``func`` count as failures. An exception
        raised by ok() itself (e.g. a broken metric) propagates instead of
        being misreported as a failed check.
        """
        try:
            self.func()
        except Exception as e:
            log.error("%s error: %s", self.label, e)
            self.error()
        else:
            self.ok()

    @property
    def healthy(self):
        """True while consecutive failures are below ``threshold``.

        An instance that has never been checked has unknown state and is
        reported as not healthy.
        """
        return self.__errors is not None and self.__errors < self.threshold

# Defaults used by Monitor() and Monitor.add().
DEFAULT_THRESHOLD: int = 1      # number of errors needed to cause a fault
DEFAULT_CHECKSECS: int = 5      # seconds to wait between checks

class Monitor:
    """Aggregate health over a set of MonitorInstance checks.

    The monitor becomes healthy only after an instance reports success while
    no instance is alerted. ``health_callback(bool)`` is called on each
    transition between healthy and unhealthy.

    Args:
        health_callback: optional callable taking the new health as a bool;
            exceptions it raises are logged and swallowed.
        check_secs: minimum seconds between non-forced check() calls, and the
            sleep between checks when ``use_thread`` is set.
        use_thread: start a daemon thread that runs a check every
            ``check_secs`` seconds.

    Raises:
        ValueError: if ``use_thread`` is set and ``check_secs`` is not > 0.
    """

    def __init__(self, health_callback=None, check_secs=DEFAULT_CHECKSECS, use_thread=False):
        self.active = []        # active monitors: run on every check()
        self.alerts = set()     # instances reported via _error() and not yet cleared by _ok()
        self.health_callback = health_callback
        self.healthy = False    # default: not healthy unless a monitor is added!
        self.check_secs = check_secs
        self.last_check = 0     # time.time() of the last check() that checked something

        if use_thread:
            # Explicit raise instead of assert: asserts are stripped under -O.
            if not (self.check_secs > 0):
                raise ValueError("threads need to sleep")
            threading.Thread(target=self._thread_loop, daemon=True).start()

    def add(self, label, check, threshold=DEFAULT_THRESHOLD, active=False, metric=None):
        """Register a check, run it once immediately, and return its instance.

        Args:
            label: name used in log messages.
            check: zero-argument callable; raising an Exception is a failure.
            threshold: failure threshold passed to the instance.
            active: if True, the check runs on every check(); otherwise it is
                only re-run while it is in ``alerts`` (e.g. after the caller
                reports a failure with ``inst.error()``).
            metric: optional metric passed through to the instance.
        """
        inst = MonitorInstance(self, label, check, threshold, active, metric)
        if active:
            self.active.append(inst)
        inst.check()
        return inst

    def _error(self, inst):
        """Mark ``inst`` as alerted and flip to unhealthy (notifying once)."""
        self.alerts.add(inst)
        was_healthy = self.healthy
        # Update state before the callback so it observes the new value.
        self.healthy = False
        if was_healthy:
            self._callback(False)

    def _thread_loop(self):
        """Background loop started by __init__ when ``use_thread`` is set."""
        while True:
            try:
                # The loop already paces itself, so bypass the throttle.
                self.check(force=True)
            except Exception:
                # Keep the daemon alive; a dead thread would silently stop monitoring.
                log.exception("deadlyexes: monitor thread check failed")
            time.sleep(self.check_secs)

    def _callback(self, value):
        """Invoke ``health_callback(value)`` if set, logging any exception."""
        if self.health_callback is None:
            return
        try:
            self.health_callback(value)
        except Exception:
            # health callback should always succeed!
            log.exception("deadlyexes: error calling %s", self.health_callback)

    def _ok(self, inst):
        """Clear the alert for ``inst``; become healthy once no alerts remain."""
        self.alerts.discard(inst)
        if not self.healthy and not self.alerts:
            # Update state before the callback so it observes the new value.
            self.healthy = True
            self._callback(True)

    def check(self, force=False):
        """Re-run alerted instances and all active instances.

        Unless ``force`` is set, nothing is done if fewer than ``check_secs``
        seconds have passed since the last check that ran something.

        Returns:
            True if at least one instance was checked.
        """
        now = time.time()
        if not force and now < self.last_check + self.check_secs:
            return False

        # Snapshot, so instances that alert/recover during the loop don't
        # disturb iteration. dict.fromkeys drops duplicates (an active instance
        # that is also alerted would otherwise run twice) while keeping order.
        pending = list(dict.fromkeys(list(self.alerts) + self.active))
        for inst in pending:
            try:
                inst.check()
            except Exception:
                log.exception("deadlyexes: error checking %s", inst.label)

        if pending:
            self.last_check = now
        return bool(pending)

import unittest
from unittest.mock import MagicMock
class TestDeadly(unittest.TestCase):
    """Behavioral tests for Monitor and MonitorInstance."""

    @staticmethod
    def _wait_for(predicate, timeout=5.0, interval=0.01):
        """Poll ``predicate`` until it is truthy or ``timeout`` seconds elapse.

        Returns the last ``predicate()`` value so callers can assert on it.
        Used instead of fixed sleeps, which waste time or flake on slow hosts.
        """
        deadline = time.monotonic() + timeout
        while True:
            result = predicate()
            if result or time.monotonic() >= deadline:
                return result
            time.sleep(interval)

    def test_basic(self):
        m = Monitor()
        ex = False
        inst = m.add("db", lambda: 1/0 if ex else 1, 1, False)
        self.assertTrue(m.healthy)
        self.assertTrue(inst.healthy)

        self.assertFalse(m.check())                 # no point in checking stuff that isn't errored
        self.assertFalse(m.check(force=True))       # no active monitors to check

        inst.error()                                # passive monitor just alerted

        self.assertFalse(m.healthy)                 # no longer healthy
        self.assertFalse(inst.healthy)

        self.assertTrue(m.check())                  # fire off a check

        self.assertTrue(m.healthy)                  # all is ok now

        self.assertFalse(m.check())                 # not checking again right away

        inst.error()                                # passive monitor just alerted
        ex = True                                   # checker will fail

        self.assertTrue(m.check(force=True))        # force-check will check
        self.assertTrue(m.check(force=True))        # force-check will check again

        self.assertFalse(m.healthy)                 # still bad

        ex = False                                  # checker will be ok

        self.assertTrue(m.check(force=True))        # force-check will check

        self.assertTrue(m.healthy)

    def test_metric(self):
        m = Monitor()
        mockmetric = MagicMock()
        ex = False
        inst = m.add("db", lambda: 1/0 if ex else 1, 1, False, mockmetric)
        mockmetric.set.assert_called_with(0)        # initial successful check resets the metric
        inst.error()
        inst.error()
        self.assertEqual(mockmetric.inc.call_count, 2)

    def test_active(self):
        m = Monitor()
        ex = False
        m.add("db", lambda: 1/0 if ex else 1, 1, active=True)
        self.assertTrue(m.healthy)
        self.assertTrue(m.check())
        self.assertTrue(m.check(force=True))        # actives are always checked, even when healthy
        self.assertTrue(m.check(force=True))
        ex = True
        self.assertTrue(m.check(force=True))
        self.assertFalse(m.healthy)
        ex = False
        self.assertTrue(m.check(force=True))
        self.assertTrue(m.healthy)

    def test_thread(self):
        with self.assertRaises(Exception):
            Monitor(use_thread=True, check_secs=0)

        m = Monitor(use_thread=True, check_secs=0.01)
        ex = False
        m.add("db", lambda: 1/0 if ex else 1, 1, True)
        self.assertTrue(m.healthy)
        ex = True
        self.assertTrue(self._wait_for(lambda: not m.healthy), "thread never saw the failure")
        ex = False
        self.assertTrue(self._wait_for(lambda: m.healthy), "thread never saw the recovery")

    def test_callback(self):
        # Collect every reported state; list.append is safe from the monitor thread.
        states = []

        def cb(val):
            states.append(val)

        m = Monitor(use_thread=True, check_secs=0.01, health_callback=cb)

        ex = False
        m.add("db", lambda: 1/0 if ex else 1, 1, True)

        self.assertTrue(m.healthy)

        ex = True

        # Bounded wait: a missing callback fails the test instead of hanging it.
        self.assertTrue(self._wait_for(lambda: False in states),
                        "health callback never reported False")