Last active
December 4, 2018 18:00
-
-
Save dmfigol/48f7bbd2aeaeac88648865679dca13f3 to your computer and use it in GitHub Desktop.
Convert logs to firewall rules
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Script which converts Splunk logs to different firewall objects and | |
prints them on the console | |
Copyright (c) 2018 Cisco and/or its affiliates. | |
This software is licensed to you under the terms of the Cisco Sample | |
Code License, Version 1.0 (the "License"). You may obtain a copy of the | |
License at | |
https://developer.cisco.com/docs/licenses | |
All use of the material herein must be in accordance with the terms of | |
the License. All rights not expressly granted by the License are | |
reserved. Unless required by applicable law or agreed to separately in | |
writing, software distributed under the License is distributed on an "AS | |
IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | |
or implied. | |
""" | |
__license__ = "Cisco Sample Code License, Version 1.0" | |
from collections import defaultdict | |
from csv import DictReader | |
from glob import glob | |
from typing import ( | |
Any, | |
Dict, | |
Optional, | |
DefaultDict, | |
ValuesView, | |
Tuple, | |
List, | |
Set, | |
Iterable, | |
cast, | |
) | |
# Ports of services with a fixed server side: the peer on this port is the
# server, the other peer is grouped as a client.
CLIENT_SERVER_SERVICES_PORTS = {80, 5000}
# Ports of services where both peers use the service port (179 is the
# IANA-registered BGP port), so both IPs are grouped together.
BIDIR_SERVICES_PORTS = {179}

# (ip, port) -> [row1, ...] for TCP/UDP flows on unrecognised ports.
UnknownServiceEntriesType = DefaultDict[Tuple[str, int], List[Dict[str, Any]]]
# (src_ip, dest_ip, protocol) -> [row1, ...] for non-TCP/UDP flows.
UnknownProtocolEntriesType = DefaultDict[Tuple[str, str, str], List[Dict[str, Any]]]
class DatacenterFirewallRules:
    """Build per-node firewall objects (ACLs, object groups) from CSV logs.

    Rows that cannot be turned into rules are kept for reporting:
    ``unknown_service_entries`` holds TCP/UDP flows whose ports are in
    neither ``BIDIR_SERVICES_PORTS`` nor ``CLIENT_SERVER_SERVICES_PORTS``,
    and ``unknown_protocol_entries`` holds rows with any other transport.
    """

    def __init__(self) -> None:
        self.name_to_node: Dict[str, "Node"] = {}
        self.unknown_service_entries: UnknownServiceEntriesType = defaultdict(list)
        self.unknown_protocol_entries: UnknownProtocolEntriesType = defaultdict(list)

    def get_node(self, name: str) -> Optional["Node"]:
        """Return the node called *name*, or None if no row mentioned it."""
        return self.name_to_node.get(name)

    @property
    def nodes(self) -> ValuesView["Node"]:
        """Live view of all nodes seen so far, in first-seen order."""
        return self.name_to_node.values()

    def process_logs(self) -> None:
        """Feed every row of every ``*.csv`` file in the CWD to process_csv_row.

        Files are processed in sorted order, so the generated rules (and the
        order of ACL entries) are reproducible. ``glob`` itself returns
        files in arbitrary order.
        """
        for filename in sorted(glob("*.csv")):
            # The csv module requires newline="" so that quoted fields
            # containing line breaks are parsed correctly.
            with open(filename, newline="") as f:
                for row in DictReader(f):
                    self.process_csv_row(row)

    def process_csv_row(self, row: Dict[str, Any]) -> None:
        """Translate one CSV log row into ACL entries on the reporting node.

        TCP/UDP flows touching a known service port become ACEs on the
        inbound ACL of the row's ``src_interface``. Other TCP/UDP flows are
        recorded in ``unknown_service_entries``, and non-TCP/UDP rows in
        ``unknown_protocol_entries``.

        Args:
            row: CSV row providing the ``dvc``, ``src_interface``,
                ``transport``, ``src_ip``, ``src_port``, ``dest_ip`` and
                ``dest_port`` columns.

        Raises:
            ValueError: if a TCP/UDP row has a non-numeric port.
        """
        node_name = row["dvc"]
        # Register the node even when the row yields no rule, so every
        # device seen in the logs appears in the output.
        node = self.name_to_node.get(node_name)
        if node is None:
            node = self.name_to_node[node_name] = Node(node_name)

        protocol = row["transport"]
        src_ip = cast(str, row["src_ip"])
        dest_ip = cast(str, row["dest_ip"])
        if protocol not in {"udp", "tcp"}:
            self.unknown_protocol_entries[(src_ip, dest_ip, protocol)].append(row)
            return

        src_port = self._parse_port(row["src_port"])
        dest_port = self._parse_port(row["dest_port"])
        # The interface (and its inbound ACL) is created for every TCP/UDP
        # row, even one whose service turns out to be unknown.
        interface = node.get_or_create_interface(row["src_interface"])
        acl = interface.applied_acls["in"]

        if src_port in BIDIR_SERVICES_PORTS or dest_port in BIDIR_SERVICES_PORTS:
            self._add_bidir_entries(
                node, acl, protocol, src_ip, src_port, dest_ip, dest_port
            )
        elif (
            src_port in CLIENT_SERVER_SERVICES_PORTS
            or dest_port in CLIENT_SERVER_SERVICES_PORTS
        ):
            self._add_client_server_entry(
                node, acl, protocol, src_ip, src_port, dest_ip, dest_port
            )
        else:
            # Either side could be the server; count the row against both.
            self.unknown_service_entries[(src_ip, cast(int, src_port))].append(row)
            self.unknown_service_entries[(dest_ip, cast(int, dest_port))].append(row)

    @staticmethod
    def _parse_port(value: Optional[str]) -> Optional[int]:
        """Return *value* as an int port, or None for an empty/missing field."""
        return int(value) if value else None

    @staticmethod
    def _bidir_group_name(prot: str, service_port: Optional[int]) -> str:
        """Name of the shared object group for a bidirectional service."""
        return f"AUTO-{prot}-{service_port}-SERVERS_CLIENTS"

    @staticmethod
    def _clients_group_name(
        server_ip: str, prot: str, service_port: Optional[int]
    ) -> str:
        """Name of the client object group for one server's service."""
        return f"AUTO-{server_ip}-{prot}-{service_port}-CLIENTS"

    def _add_bidir_entries(
        self,
        node: "Node",
        acl: "AccessList",
        protocol: str,
        src_ip: str,
        src_port: Optional[int],
        dest_ip: str,
        dest_port: Optional[int],
    ) -> None:
        """Permit a service whose peers both use the service port.

        Both IPs join one object group named after the *service* port (not
        the possibly-ephemeral source port). Two ACEs are added within that
        group: one matching the service port as source and one matching it
        as destination.
        """
        service_port = src_port if src_port in BIDIR_SERVICES_PORTS else dest_port
        group = node.get_or_create_network_obj_group(
            self._bidir_group_name(protocol.upper(), service_port),
            (src_ip, dest_ip),
        )
        ace = AccessListEntry(
            acl,
            action="permit",
            protocol=protocol,
            src_obj_group=group,
            dest_obj_group=group,
            src_port=service_port,
        )
        mirror_ace = AccessListEntry(
            acl,
            action="permit",
            protocol=protocol,
            src_obj_group=group,
            dest_obj_group=group,
            dest_port=service_port,
        )
        acl.add_access_list_entry(ace)
        acl.add_access_list_entry(mirror_ace)

    def _add_client_server_entry(
        self,
        node: "Node",
        acl: "AccessList",
        protocol: str,
        src_ip: str,
        src_port: Optional[int],
        dest_ip: str,
        dest_port: Optional[int],
    ) -> None:
        """Permit a client/server flow, grouping the clients of one server.

        The side on a CLIENT_SERVER_SERVICES_PORTS port keeps its IP and port
        in the ACE. The other side is replaced by a per-server client object
        group, with no port restriction.
        """
        prot = protocol.upper()
        if src_port in CLIENT_SERVER_SERVICES_PORTS:
            # The source is the server answering; group the destinations.
            group = node.get_or_create_network_obj_group(
                self._clients_group_name(src_ip, prot, src_port), (dest_ip,)
            )
            ace = AccessListEntry(
                acl,
                action="permit",
                protocol=protocol,
                src_ip=src_ip,
                src_port=src_port,
                dest_obj_group=group,
            )
        else:
            # The destination is the server; group the sources.
            group = node.get_or_create_network_obj_group(
                self._clients_group_name(dest_ip, prot, dest_port), (src_ip,)
            )
            ace = AccessListEntry(
                acl,
                action="permit",
                protocol=protocol,
                src_obj_group=group,
                dest_ip=dest_ip,
                dest_port=dest_port,
            )
        acl.add_access_list_entry(ace)

    def show(self) -> None:
        """Print every node's objects, then the rows that produced no rule.

        Unknown services and protocols are listed with the number of rows
        seen, most frequent first.
        """
        for node in self.nodes:
            node.show()
        print("=" * 20)
        if self.unknown_service_entries:
            print("Unknown services:")
            unknown_service_entries_num = [
                (key, len(rows)) for key, rows in self.unknown_service_entries.items()
            ]
            for (service_ip, service_port), num_rows in sorted(
                unknown_service_entries_num, key=lambda item: -item[1]
            ):
                print(f"* {service_ip}:{service_port} - {num_rows} entries")
        if self.unknown_protocol_entries:
            print("Unknown protocols:")
            unknown_protocol_entries_num = [
                (key, len(rows)) for key, rows in self.unknown_protocol_entries.items()
            ]
            for (src_ip, dest_ip, protocol), num_rows in sorted(
                unknown_protocol_entries_num, key=lambda item: -item[1]
            ):
                print(f"* {protocol}: {src_ip} -> {dest_ip} - {num_rows} entries")
class NetworkObjectGroup:
    """A named set of host IPs that ACL entries can reference as one object."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Host IP strings; filled in by whoever creates/extends the group.
        self.hosts: Set[str] = set()

    def __repr__(self) -> str:
        return "{}(name={!r})".format(self.__class__.__qualname__, self.name)

    @property
    def contains_only_clients(self) -> bool:
        """True when the group's name carries the ``-CLIENTS`` suffix."""
        return self.name.endswith("-CLIENTS")
class AccessListEntry:
    """A single rule (ACE) belonging to an access list.

    Equality and hashing use only ``CMP_ATTRS``, so ``seq_num`` is ignored
    when de-duplicating entries. A side given neither an IP nor an object
    group is stored as ``"any"``.
    """

    # Attributes that define an entry's identity (seq_num deliberately excluded).
    CMP_ATTRS = (
        "acl",
        "action",
        "protocol",
        "src_ip",
        "src_obj_group",
        "src_port",
        "dest_ip",
        "dest_obj_group",
        "dest_port",
    )

    def __init__(
        self,
        acl: "AccessList",
        action: str = "permit",
        protocol: str = "ip",
        seq_num: Optional[int] = None,
        src_ip: Optional[str] = None,
        src_obj_group: Optional["NetworkObjectGroup"] = None,
        src_port: Optional[int] = None,
        dest_ip: Optional[str] = None,
        dest_obj_group: Optional["NetworkObjectGroup"] = None,
        dest_port: Optional[int] = None,
    ) -> None:
        self.acl = acl
        self.action = action
        self.protocol = protocol
        self.seq_num = seq_num
        self.src_ip = src_ip
        if src_ip is None and src_obj_group is None:
            self.src_ip = "any"
        self.src_obj_group = src_obj_group
        self.src_port = src_port
        self.dest_ip = dest_ip
        if dest_ip is None and dest_obj_group is None:
            self.dest_ip = "any"
        self.dest_obj_group = dest_obj_group
        self.dest_port = dest_port

    def __eq__(self, other: object) -> bool:
        # Defer to Python's default handling for unrelated types instead of
        # raising AttributeError from getattr().
        if not isinstance(other, AccessListEntry):
            return NotImplemented
        return all(getattr(self, key) == getattr(other, key) for key in self.CMP_ATTRS)

    def __hash__(self) -> int:
        return hash(tuple(getattr(self, key) for key in self.CMP_ATTRS))

    @staticmethod
    def _format_endpoint(
        ip: Optional[str],
        obj_group: Optional["NetworkObjectGroup"],
        port: Optional[int],
        side: str,
    ) -> str:
        """Render one side as ``address[:port]``.

        The IP takes precedence over the object group's name.

        Raises:
            ValueError: if neither an IP nor an object group is set.
        """
        address = ip
        if not address:
            if obj_group is None:
                raise ValueError(f"{side}_ip and {side}_obj_group can't be both None")
            address = obj_group.name
        return f"{address}:{port}" if port else address

    def __str__(self) -> str:
        # Previously the f-prefix was missing, so the literal text
        # "{self.seq_num} " was printed instead of the sequence number.
        seq_num = f"{self.seq_num} " if self.seq_num is not None else ""
        src = self._format_endpoint(self.src_ip, self.src_obj_group, self.src_port, "src")
        dest = self._format_endpoint(
            self.dest_ip, self.dest_obj_group, self.dest_port, "dest"
        )
        return f"{seq_num}{self.action} {self.protocol} {src} -> {dest}"
class AccessList:
    """A named, ordered, de-duplicated list of access-list entries."""

    def __init__(self, name: str, node: "Node") -> None:
        self.name = name
        self.node = node
        # Entries in insertion order (the order they are printed in) ...
        self.ace: List["AccessListEntry"] = []
        # ... mirrored in a set for O(1) duplicate checks.
        self.ace_set: Set["AccessListEntry"] = set()

    def __repr__(self) -> str:
        return f"{self.__class__.__qualname__}(name={self.name!r})"

    def __eq__(self, other: object) -> bool:
        # Unrelated types fall back to Python's default comparison instead
        # of raising AttributeError.
        if not isinstance(other, AccessList):
            return NotImplemented
        return self.name == other.name and self.ace_set == other.ace_set

    def __hash__(self) -> int:
        # Equal lists share a name, so hashing the name alone is consistent.
        return hash(self.name)

    def __str__(self) -> str:
        return f"Access-list {self.name!r}:\n{self.ace_str}"

    def __contains__(self, ace: object) -> bool:
        # Non-entries can never be members; the container protocol expects
        # False here, not an exception.
        return isinstance(ace, AccessListEntry) and ace in self.ace_set

    @property
    def ace_str(self) -> str:
        """Entries rendered one per line as ``* <entry>``."""
        return "\n".join(f"* {ace}" for ace in self.ace)

    def add_access_list_entry(self, ace: "AccessListEntry") -> bool:
        """Append *ace* unless an equal entry exists; return True if added."""
        if ace in self.ace_set:
            return False
        self.ace.append(ace)
        self.ace_set.add(ace)
        return True
class Interface:
    """An interface of a node, together with the ACLs applied to it."""

    def __init__(self, name: str, node: "Node") -> None:
        self.node = node
        self.name = name
        # Every interface gets an inbound ACL named "<NAME>-IN", registered
        # on (or reused from) the owning node.
        inbound_acl = node.get_or_create_access_list(name.upper() + "-IN")
        # direction -> AccessList
        self.applied_acls: Dict[str, AccessList] = {"in": inbound_acl}

    def __repr__(self) -> str:
        return "%s(name=%r)" % (self.__class__.__qualname__, self.name)
class Node:
    """A firewall device: its interfaces, access lists and object groups."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.name_to_interface: Dict[str, Interface] = {}
        self.name_to_acl: Dict[str, AccessList] = {}
        self.network_obj_groups_mapping: Dict[str, NetworkObjectGroup] = {}

    def __repr__(self) -> str:
        return "{}(name={!r})".format(self.__class__.__qualname__, self.name)

    def get_network_obj_group(self, name: str) -> NetworkObjectGroup:
        """Return the object group *name*; raises KeyError if unknown."""
        return self.network_obj_groups_mapping[name]

    def add_network_obj_group(self, obj: NetworkObjectGroup) -> None:
        """Register *obj* under its name, replacing any previous group."""
        self.network_obj_groups_mapping[obj.name] = obj

    @property
    def network_obj_groups(self) -> ValuesView[NetworkObjectGroup]:
        """Live view of the registered object groups."""
        return self.network_obj_groups_mapping.values()

    @property
    def interfaces(self) -> ValuesView[Interface]:
        """Live view of the registered interfaces."""
        return self.name_to_interface.values()

    def get_or_create_interface(self, name: str) -> Interface:
        """Return interface *name*, creating and registering it on first use."""
        found = self.name_to_interface.get(name)
        if found is None:
            # Interface() also creates its inbound ACL on this node.
            found = Interface(name, self)
            self.name_to_interface[name] = found
        return found

    def get_or_create_access_list(self, name: str) -> AccessList:
        """Return access list *name*, creating and registering it on first use."""
        found = self.name_to_acl.get(name)
        if found is None:
            found = AccessList(name, self)
            self.name_to_acl[name] = found
        return found

    def get_or_create_network_obj_group(
        self, name: str, hosts: Iterable[str]
    ) -> NetworkObjectGroup:
        """Return object group *name* (creating it if needed) after adding *hosts*."""
        group = self.network_obj_groups_mapping.get(name)
        if group is None:
            group = NetworkObjectGroup(name)
            self.add_network_obj_group(group)
        group.hosts.update(hosts)
        return group

    def show(self) -> None:
        """Print the node's object groups and the ACLs of each interface."""
        print(f"=== Node name: {self.name!r} ===")
        for group in self.network_obj_groups:
            print(f"Network object group {group.name!r} contains hosts:")
            for host_ip in sorted(group.hosts):
                print(f"* {host_ip}")
            print()
        for iface in self.interfaces:
            for direction, applied_acl in iface.applied_acls.items():
                print(
                    f"Interface {iface.name!r} has ACL {applied_acl.name!r} applied, direction: {direction!r}:"
                )
                print(applied_acl.ace_str, end="\n\n")
def main() -> None:
    """Read every ``*.csv`` log in the CWD and print the derived firewall rules."""
    firewall_rules = DatacenterFirewallRules()
    firewall_rules.process_logs()
    firewall_rules.show()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment