Skip to content

Instantly share code, notes, and snippets.

@ammarfaizi2
Last active May 30, 2023 16:00
Show Gist options
  • Save ammarfaizi2/9a60f21bde326c5fe95cf4a87e46a1a9 to your computer and use it in GitHub Desktop.
Save ammarfaizi2/9a60f21bde326c5fe95cf4a87e46a1a9 to your computer and use it in GitHub Desktop.
// SPDX-License-Identifier: GPL-2.0-only
/*
* tc-nat.c - A simple NAT implementation using BPF.
*
* Copyright (C) 2023 Muhammad Aldin Setiawan <[email protected]>
* Copyright (C) 2023 Ammar Faizi <[email protected]>
*/
#include <linux/if_ether.h>
#include <linux/pkt_cls.h>
#include <linux/types.h>
#include <linux/icmp.h>
#include <linux/tcp.h>
#include <linux/bpf.h>
#include <linux/ipv6.h>
#include <linux/ip.h>
#include <bpf/bpf_helpers.h>
#include <stdint.h>
/* Build-time switch: set to 0 to compile every bp_debug() call away. */
#define DEBUG_TC_NAT 1

/* Fallback attribute helpers, only used if no header already provides them. */
#if !defined(__section)
# define __section(NAME) __attribute__((__section__(NAME), __used__))
#endif

#if !defined(__maybe_unused)
# define __maybe_unused __attribute__((__unused__))
#endif

/* bp_debug() is bpf_printk() when debugging is on, a no-op statement otherwise. */
#if !DEBUG_TC_NAT
# define bp_debug(...) do {} while (0)
#else
# define bp_debug(...) bpf_printk(__VA_ARGS__)
#endif
/*
 * IPV4(A, B, C, D) - the dotted-quad address A.B.C.D as a 32-bit constant
 * whose in-memory layout is network byte order (A is the first byte), so it
 * compares directly against iphdr->saddr / iphdr->daddr and is usable in
 * static initialisers.
 *
 * Every octet is parenthesised and widened to uint32_t before shifting:
 * this makes expression arguments (e.g. `b + 1`) expand correctly and
 * avoids signed-int overflow (undefined behaviour) for octets >= 128 in the
 * top byte.  The byte placement follows the target's endianness.
 */
#ifndef IPV4
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define IPV4(A, B, C, D) \
	(((uint32_t)(A) << 24) | ((uint32_t)(B) << 16) | \
	 ((uint32_t)(C) << 8) | (uint32_t)(D))
#else
#define IPV4(A, B, C, D) \
	((uint32_t)(A) | ((uint32_t)(B) << 8) | \
	 ((uint32_t)(C) << 16) | ((uint32_t)(D) << 24))
#endif
#endif
/*
 * BPF_FUNC(NAME, args...) declares NAME as a function pointer whose value is
 * the helper ID BPF_FUNC_NAME (from <linux/bpf.h>).  This lets the code below
 * call in-kernel BPF helpers by their short names (skb_store_bytes(), ...).
 */
#ifndef BPF_FUNC
# define BPF_FUNC(NAME, ...) \
(* NAME)(__VA_ARGS__) = (void *) BPF_FUNC_##NAME
#endif
/*
 * Helper prototypes used by the NAT code.
 *
 * NOTE(review): the UAPI helper signatures use 64-bit types in places
 * (e.g. csum_diff's return value and the from/to/flags arguments of the
 * *_csum_replace helpers).  These narrower prototypes rely on the values
 * passed here fitting in 32 bits; confirm against bpf-helpers(7).
 * l4_csum_replace is declared but not called by any code shown here.
 */
static int BPF_FUNC(skb_store_bytes, struct __sk_buff *skb, uint32_t off,
const void *from, uint32_t len, uint32_t flags);
static int BPF_FUNC(csum_diff, void *from, uint32_t from_size, void *to,
uint32_t to_size, uint32_t seed);
static int BPF_FUNC(l3_csum_replace, struct __sk_buff *skb, uint32_t off,
uint32_t from, uint32_t to, uint32_t flags);
static int BPF_FUNC(l4_csum_replace, struct __sk_buff *skb, uint32_t off,
uint32_t from, uint32_t to, uint32_t flags);
/*
 * Fixed header offsets inside the frame, assuming an untagged Ethernet header.
 *
 * These are enum constants rather than `static const int` objects so that
 * they are true integer constant expressions: in ISO C an initialiser that
 * reads another const object (L4_OFF from L3_OFF) is not a constant
 * expression and only compiles as a compiler extension.
 */
enum {
	L3_OFF = ETH_HLEN,	/* IP header offset */
	L4_OFF = L3_OFF + 20,	/* L4 header offset, valid only for an IPv4 header without options (ihl == 5) */
};
/* The subset of IP protocol numbers (iphdr->protocol) this program refers to. */
enum { IPPROTO_ICMP = 1, IPPROTO_TCP = 6 };

/* Which tc hook a packet is being processed on (stored in ip_pkt.di_type). */
enum { DI_EGRESS = 0, DI_INGRESS = 1 };
/*
 * Parsed view of the packet currently being handled, passed from the tc
 * entry points down to the NAT handlers.
 */
struct ip_pkt {
	uint8_t			di_type;	/* DI_EGRESS or DI_INGRESS */
	struct __sk_buff	*skb;		/* tc context of the packet */
	union {					/* L3 header; the valid member depends on the IP version */
		struct iphdr	*hdr;
		struct ipv6hdr	*hdr6;
	};
	void			*data_end;	/* end of linear data, for bounds checks */
};
/*
 * struct in_addr6 - a 128-bit IPv6 address in network byte order.
 *
 * o1 aliases the first 32-bit word and is kept for existing users.  The
 * arrays expose the whole address, which a single 32-bit member cannot hold.
 */
struct in_addr6 {
	union {
		__be32 o1;
		__u8 u6_addr8[16];
		__be16 u6_addr16[8];
		__be32 u6_addr32[4];
	};
};
/*
 * Key of CT4_MAP_ICMP: identifies one masqueraded IPv4 ICMP flow.
 *
 * id:  ICMP echo identifier.  Despite the __be16 type, the code that builds
 *      keys stores ntohs() of the wire value (host byte order).  Insert and
 *      lookup use the same conversion, so keys match.
 * pad: explicit padding.  It is zeroed by the designated initialisers that
 *      build keys and must stay zero for hash lookups to match.
 * src: the translated source address (SNAT_IP).
 * dst: the remote peer's address.
 *
 * NOTE(review): this is the key format of a pinned map; changing the layout
 * presumably breaks reuse of an existing pin.
 */
struct ct4_icmp {
__be16 id;
__be16 pad;
__be32 src;
__be32 dst;
};
/*
 * IPv6 counterpart of struct ct4_icmp for the not-yet-implemented
 * CT6_MAP_ICMP (its definition is commented out).  Not used by live code.
 */
struct ct6_icmp {
__be16 id;
__be16 pad;
struct in_addr6 src;
struct in_addr6 dst;
};
/*
 * IPv4 ICMP conntrack table: struct ct4_icmp -> original (pre-SNAT) inside
 * source address.
 *
 * Written on egress by rule_icmp_masquerade() and read on ingress by
 * handle_ct4_icmp() to undo the translation on replies.
 *
 * Pinned by name (LIBBPF_PIN_BY_NAME).  NOTE(review): the pin location and
 * lifetime depend on the loader.
 *
 * No code in this file deletes entries.  With a plain HASH map, inserting
 * new keys presumably fails once max_entries is reached.
 * NOTE(review): consider an LRU hash or explicit expiry.
 */
struct {
__uint(type, BPF_MAP_TYPE_HASH);
__type(key, struct ct4_icmp);
__type(value, __be32);
__uint(max_entries, 1024);
__uint(pinning, LIBBPF_PIN_BY_NAME);
} CT4_MAP_ICMP SEC(".maps");
/*
* TODO(ammarfaizi2): Implement IPv6 conntrack.
*/
// struct {
// __uint(type, BPF_MAP_TYPE_HASH);
// __type(key, struct ct6_icmp);
// __type(value, __be32);
// __uint(max_entries, 1024);
// __uint(pinning, LIBBPF_PIN_BY_NAME);
// } CT6_MAP_ICMP SEC(".maps");
/*
 * ntohs - convert a 16-bit value from network (big-endian) to host byte order.
 *
 * On little-endian targets (typically bpfel) this is a byte swap.  On
 * big-endian targets (bpfeb) it is the identity.  Because a 16-bit byte
 * swap is its own inverse, the same function also serves as htons().
 */
static __always_inline __u16 ntohs(uint16_t x)
{
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
	return x;
#else
	return (__u16)((x << 8) | (x >> 8));
#endif
}
static const __be32 SNAT_IP = IPV4(200, 0, 0, 50);
/*
* Do something similar to this:
*
* iptables -t nat -A POSTROUTING -p icmp -s 192.168.122.0/24 ! -d 192.168.122.0/24 -j MASQUERADE
*/
static void rule_icmp_masquerade(struct ip_pkt *pkt)
{
struct icmphdr *icmp = (void *)&pkt->hdr[1];
__be32 mask = IPV4(255, 255, 255, 0);
__be32 snat_ip = SNAT_IP;
__be32 sum;
if (pkt->data_end < (void *)&icmp[1] || pkt->hdr->protocol != IPPROTO_ICMP)
return;
if ((pkt->hdr->saddr & mask) != IPV4(192, 168, 122, 0)) {
/*
* It doesn't come from 192.168.122.0/24, skip!
*/
return;
}
if ((pkt->hdr->daddr & mask) == IPV4(192, 168, 122, 0)) {
/*
* Destination is in the same subnet, no need to SNAT.
*/
return;
}
bp_debug("SNAT: %x -> %x", pkt->hdr->saddr, snat_ip);
struct ct4_icmp key = {
.id = ntohs(icmp->un.echo.id),
.src = snat_ip,
.dst = pkt->hdr->daddr,
};
bpf_map_update_elem(&CT4_MAP_ICMP, &key, &pkt->hdr->saddr, BPF_ANY);
bp_debug("egress icmp_key={id=%d; src=%x; dst=%x;}", key.id, key.src, key.dst);
sum = csum_diff(&pkt->hdr->saddr, 4, &snat_ip, 4, 0);
skb_store_bytes(pkt->skb, L3_OFF + offsetof(struct iphdr, saddr), &snat_ip, 4, 0);
l3_csum_replace(pkt->skb, L3_OFF + offsetof(struct iphdr, check), 0, sum, 0);
}
static int handle_ct4_icmp(struct ip_pkt *pkt)
{
struct icmphdr *icmp = (void *)&pkt->hdr[1];
__be32 *orig_ip;
__be32 dnat_ip;
__be32 sum;
if (pkt->data_end < (void *)&icmp[1] || pkt->hdr->protocol != IPPROTO_ICMP)
return TC_ACT_OK;
struct ct4_icmp key = {
.id = ntohs(icmp->un.echo.id),
.src = SNAT_IP,
.dst = pkt->hdr->saddr,
};
orig_ip = bpf_map_lookup_elem(&CT4_MAP_ICMP, &key);
if (!orig_ip) {
bp_debug("No entry found for ICMP ID=%d", ntohs(icmp->un.echo.id));
return TC_ACT_OK;
}
bp_debug("ingress icmp_key={id=%d; src=%x; dst=%x;}", key.id, key.src, key.dst);
dnat_ip = *orig_ip;
bp_debug("DNAT: %x -> %x", pkt->hdr->daddr, dnat_ip);
sum = csum_diff(&pkt->hdr->daddr, 4, &dnat_ip, 4, 0);
skb_store_bytes(pkt->skb, L3_OFF + offsetof(struct iphdr, daddr), &dnat_ip, 4, 0);
l3_csum_replace(pkt->skb, L3_OFF + offsetof(struct iphdr, check), 0, sum, 0);
return TC_ACT_OK;
}
/*
 * POSTROUTING-style hook, applied to egress IPv4 packets.
 * Rules run in the order they are listed here.
 */
static void exec_nat_postrouting(struct ip_pkt *pkt)
{
	/* Source-NAT rules go here. */
	rule_icmp_masquerade(pkt);
}

/*
 * PREROUTING-style hook, applied to ingress IPv4 packets ahead of the
 * per-protocol handling.  No rules are configured, so it does nothing.
 */
static void exec_nat_prerouting(struct ip_pkt *pkt)
{
	/* Destination-NAT rules go here. */
	(void)pkt;
}
/*
 * Placeholder for IPv4 ICMP handling; not called by any code shown here.
 * Every packet is accepted unchanged.
 */
static int handle_ip4_icmp(struct ip_pkt *pkt)
{
	(void)pkt;
	return TC_ACT_OK;
}
/*
 * handle_ip4_packet - dispatch an IPv4 packet on the ingress path by L4 protocol.
 *
 * ICMP is passed to handle_ct4_icmp() to reverse the masquerade.  Every other
 * protocol is passed through unchanged.
 *
 * Return: the tc verdict for the packet.
 */
static int handle_ip4_packet(struct ip_pkt *pkt)
{
	switch (pkt->hdr->protocol) {
	case IPPROTO_ICMP:
		return handle_ct4_icmp(pkt);
	default:
		return TC_ACT_OK;
	}
}
static int handle_egress_ipv4(struct ip_pkt *pkt)
{
int ret;
exec_nat_postrouting(pkt);
return ret;
}
/*
 * Ingress IPv4 path: run the PREROUTING rules, then hand the packet to the
 * per-protocol handler, whose verdict is returned.
 */
static int handle_ingress_ipv4(struct ip_pkt *pkt)
{
	exec_nat_prerouting(pkt);
	return handle_ip4_packet(pkt);
}
/* IPv6 is not translated yet: egress packets pass through untouched. */
static int handle_egress_ipv6(struct ip_pkt *pkt)
{
	(void)pkt;
	return TC_ACT_OK;
}

/* IPv6 is not translated yet: ingress packets pass through untouched. */
static int handle_ingress_ipv6(struct ip_pkt *pkt)
{
	(void)pkt;
	return TC_ACT_OK;
}
__section("egress")
int tc_egress(struct __sk_buff *skb)
{
void *data_end = (void *)(unsigned long)skb->data_end;
void *data = (void *)(unsigned long)skb->data;
struct ip_pkt pkt;
if (data_end < data + sizeof(struct ethhdr) + sizeof(struct iphdr))
return TC_ACT_OK;
pkt.skb = skb;
pkt.di_type = DI_EGRESS;
pkt.hdr = data + sizeof(struct ethhdr);
pkt.data_end = data_end;
switch (pkt.hdr->version) {
case 4:
return handle_egress_ipv4(&pkt);
case 6:
return handle_egress_ipv6(&pkt);
default:
bp_debug("Unknown IP version=%d", pkt.hdr->version);
return TC_ACT_OK;
}
}
__section("ingress")
int tc_ingress(struct __sk_buff *skb)
{
void *data_end = (void *)(unsigned long)skb->data_end;
void *data = (void *)(unsigned long)skb->data;
struct ip_pkt pkt;
if (data_end < data + sizeof(struct ethhdr) + sizeof(struct iphdr))
return TC_ACT_OK;
pkt.skb = skb;
pkt.di_type = DI_INGRESS;
pkt.hdr = data + sizeof(struct ethhdr);
pkt.data_end = data_end;
switch (pkt.hdr->version) {
case 4:
return handle_ingress_ipv4(&pkt);
case 6:
return handle_ingress_ipv6(&pkt);
default:
bp_debug("Unknown IP version=%d", pkt.hdr->version);
return TC_ACT_OK;
}
}
char __license[] __section("license") = "GPL";
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment