Last active
March 11, 2023 18:04
-
-
Save sbernard31/d4fee7518a1ff130452211c0d355b3f7 to your computer and use it in GitHub Desktop.
UDP load balancer proto using bcc (XDP/Bpf)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#define KBUILD_MODNAME "foo" | |
#include <uapi/linux/bpf.h> | |
#include <linux/bpf.h> | |
#include <linux/icmp.h> | |
#include <linux/if_ether.h> | |
#include <linux/if_vlan.h> | |
#include <linux/in.h> | |
#include <linux/ip.h> | |
#include <linux/tcp.h> | |
#include <linux/udp.h> | |
/*
 * Mask applied to iphdr.frag_off to detect fragmented packets: 0x3FFF covers
 * the "more fragments" flag and the 13-bit fragment offset (RFC 791).
 * frag_off is stored in network byte order, so the mask is converted with
 * bpf_htons() rather than hard-coded as 65343 (0xFF3F), which is only the
 * right value on little-endian hosts.
 */
#define IP_FRAGMENTED bpf_htons(0x3FFF)
/* Raw 6-byte Ethernet (MAC) address. */
typedef unsigned char mac[6];
// Real Server structure (MAC address + IP address) | |
struct server { | |
__be32 ipAddr; | |
unsigned char macAddr[ETH_ALEN]; | |
}; | |
// packet structure to log load balancing | |
struct packet { | |
unsigned char dmac[ETH_ALEN]; | |
unsigned char smac[ETH_ALEN]; | |
__be32 daddr; | |
__be32 saddr; | |
}; | |
/* Perf output channel on which struct packet records are submitted. */
BPF_PERF_OUTPUT(events);
/*
 * Fold a wide checksum accumulator down towards 16 bits and return the
 * bitwise complement of the result, truncated to 16 bits.
 *
 * Each round adds the bits above bit 15 back into the low 16 bits. The
 * number of rounds is fixed (4) so the loop can be fully unrolled.
 */
__attribute__((__always_inline__))
static inline __u16 csum_fold_helper(__u64 csum) {
#pragma unroll
    for (int round = 0; round < 4; round++) {
        __u64 carry = csum >> 16;
        if (carry != 0)
            csum = (csum & 0xffff) + carry;
    }
    return (__u16)~csum;
}
__attribute__((__always_inline__)) | |
static inline void ipv4_csum_inline(void *iph, __u64 *csum) { | |
__u16 *next_iph_u16 = (__u16 *)iph; | |
#pragma clang loop unroll(full) | |
for (int i = 0; i < sizeof(struct iphdr) >> 1; i++) { | |
*csum += *next_iph_u16++; | |
} | |
*csum = csum_fold_helper(*csum); | |
} | |
__attribute__((__always_inline__)) | |
static inline void ipv4_csum(void *data_start, int data_size, __u64 *csum) { | |
*csum = bpf_csum_diff(0, 0, data_start, data_size, *csum); | |
*csum = csum_fold_helper(*csum); | |
} | |
__attribute__((__always_inline__)) | |
static inline void ipv4_l4_csum(void *data_start, __u32 data_size, | |
__u64 *csum, struct iphdr *iph) { | |
__u32 tmp = 0; | |
*csum = bpf_csum_diff(0, 0, &iph->saddr, sizeof(__be32), *csum); | |
*csum = bpf_csum_diff(0, 0, &iph->daddr, sizeof(__be32), *csum); | |
tmp = __builtin_bswap32((__u32)(iph->protocol)); | |
*csum = bpf_csum_diff(0, 0, &tmp, sizeof(__u32), *csum); | |
tmp = __builtin_bswap32((__u32)(data_size)); | |
*csum = bpf_csum_diff(0, 0, &tmp, sizeof(__u32), *csum); | |
*csum = bpf_csum_diff(0, 0, data_start, data_size, *csum); | |
*csum = csum_fold_helper(*csum); | |
} | |
/*
 * UDP ports to redirect. Only the presence of a key is tested by xdp_prog;
 * keys are compared against the port fields of the UDP header as-is.
 * TODO: decide how to handle the maximum number of supported ports.
 */
BPF_HASH(ports, __be16, int, 10);
/*
 * Real servers, keyed by an integer index (xdp_prog currently reads index 0 only).
 * TODO: decide how to handle the maximum number of real servers.
 */
BPF_HASH(realServers, int, struct server, 10);
// Virtual IP is accessible via the 'VIP' constant | |
int xdp_prog(struct CTXTYPE *ctx) { | |
void *data_end = (void *)(long)ctx->data_end; | |
void *data = (void *)(long)ctx->data; | |
// https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/tree/include/uapi/linux/if_ether.h | |
struct ethhdr * eth = data; | |
if (eth + 1 > data_end) | |
return XDP_DROP; | |
// Handle only IP packets (v4?) | |
if (eth->h_proto != bpf_htons(ETH_P_IP)){ | |
return XDP_PASS; | |
} | |
// https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/tree/include/uapi/linux/ip.h | |
struct iphdr *iph; | |
iph = eth + 1; | |
if (iph + 1 > data_end) | |
return XDP_DROP; | |
// Minimum valid header length value is 5. | |
// see (https://tools.ietf.org/html/rfc791#section-3.1) | |
if (iph->ihl < 5) | |
return XDP_DROP; | |
// IP header size is variable because of options field. | |
// see (https://tools.ietf.org/html/rfc791#section-3.1) | |
//if ((void *) iph + iph->ihl * 4 > data_end) | |
// return XDP_DROP; | |
// TODO support IP header with variable size | |
if (iph->ihl != 5) | |
return XDP_PASS; | |
// Do not support fragmented packets as L4 headers may be missing | |
if (iph->frag_off & IP_FRAGMENTED) | |
return XDP_PASS; // TODO should we support it ? | |
// We only handle UDP traffic | |
if (iph->protocol != IPPROTO_UDP) { | |
return XDP_PASS; | |
} | |
// https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/tree/include/uapi/linux/udp.h | |
struct udphdr *udp; | |
//udp = (void *) iph + iph->ihl * 4; | |
udp = iph + 1; | |
if (udp + 1 > data_end) | |
return XDP_DROP; | |
//__u16 udp_len = bpf_ntohs(udp->len); | |
__u16 udp_len = 8; | |
if (udp_len < 8) | |
return XDP_DROP; | |
if (udp_len > 512) // TODO use a more approriate max value | |
return XDP_DROP; | |
if ((void *) udp + udp_len > data_end) | |
return XDP_DROP; | |
// Is it ingress traffic ? destination IP == VIP | |
if (iph->daddr == VIP) { | |
if (!ports.lookup(&(udp->dest))) { | |
return XDP_PASS; | |
} else { | |
// Log packet before | |
struct packet pkt = {}; | |
memcpy(&pkt, data, sizeof(pkt)); // crappy | |
pkt.daddr = iph->daddr; | |
pkt.saddr = iph->saddr; | |
events.perf_submit(ctx,&pkt,sizeof(pkt)); | |
// handle ingress traffic | |
// TODO support several real server | |
int i = 0; | |
struct server * server = realServers.lookup(&i); | |
if (server == NULL) { | |
return XDP_PASS; | |
} | |
memcpy(eth->h_dest, server->macAddr, 6); | |
iph->daddr = server->ipAddr; | |
} | |
} else | |
// Is it egress traffic ? source IP == VIP | |
if (iph->saddr == VIP) { | |
if (!ports.lookup(&(udp->source))) { | |
return XDP_PASS; | |
} else { | |
// Log packet before | |
struct packet pkt = {}; | |
memcpy(&pkt, data, sizeof(pkt)); // crappy | |
pkt.daddr = iph->daddr; | |
pkt.saddr = iph->saddr; | |
events.perf_submit(ctx,&pkt,sizeof(pkt)); | |
// handle egress traffic | |
// TODO support several real server | |
int i = 0; | |
struct server * server = realServers.lookup(&i); | |
if (server == NULL) { | |
return XDP_PASS; | |
} | |
memcpy(eth->h_source, server->macAddr, 6); | |
iph->saddr = server->ipAddr; | |
} | |
} else { | |
return XDP_PASS; | |
} | |
// Update IP checksum | |
// TODO support IP header with variable size | |
iph->check = 0; | |
__u64 cs = 0 ; | |
ipv4_csum(iph, sizeof (*iph), &cs); | |
iph->check = cs; | |
// Update UDP checksum | |
udp->check = 0; | |
cs = 0; | |
ipv4_l4_csum(udp, udp_len, &cs, iph) ; | |
udp->check = cs; | |
// Log packet after | |
struct packet pkt = {}; | |
memcpy(&pkt, data, sizeof(pkt)); // crappy | |
pkt.daddr = iph->daddr; | |
pkt.saddr = iph->saddr; | |
events.perf_submit(ctx,&pkt,sizeof(pkt)); | |
return XDP_TX; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
from __future__ import print_function | |
from bcc import BPF | |
import ctypes as ct | |
import ipaddress | |
import socket | |
import argparse | |
import binascii | |
import struct | |
import re | |
# Utils | |
def ip_strton(ip_address):
    """Convert an IPv4 address to the 32-bit integer handed to the BPF program.

    The address is parsed, turned into its integer value and passed through
    ``socket.htonl``.

    Only IPv4 is accepted: the BPF side stores addresses in 32-bit fields, and
    an IPv6 address such as ``"::1"`` must not be silently treated as 0.0.0.1.

    :param ip_address: IPv4 address, e.g. ``"10.40.0.1"``.
    :return: the address as an ``int`` after ``socket.htonl``.
    :raises ValueError: (``ipaddress.AddressValueError``) if ``ip_address`` is
        not a valid IPv4 address.
    """
    return socket.htonl(int(ipaddress.IPv4Address(ip_address)))
def ip_ntostr(ip_address):
    """Turn an integer produced by ``ip_strton`` back into an address object.

    :param ip_address: plain ``int`` or ``ctypes.c_uint`` holding the address.
    :return: the ``ipaddress`` object built from ``socket.ntohl`` of the value
        (it renders as dotted-quad text when formatted).
    """
    raw = ip_address.value if isinstance(ip_address, ct.c_uint) else ip_address
    host_order = socket.ntohl(raw)
    return ipaddress.ip_address(host_order)
def mac_strtob(mac_address):
    """Parse a textual MAC address into its 6 raw bytes.

    Colons are optional: ``"5E:FF:56:A2:AF:15"`` and ``"5eff56a2af15"`` are
    both accepted.

    :param mac_address: hexadecimal MAC address string.
    :return: ``bytes`` object of length 6.
    :raises TypeError: if the decoded address is not exactly 6 bytes long.
    :raises binascii.Error: if the string is not valid hexadecimal.
    """
    raw = binascii.unhexlify(mac_address.replace(':', ''))
    # Compare by value: 'is not' tests object identity, which is not a
    # reliable way to compare integers.
    if len(raw) != 6:
        raise TypeError("mac address must be a 6 bytes arrays")
    return raw
def mac_btostr(mac_address):
    """Format a MAC address as six colon-separated lowercase hex pairs.

    :param mac_address: any object ``bytes()`` can convert (``bytes``, a
        ctypes ``c_ubyte`` array, ...). Only the first 6 bytes are rendered.
    :return: text such as ``"5e:ff:56:a2:af:15"``.
    """
    hex_digits = bytes(mac_address).hex()
    pairs = [hex_digits[2 * k:2 * k + 2] for k in range(6)]
    return ':'.join(pairs)
def ip_mac_tostr(mac_address, ip_address):
    """Render a MAC/IP pair as ``<mac>/<ip>``.

    :param mac_address: MAC address, as accepted by ``mac_btostr``.
    :param ip_address: IP address, as accepted by ``ip_ntostr``.
    :return: the formatted string.
    """
    mac_text = mac_btostr(mac_address)
    ip_text = ip_ntostr(ip_address)
    return "{}/{}".format(mac_text, ip_text)
# Custom argument parser | |
def mac_ip_parser(s,pat=re.compile("^(.+?)/(.+)$")):
    """argparse ``type=`` callback parsing a ``MAC_addr/IP_addr`` argument.

    :param s: raw command-line value, e.g. ``5E:FF:56:A2:AF:15/10.40.0.1``.
    :param pat: compiled pattern splitting the value at the first ``/``.
    :return: ``{"ip": <int from ip_strton>, "mac": <bytes from mac_strtob>}``.
    :raises argparse.ArgumentTypeError: if the value has no ``/`` separator,
        or if the MAC or IP part cannot be parsed.
    """
    match = pat.match(s)
    if match is None:
        raise argparse.ArgumentTypeError("Invalid address '{}': format is 'MAC_addr/IP_addr' (e.g. 5E:FF:56:A2:AF:15/10.40.0.1)".format(s))

    mac_text, ip_text = match.group(1), match.group(2)

    try:
        mac = mac_strtob(mac_text)
    except Exception as err:
        raise argparse.ArgumentTypeError("Invalid MAC address '{}' : {}".format(mac_text, str(err)))

    try:
        ip = ip_strton(ip_text)
    except Exception as err:
        raise argparse.ArgumentTypeError("Invalid IP address '{}' : {}".format(ip_text, str(err)))

    return {"ip": ip, "mac": mac}
# Command-line interface
parser = argparse.ArgumentParser()
parser.add_argument("ifnet",
                    help="network interface to load balance (e.g. eth0)")
parser.add_argument("-vip", "--virtual_ip", required=True,
                    help="<Required> The virtual IP of this loadbalancer")
# nargs=1: exactly one real server is accepted for now.
parser.add_argument("-rs", "--real_server", required=True, nargs=1,
                    type=mac_ip_parser,
                    help="<Required> Real server addresse(s) e.g. 5E:FF:56:A2:AF:15/10.40.0.1")
parser.add_argument("-p", "--port", required=True, nargs='+', type=int,
                    help="<Required> UDP port(s) to load balance")
parser.add_argument("-d", "--debug", default=0, type=int,
                    choices=[0, 1, 2, 3, 4],
                    help="Use to set bpf verbosity (0 is minimal)")
args = parser.parse_args()

# Configuration extracted from the arguments
ifnet = args.ifnet                   # network interface the XDP program is attached to
vip = ip_strton(args.virtual_ip)     # virtual IP of the load balancer, as an integer
real_servers = args.real_server      # list of {"ip": ..., "mac": ...} dicts
ports = args.port                    # UDP ports to load balance
debug = args.debug                   # bpf verbosity

# Startup banner.
# NOTE(review): the third format argument is not referenced by the format string.
print("\nLoad balancing UDP traffic over {} interface for port(s) {} from :".format(ifnet, ports, ip_ntostr(vip)))
for real_server in real_servers:
    server_text = ip_mac_tostr(real_server["mac"], real_server["ip"])
    print("VIP:{} <=======> Real Server:{}".format(ip_ntostr(vip), server_text))
class Data(ct.Structure):
    """Record read from the ``events`` perf buffer.

    Mirrors the C ``struct packet`` field for field: destination and source
    MAC (6 bytes each) followed by destination and source IPv4 address.
    """
    _fields_ = [
        ("dmac", ct.c_ubyte * 6),   # Ethernet destination address
        ("smac", ct.c_ubyte * 6),   # Ethernet source address
        ("daddr", ct.c_uint),       # IPv4 destination address
        ("saddr", ct.c_uint),       # IPv4 source address
    ]
# Compile the BPF program (the virtual IP and context type are passed as
# preprocessor definitions), then load and attach it to the interface.
cflags = ["-w", "-DVIP={}".format(vip), "-DCTXTYPE=xdp_md"]
b = BPF(src_file="test.c", debug=debug, cflags=cflags)
fn = b.load_func("xdp_prog", BPF.XDP)
b.attach_xdp(ifnet, fn)

# Fill the BPF maps with the configuration
## Ports: keys are the ports after socket.htons, the value is only a marker
ports_map = b["ports"]
for port in ports:
    port_key = ports_map.Key(socket.htons(port))
    ports_map[port_key] = ports_map.Leaf(True)

## Real servers: keyed by their position in the argument list
real_servers_map = b.get_table("realServers")
for i, real_server in enumerate(real_servers):
    mac_array = (ct.c_ubyte * 6).from_buffer_copy(real_server['mac'])
    real_servers_map[real_servers_map.Key(i)] = real_servers_map.Leaf(real_server['ip'], mac_array)
def print_event(cpu, data, size):
    """Perf-buffer callback: print the addresses carried by one event.

    :param cpu: passed by the perf buffer, unused here.
    :param data: pointer to the raw event, reinterpreted as a ``Data`` record.
    :param size: passed by the perf buffer, unused here.
    """
    event = ct.cast(data, ct.POINTER(Data)).contents
    source_text = ip_mac_tostr(event.smac, event.saddr)
    dest_text = ip_mac_tostr(event.dmac, event.daddr)
    print("source {} --> dest {}".format(source_text, dest_text))
# Read the perf buffer until interrupted.
b["events"].open_perf_buffer(print_event)
try:
    while True:
        try:
            b.perf_buffer_poll()
            # DEBUG STUFF
            #(task, pid, cpu, flags, ts, msg) = b.trace_fields()
            #print("%s \n" % (msg))
        except ValueError:
            continue
        except KeyboardInterrupt:
            break
finally:
    # Always detach the bpf program, whatever ended the loop (Ctrl-C, an
    # unexpected exception, interpreter exit), so it is never left attached
    # to the interface.
    b.remove_xdp(ifnet)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
See a more advanced version of the code at : https://github.com/AirVantage/sbulb