#!/usr/bin/python

from __future__ import print_function
from bcc import BPF
from bcc.utils import printb
from time import sleep
import sys
import pdb

# Parse the single required argument: the PID of the process to trace.
# Usage errors go to stderr and exit with status 1 instead of surfacing a
# Python traceback for a malformed PID.
if len(sys.argv) < 2:
    print("USAGE: tracer PID", file=sys.stderr)
    sys.exit(1)

try:
    pid = int(sys.argv[1])
except ValueError:
    print("USAGE: tracer PID (PID must be an integer, got %r)" % sys.argv[1],
          file=sys.stderr)
    sys.exit(1)

# BPF program.
#
# alloc_leave runs on return from malloc() and records when each returned
# pointer was handed out. free_enter runs on entry to free(), computes how long
# that pointer lived, and for lifetimes of at most SHORT_ALLOC_NS counts the
# user-space stack observed at the free() call site.
b = BPF(text="""
#include <uapi/linux/ptrace.h>

// Maps an allocated pointer to the time (bpf_ktime_get_ns) malloc returned it.
// Entries are removed on free(), otherwise the map fills with stale pointers
// and, once full, silently stops tracking new allocations.
BPF_HASH(alloc_time, u64, u64);
// For all short allocations, maps the stack id (captured at free) to a hit count.
BPF_HASH(short_allocs, u32, u64, 1<<12);
BPF_STACK_TRACE(stack_traces, 1<<12);

// Lifetimes at or below this many nanoseconds (1 ms) count as "short".
#define SHORT_ALLOC_NS 1000000

int alloc_leave(struct pt_regs *ctx) {
    u64 ptr = PT_REGS_RC(ctx);
    // A NULL return is a failed malloc, not an allocation worth tracking.
    if (ptr == 0)
        return 0;

    u64 start_time = bpf_ktime_get_ns();
    alloc_time.update(&ptr, &start_time);
    return 0;
}

int free_enter(struct pt_regs *ctx) {
    u64 ptr = PT_REGS_PARM1(ctx);

    u64 *val = alloc_time.lookup(&ptr);
    if (!val)
        return 0;

    // Copy the timestamp out before deleting the entry that holds it.
    u64 alloc_ts = *val;
    alloc_time.delete(&ptr);

    u64 delta = bpf_ktime_get_ns() - alloc_ts;
    if (delta > SHORT_ALLOC_NS)
        return 0;

    // Only pay for a stack walk once we know this was a short allocation.
    int loc = stack_traces.get_stackid(ctx, BPF_F_USER_STACK);
    if (loc < 0)
        return 0;

    u32 stack_id = loc;
    u64 zero = 0;
    u64 *hits = short_allocs.lookup_or_try_init(&stack_id, &zero);
    if (hits) {
        // Atomic add: several threads of the target may free() concurrently.
        (void)__sync_fetch_and_add(hits, 1);
    }
    return 0;
}
""")

# Hook the return of malloc() (to learn the pointer it handed out) and the
# entry of free() (to see that pointer come back), limited to the target PID.
b.attach_uretprobe(name="c", sym="malloc", fn_name="alloc_leave", pid=pid)
b.attach_uprobe(name="c", sym="free", fn_name="free_enter", pid=pid)

print("Attaching to malloc/free in pid %d, Ctrl+C to quit." % (pid,))

# Python-side handles onto the maps declared in the BPF program text.
stack_traces, short_allocs = (b.get_table(table_name)
                              for table_name in ("stack_traces", "short_allocs"))

def print_top_stack_traces(stack_traces, short_allocs, num):
    """Print the `num` stacks with the most short-lived allocations.

    Args:
        stack_traces: the BPF "stack_traces" table, used to turn a stack id
            into its frames.
        short_allocs: the BPF "short_allocs" table mapping stack id -> hit
            count.
        num: maximum number of stacks to print, highest hit count first.
    """
    top_allocs = sorted(short_allocs.items(),
                        key=lambda kv: kv[1].value, reverse=True)[:num]

    for idx, (st, hits) in enumerate(top_allocs):
        print("[%d] Hits: %d" % (idx, hits.value))
        print_stack_trace(stack_traces, st.value)
        print("-" * 10)

def print_stack_trace(stack_traces, stack_id):
    """Print every frame of one captured stack, one line per frame.

    Each address is symbolized against the traced process (module-level `b`
    and `pid`), with the offset into the symbol shown.
    """
    for frame_addr in stack_traces.walk(stack_id):
        symbol = b.sym(frame_addr, pid, show_offset=True)
        printb(b"\t[%08x] %s" % (frame_addr, symbol))


# Process events until Ctrl-C: once a second, dump the current top 10 stacks.
# short_allocs is never cleared, so each dump shows totals since attach.
try:
    while True:
        sleep(1)
        print_top_stack_traces(stack_traces, short_allocs, 10)
except KeyboardInterrupt:
    print("Detaching from malloc/free in pid %d." % pid)
    b.detach_uretprobe(name="c", sym="malloc", pid=pid)
    b.detach_uprobe(name="c", sym="free", pid=pid)
    # sys.exit, not the site-provided exit(), which is absent under `python -S`.
    sys.exit()