"""Intel BTS (Branch Trace Store) edge collection through ``libbts.so``.

When the kernel exposes the ``intel_bts`` perf event source, this module
loads ``libbts.so`` (expected next to this file) with cffi and records
branch (from_addr, to_addr) pairs while functions wrapped by ``wrap_fn``
execute.  When BTS is unavailable (``ENABLED`` is False) every public entry
point is a no-op.
"""

import functools
import os
import re
import threading
import time

__all__ = ["teardown", "reset", "report", "wrap_fn"]

SELF_DIR = os.path.dirname(os.path.abspath(__file__))

BTS_DECLARATIONS = """
struct bts_aux_record {
        uint64_t from_addr;  /* from and to are instruction addresses. */
        uint64_t to_addr;
        uint64_t flags;  /* 0x10 = predicted, in theory, maybe. */
};

/*
 * This function must be called with the value in
 * `/sys/bus/event_source/devices/intel_bts/type` before calling
 * `bts_setup`.
 */
void bts_init(int detected_bts_perf_type);

/*
 * Cleans up any BTS state for the current thread.
 */
void bts_teardown(void);

/*
 * Overwrites or creates the BTS state for the current thread, with
 * an auxiliary (tracing) buffer of `aux_size` bytes.  `aux_size`
 * must be a power of two and must be at least one page.
 *
 * Returns 0 on success, negative on failure.
 */
int bts_setup(size_t aux_size);

/*
 * Enables branch tracing for the calling thread, which must have
 * a BTS state (i.e., only call `bts_start` after `bts_setup`).
 *
 * Returns 0 on success, negative on failure.
 */
int bts_start(void);

/*
 * Stops branch tracing for the current thread, and returns a
 * temporary (thread-local) buffer of the BTS records since
 * the last call to `bts_start`.
 *
 * The first argument is overwritten with the number of valid
 * records in the return value, or a negative count on error.
 *
 * When `(*OUT_num_elements + 2) * sizeof(struct bts_aux_record)`
 * exceeds the `aux_size` passed to `bts_setup`, tracing may have
 * exhausted the buffer space and stopped early.  This trace
 * truncation does not affect the execution of the traced program.
 */
const struct bts_aux_record *bts_stop(ssize_t *OUT_num_elements);
"""

DEFAULT_AUX_SIZE = 2 ** 25

FFI = None
BTS = None
ENABLED = False
BTS_TYPE = None

BTS_TYPE_PATH = "/sys/bus/event_source/devices/intel_bts/type"

try:
    with open(BTS_TYPE_PATH, "r") as f:
        BTS_TYPE = int(f.read())
    ENABLED = True
except (OSError, ValueError):
    # Best effort: a missing, unreadable or malformed type file simply means
    # BTS tracing stays disabled.  (The previous bare ``except:`` also hid
    # KeyboardInterrupt/SystemExit.)
    pass


def _init_bts():
    """Hand the detected perf type to libbts; run once via ``FFI.init_once``."""
    BTS.bts_init(BTS_TYPE)


if ENABLED:
    # cffi is only needed when BTS is actually available, so the disabled
    # (no-op) mode does not require it to be installed.
    import cffi

    FFI = cffi.FFI()
    FFI.cdef(BTS_DECLARATIONS)
    BTS = FFI.dlopen(os.path.join(SELF_DIR, "libbts.so"))
    FFI.init_once(_init_bts, "init_bts")


def find_current_mappings():
    """Return ``(start, end)`` ranges of this process's ``r-xp`` mappings.

    Parsed from ``/proc/self/maps``; ranges are treated as half-open
    ``[start, end)`` by ``address_in_baseline_map``.
    """
    ret = []
    with open("/proc/self/maps", "r") as f:
        for line in f:
            m = re.match(r"^([0-9a-f]+)-([0-9a-f]+) r-xp .*", line)
            if m:
                ret.append((int(m.group(1), 16), int(m.group(2), 16)))
    return ret


# Executable mappings present at import time (including libbts.so when it
# was loaded above).  Edges touching these are ignored by ``hash_report``.
BASELINE_MAPPINGS = find_current_mappings()


def address_in_baseline_map(x):
    """Return True if address ``x`` falls inside any of ``BASELINE_MAPPINGS``."""
    return any(start <= x < end for start, end in BASELINE_MAPPINGS)


# Per-thread flag: True once bts_setup has succeeded for that thread.
FULLY_SETUP = threading.local()


def teardown():
    """Release the calling thread's BTS state (no-op when BTS is disabled)."""
    if not ENABLED:
        return
    BTS.bts_teardown()
    FULLY_SETUP.setup = False


def _setup(buffer_size):
    """(Re)create the calling thread's BTS state and mark it as set up.

    The call is made explicitly rather than inside an ``assert``; an
    ``assert``-wrapped call is stripped under ``python -O``, which would
    skip setup entirely while still marking the thread as set up.

    Raises:
        AssertionError: if ``bts_setup`` reports failure (kept as the
            exception type the previous ``assert`` raised).
    """
    rc = BTS.bts_setup(buffer_size)
    if rc != 0:
        raise AssertionError("bts_setup(%d) failed with %d" % (buffer_size, rc))
    FULLY_SETUP.setup = True


def ensure_setup(buffer_size=DEFAULT_AUX_SIZE):
    """Set up BTS for the calling thread unless it already is.

    Args:
        buffer_size: aux buffer size in bytes handed to ``bts_setup``
            (per the C declaration: a power of two, at least one page).
    """
    if not ENABLED or getattr(FULLY_SETUP, "setup", None):
        return
    _setup(buffer_size)


# Per-thread accumulator filled by ``_stop``: ``buffer`` holds
# (from_addr, to_addr) pairs, ``call_count`` counts traced calls.
EDGE_BUFFER = threading.local()


def reset(buffer_size=DEFAULT_AUX_SIZE):
    """Clear this thread's collected edges and (re)create its BTS state.

    Args:
        buffer_size: aux buffer size in bytes handed to ``bts_setup``.
    """
    if not ENABLED:
        return
    EDGE_BUFFER.buffer = []
    EDGE_BUFFER.call_count = 0
    _setup(buffer_size)


# Edges with an endpoint below MIN_ADDRESS (the first 4 KiB) or above
# MAX_ADDRESS (the upper, kernel half of the address space) are dropped.
MIN_ADDRESS = 2 ** 12
MAX_ADDRESS = 2 ** 63 - 1

# Process-wide edge -> small integer id, assigned in first-seen order.
ALL_SEEN_EDGES = dict()
# Edges to ignore in ``hash_report`` (filled by ``update_useless_edges``).
USELESS_EDGES = set()
initial_time = time.time()


def hash_report(od_pairs):
    """Sketches the *unique* origin/destination pairs into an array of values

    Pairs are skipped when either endpoint is above ``MAX_ADDRESS`` or below
    ``MIN_ADDRESS``, when the pair is in ``USELESS_EDGES``, when it already
    appeared earlier in ``od_pairs``, or when either endpoint lies in a
    baseline mapping.  Each remaining pair is mapped to its id in
    ``ALL_SEEN_EDGES``; pairs never seen before get the next id and a line
    is printed to stdout.

    Returns:
        List of edge ids, in order of first appearance in ``od_pairs``.
    """
    ret = []
    seen = set()
    for pair in od_pairs:
        origin, destination = pair[0], pair[1]
        # Skip kernel addresses.
        if origin > MAX_ADDRESS or destination > MAX_ADDRESS:
            continue
        if origin < MIN_ADDRESS or destination < MIN_ADDRESS:
            continue
        # Membership test instead of copying USELESS_EDGES on every call.
        if pair in USELESS_EDGES or pair in seen:
            continue
        if address_in_baseline_map(origin) or address_in_baseline_map(destination):
            continue
        if pair not in ALL_SEEN_EDGES:
            print(
                "%f new edge %i %s"
                % (
                    time.time() - initial_time,
                    len(ALL_SEEN_EDGES),
                    (hex(origin), hex(destination)),
                )
            )
            ALL_SEEN_EDGES[pair] = len(ALL_SEEN_EDGES)
        seen.add(pair)
        ret.append(ALL_SEEN_EDGES[pair])
    return ret


def _drain_trace():
    """Stop tracing on this thread and return the recorded edge pairs.

    A negative (error) count from ``bts_stop`` yields an empty list.
    """
    num = FFI.new("ssize_t *")
    records = BTS.bts_stop(num)
    return [(records[i].from_addr, records[i].to_addr) for i in range(num[0])]


def update_useless_edges():
    """Add every edge currently traced or buffered on this thread to USELESS_EDGES."""
    if not ENABLED:
        return
    USELESS_EDGES.update(_drain_trace())
    USELESS_EDGES.update(getattr(EDGE_BUFFER, "buffer", []))


def report():
    """Stop tracing and summarise this thread's edges.

    Returns:
        ``(call_count, edge_ids)`` where ``call_count`` is at least 1 and
        ``edge_ids`` comes from ``hash_report``.  When BTS is disabled this
        is ``(1, [])`` -- the same shape as the enabled path (it used to be
        a bare ``[]``, which broke tuple unpacking).
    """
    if not ENABLED:
        return 1, []
    od_pairs = _drain_trace() + getattr(EDGE_BUFFER, "buffer", [])
    call_count = max(1, getattr(EDGE_BUFFER, "call_count", 0))
    return call_count, hash_report(od_pairs)


def _start():
    """Begin tracing if this thread is set up; bts_start's result is not checked."""
    if not ENABLED or not getattr(FULLY_SETUP, "setup", None):
        return
    BTS.bts_start()


def _stop():
    """Stop tracing and append the recorded edges to this thread's buffer."""
    if not ENABLED or not getattr(FULLY_SETUP, "setup", None):
        return
    if getattr(EDGE_BUFFER, "buffer", None) is None:
        EDGE_BUFFER.buffer = []
    EDGE_BUFFER.call_count = 1 + getattr(EDGE_BUFFER, "call_count", 0)
    EDGE_BUFFER.buffer.extend(_drain_trace())


def wrap_fn(fn):
    """Return ``fn`` wrapped so each call is bracketed by _start/_stop.

    Returns ``fn`` unchanged when BTS is disabled or ``fn`` is not callable.
    """
    if not ENABLED or not callable(fn):
        return fn

    @functools.wraps(fn)
    def wrapper(*arg, **kwargs):
        try:
            _start()
            return fn(*arg, **kwargs)
        finally:
            # Always collect the trace, even if fn (or _start) raised.
            _stop()

    return wrapper