import functools
import os
import re
import threading
import time

import cffi

# Public API of this module; every other name is an implementation detail.
__all__ = ["teardown", "reset", "report", "wrap_fn"]

# Directory containing this source file (libbts.so is loaded from here).
SELF_DIR = os.path.split(os.path.abspath(__file__))[0]

# C declarations of the libbts.so API, handed to `FFI.cdef` when BTS is
# enabled.  The C comments inside are documentation only.
BTS_DECLARATIONS = """
struct bts_aux_record {
        uint64_t from_addr;   /* from and to are instruction addresses. */
        uint64_t to_addr;
        uint64_t flags;  /* 0x10 = predicted, in theory, maybe. */
};

/*
 * This function must be called with the value in
 * `/sys/bus/event_source/devices/intel_bts/type` before calling
 * `bts_setup`.
 */
void bts_init(int detected_bts_perf_type);

/*
 * Cleans up any BTS state for the current thread.
 */
void bts_teardown(void);

/*
 * Overwrites or creates the BTS state for the current thread, with
 * an auxiliary (tracing) buffer of `aux_size` bytes.  `aux_size`
 * must be a power of two and must be at least one page.
 *
 * Returns 0 on success, negative on failure.
 */
int bts_setup(size_t aux_size);

/*
 * Enables branch tracing for the calling thread, which must have
 * a BTS state (i.e., only call `bts_start` after `bts_setup`).
 *
 * Returns 0 on success, negative on failure.
 */
int bts_start(void);

/*
 * Stops branch tracing for the current thread, and returns a
 * temporary (thread-local) buffer of the BTS records since
 * the last call to `bts_start`.
 *
 * The first argument is overwritten with the number of valid
 * records in the return value, or a negative count on error.
 *
 * When `(*OUT_num_elements + 2) * sizeof(struct bts_aux_record)`
 * exceeds the `aux_size` passed to `bts_setup`, tracing may have
 * exhausted the buffer space and stopped early.  This trace
 * truncation does not affect the execution of the traced program.
 */
const struct bts_aux_record *bts_stop(ssize_t *OUT_num_elements);
"""

# Default size in bytes of the per-thread auxiliary (tracing) buffer passed
# to `bts_setup`, which requires a power of two of at least one page.
DEFAULT_AUX_SIZE = 2 ** 25

# cffi FFI instance and the loaded libbts.so handle; both stay None unless
# BTS support is detected below.
FFI = None
BTS = None

# ENABLED becomes True only if the intel_bts perf type was read successfully;
# BTS_TYPE then holds that integer (the argument `bts_init` expects).
ENABLED = False
BTS_TYPE = None

BTS_TYPE_PATH = "/sys/bus/event_source/devices/intel_bts/type"

# Probe for Intel BTS support.  A missing or unreadable sysfs file (OSError)
# or unexpected contents (ValueError from int()) just leaves tracing
# disabled.  Unlike a bare `except:`, this no longer swallows
# KeyboardInterrupt, SystemExit or genuine programming errors.
try:
    with open(BTS_TYPE_PATH, "r") as f:
        BTS_TYPE = int(f.read())
    ENABLED = True
except (OSError, ValueError):
    pass


def _init_bts():
    """One-time library initialisation hook, invoked via `FFI.init_once`.

    Passes the intel_bts perf event type read from sysfs to `bts_init`,
    which the C declarations require before any `bts_setup` call.
    """
    perf_type = BTS_TYPE
    BTS.bts_init(perf_type)


# Only when BTS support was detected:
#  1. declare the C API to cffi,
#  2. load libbts.so from this file's directory,
#  3. run `_init_bts` through `FFI.init_once` so it runs only once for the
#     "init_bts" tag.
# The order matters: `_init_bts` uses the `BTS` handle created by dlopen.
if ENABLED:
    FFI = cffi.FFI()
    FFI.cdef(BTS_DECLARATIONS)
    BTS = FFI.dlopen(SELF_DIR + "/libbts.so")
    FFI.init_once(_init_bts, "init_bts")


def find_current_mappings():
    """Return the process's private executable (``r-xp``) mappings.

    Parses ``/proc/self/maps`` and returns a list of ``(start, end)`` integer
    pairs decoded from the hexadecimal address range at the start of each
    matching line, in file order.
    """
    pattern = re.compile(r"^([0-9a-f]+)-([0-9a-f]+) r-xp .*")
    mappings = []
    with open("/proc/self/maps", "r") as maps_file:
        for entry in maps_file:
            match = pattern.match(entry)
            if match is None:
                continue
            start_hex, end_hex = match.groups()
            mappings.append((int(start_hex, 16), int(end_hex, 16)))
    return mappings


# Executable mappings already present when this module is imported
# (presumably the interpreter and libraries loaded so far).  hash_report
# drops any branch whose source or target lies in one of these ranges,
# checked via address_in_baseline_map.
BASELINE_MAPPINGS = find_current_mappings()


def address_in_baseline_map(x):
    """Return True if address `x` lies inside any BASELINE_MAPPINGS range.

    Each range is treated as half-open: ``start <= x < end``.
    """
    return any(start <= x < end for start, end in BASELINE_MAPPINGS)


# Per-thread flag.  `FULLY_SETUP.setup` is set to True by ensure_setup and
# reset after calling bts_setup, and back to False by teardown.  _start and
# _stop do nothing on a thread where it is unset or falsy.
FULLY_SETUP = threading.local()


def teardown():
    """Release the calling thread's BTS state when BTS is available.

    Also marks the thread as not set up, so _start/_stop become no-ops until
    ensure_setup or reset runs again.
    """
    if ENABLED:
        BTS.bts_teardown()
        FULLY_SETUP.setup = False


def ensure_setup(buffer_size=DEFAULT_AUX_SIZE):
    """Create the calling thread's BTS state unless it already exists.

    This is a no-op when BTS is unavailable or the thread is already set up.

    :param buffer_size: size in bytes of the auxiliary tracing buffer.  Per
        the C declarations it must be a power of two of at least one page.
    :raises AssertionError: if ``bts_setup`` reports failure.
    """
    if not ENABLED or getattr(FULLY_SETUP, "setup", None):
        return
    # bts_setup is called outside of an ``assert``.  Under ``python -O``
    # asserts are stripped, which used to skip the call entirely while still
    # marking the thread as set up.  AssertionError is kept as the failure
    # type so callers see the same exception as before.
    status = BTS.bts_setup(buffer_size)
    if status != 0:
        raise AssertionError(
            "bts_setup(%r) failed with status %r" % (buffer_size, status)
        )
    FULLY_SETUP.setup = True


# Per-thread accumulator:
#  - `EDGE_BUFFER.buffer` is the list of (from_addr, to_addr) pairs appended
#    by _stop;
#  - `EDGE_BUFFER.call_count` counts _stop calls.
# Both are reinitialised by reset.
EDGE_BUFFER = threading.local()


def reset(buffer_size=DEFAULT_AUX_SIZE):
    """Start a fresh collection for the calling thread.

    Clears the thread's buffered edges and call count. It then (re)creates
    the thread's BTS state with an auxiliary buffer of `buffer_size` bytes,
    overwriting any existing state.  No-op when BTS is unavailable.

    :param buffer_size: size in bytes of the auxiliary tracing buffer.  Per
        the C declarations it must be a power of two of at least one page.
    :raises AssertionError: if ``bts_setup`` reports failure.
    """
    if not ENABLED:
        return
    EDGE_BUFFER.buffer = []
    EDGE_BUFFER.call_count = 0
    # bts_setup is called outside of an ``assert``, which ``python -O`` would
    # strip, silently skipping setup.  AssertionError is kept as the failure
    # type for compatibility with existing callers.
    status = BTS.bts_setup(buffer_size)
    if status != 0:
        raise AssertionError(
            "bts_setup(%r) failed with status %r" % (buffer_size, status)
        )
    FULLY_SETUP.setup = True


# Addresses below this bound (the first 4096 bytes) are skipped by hash_report.
MIN_ADDRESS = 1 << 12


# Largest address hash_report keeps; anything above it is treated as a
# kernel address and skipped.
MAX_ADDRESS = (1 << 63) - 1


# Every kept edge (from_addr, to_addr) seen so far, mapped to a dense integer
# id assigned in first-seen order.
ALL_SEEN_EDGES = {}


# Edges collected by update_useless_edges; hash_report ignores them.
USELESS_EDGES = set()


# Module import timestamp; new-edge log lines report times relative to it.
initial_time = time.time()


def hash_report(od_pairs):
    """Sketch the *unique* origin/destination pairs into a list of edge ids.

    A pair ``(from_addr, to_addr)`` is kept only when all of these hold:

    * both addresses lie within ``[MIN_ADDRESS, MAX_ADDRESS]``, so kernel
      addresses and addresses below 4096 are skipped;
    * the pair is not in ``USELESS_EDGES``;
    * neither address falls inside a ``BASELINE_MAPPINGS`` range;
    * the pair has not already been kept earlier in this call.

    A kept pair that has never been seen before gets the next id in
    ``ALL_SEEN_EDGES``, and a line is printed to stdout.

    :param od_pairs: iterable of ``(from_addr, to_addr)`` integer pairs.
    :returns: the ids of the kept pairs, in input order.
    """
    ret = []
    # Pairs already kept during this call.  USELESS_EDGES is checked
    # directly; it used to be copied into this set on every call, which cost
    # O(len(USELESS_EDGES)) per report.
    seen = set()
    for pair in od_pairs:
        src, dst = pair[0], pair[1]
        # Skip kernel addresses and addresses below MIN_ADDRESS.
        if not (
            MIN_ADDRESS <= src <= MAX_ADDRESS and MIN_ADDRESS <= dst <= MAX_ADDRESS
        ):
            continue
        if pair in seen or pair in USELESS_EDGES:
            continue
        if address_in_baseline_map(src) or address_in_baseline_map(dst):
            continue
        if pair not in ALL_SEEN_EDGES:
            print(
                "%f new edge %i %s"
                % (
                    time.time() - initial_time,
                    len(ALL_SEEN_EDGES),
                    (hex(src), hex(dst)),
                )
            )
            ALL_SEEN_EDGES[pair] = len(ALL_SEEN_EDGES)
        seen.add(pair)
        ret.append(ALL_SEEN_EDGES[pair])
    return ret


def update_useless_edges():
    """Mark every edge recorded so far on this thread as useless.

    Calls bts_stop, which stops tracing for the current thread.  It then adds
    both the returned BTS records and the thread's buffered edges to
    USELESS_EDGES, which hash_report skips.  The thread's buffer itself is
    left untouched.  No-op when BTS is unavailable.
    """
    if ENABLED:
        count = FFI.new("ssize_t *")
        records = BTS.bts_stop(count)
        USELESS_EDGES.update(
            (records[i].from_addr, records[i].to_addr) for i in range(count[0])
        )
        USELESS_EDGES.update(getattr(EDGE_BUFFER, "buffer", []))


def report():
    """Stop tracing on this thread and summarise the collected edges.

    The BTS records since the last ``bts_start`` come first, followed by the
    edges buffered by _stop.  Together they are mapped to edge ids via
    hash_report.  EDGE_BUFFER is not cleared here; reset does that.

    :returns: ``(call_count, edge_ids)``.  ``call_count`` is the thread's
        _stop count since the last reset, clamped to at least 1.  Returns
        ``[]`` when BTS is unavailable.
    """
    # NOTE(review): the disabled path returns ``[]`` rather than a 2-tuple,
    # so callers that unpack the result must check ENABLED first.  Confirm
    # this asymmetry is intended.
    if not ENABLED:
        return []
    count = FFI.new("ssize_t *")
    records = BTS.bts_stop(count)
    od_pairs = [(records[i].from_addr, records[i].to_addr) for i in range(count[0])]
    od_pairs.extend(getattr(EDGE_BUFFER, "buffer", []))
    calls = max(1, getattr(EDGE_BUFFER, "call_count", 0))
    return calls, hash_report(od_pairs)


def _start():
    """Begin branch tracing on this thread if its BTS state is set up.

    The status returned by ``bts_start`` is not checked.
    """
    ready = ENABLED and getattr(FULLY_SETUP, "setup", None)
    if ready:
        BTS.bts_start()


def _stop():
    """Stop tracing on this thread and buffer the recorded edges.

    Increments ``EDGE_BUFFER.call_count``.  It then appends one
    ``(from_addr, to_addr)`` pair per BTS record to ``EDGE_BUFFER.buffer``,
    creating the list on first use.  No-op unless BTS is available and this
    thread is set up.
    """
    if not (ENABLED and getattr(FULLY_SETUP, "setup", None)):
        return
    edges = getattr(EDGE_BUFFER, "buffer", None)
    if edges is None:
        edges = EDGE_BUFFER.buffer = []
    EDGE_BUFFER.call_count = getattr(EDGE_BUFFER, "call_count", 0) + 1
    count = FFI.new("ssize_t *")
    records = BTS.bts_stop(count)
    edges.extend(
        (records[i].from_addr, records[i].to_addr) for i in range(count[0])
    )

def wrap_fn(fn):
    """Return `fn` wrapped so that each call is branch-traced.

    The wrapper calls _start before invoking `fn`.  It always calls _stop
    afterwards, even if `fn` raises, which buffers the recorded edges.  When
    BTS is unavailable or `fn` is not callable, `fn` is returned unchanged.
    """
    if ENABLED and callable(fn):

        @functools.wraps(fn)
        def traced(*args, **kwargs):
            try:
                _start()
                return fn(*args, **kwargs)
            finally:
                _stop()

        return traced
    return fn