Last active
May 30, 2023 16:00
-
-
Save ammarfaizi2/9a60f21bde326c5fe95cf4a87e46a1a9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: GPL-2.0-only
/*
 * tc-nat.c - A simple NAT implementation for the tc (traffic control)
 * ingress/egress hooks using BPF.
 *
 * Copyright (C) 2023 Muhammad Aldin Setiawan <[email protected]>
 * Copyright (C) 2023 Ammar Faizi <[email protected]>
 */
#include <linux/if_ether.h> | |
#include <linux/pkt_cls.h> | |
#include <linux/types.h> | |
#include <linux/icmp.h> | |
#include <linux/tcp.h> | |
#include <linux/bpf.h> | |
#include <linux/ipv6.h> | |
#include <linux/ip.h> | |
#include <bpf/bpf_helpers.h> | |
#include <stdint.h> | |
/* Set to 0 to compile out all bpf_printk() debug output. */
#define DEBUG_TC_NAT 1

#ifndef __section
#define __section(NAME) __attribute__((__section__(NAME), __used__))
#endif

#ifndef __maybe_unused
#define __maybe_unused __attribute__((__unused__))
#endif

#if DEBUG_TC_NAT
#define bp_debug(...) bpf_printk(__VA_ARGS__)
#else
#define bp_debug(...) do {} while (0)
#endif

/*
 * IPV4(A, B, C, D) - build the 32-bit value for the dotted quad A.B.C.D,
 * with octet A in the least significant byte.
 *
 * On a little-endian target this equals the in-memory (network byte order)
 * representation of the address, so the result can be compared directly
 * with iphdr->saddr / iphdr->daddr.
 * NOTE(review): on a big-endian target the result is NOT in network order.
 *
 * Each argument is parenthesized so expressions such as `x | y` are not
 * split by operator precedence, and widened to uint32_t before shifting so
 * an octet >= 128 in position D does not overflow a signed int (UB).
 */
#ifndef IPV4
#define IPV4(A, B, C, D)					\
	((uint32_t)(A) | ((uint32_t)(B) << 8) |			\
	 ((uint32_t)(C) << 16) | ((uint32_t)(D) << 24))
#endif

/*
 * BPF_FUNC(NAME, args...) - declare a function pointer NAME whose value is
 * the helper id BPF_FUNC_NAME, so NAME(...) can be called like a function.
 */
#ifndef BPF_FUNC
# define BPF_FUNC(NAME, ...) \
	(* NAME)(__VA_ARGS__) = (void *) BPF_FUNC_##NAME
#endif
static int BPF_FUNC(skb_store_bytes, struct __sk_buff *skb, uint32_t off, | |
const void *from, uint32_t len, uint32_t flags); | |
static int BPF_FUNC(csum_diff, void *from, uint32_t from_size, void *to, | |
uint32_t to_size, uint32_t seed); | |
static int BPF_FUNC(l3_csum_replace, struct __sk_buff *skb, uint32_t off, | |
uint32_t from, uint32_t to, uint32_t flags); | |
static int BPF_FUNC(l4_csum_replace, struct __sk_buff *skb, uint32_t off, | |
uint32_t from, uint32_t to, uint32_t flags); | |
static const int L3_OFF = ETH_HLEN; // IP header offset | |
static const int L4_OFF = L3_OFF + 20; // TCP header offset: l3_off + sizeof(struct iphdr) | |
/* IP protocol numbers this program inspects (only the ones it needs). */
enum {
	IPPROTO_ICMP = 1,	/* Internet Control Message Protocol */
	IPPROTO_TCP = 6,	/* Transmission Control Protocol */
};

/* Traffic direction, stored in struct ip_pkt's di_type field. */
enum {
	DI_EGRESS = 0,		/* packet seen by the tc egress program */
	DI_INGRESS = 1,		/* packet seen by the tc ingress program */
};
/*
 * Per-packet context handed to every handler and NAT rule.
 * The IP header pointer is shared between the IPv4 and IPv6 views.
 */
struct ip_pkt {
	uint8_t di_type;		/* DI_EGRESS or DI_INGRESS */
	struct __sk_buff *skb;		/* the packet being processed */
	union {				/* IP header, right after the Ethernet header */
		struct iphdr *hdr;	/* ...viewed as IPv4 */
		struct ipv6hdr *hdr6;	/* ...viewed as IPv6 */
	};
	void *data_end;			/* end of packet data, for bounds checks */
};
struct in_addr6 { | |
union { | |
__be32 o1; | |
}; | |
}; | |
struct ct4_icmp { | |
__be16 id; | |
__be16 pad; | |
__be32 src; | |
__be32 dst; | |
}; | |
struct ct6_icmp { | |
__be16 id; | |
__be16 pad; | |
struct in_addr6 src; | |
struct in_addr6 dst; | |
}; | |
struct { | |
__uint(type, BPF_MAP_TYPE_HASH); | |
__type(key, struct ct4_icmp); | |
__type(value, __be32); | |
__uint(max_entries, 1024); | |
__uint(pinning, LIBBPF_PIN_BY_NAME); | |
} CT4_MAP_ICMP SEC(".maps"); | |
/*
 * TODO(ammarfaizi2): Implement IPv6 conntrack.
 *
 * When enabled, the value type presumably needs to hold a full IPv6 address
 * (struct in_addr6) rather than a __be32, since it stores the original
 * source address.
 */
// struct {
// 	__uint(type, BPF_MAP_TYPE_HASH);
// 	__type(key, struct ct6_icmp);
// 	__type(value, __be32);
// 	__uint(max_entries, 1024);
// 	__uint(pinning, LIBBPF_PIN_BY_NAME);
// } CT6_MAP_ICMP SEC(".maps");
/*
 * ntohs - convert a 16-bit value from network (big-endian) to host order.
 * @x: value in network byte order.
 *
 * Swaps the two bytes on little-endian targets. On big-endian targets,
 * network order already is host order, so @x is returned unchanged.
 * Unconditionally swapping would give the wrong value there.
 */
static __always_inline uint16_t ntohs(uint16_t x)
{
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
	return x;
#else
	return (uint16_t)(((x & 0xff) << 8) | ((x >> 8) & 0xff));
#endif
}
static const __be32 SNAT_IP = IPV4(200, 0, 0, 50); | |
/* | |
* Do something similar to this: | |
* | |
* iptables -t nat -A POSTROUTING -p icmp -s 192.168.122.0/24 ! -d 192.168.122.0/24 -j MASQUERADE | |
*/ | |
static void rule_icmp_masquerade(struct ip_pkt *pkt) | |
{ | |
struct icmphdr *icmp = (void *)&pkt->hdr[1]; | |
__be32 mask = IPV4(255, 255, 255, 0); | |
__be32 snat_ip = SNAT_IP; | |
__be32 sum; | |
if (pkt->data_end < (void *)&icmp[1] || pkt->hdr->protocol != IPPROTO_ICMP) | |
return; | |
if ((pkt->hdr->saddr & mask) != IPV4(192, 168, 122, 0)) { | |
/* | |
* It doesn't come from 192.168.122.0/24, skip! | |
*/ | |
return; | |
} | |
if ((pkt->hdr->daddr & mask) == IPV4(192, 168, 122, 0)) { | |
/* | |
* Destination is in the same subnet, no need to SNAT. | |
*/ | |
return; | |
} | |
bp_debug("SNAT: %x -> %x", pkt->hdr->saddr, snat_ip); | |
struct ct4_icmp key = { | |
.id = ntohs(icmp->un.echo.id), | |
.src = snat_ip, | |
.dst = pkt->hdr->daddr, | |
}; | |
bpf_map_update_elem(&CT4_MAP_ICMP, &key, &pkt->hdr->saddr, BPF_ANY); | |
bp_debug("egress icmp_key={id=%d; src=%x; dst=%x;}", key.id, key.src, key.dst); | |
sum = csum_diff(&pkt->hdr->saddr, 4, &snat_ip, 4, 0); | |
skb_store_bytes(pkt->skb, L3_OFF + offsetof(struct iphdr, saddr), &snat_ip, 4, 0); | |
l3_csum_replace(pkt->skb, L3_OFF + offsetof(struct iphdr, check), 0, sum, 0); | |
} | |
static int handle_ct4_icmp(struct ip_pkt *pkt) | |
{ | |
struct icmphdr *icmp = (void *)&pkt->hdr[1]; | |
__be32 *orig_ip; | |
__be32 dnat_ip; | |
__be32 sum; | |
if (pkt->data_end < (void *)&icmp[1] || pkt->hdr->protocol != IPPROTO_ICMP) | |
return TC_ACT_OK; | |
struct ct4_icmp key = { | |
.id = ntohs(icmp->un.echo.id), | |
.src = SNAT_IP, | |
.dst = pkt->hdr->saddr, | |
}; | |
orig_ip = bpf_map_lookup_elem(&CT4_MAP_ICMP, &key); | |
if (!orig_ip) { | |
bp_debug("No entry found for ICMP ID=%d", ntohs(icmp->un.echo.id)); | |
return TC_ACT_OK; | |
} | |
bp_debug("ingress icmp_key={id=%d; src=%x; dst=%x;}", key.id, key.src, key.dst); | |
dnat_ip = *orig_ip; | |
bp_debug("DNAT: %x -> %x", pkt->hdr->daddr, dnat_ip); | |
sum = csum_diff(&pkt->hdr->daddr, 4, &dnat_ip, 4, 0); | |
skb_store_bytes(pkt->skb, L3_OFF + offsetof(struct iphdr, daddr), &dnat_ip, 4, 0); | |
l3_csum_replace(pkt->skb, L3_OFF + offsetof(struct iphdr, check), 0, sum, 0); | |
return TC_ACT_OK; | |
} | |
/*
 * POSTROUTING-style hook, invoked for egress IPv4 packets.
 * Rules run in the order they are called; append new rule_*() calls below.
 */
static void exec_nat_postrouting(struct ip_pkt *pkt)
{
	rule_icmp_masquerade(pkt);	/* ICMP SNAT for 192.168.122.0/24 */
}
/*
 * PREROUTING-style hook, invoked for ingress IPv4 packets.
 * No rules are defined yet, so it is currently a no-op; add rule_*() calls
 * here.
 */
static void exec_nat_prerouting(struct ip_pkt *pkt)
{
	(void)pkt;	/* unused until a rule is added */
}
static int handle_ip4_icmp(struct ip_pkt *pkt) | |
{ | |
return TC_ACT_OK; | |
} | |
static int handle_ip4_packet(struct ip_pkt *pkt) | |
{ | |
int ret; | |
switch (pkt->hdr->protocol) { | |
case IPPROTO_ICMP: | |
return handle_ct4_icmp(pkt); | |
default: | |
return TC_ACT_OK; | |
} | |
} | |
static int handle_egress_ipv4(struct ip_pkt *pkt) | |
{ | |
int ret; | |
exec_nat_postrouting(pkt); | |
return ret; | |
} | |
/*
 * Ingress IPv4 path: PREROUTING rules first, then the per-protocol stage,
 * whose TC action is returned.
 */
static int handle_ingress_ipv4(struct ip_pkt *pkt)
{
	exec_nat_prerouting(pkt);
	return handle_ip4_packet(pkt);
}
static int handle_egress_ipv6(struct ip_pkt *pkt) | |
{ | |
return TC_ACT_OK; | |
} | |
static int handle_ingress_ipv6(struct ip_pkt *pkt) | |
{ | |
return TC_ACT_OK; | |
} | |
__section("egress") | |
int tc_egress(struct __sk_buff *skb) | |
{ | |
void *data_end = (void *)(unsigned long)skb->data_end; | |
void *data = (void *)(unsigned long)skb->data; | |
struct ip_pkt pkt; | |
if (data_end < data + sizeof(struct ethhdr) + sizeof(struct iphdr)) | |
return TC_ACT_OK; | |
pkt.skb = skb; | |
pkt.di_type = DI_EGRESS; | |
pkt.hdr = data + sizeof(struct ethhdr); | |
pkt.data_end = data_end; | |
switch (pkt.hdr->version) { | |
case 4: | |
return handle_egress_ipv4(&pkt); | |
case 6: | |
return handle_egress_ipv6(&pkt); | |
default: | |
bp_debug("Unknown IP version=%d", pkt.hdr->version); | |
return TC_ACT_OK; | |
} | |
} | |
__section("ingress") | |
int tc_ingress(struct __sk_buff *skb) | |
{ | |
void *data_end = (void *)(unsigned long)skb->data_end; | |
void *data = (void *)(unsigned long)skb->data; | |
struct ip_pkt pkt; | |
if (data_end < data + sizeof(struct ethhdr) + sizeof(struct iphdr)) | |
return TC_ACT_OK; | |
pkt.skb = skb; | |
pkt.di_type = DI_INGRESS; | |
pkt.hdr = data + sizeof(struct ethhdr); | |
pkt.data_end = data_end; | |
switch (pkt.hdr->version) { | |
case 4: | |
return handle_ingress_ipv4(&pkt); | |
case 6: | |
return handle_ingress_ipv6(&pkt); | |
default: | |
bp_debug("Unknown IP version=%d", pkt.hdr->version); | |
return TC_ACT_OK; | |
} | |
} | |
char __license[] __section("license") = "GPL"; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment