#!/usr/bin/env python
# Lesson 8

from bcc import BPF
import ctypes as ct

# load BPF program

# The BPF program is embedded inline so the script does not depend on (or
# crash for lack of) an external source file. `bpf_text` holds the C source
# that is handed to BCC for compilation.
bpf_text = """
#include <uapi/linux/ptrace.h>
#include <linux/blkdev.h>

// Record pushed to user space for every sync that follows the previous
// one by less than a second.
struct data_t {
    u32 pid;
    u64 ts;
    u64 count;
    u64 delta;
    char comm[TASK_COMM_LEN];
};
BPF_PERF_OUTPUT(events);
// One map is used for two values:
//   key 0 -> running number of traced calls
//   key 1 -> timestamp (ns) of the previous traced call
BPF_HASH(last);

int do_trace(struct pt_regs *ctx) {
    u64 ts, count = 1, *tsp, *valp, delta, key = 1, key_count = 0;
    struct data_t data = {};

    // assigning to a u32 keeps the low 32 bits of the pid_tgid value
    data.pid = bpf_get_current_pid_tgid();
    data.ts  = bpf_ktime_get_ns();
    bpf_get_current_comm(&data.comm, sizeof(data.comm));

    // bump the call counter; a plain lookup/update has no hidden early
    // return and does not depend on which lookup_or_*init helper the
    // installed BCC provides
    valp = last.lookup(&key_count);
    if (valp != 0) {
        count = *valp + 1;
    }
    last.update(&key_count, &count);
    data.count = count;

    // attempt to read stored timestamp
    tsp = last.lookup(&key);
    if (tsp != 0) {
        delta = bpf_ktime_get_ns() - *tsp;
        if (delta < 1000000000) {
            // report the gap in milliseconds
            data.delta = delta / 1000000;
            // output if time is less than 1 second
            events.perf_submit(ctx, &data, sizeof(data));
        }
        last.delete(&key);
    }

    // update stored timestamp
    ts = bpf_ktime_get_ns();
    last.update(&key, &ts);

    return 0;
}
"""

b = BPF(text=bpf_text)

# Let BCC resolve the kernel symbol for the sync() syscall: its name is
# architecture/kernel dependent (e.g. "sys_sync" vs "__x64_sys_sync"), so a
# hard-coded "sys_sync" fails to attach on kernels that use a prefix.
b.attach_kprobe(event=b.get_syscall_fnname("sync"), fn_name="do_trace")
print("Tracing for quick sync's... Ctrl-C to end")

# define output data structure in Python
TASK_COMM_LEN = 16    # linux/sched.h


class Data(ct.Structure):
    """User-space mirror of `struct data_t` from the BPF program.

    Field order and types must match the C declaration so that ctypes
    computes the same offsets (including the padding after the 32-bit
    `pid`) as the compiler does.
    """
    _fields_ = [("pid", ct.c_uint),           # u32 in the C struct
                ("ts", ct.c_ulonglong),       # event timestamp, ns
                ("count", ct.c_ulonglong),    # running number of traced calls
                ("delta", ct.c_ulonglong),    # gap since previous call, ms
                ("comm", ct.c_char * TASK_COMM_LEN)]

# format output
start = 0    # timestamp of the first event received; 0 until one arrives

def print_event(cpu, data, size):
    """Perf-buffer callback: print one line describing a received event.

    `data` is cast to a `Data` record. The printed time is the record's
    timestamp relative to the first record this callback ever saw,
    converted from nanoseconds to seconds. `cpu` and `size` are accepted
    for the callback signature but not used.
    """
    global start
    record = ct.cast(data, ct.POINTER(Data)).contents
    # Latch the first timestamp so every later time is relative to it.
    start = start or record.ts
    elapsed_s = float(record.ts - start) / 1000000000
    message = "%d at time %.2f s: multiple syncs detected, last %s ms ago" % (
        record.count, elapsed_s, record.delta)
    print(message)

# Register the callback, then block delivering events until interrupted.
b["events"].open_perf_buffer(print_event)
while True:
    try:
        # perf_buffer_poll() is the current name of the deprecated
        # kprobe_poll(); it invokes print_event for each pending record.
        b.perf_buffer_poll()
    except KeyboardInterrupt:
        # Ctrl-C is the advertised way to stop: leave quietly instead of
        # dumping a traceback.
        break