Created
July 8, 2018 07:40
-
-
Save yunazuno/16a37a51de3f26c25f1182556c677971 to your computer and use it in GitHub Desktop.
PoC: Offloading the L3 FIB into NIC hardware through tc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import pyroute2 | |
import socket | |
from pyroute2.netlink import rtnl | |
import subprocess | |
import time | |
from operator import itemgetter | |
class TCL3Switch(object):
    """Mirror a VRF's IPv4 forwarding state into tc flower filters.

    REACHABLE IPv4 neighbours on the VRF's slaves and unicast gateway routes in
    the VRF's routing table are converted into flower filters on a shared tc
    block. That block is attached to the ingress of every VRF slave. Each filter
    rewrites the destination MAC to the next hop's address and redirects the
    packet out of the egress slave. The intent is to let NIC hardware that
    offloads tc flower perform the L3 forwarding.

    Args:
        l3mdev_ifname: name of the VRF (l3mdev) master device.
        block: index of the shared tc block used for the slaves' ingress.
        chain: tc chain inside the block that holds the generated filters.

    Raises:
        ValueError: if ``l3mdev_ifname`` does not exist or is not a VRF that
            carries a routing table id.
    """

    def __init__(self, l3mdev_ifname, block=1, chain=0):
        self._ipr = pyroute2.IPRoute()
        try:
            self._l3mdev_ifindex = self._get_ifindex(l3mdev_ifname)
            self._vrf_table_id = self._get_vrf_table_id(self._l3mdev_ifindex)
            # {ifindex: ifname}. This snapshot is taken once, so interfaces
            # enslaved to the VRF later are not picked up.
            self._l3mdev_slaves = self._get_l3mdev_slaves(self._l3mdev_ifindex)
        except Exception:
            # Do not leak the netlink socket when construction fails.
            self._ipr.close()
            raise
        self._block = block
        self._chain = chain

    def close(self):
        """Close the underlying netlink socket."""
        self._ipr.close()

    def _get_ifindex(self, ifname):
        """Return the ifindex of *ifname*.

        Raises:
            ValueError: if no interface with that name exists.
        """
        try:
            return self._ipr.link_lookup(ifname=ifname)[0]
        except IndexError as e:
            raise ValueError(f"ifname {ifname} is missing") from e

    def _get_vrf_table_id(self, ifindex):
        """Return the routing table id bound to the VRF device *ifindex*.

        Raises:
            ValueError: if the link is not a VRF or reports no table id.
        """
        link = self._ipr.get_links(ifindex)[0]
        for linkinfo in link.get_attrs("IFLA_LINKINFO"):
            if linkinfo.get_attr("IFLA_INFO_KIND") != "vrf":
                continue
            info_data = linkinfo.get_attr("IFLA_INFO_DATA")
            table_id = info_data.get_attr("IFLA_VRF_TABLE") if info_data is not None else None
            if table_id is not None:
                return table_id
        raise ValueError(f"Failed to find VRF table id for ifindex {ifindex}")

    def _get_l3mdev_slaves(self, ifindex):
        """Return ``{ifindex: ifname}`` for every interface enslaved to *ifindex*."""
        slave_indexes = self._ipr.link_lookup(master=ifindex)
        if not slave_indexes:
            # pyroute2's get_links() with no arguments returns *all* links, so
            # an empty VRF must short-circuit here.
            return {}
        return {
            link["index"]: link.get_attr("IFLA_IFNAME")
            for link in self._ipr.get_links(*slave_indexes)
        }

    def _build_neighbour_flows(self):
        """Yield a /32 redirect flow for each REACHABLE IPv4 neighbour on a VRF slave."""
        neighbours = self._ipr.get_neighbours(state=rtnl.ndmsg.NUD_REACHABLE, family=socket.AF_INET)
        for neigh in neighbours:
            if neigh["ifindex"] not in self._l3mdev_slaves:
                continue
            mac = neigh.get_attr("NDA_LLADDR")
            if mac is None:
                # Without a MAC the pedit action would be rendered as "set None".
                continue
            yield dict(
                dst=neigh.get_attr("NDA_DST"),
                dst_len=32,
                action="redirect",
                redirect_mac=mac,
                redirect_ifindex=neigh["ifindex"],
            )

    def _build_route_flows(self):
        """Yield a redirect flow for each unicast gateway route in the VRF table.

        A route is offloaded only when both of these hold:
        - its output interface is a VRF slave (the same filter the neighbour
          flows use, and required to map the ifindex back to an ifname);
        - its gateway currently has a REACHABLE neighbour entry on that
          interface (that entry supplies the next-hop MAC).

        All other routes are skipped silently.
        """
        routes = self._ipr.get_routes(table=self._vrf_table_id, family=socket.AF_INET)
        for route in routes:
            if route["type"] != rtnl.rt_type["unicast"]:
                continue
            oif = route.get_attr("RTA_OIF")
            gateway = route.get_attr("RTA_GATEWAY")
            if not (gateway and oif) or oif not in self._l3mdev_slaves:
                continue
            dst = route.get_attr("RTA_DST")
            dst_len = route["dst_len"]
            if dst is None and dst_len == 0:
                # Default route: the kernel omits RTA_DST.
                dst = "0.0.0.0"
            gateway_neighs = self._ipr.get_neighbours(state=rtnl.ndmsg.NUD_REACHABLE, ifindex=oif,
                                                      dst=gateway)
            if not gateway_neighs:
                # Next hop not (yet) resolved; retried on the next round.
                continue
            mac = gateway_neighs[0].get_attr("NDA_LLADDR")
            if mac is None:
                continue
            yield dict(
                dst=dst,
                dst_len=dst_len,
                action="redirect",
                redirect_mac=mac,
                redirect_ifindex=oif,
            )

    def _build_flows(self):
        """Return all neighbour and route flows, longest prefix first.

        tc evaluates filters in ascending pref order, and install_filters
        assigns prefs in list order. Sorting by descending prefix length
        therefore gives longest-prefix-match semantics. The sort is stable, so a
        neighbour /32 stays ahead of a /32 route to the same destination.
        """
        flows = list(self._build_neighbour_flows()) + list(self._build_route_flows())
        flows.sort(key=itemgetter("dst_len"), reverse=True)
        return flows

    def _flower_filter_argv(self):
        """Yield each flow as a list of ``tc filter`` arguments, starting at ``flower``."""
        for flow in self._build_flows():
            argv = ["flower", "dst_ip", f"{flow['dst']}/{flow['dst_len']}"]
            if flow["action"] == "redirect":
                ifname = self._l3mdev_slaves[flow["redirect_ifindex"]]
                argv += ["action", "pedit", "ex", "munge", "eth", "dst", "set", flow["redirect_mac"],
                         "pipe", "mirred", "egress", "redirect", "dev", ifname]
            yield argv

    def _generate_flower_filters(self):
        """Yield each flower filter spec as one space-separated string."""
        for argv in self._flower_filter_argv():
            yield " ".join(argv)

    @staticmethod
    def _run_tc(*args):
        """Run ``tc`` with *args* as its argv.

        No shell is involved, so interface names are never shell-interpreted.
        The exit status is not checked: individual tc failures are best effort.
        """
        subprocess.run(["tc", *args])

    def set_ingress_qdisc(self):
        """Attach every VRF slave's ingress to the shared tc block."""
        for ifname in self._l3mdev_slaves.values():
            self._run_tc("qdisc", "add", "dev", ifname, "ingress_block", str(self._block), "ingress")

    def install_filters(self, pref_start):
        """Install the current flows at consecutive prefs starting at *pref_start*.

        Returns:
            The number of filters generated, i.e. prefs ``pref_start`` through
            ``pref_start + n - 1`` were used.
        """
        filters = list(self._flower_filter_argv())
        for pref, flower_args in enumerate(filters, start=pref_start):
            self._run_tc("filter", "add", "block", str(self._block), "protocol", "ip",
                         "chain", str(self._chain), "pref", str(pref), *flower_args)
        return len(filters)

    def delete_filters(self, pref_start, num):
        """Delete the filters at prefs ``pref_start`` through ``pref_start + num - 1``.

        Deletion goes from the highest pref to the lowest.
        """
        for pref in reversed(range(pref_start, pref_start + num)):
            self._run_tc("filter", "del", "block", str(self._block), "protocol", "ip",
                         "chain", str(self._chain), "pref", str(pref))

    def run(self, pref_offset=(1, 1001), interval=1):
        """Resynchronise the filters with the kernel tables forever (never returns).

        Each round installs the fresh flow set in one pref range. Only then does
        it delete the previous round's filters from the other range. As a
        result, a destination present in consecutive rounds is never left
        without a filter.

        Args:
            pref_offset: the two alternating pref start values. NOTE: the gap
                between them must exceed the largest flow count, otherwise the
                ranges overlap and deleting the old range removes new filters.
            interval: seconds to sleep between rounds.
        """
        pref_index = 0
        pref_start = pref_offset[pref_index]
        num_old = 0
        while True:
            num_new = self.install_filters(pref_start)
            pref_index = (pref_index + 1) % 2
            pref_start = pref_offset[pref_index]
            self.delete_filters(pref_start, num_old)
            num_old = num_new
            time.sleep(interval)
if __name__ == '__main__':
    import sys

    # Expects the VRF (l3mdev) device name as the first argument; extra
    # arguments are ignored.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} L3MDEV_IFNAME")
    l3mdev_ifname = sys.argv[1]
    l3sw = TCL3Switch(l3mdev_ifname)
    l3sw.set_ingress_qdisc()
    try:
        l3sw.run()
    except KeyboardInterrupt:
        # Exit quietly on Ctrl-C. Filters already installed are left in place.
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment