UDP load balancer prototype using bcc (XDP/BPF)
#define KBUILD_MODNAME "foo"
#include <uapi/linux/bpf.h>
#include <linux/bpf.h>
#include <linux/icmp.h>
#include <linux/if_ether.h>
#include <linux/if_vlan.h>
#include <linux/in.h>
#include <linux/ip.h>
#include <linux/tcp.h>
#include <linux/udp.h>
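/* Mask for the "more fragments" flag plus the 13-bit fragment offset (0x3FFF),
 * written pre-swapped so it can be applied to the raw network-order frag_off
 * field on a little-endian host: htons(0x3FFF) == 0xFF3F == 65343. */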
#define IP_FRAGMENTED 65343
// MAC address: a fixed-size array of 6 raw bytes.
// NOTE(review): this typedef is not referenced anywhere else in the visible
// source; the structures below use `unsigned char[ETH_ALEN]` directly.
typedef unsigned char mac[6];
// Real server descriptor (IP address + MAC address).
// This is the value type of the `realServers` map; xdp_prog copies both
// fields verbatim into the packet it rewrites.
struct server {
    __be32 ipAddr;                    // IPv4 address, written into iph->daddr / iph->saddr
    unsigned char macAddr[ETH_ALEN];  // MAC address, written into the Ethernet header
};
// Packet snapshot used to log load-balancing decisions.
// The field order and sizes mirror the ctypes `Data` structure declared by
// the Python loader (dmac, smac, daddr, saddr).
struct packet {
    unsigned char dmac[ETH_ALEN];  // destination MAC address
    unsigned char smac[ETH_ALEN];  // source MAC address
    __be32 daddr;                  // destination IPv4 address, as found in the IP header
    __be32 saddr;                  // source IPv4 address, as found in the IP header
};
/*
 * Fold a wide ones'-complement accumulator into 16 bits and return the
 * bitwise complement of the folded value, truncated to __u16.
 *
 * Each round adds the bits above bit 15 back into the low 16 bits. At most
 * four rounds are performed; two rounds are already enough for any
 * accumulator that fits in 32 bits.
 */
static inline __u16 csum_fold_helper(__u64 csum) {
    int i;
#pragma unroll
    for (i = 0; i < 4; i++) {
        if (csum >> 16)
            csum = (csum & 0xffff) + (csum >> 16);
    }
    return (__u16)~csum;
}
/*
 * Checksum a fixed-size IPv4 header without calling a BPF helper.
 *
 * Adds the sizeof(struct iphdr) / 2 consecutive 16-bit words starting at
 * `iph` to `*csum`, then replaces `*csum` with the folded, complemented
 * result from csum_fold_helper().
 *
 * iph:  start of the IPv4 header; sizeof(struct iphdr) bytes are read.
 * csum: in = initial accumulator, out = final 16-bit checksum value.
 */
static inline void ipv4_csum_inline(void *iph, __u64 *csum) {
    __u16 *next_iph_u16 = (__u16 *)iph;
#pragma clang loop unroll(full)
    for (int i = 0; i < (int)(sizeof(struct iphdr) >> 1); i++) {
        *csum += *next_iph_u16++;
    }
    *csum = csum_fold_helper(*csum);
}
/*
 * Checksum `data_size` bytes starting at `data_start` with bpf_csum_diff(),
 * seeding it with the current value of `*csum`, then replace `*csum` with
 * the folded, complemented result from csum_fold_helper().
 *
 * NOTE(review): bpf_csum_diff() is expected to require a size that is a
 * multiple of 4 -- confirm against the bpf-helpers documentation.
 */
static inline void ipv4_csum(void *data_start, int data_size, __u64 *csum) {
    *csum = bpf_csum_diff(0, 0, data_start, data_size, *csum);
    *csum = csum_fold_helper(*csum);
}
/*
 * Compute an L4 checksum that includes the IPv4 pseudo-header.
 *
 * The accumulator is fed, in order, with: the source address, the
 * destination address, the protocol number, the L4 length (`data_size`) and
 * finally the `data_size` bytes starting at `data_start`. `*csum` is then
 * replaced with the folded, complemented result from csum_fold_helper().
 *
 * data_start: first byte of the L4 header.
 * data_size:  number of bytes to checksum; also used as the pseudo-header
 *             length field.
 * csum:       in = initial accumulator, out = final 16-bit checksum value.
 * iph:        IPv4 header supplying saddr, daddr and protocol.
 *
 * NOTE(review): the byte swaps place protocol and length in the low-order
 * bytes of a 32-bit big-endian word, which assumes a little-endian host.
 */
static inline void ipv4_l4_csum(void *data_start, __u32 data_size,
                                __u64 *csum, struct iphdr *iph) {
    __u32 tmp = 0;

    /* Pseudo-header: source and destination addresses. */
    *csum = bpf_csum_diff(0, 0, &iph->saddr, sizeof(__be32), *csum);
    *csum = bpf_csum_diff(0, 0, &iph->daddr, sizeof(__be32), *csum);

    /* Pseudo-header: protocol number, widened to 32 bits. */
    tmp = __builtin_bswap32((__u32)(iph->protocol));
    *csum = bpf_csum_diff(0, 0, &tmp, sizeof(__u32), *csum);

    /* Pseudo-header: L4 length, widened to 32 bits. */
    tmp = __builtin_bswap32((__u32)(data_size));
    *csum = bpf_csum_diff(0, 0, &tmp, sizeof(__u32), *csum);

    /* L4 header and payload bytes. */
    *csum = bpf_csum_diff(0, 0, data_start, data_size, *csum);
    *csum = csum_fold_helper(*csum);
}
// Map of the UDP ports to redirect, sized for 10 entries.
// Keys are port numbers in network byte order: the loader stores
// socket.htons(port) and xdp_prog looks up udp->dest / udp->source as-is.
// xdp_prog only tests whether a key is present; the int value is not read.
BPF_HASH(ports, __be16, int, 10); // TODO: how do we handle the max number of ports we support?
// Map of the real servers, sized for 10 entries, keyed by an int index.
// xdp_prog currently only looks up index 0.
BPF_HASH(realServers, int, struct server, 10); // TODO: how do we handle the max number of real servers?
// Virtual IP is accessible via the 'VIP' constant
int xdp_prog(struct CTXTYPE *ctx) {
void *data_end = (void *)(long)ctx->data_end;
void *data = (void *)(long)ctx->data;
struct ethhdr * eth = data;
if (eth + 1 > data_end)
return XDP_DROP;
// Handle only IP packets (v4?)
if (eth->h_proto != bpf_htons(ETH_P_IP)){
return XDP_PASS;
struct iphdr *iph;
iph = eth + 1;
if (iph + 1 > data_end)
return XDP_DROP;
// Minimum valid header length value is 5.
// see (
if (iph->ihl < 5)
return XDP_DROP;
// IP header size is variable because of options field.
// see (
//if ((void *) iph + iph->ihl * 4 > data_end)
// return XDP_DROP;
// TODO support IP header with variable size
if (iph->ihl != 5)
return XDP_PASS;
// Do not support fragmented packets as L4 headers may be missing
if (iph->frag_off & IP_FRAGMENTED)
return XDP_PASS; // TODO should we support it ?
// We only handle UDP traffic
if (iph->protocol != IPPROTO_UDP) {
return XDP_PASS;
struct udphdr *udp;
//udp = (void *) iph + iph->ihl * 4;
udp = iph + 1;
if (udp + 1 > data_end)
return XDP_DROP;
//__u16 udp_len = bpf_ntohs(udp->len);
__u16 udp_len = 8;
if (udp_len < 8)
return XDP_DROP;
if (udp_len > 512) // TODO use a more approriate max value
return XDP_DROP;
if ((void *) udp + udp_len > data_end)
return XDP_DROP;
// Is it ingress traffic ? destination IP == VIP
if (iph->daddr == VIP) {
if (!ports.lookup(&(udp->dest))) {
return XDP_PASS;
} else {
// Log packet before
struct packet pkt = {};
memcpy(&pkt, data, sizeof(pkt)); // crappy
pkt.daddr = iph->daddr;
pkt.saddr = iph->saddr;
// handle ingress traffic
// TODO support several real server
int i = 0;
struct server * server = realServers.lookup(&i);
if (server == NULL) {
return XDP_PASS;
memcpy(eth->h_dest, server->macAddr, 6);
iph->daddr = server->ipAddr;
} else
// Is it egress traffic ? source IP == VIP
if (iph->saddr == VIP) {
if (!ports.lookup(&(udp->source))) {
return XDP_PASS;
} else {
// Log packet before
struct packet pkt = {};
memcpy(&pkt, data, sizeof(pkt)); // crappy
pkt.daddr = iph->daddr;
pkt.saddr = iph->saddr;
// handle egress traffic
// TODO support several real server
int i = 0;
struct server * server = realServers.lookup(&i);
if (server == NULL) {
return XDP_PASS;
memcpy(eth->h_source, server->macAddr, 6);
iph->saddr = server->ipAddr;
} else {
return XDP_PASS;
// Update IP checksum
// TODO support IP header with variable size
iph->check = 0;
__u64 cs = 0 ;
ipv4_csum(iph, sizeof (*iph), &cs);
iph->check = cs;
// Update UDP checksum
udp->check = 0;
cs = 0;
ipv4_l4_csum(udp, udp_len, &cs, iph) ;
udp->check = cs;
// Log packet after
struct packet pkt = {};
memcpy(&pkt, data, sizeof(pkt)); // crappy
pkt.daddr = iph->daddr;
pkt.saddr = iph->saddr;
return XDP_TX;
from __future__ import print_function
from bcc import BPF
import ctypes as ct
import ipaddress
import socket
import argparse
import binascii
import struct
import re
# Utils
def ip_strton(ip_address):
    """Convert a textual IP address to the integer given by socket.htonl().

    The string is parsed with ipaddress.ip_address(), turned into an int and
    passed through socket.htonl(). For an IPv4 address this equals
    struct.unpack("I", socket.inet_aton(ip_address))[0].

    Raises whatever ipaddress.ip_address() or socket.htonl() raise on input
    they reject (e.g. ValueError for a string that is not an IP address).
    """
    return socket.htonl(int(ipaddress.ip_address(ip_address)))
def ip_ntostr(ip_address):
    """Inverse of ip_strton(): turn an integer address into an ipaddress object.

    Accepts a plain int or a ctypes c_uint (unwrapped through .value). The
    value is passed through socket.ntohl() and handed to
    ipaddress.ip_address(); callers format the result with str()/format().
    """
    if isinstance(ip_address, ct.c_uint):
        ip_address = ip_address.value
    return ipaddress.ip_address(socket.ntohl(ip_address))
def mac_strtob(mac_address):
    """Convert a MAC address string (with or without ':') to 6 raw bytes.

    Raises TypeError when the decoded value is not exactly 6 bytes long;
    binascii.unhexlify() raises on input that is not valid hexadecimal.
    """
    raw = binascii.unhexlify(mac_address.replace(':', ''))
    # Compare by value: `is not 6` tests object identity, not equality.
    if len(raw) != 6:
        raise TypeError("mac address must be a 6 bytes arrays")
    return raw
def mac_btostr(mac_address):
    """Format a 6-byte MAC address as lowercase 'xx:xx:xx:xx:xx:xx'.

    `mac_address` may be any object accepted by bytes(), e.g. a bytes value
    or a ctypes c_ubyte array. Only the first 12 hex digits are used.
    """
    bytestr = bytes(mac_address).hex()
    return ':'.join(bytestr[i:i+2] for i in range(0, 12, 2))
def ip_mac_tostr(mac_address, ip_address):
    """Return 'MAC/IP' for display, built with mac_btostr() and ip_ntostr()."""
    return "{}/{}".format(mac_btostr(mac_address), ip_ntostr(ip_address))
# Custom argument parser
def mac_ip_parser(s, pat=re.compile("^(.+?)/(.+)$")):
    """argparse `type=` callback parsing a 'MAC_addr/IP_addr' argument.

    The text before the first '/' is converted with mac_strtob(), the rest
    with ip_strton(). Returns {"ip": <int>, "mac": <bytes>}.

    Raises argparse.ArgumentTypeError when the '/' separator is missing or
    when either part fails to convert.
    """
    m = pat.match(s)
    if not m:
        raise argparse.ArgumentTypeError("Invalid address '{}': format is 'MAC_addr/IP_addr' (e.g. 5E:FF:56:A2:AF:15/192.0.2.10)".format(s))
    try:
        mac = mac_strtob(m.group(1))
    except Exception as e:
        raise argparse.ArgumentTypeError("Invalid MAC address '{}' : {}".format(m.group(1), str(e)))
    try:
        ip = ip_strton(m.group(2))
    except Exception as e:
        raise argparse.ArgumentTypeError("Invalid IP address '{}' : {}".format(m.group(2), str(e)))
    return {"ip":ip,"mac":mac}
# Parse Arguments
# --real_server values go through mac_ip_parser, so each entry of
# args.real_server is a {"ip": ..., "mac": ...} dict; nargs=1 means exactly
# one real server is accepted per invocation.
parser = argparse.ArgumentParser()
parser.add_argument("ifnet", help="network interface to load balance (e.g. eth0)")
parser.add_argument("-vip", "--virtual_ip", help="<Required> The virtual IP of this loadbalancer", required=True)
parser.add_argument("-rs", "--real_server", type=mac_ip_parser, nargs=1, help="<Required> Real server addresse(s) e.g. 5E:FF:56:A2:AF:15/192.0.2.10", required=True)
parser.add_argument("-p", "--port", type=int, nargs='+', help="<Required> UDP port(s) to load balance", required=True)
parser.add_argument("-d", "--debug", type=int, choices=[0, 1, 2, 3, 4],
                    help="Use to set bpf verbosity (0 is minimal)", default=0)
args = parser.parse_args()
# Get configuration from Arguments
ifnet = args.ifnet                # network interface to attach xdp program
vip = ip_strton(args.virtual_ip)  # virtual ip of load balancer
real_servers = args.real_server   # list of {"ip": ..., "mac": ...} dicts
ports = args.port                 # ports to load balance
debug = args.debug                # bpf verbosity

# Startup banner. The format string has two placeholders; the VIP is shown
# on the per-server lines printed below.
print("\nLoad balancing UDP traffic over {} interface for port(s) {} from :".format(ifnet, ports))
for real_server in real_servers:
    print("VIP:{} <=======> Real Server:{}".format(ip_ntostr(vip), ip_mac_tostr(real_server["mac"], real_server["ip"])))
# Shared structure used for perf_buffer
class Data(ct.Structure):
    """ctypes view of one logged packet: two 6-byte MACs then two 32-bit IPs.

    print_event() casts the raw event buffer to this layout.
    """
    _fields_ = [
        ("dmac", ct.c_ubyte * 6),  # destination MAC address
        ("smac", ct.c_ubyte * 6),  # source MAC address
        ("daddr", ct.c_uint),      # destination IP address
        ("saddr", ct.c_uint),      # source IP address
    ]
# Compile & attach bpf program
# The virtual IP and the XDP context type reach the C source as preprocessor
# definitions; "-w" silences compiler warnings.
bpf_cflags = [
    "-w",
    "-DVIP={}".format(vip),
    "-DCTXTYPE=xdp_md",
]
b = BPF(src_file="test.c", debug=debug, cflags=bpf_cflags)
fn = b.load_func("xdp_prog", BPF.XDP)
b.attach_xdp(ifnet, fn)
# Set Configurations
## Ports configs
# Keys are stored in network byte order so they match the raw UDP header
# fields looked up by the XDP program.
ports_map = b["ports"]
for port in ports:
    ports_map[ports_map.Key(socket.htons(port))] = ports_map.Leaf(True)
## Real servers configs
# One entry per real server, keyed by its position in the list. The index
# used to stay at 0, so every server overwrote the same slot.
real_servers_map = b.get_table("realServers")
for i, real_server in enumerate(real_servers):
    real_servers_map[real_servers_map.Key(i)] = real_servers_map.Leaf(real_server['ip'], (ct.c_ubyte * 6).from_buffer_copy(real_server['mac']))
# Utility function to print udp dest NAT.
def print_event(cpu, data, size):
    """Perf-buffer style callback: decode one event as `Data` and print it.

    `data` is cast to a pointer to Data; `cpu` and `size` are not used.
    """
    event = ct.cast(data, ct.POINTER(Data)).contents
    print("source {} --> dest {}".format(ip_mac_tostr(event.smac, event.saddr), ip_mac_tostr(event.dmac, event.daddr)))
# Loop to read perf buffer
# NOTE(review): the statements inside this loop and the detach call were
# missing from the source; they are reconstructed from the surviving
# comments and `except` clauses. Confirm against the intended behaviour.
while 1:
    try:
        b.perf_buffer_poll()
        #(task, pid, cpu, flags, ts, msg) = b.trace_fields()
        #print("%s \n" % (msg))
    except ValueError:
        # Ignore an unparsable event and keep polling.
        continue
    except KeyboardInterrupt:
        # Ctrl-C ends the loop so the program can be detached below.
        break
# Detach bpf program
b.remove_xdp(ifnet)
