Skip to content

Instantly share code, notes, and snippets.

@regit
Created November 28, 2018 20:29
Show Gist options
  • Save regit/1e591311fa3ba5cd0b8d73940348599a to your computer and use it in GitHub Desktop.
Save regit/1e591311fa3ba5cd0b8d73940348599a to your computer and use it in GitHub Desktop.
Sobind: an eBPF script to detect network bind attempts
#! /usr/bin/python2
#
# sobind Trace TCP bind events
# For Linux, uses BCC, eBPF. Embedded C.
#
# USAGE: sobind.py [-h] [--show-netns] [-p PID] [-n NETNS]
#
# This is provided as a basic example of TCP connection & socket tracing.
# It could be useful in scenarios where load balancers need to be updated
# dynamically once an application is fully initialized.
#
# All IPv4 and IPv6 bind attempts are traced, even if they ultimately fail or
# the binding program is not willing to accept().
#
# Copyright (c) 2016 Jean-Tiare Le Bigot.
# Licensed under the Apache License, Version 2.0 (the "License")
#
# 04-Mar-2016 Jean-Tiare Le Bigot Created this.
import os
from socket import inet_ntop, AF_INET, AF_INET6, SOCK_STREAM, SOCK_DGRAM, ntohs
from struct import pack
import argparse
from bcc import BPF
import ctypes as ct
# Arguments
examples = """Examples:
./sobind.py # Stream socket bind
./sobind.py -p 1234 # Stream socket bind for specified PID only
./sobind.py --netns 4242 # " for the specified network namespace ID only
./sobind.py --show-netns # Show network ns ID (useful for containers)
"""

# Command-line options, listed in the order they are registered (and
# therefore shown in --help). Each entry is (flags, add_argument kwargs).
_OPTION_SPECS = (
    (("--show-netns",),
     dict(action="store_true", help="show network namespace")),
    (("-p", "--pid"),
     dict(default=0, type=int, help="trace this PID only")),
    (("-n", "--netns"),
     dict(default=0, type=int, help="trace this Network Namespace only")),
    # Hidden from --help: the help text is suppressed on purpose.
    (("--ebpf",),
     dict(action="store_true", help=argparse.SUPPRESS)),
)

# The raw formatter keeps the line breaks of the examples epilog intact.
parser = argparse.ArgumentParser(
    epilog=examples,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="Stream sockets bind",
)
for _flags, _settings in _OPTION_SPECS:
    parser.add_argument(*_flags, **_settings)
# BPF Program
#
# C source handed to BCC (see BPF(text=bpf_text) in the main section).
# Everything between the triple quotes is runtime text, not Python: do not
# reformat it casually, any change alters the compiled program.
#
# Overview of the embedded program:
#   * struct bind_evt_t is the record pushed to userland through the
#     bind_evt perf output; the ctypes class ListenEvt must mirror its field
#     order and sizes exactly.
#   * inet_bind() is the shared handler. It records a timestamp, the task
#     comm, family/type packed as (family << 16 | type), the pid_tgid value,
#     the requested local port (converted with ntohs) and address, and the
#     network namespace inode when CONFIG_NET_NS is defined.
#   * kprobe__inet_bind() and kprobe__inet6_bind() both forward to
#     inet_bind(); the kprobe__ prefix is the BCC naming convention for
#     auto-attached kprobes.
#   * ##FILTER_PID## and ##FILTER_NETNS## are placeholders substituted with
#     C "if (...) return 0;" statements (or an empty string) by the main
#     section before compilation.
#
# NOTE(review): the protocol byte is read at a fixed offset (-3) from
# sk_gso_max_segs instead of through sk->sk_protocol; this presumably depends
# on the struct sock layout of the running kernel -- confirm on the kernels
# you target.
# NOTE(review): protocols other than TCP/UDP keep their raw IPPROTO_* number
# in the low byte of evt.proto, while userland compares that byte against
# SOCK_STREAM / SOCK_DGRAM; a raw protocol number equal to one of those
# constants would presumably be mislabelled -- confirm.
bpf_text = """
#include <net/net_namespace.h>
#include <bcc/proto.h>
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wenum-conversion"
#include <net/inet_sock.h>
#pragma clang diagnostic pop
// Common structure for UDP/TCP IPv4/IPv6
struct bind_evt_t {
u64 ts_us;
u64 pid_tgid;
u64 netns;
u64 proto; // family << 16 | type
u64 lport; // use only 16 bits
u64 laddr[2]; // IPv4: store in laddr[0]
char task[TASK_COMM_LEN];
};
BPF_PERF_OUTPUT(bind_evt);
// Send an event for each IPv4 bind with PID, bound address and port
static int inet_bind(struct pt_regs *ctx, struct socket *sock,
const struct sockaddr *addr,
int addrlen)
{
// cast types. Intermediate cast not needed, kept for readability
struct sock *sk = sock->sk;
// Built event for userland
struct bind_evt_t evt = {
.ts_us = bpf_ktime_get_ns() / 1000,
};
// Get process comm. Needs LLVM >= 3.7.1
// see https://github.com/iovisor/bcc/issues/393
bpf_get_current_comm(evt.task, TASK_COMM_LEN);
// Get socket IP family
u16 family = sk->__sk_common.skc_family;
u8 protocol = 0;
// Can't read bitfield directly so use this direct read
bpf_probe_read(&protocol, 1, ((u8*)&sk->sk_gso_max_segs - 3));
if (protocol == IPPROTO_TCP) {
protocol = SOCK_STREAM;
} else if (protocol == IPPROTO_UDP) {
protocol = SOCK_DGRAM;
}
evt.proto = family << 16 | protocol;
// Get PID
evt.pid_tgid = bpf_get_current_pid_tgid();
##FILTER_PID##
if (family == AF_INET) {
struct sockaddr_in *in_addr = (struct sockaddr_in *)addr;
// Get port
evt.lport = in_addr->sin_port;
evt.lport = ntohs(evt.lport);
evt.laddr[0] = in_addr->sin_addr.s_addr;
} else if (family == AF_INET6) {
struct sockaddr_in6 *in6_addr = (struct sockaddr_in6 *)addr;
evt.lport = in6_addr->sin6_port;
evt.lport = ntohs(evt.lport);
bpf_probe_read(evt.laddr, sizeof(evt.laddr),
in6_addr->sin6_addr.s6_addr);
}
// Get network namespace id, if kernel supports it
#ifdef CONFIG_NET_NS
evt.netns = sk->__sk_common.skc_net.net->ns.inum;
#else
evt.netns = 0;
#endif
##FILTER_NETNS##
// Send event to userland
bind_evt.perf_submit(ctx, &evt, sizeof(evt));
return 0;
};
int kprobe__inet_bind(struct pt_regs *ctx, struct socket *sock,
const struct sockaddr *addr,
int addrlen)
{
return inet_bind(ctx, sock, addr, addrlen);
}
// Send an event for each IPv6 bind with PID, bound address and port
int kprobe__inet6_bind(struct pt_regs *ctx, struct socket *sock,
const struct sockaddr *addr,
int addrlen)
{
return inet_bind(ctx, sock, addr, addrlen);
}
"""
# event data
TASK_COMM_LEN = 16  # linux/sched.h


class ListenEvt(ct.Structure):
    """Userland view of one ``struct bind_evt_t`` record from the BPF program.

    Field names, order and widths have to stay in lock-step with the C
    struct, since records are reinterpreted in place through a ctypes cast.
    """

    # Five leading 64-bit scalars, then the 2 x u64 address and the comm.
    _fields_ = list(zip(
        ("ts_us", "pid_tgid", "netns", "proto", "lport"),
        (ct.c_ulonglong,) * 5,
    )) + [
        ("laddr", ct.c_ulonglong * 2),
        ("task", ct.c_char * TASK_COMM_LEN),
    ]
    # TODO: properties to unpack protocol / ip / pid / tgid ...
# Format output
def event_printer(show_netns):
    """Return a perf-buffer callback that prints one line per bind event.

    Args:
        show_netns: when true, every line carries the network namespace id
            between the COMM and PROTO columns.

    Returns:
        A ``print_event(cpu, data, size)`` function to hand to
        ``open_perf_buffer``; ``data`` must point at a raw ``ListenEvt``
        record. ``cpu`` and ``size`` are accepted but unused.
    """
    def print_event(cpu, data, size):
        # Decode event
        event = ct.cast(data, ct.POINTER(ListenEvt)).contents

        # bpf_get_current_pid_tgid() stores the thread-group id (what
        # userland calls the PID) in the upper 32 bits; the lower 32 bits
        # are the thread id, which is not what the PID column should show.
        pid = event.pid_tgid >> 32

        # proto is packed by the BPF program as (family << 16 | type).
        sock_type = event.proto & 0xff
        family = event.proto >> 16 & 0xff

        if sock_type == SOCK_STREAM:
            protocol = "TCP"
        elif sock_type == SOCK_DGRAM:
            protocol = "UDP"
        else:
            protocol = "UNK"

        address = ""
        if family == AF_INET:
            protocol += "v4"
            # IPv4 address lives in laddr[0], already in network byte order.
            address = inet_ntop(AF_INET, pack("I", event.laddr[0]))
        elif family == AF_INET6:
            address = inet_ntop(AF_INET6, event.laddr)
            protocol += "v6"

        # On Python 3 ctypes returns c_char arrays as bytes; decode so the
        # command name is not rendered as b'...'. On Python 2 it is already
        # a str and is left untouched.
        task = event.task
        if not isinstance(task, str):
            task = task.decode("utf-8", "replace")

        # Display
        if show_netns:
            print("%-6d %-12.12s %-12s %-6s %-5s %-39s" % (
                pid, task, event.netns, protocol,
                event.lport, address,
            ))
        else:
            print("%-6d %-12.12s %-6s %-5s %-39s" % (
                pid, task, protocol,
                event.lport, address,
            ))

    return print_event
if __name__ == "__main__":
    # Parse arguments
    args = parser.parse_args()

    # Build the optional C filter statements spliced into the BPF program.
    pid_filter = ""
    netns_filter = ""
    if args.pid:
        # evt.pid_tgid holds (tgid << 32 | tid). Comparing the whole 64-bit
        # value with a PID can never match, so compare the upper half (the
        # userland-visible PID) instead.
        pid_filter = "if ((evt.pid_tgid >> 32) != %d) return 0;" % args.pid
    if args.netns:
        netns_filter = "if (evt.netns != %d) return 0;" % args.netns

    bpf_text = bpf_text.replace("##FILTER_PID##", pid_filter)
    bpf_text = bpf_text.replace("##FILTER_NETNS##", netns_filter)

    # --ebpf: dump the final program text instead of running it.
    if args.ebpf:
        print(bpf_text)
        # SystemExit works even when the site-provided exit() helper is
        # unavailable.
        raise SystemExit(0)

    # Initialize BPF
    b = BPF(text=bpf_text)
    b["bind_evt"].open_perf_buffer(event_printer(args.show_netns))

    # Print headers
    if args.show_netns:
        print("%-6s %-12s %-12s %-6s %-5s %-39s" %
              ("PID", "COMM", "NETNS", "PROTO", "PORT", "ADDR"))
    else:
        print("%-6s %-12s %-6s %-5s %-39s" %
              ("PID", "COMM", "PROTO", "PORT", "ADDR"))

    # Read events until interrupted; Ctrl-C ends the trace without a
    # traceback.
    while True:
        try:
            b.perf_buffer_poll()
        except KeyboardInterrupt:
            raise SystemExit(0)
@ivanmrsulja
Copy link

Thank you for this code. One minor improvement: if you are reading the socket protocol on a Linux system, you can use:

bpf_probe_read(&protocol, sizeof(protocol), &sk->sk_protocol);

@karelbilek
Copy link

karelbilek commented Jan 6, 2025

I would love this to be run with python3 (as getting python2 tooling is not easy anymore)... but I don't know any python; time to try the AI thing

edit: well claude converted it to python3 and it seems to work. Thank you robotic overlords

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment