simple-packet-sniffer.py
""" | |
simplepacketsniffer.py 0.4 | |
Usage examples: | |
Dataset Building Mode (with LLM labeling): | |
sudo python simplepacketsniffer.py --build-dataset --llm-label --num-samples 10 --dataset-out training_data.csv -i INTERFACE -v INFO [--xdp] | |
Training Mode: | |
sudo python simplepacketsniffer.py --train --dataset training_data.csv --model-path model.pkl -v INFO | |
Sniffer (Protection) Mode: | |
sudo python simplepacketsniffer.py -i INTERFACE --red-alert --whitelist US,CA --sentry --model-path model.pkl -v DEBUG [--xdp] | |
Expensive Mode (Real-time LLM analysis for every packet): | |
sudo python simplepacketsniffer.py --expensive -i INTERFACE -v DEBUG [--xdp] | |
***logging.yaml
version: 1
disable_existing_loggers: false
formatters:
  standard:
    format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  simple:
    format: '%(asctime)s - %(levelname)s - %(message)s'
handlers:
  console:
    class: logging.StreamHandler
    level: DEBUG
    formatter: standard
    stream: ext://sys.stdout
  file_handler:
    class: logging.handlers.RotatingFileHandler
    level: DEBUG
    formatter: standard
    filename: sniffer.log
    maxBytes: 1048576
    backupCount: 3
  flagged_ip_handler:
    class: logging.handlers.RotatingFileHandler
    level: WARNING
    formatter: standard
    filename: flagged_ips.log
    maxBytes: 1048576
    backupCount: 3
  dns_handler:
    class: logging.handlers.RotatingFileHandler
    level: INFO
    formatter: standard
    filename: dns_queries.log
    maxBytes: 1048576
    backupCount: 3
loggers:
  PacketSniffer:
    level: DEBUG
    handlers: [console, file_handler]
    propagate: no
  FlaggedIPLogger:
    level: WARNING
    handlers: [flagged_ip_handler, console]
    propagate: no
  DNSQueryLogger:
    level: INFO
    handlers: [dns_handler, console]
    propagate: no
root:
  level: DEBUG
  handlers: [console]
***
""" | |
import argparse | |
import asyncio | |
import aiohttp | |
import logging | |
import logging.config | |
import yaml | |
import re | |
import ipaddress | |
import time | |
import threading | |
import struct | |
import os | |
import csv | |
import string | |
from collections import Counter | |
from cachetools import TTLCache | |
from typing import Tuple, Dict, Any, List, Optional | |
import numpy as np | |
import scapy.all as scapy | |
from scapy.all import sniff, Ether, IP, IPv6, TCP, UDP, ICMP, ARP, Raw | |
from scapy.layers.dns import DNS, DNSQR | |
try: | |
from scapy.layers.tls.all import TLS, TLSClientHello, TLSHandshake | |
except ImportError: | |
TLS = None | |
TLSClientHello = None | |
TLSHandshake = None | |
try: | |
from openai import AsyncOpenAI | |
except ImportError: | |
AsyncOpenAI = None | |
try: | |
from bcc import BPF | |
except ImportError: | |
BPF = None | |
last_xdp_metadata: Optional[Dict[str, Any]] = None | |

def setup_logging(config_file: str = "logging.yaml") -> None:
    try:
        with open(config_file, "r") as f:
            config = yaml.safe_load(f.read())
        logging.config.dictConfig(config)
    except Exception as e:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s"
        )
        logging.getLogger(__name__).exception(
            "Failed to load logging config from %s, using basic config: %s",
            config_file, e
        )

class AsyncRunner:
    def __init__(self) -> None:
        self.loop = asyncio.new_event_loop()
        self.thread = threading.Thread(target=self._run_loop, daemon=True)
        self.thread.start()

    def _run_loop(self) -> None:
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def run_coroutine(self, coro):
        return asyncio.run_coroutine_threadsafe(coro, self.loop)
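
# Usage sketch: run_coroutine() hands a coroutine to the background loop and
# returns a concurrent.futures.Future, so synchronous callers can block on it
# with a timeout (this is how PacketSniffer consumes IP lookups below).
# `fetch()` here is a hypothetical coroutine, not part of this script:
#   runner = AsyncRunner()
#   future = runner.run_coroutine(fetch())
#   result = future.result(timeout=5)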

class IPLookup:
    def __init__(self, ttl: int = 3600, cache_size: int = 1000, logger: Optional[logging.Logger] = None) -> None:
        self.cache = TTLCache(maxsize=cache_size, ttl=ttl)
        self.logger = logger or logging.getLogger(__name__)

    async def fetch_ip_info(self, session: aiohttp.ClientSession, ip: str) -> Dict[str, Any]:
        if ip in self.cache:
            return self.cache[ip]
        try:
            async with session.get(f"https://ipinfo.io/{ip}/json", timeout=5) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    self.cache[ip] = data
                    return data
                else:
                    self.logger.warning("IP lookup for %s returned status %s", ip, resp.status)
        except aiohttp.ClientError as e:
            self.logger.exception("Network error during IP lookup for %s: %s", ip, e)
        except Exception:
            self.logger.exception("Unexpected error during IP lookup for %s", ip)
        return {}

    async def get_ip_info(self, ip: str) -> Dict[str, Any]:
        async with aiohttp.ClientSession() as session:
            return await self.fetch_ip_info(session, ip)
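
# Note: a successful ipinfo.io response is a JSON object whose keys typically
# include "hostname", "city", "region", "country", and "org"; log_flagged_ip()
# below filters on exactly those keys, and red-alert mode reads "country".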

class DNSParser:
    def __init__(self, logger: Optional[logging.Logger] = None, blacklist: Optional[List[str]] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.blacklist = set(blacklist or ['ipinfo.io', 'ipinfo.io.'])  # TODO: add wildcard domains

    def is_blacklisted(self, domain: str) -> bool:
        return domain in self.blacklist

    def parse_dns_name(self, payload: bytes, offset: int) -> Tuple[str, int]:
        labels = []
        while True:
            if offset >= len(payload):
                break
            length = payload[offset]
            if (length & 0xC0) == 0xC0:
                # Compression pointer: the remaining 14 bits are an offset into the message.
                if offset + 1 >= len(payload):
                    break
                pointer = ((length & 0x3F) << 8) | payload[offset + 1]
                pointed_name, _ = self.parse_dns_name(payload, pointer)
                labels.append(pointed_name)
                offset += 2
                break
            if length == 0:
                offset += 1
                break
            offset += 1
            if offset + length > len(payload):
                # Truncated label: stop rather than decode past the end of the payload.
                break
            label = payload[offset: offset + length].decode('utf-8', errors='replace')
            labels.append(label)
            offset += length
        domain_name = ".".join(labels)
        return domain_name, offset
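
    # Worked example: the wire name b"\x03www\x07example\x03com\x00" decodes as
    # "www.example.com" (each label is length-prefixed; 0x00 terminates). A byte
    # pair such as b"\xC0\x0C" (top two bits set) is a compression pointer
    # meaning "continue at offset 12 from the start of the DNS message". A
    # maliciously self-referencing pointer could still recurse indefinitely
    # here; a recursion depth limit would harden this parser further.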

    def parse_dns_payload(self, payload: bytes) -> None:
        self.logger.info("Parsing DNS payload...")
        if len(payload) < 12:
            self.logger.warning("DNS payload too short to parse.")
            return
        try:
            transaction_id, flags, qdcount, ancount, nscount, arcount = struct.unpack("!HHHHHH", payload[:12])
            self.logger.info("DNS Header: ID=%#04x, Flags=%#04x, QD=%d, AN=%d, NS=%d, AR=%d",
                             transaction_id, flags, qdcount, ancount, nscount, arcount)
        except Exception as e:
            self.logger.exception("Error parsing DNS header: %s", e)
            return
        offset = 12
        for i in range(qdcount):
            try:
                domain, offset = self.parse_dns_name(payload, offset)
                if offset + 4 > len(payload):
                    self.logger.warning("DNS question truncated.")
                    break
                qtype, qclass = struct.unpack("!HH", payload[offset:offset + 4])
                offset += 4
                self.logger.info("DNS Question %d: %s, type: %d, class: %d", i + 1, domain, qtype, qclass)
            except Exception as e:
                self.logger.exception("Error parsing DNS question: %s", e)
                break

class PayloadParser:
    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.SIGNATURES = {
            "504B0304": "ZIP archive / Office Open XML document (DOCX, XLSX, PPTX)",
            "504B030414000600": "Office Open XML (DOCX, XLSX, PPTX) extended header",
            "1F8B08": "GZIP archive",
            "377ABCAF271C": "7-Zip archive",
            "52617221": "RAR archive",
            "425A68": "BZIP2 archive",
            "213C617263683E0A": "Ar (UNIX archive) / Debian package",
            "7F454C46": "ELF executable (Unix/Linux)",
            "4D5A": "Windows executable (EXE, MZ header / DLL)",
            "CAFEBABE": "Java class file or Mach-O Fat Binary (ambiguous)",
            "FEEDFACE": "Mach-O executable (32-bit, little-endian)",
            "CEFAEDFE": "Mach-O executable (32-bit, big-endian)",
            "FEEDFACF": "Mach-O executable (64-bit, little-endian)",
            "CFFAEDFE": "Mach-O executable (64-bit, big-endian)",
            "BEBAFECA": "Mach-O Fat Binary (little endian)",
            "4C000000": "Windows shortcut file (.lnk)",
            "4D534346": "Microsoft Cabinet file (CAB)",
            "D0CF11E0": "Microsoft Office legacy format (DOC, XLS, PPT)",
            "25504446": "PDF document",
            "7B5C727466": "RTF document (starting with '{\\rtf')",
            "3C3F786D6C": "XML file (<?xml)",
            "3C68746D6C3E": "HTML file",
            "252150532D41646F6265": "PostScript/EPS document (starts with '%!PS-Adobe')",
            "4D2D2D2D": "PostScript file (---)",
            "89504E47": "PNG image",
            "47494638": "GIF image",
            "FFD8FF": "JPEG image (general)",
            "FFD8FFE0": "JPEG image (JFIF)",
            "FFD8FFE1": "JPEG image (EXIF)",
            "424D": "Bitmap image (BMP)",
            "49492A00": "TIFF image (little endian / Intel)",
            "4D4D002A": "TIFF image (big endian / Motorola)",
            "38425053": "Adobe Photoshop document (PSD)",
            "00000100": "ICO icon file",
            "00000200": "CUR cursor file",
            "494433": "MP3 audio (ID3 tag)",
            "000001BA": "MPEG video (VCD)",
            "000001B3": "MPEG video",
            "66747970": "MP4/MOV file (ftyp)",
            "4D546864": "MIDI file",
            "464F524D": "AIFF audio file",
            "52494646": "AVI file (RIFF) [Also used in WAV]",
            "664C6143": "FLAC audio file",
            "4F676753": "OGG container file (OggS)",
            "53514C69": "SQLite database file (SQLite format 3)",
            "420D0D0A": "Python compiled file (.pyc) [example magic, may vary]",
            "6465780A": "Android Dalvik Executable (DEX) file",
            "EDABEEDB": "RPM package file",
            "786172210D0A1A0A": "XAR archive (macOS installer package)",
        }

    def normalize_payload(self, payload: bytes) -> str:
        """
        Attempt to decode the payload as UTF-8 text and replace non-printable characters.
        If decoding fails or the result is mostly gibberish, fall back to a hex representation.
        """
        try:
            text = payload.decode('utf-8', errors='ignore')
            normalized = ''.join(ch if ch in string.printable else '.' for ch in text)
            if sum(1 for ch in normalized if ch == '.') > len(normalized) * 0.5:
                return payload.hex()
            return normalized
        except Exception:
            return payload.hex()
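
    # Example: b"GET / HTTP/1.1\r\n" comes back essentially unchanged, while a
    # mostly-binary payload (more than half of the decoded characters replaced
    # by '.') falls back to its hex string.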

    def parse_http_payload(self, payload: bytes) -> None:
        try:
            text = payload.decode('utf-8', errors='replace')
            self.logger.info("HTTP Payload:\n%s", text)
        except Exception as e:
            self.logger.exception("Error parsing HTTP payload: %s", e)

    def parse_text_payload(self, payload: bytes) -> None:
        try:
            text = payload.decode('utf-8', errors='replace')
            self.logger.info("Text Payload:\n%s", text)
        except Exception as e:
            self.logger.exception("Error decoding text payload: %s", e)

    def parse_tls_payload(self, payload: bytes) -> None:
        self.logger.info("TLS Payload (hex):\n%s", payload.hex())

    def analyze_hex_dump(self, payload: bytes) -> List[Tuple[str, str]]:
        head_hex = payload[:23].hex().upper()
        full_hex = payload.hex().upper()
        self.logger.info("Analyzing payload hex dump for signatures...")
        found = []
        for sig, desc in self.SIGNATURES.items():
            if head_hex.startswith(sig) or sig in full_hex:
                found.append((sig, desc))
                self.logger.warning("Detected signature %s: %s", sig, desc)
        return found
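
# Note: the `sig in full_hex` check scans the whole payload, so short
# signatures such as "4D5A" (2 bytes) can match by coincidence inside
# arbitrary binary data; the head_hex prefix check on the first 23 bytes
# is the higher-confidence signal.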

class MLClassifier:
    def __init__(self, model_path: str = "model.pkl", logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        try:
            import joblib
            self.model = joblib.load(model_path)
            self.logger.info("ML model loaded successfully from %s", model_path)
        except Exception as e:
            self.logger.error("Failed to load ML model from %s: %s", model_path, e)
            self.model = None

    def extract_features(self, payload: bytes) -> np.ndarray:
        length = len(payload)
        counts = Counter(payload)
        total = length if length > 0 else 1
        entropy = -sum((count / total) * np.log2(count / total) for count in counts.values() if count > 0)
        return np.array([[length, entropy]])
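
    # Worked example: for payload b"ABAB", length = 4 and the Shannon entropy
    # is -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1.0 bit per byte; a uniformly
    # random byte stream approaches the 8.0-bit maximum. The fallback heuristic
    # in classify() below treats large (>1000 bytes), high-entropy (>7.0)
    # payloads as suspicious, which is characteristic of compressed or
    # encrypted data.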

    def classify(self, payload: bytes) -> bool:
        features = self.extract_features(payload)
        if self.model is not None:
            prediction = self.model.predict(features)
            self.logger.debug("ML model prediction: %s", prediction)
            return bool(prediction[0])
        else:
            if features[0, 0] > 1000 and features[0, 1] > 7.0:
                self.logger.debug("Fallback heuristic: payload marked as malicious.")
                return True
            return False

    @staticmethod
    def train_model(dataset_path: str, model_output_path: str, logger: Optional[logging.Logger] = None) -> None:
        import pandas as pd
        from sklearn.ensemble import RandomForestClassifier
        import joblib
        logger = logger or logging.getLogger(__name__)
        logger.info("Loading dataset from %s", dataset_path)
        try:
            df = pd.read_csv(dataset_path)
        except Exception as e:
            logger.error("Failed to load dataset: %s", e)
            return
        features = []
        labels = []
        classifier = MLClassifier(logger=logger)
        for index, row in df.iterrows():
            payload_str = row.get('payload', '')
            try:
                payload_bytes = bytes.fromhex(payload_str)
            except Exception as e:
                logger.error("Error converting payload to bytes for row %d: %s", index, e)
                continue
            feats = classifier.extract_features(payload_bytes)[0]
            features.append(feats)
            labels.append(row.get('label', 0))
        if not features:
            logger.error("No valid training samples found.")
            return
        logger.info("Training model on %d samples", len(features))
        clf = RandomForestClassifier(n_estimators=100, random_state=42)
        clf.fit(features, labels)
        joblib.dump(clf, model_output_path)
        logger.info("Model trained and saved to %s", model_output_path)

def get_local_process_info(port: int) -> str:
    try:
        import psutil
        for conn in psutil.net_connections(kind="inet"):
            if conn.laddr and conn.laddr.port == port:
                return str(conn.pid) if conn.pid is not None else "N/A"
    except ImportError:
        return "N/A (psutil not installed)"
    return "N/A"

class XDPCollector:
    def __init__(self, interface: str, logger: Optional[logging.Logger] = None) -> None:
        if BPF is None:
            raise ImportError("BCC BPF module is not available. Please install bcc.")
        self.logger = logger or logging.getLogger(__name__)
        self.interface = interface
        self.bpf = BPF(text=self._bpf_program())
        func = self.bpf.load_func("xdp_prog", BPF.XDP)
        self.bpf.attach_xdp(self.interface, func, 0)
        self.logger.info("XDP program attached on interface %s", self.interface)
        self.bpf["events"].open_perf_buffer(self._handle_event)

    def _bpf_program(self) -> str:
        return """
        #include <uapi/linux/bpf.h>
        #include <linux/if_ether.h>
        #include <linux/ip.h>

        struct data_t {
            u32 pkt_len;
            u32 src_ip;
            u32 dst_ip;
            u8 protocol;
        };

        BPF_PERF_OUTPUT(events);

        int xdp_prog(struct xdp_md *ctx) {
            struct data_t data = {};
            void *data_end = (void *)(long)ctx->data_end;
            void *data_ptr = (void *)(long)ctx->data;
            struct ethhdr *eth = data_ptr;
            if (data_ptr + sizeof(*eth) > data_end)
                return XDP_PASS;
            if (eth->h_proto == __constant_htons(ETH_P_IP)) {
                struct iphdr *ip = data_ptr + sizeof(*eth);
                if ((void *)ip + sizeof(*ip) > data_end)
                    return XDP_PASS;
                data.pkt_len = data_end - data_ptr;
                data.src_ip = ip->saddr;
                data.dst_ip = ip->daddr;
                data.protocol = ip->protocol;
                events.perf_submit(ctx, &data, sizeof(data));
            }
            return XDP_PASS;
        }
        """

    def _handle_event(self, cpu, data, size):
        global last_xdp_metadata
        event = self.bpf["events"].event(data)
        last_xdp_metadata = {
            "pkt_len": event.pkt_len,
            "src_ip": event.src_ip,
            "dst_ip": event.dst_ip,
            "protocol": event.protocol
        }
        self.logger.debug("XDP event: %s", last_xdp_metadata)

    def poll(self):
        self.bpf.perf_buffer_poll(timeout=100)

    def detach(self):
        self.bpf.remove_xdp(self.interface, 0)

def get_xdp_metadata_details() -> str:
    global last_xdp_metadata
    if last_xdp_metadata:
        return (f"XDP Metadata: Packet Length: {last_xdp_metadata.get('pkt_len')}, "
                f"Src IP: {last_xdp_metadata.get('src_ip')}, "
                f"Dst IP: {last_xdp_metadata.get('dst_ip')}, "
                f"Protocol: {last_xdp_metadata.get('protocol')}")
    else:
        return "No XDP metadata available."

class PacketSniffer:
    COMMON_PORTS = {
        20: ("FTP Data", "File Transfer Protocol - Data channel"),
        21: ("FTP Control", "File Transfer Protocol - Control channel"),
        22: ("SSH", "Secure Shell"),
        23: ("Telnet", "Telnet protocol"),
        25: ("SMTP", "Simple Mail Transfer Protocol"),
        53: ("DNS", "Domain Name System"),
        67: ("DHCP", "DHCP Server"),
        68: ("DHCP", "DHCP Client"),
        80: ("HTTP", "Hypertext Transfer Protocol"),
        110: ("POP3", "Post Office Protocol"),
        119: ("NNTP", "Network News Transfer Protocol"),
        123: ("NTP", "Network Time Protocol"),
        143: ("IMAP", "Internet Message Access Protocol"),
        161: ("SNMP", "Simple Network Management Protocol"),
        443: ("HTTPS", "HTTP Secure"),
        3306: ("MySQL", "MySQL database service"),
        5432: ("PostgreSQL", "PostgreSQL database service"),
        3389: ("RDP", "Remote Desktop Protocol")
    }

    def __init__(self, args: argparse.Namespace) -> None:
        self.args = args
        self.logger = logging.getLogger("PacketSniffer")
        self.flagged_logger = logging.getLogger("FlaggedIPLogger")
        self.dns_logger = logging.getLogger("DNSQueryLogger")
        self.ip_lookup = IPLookup(ttl=3600, logger=self.logger)
        self.dns_parser = DNSParser(logger=self.logger)
        self.payload_parser = PayloadParser(logger=self.logger)
        self.async_runner = AsyncRunner()
        if self.args.sentry:
            self.ml_classifier = MLClassifier(model_path=self.args.model_path, logger=self.logger)
        self.xdp_collector = None
        if self.args.xdp and BPF is not None:
            try:
                self.xdp_collector = XDPCollector(self.args.interface, logger=self.logger)
                self.xdp_thread = threading.Thread(target=self._poll_xdp, daemon=True)
                self.xdp_thread.start()
            except Exception as e:
                self.logger.exception("Failed to initialize XDPCollector: %s", e)

    def _poll_xdp(self):
        while True:
            try:
                if self.xdp_collector:
                    self.xdp_collector.poll()
            except Exception as e:
                self.logger.exception("Error polling XDP events: %s", e)

    def identify_application(self, src_port: int, dst_port: int) -> Tuple[str, str]:
        for port in (dst_port, src_port):
            if port in self.COMMON_PORTS:
                return self.COMMON_PORTS[port]
        return ("Unknown", "Unknown application protocol")

    def block_ip(self, ip: str) -> None:
        cmd = f"echo 'Blocking IP {ip}' >> BLOCKER.GARY"
        self.logger.info("Blocking IP %s with command: %s", ip, cmd)
        os.system(cmd)
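
    # Note: block_ip() above is a stand-in that only appends to a local file.
    # A real deployment might shell out to a firewall instead, e.g. (hedged
    # sketch, requires root, not part of this script's behavior):
    #   os.system(f"iptables -A INPUT -s {ip} -j DROP")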

    def log_flagged_ip(self, packet, flagged_signatures: List[Tuple[str, str]],
                       app_name: str, app_details: str) -> None:
        source_ip = "Unknown"
        dest_ip = "Unknown"
        port_info = ""
        process_id = "N/A"
        if packet.haslayer(IP):
            source_ip = packet[IP].src
            dest_ip = packet[IP].dst
        elif packet.haslayer(IPv6):
            source_ip = packet[IPv6].src
            dest_ip = packet[IPv6].dst
        if packet.haslayer(TCP):
            tcp_layer = packet[TCP]
            port_info = f"TCP src: {tcp_layer.sport}, dst: {tcp_layer.dport}"
            process_id = get_local_process_info(tcp_layer.dport)
        elif packet.haslayer(UDP):
            udp_layer = packet[UDP]
            port_info = f"UDP src: {udp_layer.sport}, dst: {udp_layer.dport}"
            process_id = get_local_process_info(udp_layer.dport)
        ip_background = "No background info available."
        future = self.async_runner.run_coroutine(self.ip_lookup.get_ip_info(source_ip))
        try:
            info = future.result(timeout=6)
            if info:
                ip_background = "\n".join(
                    f"{k.capitalize()}: {v}" for k, v in info.items()
                    if k in ("hostname", "city", "region", "country", "org")
                )
        except Exception as e:
            self.logger.exception("Error retrieving IP background info: %s", e)
        message = (
            "\n====== FLAGGED IP ALERT ======\n"
            f"Source IP: {source_ip}\n"
            f"Destination IP: {dest_ip}\n"
            f"Application: {app_name} ({app_details})\n"
            f"Port Info: {port_info}\n"
            f"Process ID: {process_id}\n"
            f"IP Background:\n{ip_background}\n"
            f"Flagged Signatures: {flagged_signatures}\n"
            "===============================\n"
        )
        self.flagged_logger.warning(message)

    def parse_payload(self, packet, app_name: str, payload: bytes) -> None:
        if self.args.sentry:
            if self.ml_classifier.classify(payload):
                src_ip = packet[IP].src if packet.haslayer(IP) else "unknown"
                self.logger.warning("Sentry mode: payload classified as malicious. Blocking IP %s", src_ip)
                self.block_ip(src_ip)
                return
        self.logger.info("Parsing payload for application: %s", app_name)
        flagged_signatures = self.payload_parser.analyze_hex_dump(payload)
        if flagged_signatures:
            app_info = self.identify_application(
                packet[TCP].sport if packet.haslayer(TCP) else (packet[UDP].sport if packet.haslayer(UDP) else 0),
                packet[TCP].dport if packet.haslayer(TCP) else (packet[UDP].dport if packet.haslayer(UDP) else 0)
            )
            self.log_flagged_ip(packet, flagged_signatures, app_name, app_info[1])
        if "HTTP" in app_name:
            self.payload_parser.parse_http_payload(payload)
        elif app_name == "DNS":
            self.dns_parser.parse_dns_payload(payload)
        else:
            self.payload_parser.parse_text_payload(payload)

    def packet_handler(self, packet) -> None:
        if self.args.expensive and packet.haslayer(Raw):
            payload = bytes(packet[Raw].load)
            normalized_payload = self.payload_parser.normalize_payload(payload)
            hex_payload = payload.hex()
            packet_summary = packet.summary()
            packet_details = packet.show(dump=True)
            self.logger.info("Expensive mode: analyzing packet with LLM.")
            label = llm_label_packet(packet_summary, normalized_payload, hex_payload, packet_details, logger=self.logger)
            if label == 1:
                self.logger.warning("Expensive Mode: Packet flagged as malicious by LLM.")
                self.log_flagged_ip(packet, flagged_signatures=[], app_name="Expensive Mode", app_details="LLM flagged malicious")
                src_ip = packet[IP].src if packet.haslayer(IP) else "unknown"
                self.block_ip(src_ip)
            else:
                self.logger.info("Expensive Mode: Packet deemed benign by LLM.")
            return
        ip_str = None
        if packet.haslayer(IP):
            ip_str = packet[IP].src
            ip_obj = ipaddress.ip_address(ip_str)
            if self.args.local_only and not ip_obj.is_private:
                return
            if not self.args.local_only and ip_obj.is_private:
                return
        elif packet.haslayer(IPv6):
            ip_str = packet[IPv6].src
            ip_obj = ipaddress.ip_address(ip_str)
            if self.args.local_only and not ip_obj.is_private:
                return
            if not self.args.local_only and ip_obj.is_private:
                return
        else:
            self.logger.warning("Packet without IP/IPv6 layer")
            return
        if not self.args.local_only and self.args.red_alert:
            future = self.async_runner.run_coroutine(self.ip_lookup.get_ip_info(ip_str))
            try:
                ip_info = future.result(timeout=6)
                country = ip_info.get("country", "").upper()
                if country in self.args.whitelist:
                    self.logger.info("Skipping packet from whitelisted country: %s", country)
                    return
            except Exception as e:
                self.logger.exception("Error during red alert IP filtering: %s", e)
        self.logger.info("=" * 80)
        self.logger.info("Packet: %s", packet.summary())
        summary_str = packet.summary()
        dns_match = re.search(r"DNS Qry b'([^']+)'", summary_str)
        if dns_match:
            dns_query = dns_match.group(1)
            if not self.dns_parser.is_blacklisted(dns_query):
                self.dns_logger.info("DNS Query (fallback): %s", dns_query)
        if packet.haslayer(DNS):
            dns_layer = packet[DNS]
            if dns_layer.qr == 0 and dns_layer.qd is not None:
                try:
                    if isinstance(dns_layer.qd, DNSQR):
                        dns_query = (dns_layer.qd.qname.decode()
                                     if isinstance(dns_layer.qd.qname, bytes)
                                     else dns_layer.qd.qname)
                    else:
                        dns_query = ", ".join(
                            q.qname.decode() if isinstance(q.qname, bytes) else q.qname
                            for q in dns_layer.qd
                        )
                except Exception:
                    dns_query = str(dns_layer.qd)
                self.logger.info("DNS Query (from DNS layer): %s", dns_query)
                try:
                    raw_dns_payload = bytes(dns_layer)
                    self.dns_parser.parse_dns_payload(raw_dns_payload)
                except Exception as e:
                    self.logger.exception("Error processing DNS layer: %s", e)
        if packet.haslayer(Ether):
            eth = packet[Ether]
            self.logger.info("Ethernet: src=%s, dst=%s, type=0x%04x", eth.src, eth.dst, eth.type)
        else:
            self.logger.warning("No Ethernet layer found.")
            return
        if packet.haslayer(ARP):
            arp = packet[ARP]
            self.logger.info("ARP: op=%s, src=%s, dst=%s", arp.op, arp.psrc, arp.pdst)
            return
        if packet.haslayer(IP):
            ip_layer = packet[IP]
            self.logger.info("IPv4: src=%s, dst=%s, ttl=%s, proto=%s",
                             ip_layer.src, ip_layer.dst, ip_layer.ttl, ip_layer.proto)
        elif packet.haslayer(IPv6):
            ip_layer = packet[IPv6]
            self.logger.info("IPv6: src=%s, dst=%s, hlim=%s",
                             ip_layer.src, ip_layer.dst, ip_layer.hlim)
        if packet.haslayer(TCP):
            tcp_layer = packet[TCP]
            self.logger.info("TCP: sport=%s, dport=%s", tcp_layer.sport, tcp_layer.dport)
            app_name, app_details = self.identify_application(tcp_layer.sport, tcp_layer.dport)
            self.logger.info("Identified Application: %s (%s)", app_name, app_details)
            if app_name == "HTTPS" or (TLS and packet.haslayer(TLS)):
                if TLS and packet.haslayer(TLS):
                    tls_layer = packet[TLS]
                    self.logger.info("TLS Record: %s", tls_layer.summary())
                    if packet.haslayer(TLSClientHello):
                        client_hello = packet[TLSClientHello]
                        self.logger.info("TLS ClientHello: %s", client_hello.summary())
                        if hasattr(client_hello, 'servernames'):
                            self.logger.info("SNI: %s", client_hello.servernames)
                else:
                    if packet.haslayer(Raw):
                        payload = bytes(packet[Raw].load)
                        self.payload_parser.parse_tls_payload(payload)
            else:
                if packet.haslayer(Raw):
                    payload = bytes(packet[Raw].load)
                    self.parse_payload(packet, app_name, payload)
        elif packet.haslayer(UDP):
            udp_layer = packet[UDP]
            self.logger.info("UDP: sport=%s, dport=%s", udp_layer.sport, udp_layer.dport)
            app_name, app_details = self.identify_application(udp_layer.sport, udp_layer.dport)
            self.logger.info("Identified Application: %s (%s)", app_name, app_details)
            if packet.haslayer(Raw):
                payload = bytes(packet[Raw].load)
                self.parse_payload(packet, app_name, payload)
        elif packet.haslayer(ICMP):
            icmp_layer = packet[ICMP]
            self.logger.info("ICMP: type=%s, code=%s", icmp_layer.type, icmp_layer.code)
        else:
            self.logger.warning("Unsupported transport layer.")

    def run(self) -> None:
        self.logger.info("Starting Enhanced Packet Sniffer on interface '%s'", self.args.interface)
        try:
            bpf_filter = "ip" if not self.args.local_only else ""
            sniff(iface=self.args.interface, prn=self.packet_handler, store=0, filter=bpf_filter)
        except KeyboardInterrupt:
            self.logger.info("Stopping packet capture (KeyboardInterrupt received)...")
        except Exception as e:
            self.logger.exception("Error during packet capture: %s", e)

async def async_llm_label_packet(packet_summary: str, normalized_payload: str, hex_payload: str, packet_details: str, logger: Optional[logging.Logger] = None) -> int:
    """
    Uses the AsyncOpenAI client to determine if a packet is malicious or benign.
    Returns 1 if the answer indicates 'malicious', 0 otherwise.
    """
    logger = logger or logging.getLogger(__name__)
    if AsyncOpenAI is None:
        logger.error("AsyncOpenAI client is not available. Please install the openai package (>=1.0.0).")
        return 0
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        logger.error("OpenAI API key not found in environment (OPENAI_API_KEY).")
        return 0
    client = AsyncOpenAI(api_key=api_key)
    xdp_details = get_xdp_metadata_details()
    prompt = (
        f"Examine the following packet details and decide if the packet is malicious or benign.\n\n"
        f"Packet Summary:\n{packet_summary}\n\n"
        f"Packet Detailed Info (Scapy dump):\n{packet_details}\n\n"
        f"Normalized Payload (first 300 characters):\n{normalized_payload[:300]}\n\n"
        f"Hex Payload (first 200 characters):\n{hex_payload[:200]}\n\n"
        f"{xdp_details}\n\n"
        "Answer with a single word: 'malicious' or 'benign'."
    )
    logger.debug("LLM Prompt:\n%s", prompt)
    try:
        response = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a network security analyst."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=10,
            temperature=0
        )
        answer = response.choices[0].message.content.strip().lower()
        logger.debug("LLM response: %s", answer)
        return 1 if "malicious" in answer else 0
    except Exception as e:
        logger.exception("Error querying OpenAI API: %s", e)
        return 0

def llm_label_packet(packet_summary: str, normalized_payload: str, hex_payload: str, packet_details: str, logger: Optional[logging.Logger] = None) -> int:
    """
    Synchronous wrapper for async_llm_label_packet.
    """
    logger = logger or logging.getLogger(__name__)
    new_loop = asyncio.new_event_loop()
    try:
        result = new_loop.run_until_complete(
            async_llm_label_packet(packet_summary, normalized_payload, hex_payload, packet_details, logger=logger)
        )
        return result
    except Exception as e:
        logger.exception("Error in LLM labeling: %s", e)
        return 0
    finally:
        new_loop.run_until_complete(new_loop.shutdown_asyncgens())
        new_loop.close()
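
# Note: this wrapper builds and tears down a fresh event loop on every call,
# which is fine for interactive dataset labeling but adds per-packet overhead
# in --expensive mode, where it runs once for each packet with a Raw payload.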

def build_dataset_main(interface: str, num_samples: int, output_path: str, use_llm: bool, logger: Optional[logging.Logger] = None) -> None:
    """
    Captures packets with a Raw payload and labels them.
    If use_llm is True, the LLM (via the async client) is used for automatic labeling.
    The results are saved to a CSV file.
    """
    logger = logger or logging.getLogger(__name__)
    logger.info("Starting dataset capture on interface %s; capturing %d samples.", interface, num_samples)
    samples = []

    def packet_callback(packet):
        if packet.haslayer(Raw):
            payload = bytes(packet[Raw].load)
            hex_payload = payload.hex()
            normalized_payload = PayloadParser(logger=logger).normalize_payload(payload)
            packet_summary = packet.summary()
            packet_details = packet.show(dump=True)
            print("\nPacket captured:")
            print(packet_summary)
            print("Packet Detailed Info:")
            print(packet_details)
            print("Normalized Payload (first 40 characters):", normalized_payload[:40])
            if use_llm:
                label = llm_label_packet(packet_summary, normalized_payload, hex_payload, packet_details, logger=logger)
                print(f"LLM labeled this packet as: {'malicious' if label == 1 else 'benign'}")
            else:
                user_input = input("Label this packet as malicious (1) or benign (0) [default=0]: ").strip()
                label = int(user_input) if user_input in ["0", "1"] else 0
            samples.append({"payload": hex_payload, "label": label})

    # Scapy ignores the return value of prn, so a stop_filter is needed to end
    # the capture once enough samples have been collected.
    scapy.sniff(iface=interface, prn=packet_callback, store=0, timeout=60,
                stop_filter=lambda pkt: len(samples) >= num_samples)
    if samples:
        try:
            with open(output_path, "w", newline="") as csvfile:
                fieldnames = ["payload", "label"]
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
                for sample in samples:
                    writer.writerow(sample)
            logger.info("Dataset built and saved to %s (%d samples).", output_path, len(samples))
        except Exception as e:
            logger.error("Failed to write dataset to %s: %s", output_path, e)
    else:
        logger.error("No samples were captured.")

def train_model_main(dataset_path: str, model_output_path: str, logger: Optional[logging.Logger] = None) -> None:
    logger = logger or logging.getLogger(__name__)
    MLClassifier.train_model(dataset_path, model_output_path, logger=logger)

def parse_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Enhanced Packet Sniffer with Red Alert, Whitelist, Sentry Mode, Dataset Building, Training, and LLM-assisted Labeling"
    )
    parser.add_argument("-i", "--interface", type=str, default="enp0s31f6",
                        help="Network interface to sniff on")
    parser.add_argument("-l", "--logfile", type=str, default="sniffer.log",
                        help="Path to the main log file")
    parser.add_argument("--no-bgcheck", action="store_true",
                        help="Disable IP background lookup")
    parser.add_argument("--local-only", action="store_true",
                        help="Capture only local (private) traffic")
    parser.add_argument("--red-alert", action="store_true",
                        help="Enable red alert mode: only log packets from non-allied countries")
    parser.add_argument("--whitelist", type=str, default="US",
                        help="Comma-separated list of allied (whitelisted) country codes (default: US)")
    parser.add_argument("--sentry", action="store_true",
                        help="Enable sentry mode (ML-based blocking of malicious packets)")
    parser.add_argument("--model-path", type=str, default="model.pkl",
                        help="Path to the ML model file (used in sentry mode)")
    parser.add_argument("--train", action="store_true",
                        help="Run training mode to build a new ML model from a dataset")
    parser.add_argument("--dataset", type=str, default="",
                        help="Path to the CSV dataset for training")
    parser.add_argument("--build-dataset", action="store_true",
                        help="Capture and build a labeled dataset interactively")
    parser.add_argument("--num-samples", type=int, default=10,
                        help="Number of samples to capture when building the dataset")
    parser.add_argument("--dataset-out", type=str, default="training_data.csv",
                        help="Path to save the built dataset CSV file")
    parser.add_argument("--llm-label", action="store_true",
                        help="Use the OpenAI API to automatically label packets in dataset building mode")
    parser.add_argument("--xdp", action="store_true",
                        help="Enable XDP (Express Data Path) to gather additional packet metadata")
    parser.add_argument("--expensive", action="store_true",
                        help="Use the LLM to analyze every packet in real time (expensive mode)")
    parser.add_argument("-v", "--verbosity", type=str, default="INFO",
                        help="Logging level (DEBUG, INFO, WARNING, ERROR)")
    return parser.parse_args()
if __name__ == "__main__": | |
args = parse_arguments() | |
args.whitelist = set(code.strip().upper() for code in args.whitelist.split(',')) | |
logging_level = getattr(logging, args.verbosity.upper(), logging.INFO) | |
setup_logging("logging.yaml") | |
logging.getLogger().setLevel(logging_level) | |
logger = logging.getLogger(__name__) | |
if args.build_dataset: | |
build_dataset_main(args.interface, args.num_samples, args.dataset_out, args.llm_label, logger=logger) | |
elif args.train: | |
if not args.dataset: | |
logger.error("Training mode requires a dataset. Please provide --dataset <path>") | |
else: | |
train_model_main(args.dataset, args.model_path, logger=logger) | |
else: | |
sniffer = PacketSniffer(args) | |
sniffer.run() |