Skip to content

Instantly share code, notes, and snippets.

@gary23w
Last active February 16, 2025 14:59
Show Gist options
  • Save gary23w/3bbb51566a7ba1a7bbc1a7e2d384b8b6 to your computer and use it in GitHub Desktop.
Save gary23w/3bbb51566a7ba1a7bbc1a7e2d384b8b6 to your computer and use it in GitHub Desktop.
simple-packet-sniffer.py
"""
simplepacketsniffer.py 0.4
Usage examples:
Dataset Building Mode (with LLM labeling):
sudo python simplepacketsniffer.py --build-dataset --llm-label --num-samples 10 --dataset-out training_data.csv -i INTERFACE -v INFO [--xdp]
Training Mode:
sudo python simplepacketsniffer.py --train --dataset training_data.csv --model-path model.pkl -v INFO
Sniffer (Protection) Mode:
sudo python simplepacketsniffer.py -i INTERFACE --red-alert --whitelist US,CA --sentry --model-path model.pkl -v DEBUG [--xdp]
Expensive Mode (Real-time LLM analysis for every packet):
sudo python simplepacketsniffer.py --expensive -i INTERFACE -v DEBUG [--xdp]
***logging.yaml (example configuration; YAML indentation is significant)
version: 1
disable_existing_loggers: false
formatters:
  standard:
    format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  simple:
    format: '%(asctime)s - %(levelname)s - %(message)s'
handlers:
  console:
    class: logging.StreamHandler
    level: DEBUG
    formatter: standard
    stream: ext://sys.stdout
  file_handler:
    class: logging.handlers.RotatingFileHandler
    level: DEBUG
    formatter: standard
    filename: sniffer.log
    maxBytes: 1048576
    backupCount: 3
  flagged_ip_handler:
    class: logging.handlers.RotatingFileHandler
    level: WARNING
    formatter: standard
    filename: flagged_ips.log
    maxBytes: 1048576
    backupCount: 3
  dns_handler:
    class: logging.handlers.RotatingFileHandler
    level: INFO
    formatter: standard
    filename: dns_queries.log
    maxBytes: 1048576
    backupCount: 3
loggers:
  PacketSniffer:
    level: DEBUG
    handlers: [console, file_handler]
    propagate: no
  FlaggedIPLogger:
    level: WARNING
    handlers: [flagged_ip_handler, console]
    propagate: no
  DNSQueryLogger:
    level: INFO
    handlers: [dns_handler, console]
    propagate: no
root:
  level: DEBUG
  handlers: [console]
***
"""
import argparse
import asyncio
import aiohttp
import logging
import logging.config
import yaml
import re
import ipaddress
import time
import threading
import struct
import os
import csv
import string
from collections import Counter
from cachetools import TTLCache
from typing import Tuple, Dict, Any, List, Optional
import numpy as np
import scapy.all as scapy
from scapy.all import sniff, Ether, IP, IPv6, TCP, UDP, ICMP, ARP, Raw
from scapy.layers.dns import DNS, DNSQR
try:
from scapy.layers.tls.all import TLS, TLSClientHello, TLSHandshake
except ImportError:
TLS = None
TLSClientHello = None
TLSHandshake = None
try:
from openai import AsyncOpenAI
except ImportError:
AsyncOpenAI = None
try:
from bcc import BPF
except ImportError:
BPF = None
# Latest packet metadata published by XDPCollector._handle_event and read by
# get_xdp_metadata_details(); stays None until the first XDP event arrives.
last_xdp_metadata: Optional[Dict[str, Any]]
last_xdp_metadata = None
def setup_logging(config_file: str = "logging.yaml") -> None:
    """Configure logging from a YAML ``dictConfig`` file.

    If the file cannot be opened, parsed or applied, fall back to
    ``logging.basicConfig`` at INFO level and log the failure (with
    traceback) instead of aborting start-up.

    Args:
        config_file: Path to a YAML document in ``logging.config.dictConfig``
            schema.
    """
    try:
        with open(config_file, "r") as handle:
            raw_text = handle.read()
        logging.config.dictConfig(yaml.safe_load(raw_text))
    except Exception as err:
        # Degrade to a minimal console configuration on any failure.
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(message)s",
            level=logging.INFO,
        )
        fallback_logger = logging.getLogger(__name__)
        fallback_logger.exception(
            "Failed to load logging config from %s, using basic config: %s",
            config_file,
            err,
        )
class AsyncRunner:
    """Own an asyncio event loop that runs forever on a daemon thread.

    Synchronous code submits coroutines with :meth:`run_coroutine` and gets a
    ``concurrent.futures.Future`` back.
    """

    def __init__(self) -> None:
        self.loop = asyncio.new_event_loop()
        worker = threading.Thread(target=self._run_loop, daemon=True)
        self.thread = worker
        worker.start()

    def _run_loop(self) -> None:
        """Thread target: bind the loop to this thread and run it until stopped."""
        background_loop = self.loop
        asyncio.set_event_loop(background_loop)
        background_loop.run_forever()

    def run_coroutine(self, coro):
        """Schedule *coro* on the background loop and return its Future."""
        return asyncio.run_coroutine_threadsafe(coro, loop=self.loop)
class IPLookup:
    """Asynchronous ipinfo.io lookups backed by a TTL cache of successful replies."""

    # Per-request timeout (seconds) for ipinfo.io. aiohttp expects a
    # ClientTimeout object; passing a bare number is a deprecated form.
    REQUEST_TIMEOUT_SECONDS = 5

    def __init__(self, ttl: int = 3600, cache_size: int = 1000, logger: Optional[logging.Logger] = None) -> None:
        """
        Args:
            ttl: Seconds a successful lookup stays cached.
            cache_size: Maximum number of cached IP entries.
            logger: Logger for diagnostics; defaults to this module's logger.
        """
        self.cache = TTLCache(maxsize=cache_size, ttl=ttl)
        self.logger = logger or logging.getLogger(__name__)

    async def fetch_ip_info(self, session: aiohttp.ClientSession, ip: str) -> Dict[str, Any]:
        """Return the ipinfo.io JSON for *ip* using *session*, or ``{}`` on failure.

        Only HTTP 200 replies are cached. Strings that are not IP addresses
        (e.g. the "Unknown" placeholder used by callers) are rejected locally
        instead of being sent to the service.
        """
        # .get() avoids a check-then-read race with TTL expiry.
        cached = self.cache.get(ip)
        if cached is not None:
            return cached
        try:
            ipaddress.ip_address(ip)
        except ValueError:
            self.logger.warning("Skipping IP lookup for invalid address %r", ip)
            return {}
        timeout = aiohttp.ClientTimeout(total=self.REQUEST_TIMEOUT_SECONDS)
        try:
            async with session.get(f"https://ipinfo.io/{ip}/json", timeout=timeout) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    self.cache[ip] = data
                    return data
                self.logger.warning("IP lookup for %s returned status %s", ip, resp.status)
        except asyncio.TimeoutError:
            # Expected under load or packet loss; no traceback needed.
            self.logger.warning("IP lookup for %s timed out after %ss", ip, self.REQUEST_TIMEOUT_SECONDS)
        except aiohttp.ClientError as e:
            self.logger.exception("Network error during IP lookup for %s: %s", ip, e)
        except Exception:
            self.logger.exception("Unexpected error during IP lookup for %s", ip)
        return {}

    async def get_ip_info(self, ip: str) -> Dict[str, Any]:
        """Look up *ip*, opening a short-lived HTTP session only on a cache miss."""
        cached = self.cache.get(ip)
        if cached is not None:
            return cached
        async with aiohttp.ClientSession() as session:
            return await self.fetch_ip_info(session, ip)
class DNSParser:
    """Decode and log the header and question section of raw DNS messages."""

    def __init__(self, logger: Optional[logging.Logger] = None, blacklist: Optional[List[str]] = None) -> None:
        """
        Args:
            logger: Destination for parse results; defaults to this module's logger.
            blacklist: Domains that is_blacklisted() reports as blacklisted.
                ``None`` selects the default (ipinfo.io, the service IPLookup
                queries); an explicit empty list disables the blacklist.
        """
        self.logger = logger or logging.getLogger(__name__)
        if blacklist is None:
            blacklist = ['ipinfo.io', 'ipinfo.io.']
        # Stored normalised so matching ignores case and a trailing root dot.
        self.blacklist = {self._normalize_domain(d) for d in blacklist}  # TODO: add wildcard domains

    @staticmethod
    def _normalize_domain(domain: str) -> str:
        """Lower-case *domain* and strip surrounding whitespace and trailing dots."""
        return domain.strip().rstrip(".").lower()

    def is_blacklisted(self, domain: str) -> bool:
        """Return True if *domain* (case-insensitive, trailing dot optional) is blacklisted."""
        return self._normalize_domain(domain) in self.blacklist

    def parse_dns_name(self, payload: bytes, offset: int) -> Tuple[str, int]:
        """Decode a possibly compressed domain name starting at *offset*.

        Compression pointers are followed iteratively. A pointer back to a
        position already read for this name is a loop, which only occurs in
        malformed or hostile packets. Decoding stops there instead of
        recursing until RecursionError.

        Returns:
            ``(name, next_offset)``. ``next_offset`` is the first byte after
            the name at its original position, i.e. just past the first
            compression pointer if one was followed. Truncated input yields
            the labels read so far.
        """
        labels: List[str] = []
        seen = set()
        resume_offset: Optional[int] = None
        while offset < len(payload):
            seen.add(offset)
            length = payload[offset]
            if (length & 0xC0) == 0xC0:
                if offset + 1 >= len(payload):
                    break
                pointer = ((length & 0x3F) << 8) | payload[offset + 1]
                if resume_offset is None:
                    # The name ends (in place) right after its first pointer.
                    resume_offset = offset + 2
                if pointer in seen:
                    self.logger.warning("DNS name compression loop at offset %d; name truncated.", offset)
                    break
                offset = pointer
                continue
            if length == 0:
                offset += 1
                break
            offset += 1
            labels.append(payload[offset:offset + length].decode('utf-8', errors='replace'))
            offset += length
        next_offset = offset if resume_offset is None else resume_offset
        return ".".join(labels), next_offset

    def parse_dns_payload(self, payload: bytes) -> None:
        """Log the 12-byte DNS header and every question entry of *payload*.

        Parsing stops at the first truncated or undecodable question.
        """
        self.logger.info("Parsing DNS payload...")
        if len(payload) < 12:
            self.logger.warning("DNS payload too short to parse.")
            return
        try:
            transaction_id, flags, qdcount, ancount, nscount, arcount = struct.unpack("!HHHHHH", payload[:12])
            # %#06x: "0x" prefix plus four hex digits for these 16-bit fields.
            self.logger.info("DNS Header: ID=%#06x, Flags=%#06x, QD=%d, AN=%d, NS=%d, AR=%d",
                             transaction_id, flags, qdcount, ancount, nscount, arcount)
        except Exception as e:
            self.logger.exception("Error parsing DNS header: %s", e)
            return
        offset = 12
        for i in range(qdcount):
            try:
                domain, offset = self.parse_dns_name(payload, offset)
                if offset + 4 > len(payload):
                    self.logger.warning("DNS question truncated.")
                    break
                qtype, qclass = struct.unpack("!HH", payload[offset:offset + 4])
                offset += 4
                self.logger.info("DNS Question %d: %s, type: %d, class: %d", i + 1, domain, qtype, qclass)
            except Exception as e:
                self.logger.exception("Error parsing DNS question: %s", e)
                break
class PayloadParser:
    """Render/log application payloads and detect embedded file-format signatures."""

    # Characters normalize_payload() keeps as-is; everything else becomes '.'.
    _PRINTABLE = frozenset(string.printable)

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        # Hex-encoded magic numbers -> description. analyze_hex_dump() looks for
        # each one as a byte sequence anywhere in the payload.
        self.SIGNATURES = {
            "504B0304": "ZIP archive / Office Open XML document (DOCX, XLSX, PPTX)",
            "504B030414000600": "Office Open XML (DOCX, XLSX, PPTX) extended header",
            "1F8B08": "GZIP archive",
            "377ABCAF271C": "7-Zip archive",
            "52617221": "RAR archive",
            "425A68": "BZIP2 archive",
            "213C617263683E0A": "Ar (UNIX archive) / Debian package",
            "7F454C46": "ELF executable (Unix/Linux)",
            "4D5A": "Windows executable (EXE, MZ header / DLL)",
            "CAFEBABE": "Java class file or Mach-O Fat Binary (ambiguous)",
            # Mach-O magics as they appear on disk: FEEDFACE/FEEDFACF are the
            # big-endian byte order, CEFAEDFE/CFFAEDFE the little-endian one.
            "FEEDFACE": "Mach-O executable (32-bit, big-endian)",
            "CEFAEDFE": "Mach-O executable (32-bit, little-endian)",
            "FEEDFACF": "Mach-O executable (64-bit, big-endian)",
            "CFFAEDFE": "Mach-O executable (64-bit, little-endian)",
            "BEBAFECA": "Mach-O Fat Binary (little endian)",
            "4C000000": "Windows shortcut file (.lnk)",
            "4D534346": "Microsoft Cabinet file (CAB)",
            "D0CF11E0": "Microsoft Office legacy format (DOC, XLS, PPT)",
            "25504446": "PDF document",
            "7B5C727466": "RTF document (starting with '{\\rtf')",
            "3C3F786D6C": "XML file (<?xml)",
            "3C68746D6C3E": "HTML file",
            "252150532D41646F6265": "PostScript/EPS document (starts with '%!PS-Adobe')",
            "4D2D2D2D": "PostScript file (---)",
            "89504E47": "PNG image",
            "47494638": "GIF image",
            "FFD8FF": "JPEG image (general)",
            "FFD8FFE0": "JPEG image (JFIF)",
            "FFD8FFE1": "JPEG image (EXIF)",
            "424D": "Bitmap image (BMP)",
            "49492A00": "TIFF image (little endian / Intel)",
            "4D4D002A": "TIFF image (big endian / Motorola)",
            "38425053": "Adobe Photoshop document (PSD)",
            "00000100": "ICO icon file",
            "00000200": "CUR cursor file",
            "494433": "MP3 audio (ID3 tag)",
            "000001BA": "MPEG video (VCD)",
            "000001B3": "MPEG video",
            "66747970": "MP4/MOV file (ftyp)",
            "4D546864": "MIDI file",
            "464F524D": "AIFF audio file",
            "52494646": "AVI file (RIFF) [Also used in WAV]",
            "664C6143": "FLAC audio file",
            "4F676753": "OGG container file (OggS)",
            "53514C69": "SQLite database file (SQLite format 3)",
            "420D0D0A": "Python compiled file (.pyc) [example magic, may vary]",
            "6465780A": "Android Dalvik Executable (DEX) file",
            "EDABEEDB": "RPM package file",
            "786172210D0A1A0A": "XAR archive (macOS installer package)",
        }

    def normalize_payload(self, payload: bytes) -> str:
        """Return a log/LLM-friendly rendering of *payload*.

        Decodes as UTF-8 (dropping undecodable bytes) and replaces
        non-printable characters with '.'. If more than half of the result is
        '.', the payload is treated as binary and its hex encoding is returned.
        """
        text = payload.decode('utf-8', errors='ignore')
        normalized = ''.join(ch if ch in self._PRINTABLE else '.' for ch in text)
        if normalized.count('.') > len(normalized) * 0.5:
            return payload.hex()
        return normalized

    def parse_http_payload(self, payload: bytes) -> None:
        """Log *payload* decoded as UTF-8 text (invalid bytes replaced)."""
        try:
            text = payload.decode('utf-8', errors='replace')
            self.logger.info("HTTP Payload:\n%s", text)
        except Exception as e:
            self.logger.exception("Error parsing HTTP payload: %s", e)

    def parse_text_payload(self, payload: bytes) -> None:
        """Log *payload* decoded as UTF-8 text (invalid bytes replaced)."""
        try:
            text = payload.decode('utf-8', errors='replace')
            self.logger.info("Text Payload:\n%s", text)
        except Exception as e:
            self.logger.exception("Error decoding text payload: %s", e)

    def parse_tls_payload(self, payload: bytes) -> None:
        """Log an (encrypted) TLS payload as hex."""
        self.logger.info("TLS Payload (hex):\n%s", payload.hex())

    def analyze_hex_dump(self, payload: bytes) -> List[Tuple[str, str]]:
        """Return ``(signature_hex, description)`` for every signature found in *payload*.

        Each hit is also logged at WARNING level. Results follow the order of
        ``self.SIGNATURES``.
        """
        self.logger.info("Analyzing payload hex dump for signatures...")
        data = bytes(payload)
        found = []
        for sig, desc in self.SIGNATURES.items():
            # Compare raw bytes. Searching the hex *string* would also match at
            # odd nibble offsets straddling byte boundaries, e.g. "4D5A" inside
            # the bytes 14 D5 A0, which produced false positives.
            if bytes.fromhex(sig) in data:
                found.append((sig, desc))
                self.logger.warning("Detected signature %s: %s", sig, desc)
        return found
class MLClassifier:
    """Malicious/benign payload classifier over ``[length, byte entropy]`` features.

    If no model can be loaded, a fixed heuristic is used: payloads longer than
    1000 bytes with entropy above 7 bits/byte are flagged.
    """

    def __init__(self, model_path: str = "model.pkl", logger: Optional[logging.Logger] = None) -> None:
        """Load a joblib-serialised model from *model_path*; ``self.model`` is None on failure."""
        self.logger = logger or logging.getLogger(__name__)
        try:
            import joblib
            # SECURITY: joblib.load unpickles the file and can execute arbitrary
            # code; only point model_path at files you trust.
            self.model = joblib.load(model_path)
            self.logger.info("ML model loaded successfully from %s", model_path)
        except Exception as e:
            self.logger.error("Failed to load ML model from %s: %s", model_path, e)
            self.model = None

    @staticmethod
    def _compute_features(payload: bytes) -> np.ndarray:
        """Return a 1x2 array ``[[len(payload), Shannon entropy in bits/byte]]``."""
        length = len(payload)
        counts = Counter(payload)
        total = length if length > 0 else 1
        entropy = -sum((count / total) * np.log2(count / total) for count in counts.values() if count > 0)
        return np.array([[length, entropy]])

    def extract_features(self, payload: bytes) -> np.ndarray:
        """Return the model's input row for *payload* (see ``_compute_features``)."""
        return self._compute_features(payload)

    def classify(self, payload: bytes) -> bool:
        """Return True if *payload* is judged malicious by the model (or heuristic)."""
        features = self.extract_features(payload)
        if self.model is not None:
            prediction = self.model.predict(features)
            self.logger.debug("ML model prediction: %s", prediction)
            return bool(prediction[0])
        if features[0, 0] > 1000 and features[0, 1] > 7.0:
            self.logger.debug("Fallback heuristic: payload marked as malicious.")
            return True
        return False

    @staticmethod
    def train_model(dataset_path: str, model_output_path: str, logger: Optional[logging.Logger] = None) -> None:
        """Fit a RandomForest on a CSV dataset and save it to *model_output_path*.

        The CSV needs a hex-encoded ``payload`` column and a ``label`` column,
        which is the layout build_dataset_main writes. Rows whose payload is
        not valid hex are skipped with an error log.
        """
        import pandas as pd
        from sklearn.ensemble import RandomForestClassifier
        import joblib
        logger = logger or logging.getLogger(__name__)
        logger.info("Loading dataset from %s", dataset_path)
        try:
            df = pd.read_csv(dataset_path)
        except Exception as e:
            logger.error("Failed to load dataset: %s", e)
            return
        features = []
        labels = []
        for index, row in df.iterrows():
            payload_str = row.get('payload', '')
            try:
                payload_bytes = bytes.fromhex(payload_str)
            except (TypeError, ValueError) as e:
                # ValueError: invalid hex; TypeError: empty cell read as NaN.
                logger.error("Error converting payload to bytes for row %d: %s", index, e)
                continue
            # Compute features directly. Building an MLClassifier here would try
            # to load (and log a failure for) an unrelated default model file.
            features.append(MLClassifier._compute_features(payload_bytes)[0])
            labels.append(row.get('label', 0))
        if not features:
            logger.error("No valid training samples found.")
            return
        logger.info("Training model on %d samples", len(features))
        clf = RandomForestClassifier(n_estimators=100, random_state=42)
        clf.fit(features, labels)
        joblib.dump(clf, model_output_path)
        logger.info("Model trained and saved to %s", model_output_path)
def get_local_process_info(port: int) -> str:
    """Return the PID (as text) of a local inet socket bound to *port*.

    Returns:
        The PID string of the first matching socket, or "N/A" when no socket
        matches, its PID is not visible, or the socket table cannot be read.
        Returns "N/A (psutil not installed)" when psutil is unavailable.
    """
    try:
        import psutil
    except ImportError:
        return "N/A (psutil not installed)"
    try:
        connections = psutil.net_connections(kind="inet")
    except (psutil.Error, OSError) as exc:
        # e.g. psutil.AccessDenied without sufficient privileges. Report
        # "unknown" rather than aborting the caller's alert logging.
        logging.getLogger(__name__).debug("Cannot list connections for port %s: %s", port, exc)
        return "N/A"
    for conn in connections:
        if conn.laddr and conn.laddr.port == port:
            return str(conn.pid) if conn.pid is not None else "N/A"
    return "N/A"
class XDPCollector:
    """Attach an XDP program that reports per-packet IPv4 metadata to user space.

    Every frame is passed on unchanged (XDP_PASS). For IPv4 frames the
    length, addresses and protocol are submitted over a perf buffer, and
    _handle_event stores them in the module-level ``last_xdp_metadata``.
    """

    def __init__(self, interface: str, logger: Optional[logging.Logger] = None) -> None:
        """Compile the BPF program and attach it to *interface*.

        Raises:
            ImportError: if the bcc ``BPF`` module is unavailable.
        """
        if BPF is None:
            raise ImportError("BCC BPF module is not available. Please install bcc.")
        self.logger = logger or logging.getLogger(__name__)
        self.interface = interface
        self.bpf = BPF(text=self._bpf_program())
        func = self.bpf.load_func("xdp_prog", BPF.XDP)
        self.bpf.attach_xdp(self.interface, func, 0)
        self.logger.info("XDP program attached on interface %s", self.interface)
        self.bpf["events"].open_perf_buffer(self._handle_event)

    def _bpf_program(self) -> str:
        """Return the C source of the XDP program compiled by bcc."""
        return """
#include <uapi/linux/bpf.h>
#include <linux/if_ether.h>
#include <linux/ip.h>

struct data_t {
    u32 pkt_len;
    u32 src_ip;
    u32 dst_ip;
    u8 protocol;
};

BPF_PERF_OUTPUT(events);

int xdp_prog(struct xdp_md *ctx) {
    struct data_t data = {};
    void *data_end = (void *)(long)ctx->data_end;
    void *data_ptr = (void *)(long)ctx->data;
    struct ethhdr *eth = data_ptr;
    if (data_ptr + sizeof(*eth) > data_end)
        return XDP_PASS;
    if (eth->h_proto == __constant_htons(ETH_P_IP)) {
        struct iphdr *ip = data_ptr + sizeof(*eth);
        if ((void *)ip + sizeof(*ip) > data_end)
            return XDP_PASS;
        data.pkt_len = data_end - data_ptr;
        data.src_ip = ip->saddr;
        data.dst_ip = ip->daddr;
        data.protocol = ip->protocol;
        events.perf_submit(ctx, &data, sizeof(data));
    }
    return XDP_PASS;
}
"""

    @staticmethod
    def _u32_to_ipv4(value: int) -> str:
        """Convert an address copied from ``iphdr.saddr``/``daddr`` to dotted-quad text.

        The kernel field holds network-order bytes that user space reads back
        as a host-endian u32. Re-packing with native byte order (``=I``)
        recovers the original four bytes.
        """
        return str(ipaddress.IPv4Address(struct.pack("=I", value)))

    def _handle_event(self, cpu, data, size):
        """Perf-buffer callback: publish the event as ``last_xdp_metadata``."""
        global last_xdp_metadata
        event = self.bpf["events"].event(data)
        last_xdp_metadata = {
            "pkt_len": event.pkt_len,
            # Readable addresses for logs/LLM prompts instead of raw u32 values.
            "src_ip": self._u32_to_ipv4(event.src_ip),
            "dst_ip": self._u32_to_ipv4(event.dst_ip),
            "protocol": event.protocol,
        }
        self.logger.debug("XDP event: %s", last_xdp_metadata)

    def poll(self):
        """Dispatch pending perf events, waiting at most 100 ms."""
        self.bpf.perf_buffer_poll(timeout=100)

    def detach(self):
        """Remove the XDP program from the interface."""
        self.bpf.remove_xdp(self.interface, 0)
def get_xdp_metadata_details() -> str:
    """Summarise the most recent XDP event as one line for prompts and logs.

    Returns "No XDP metadata available." while ``last_xdp_metadata`` is unset
    or empty.
    """
    meta = last_xdp_metadata
    if not meta:
        return "No XDP metadata available."
    return "XDP Metadata: Packet Length: {}, Src IP: {}, Dst IP: {}, Protocol: {}".format(
        meta.get("pkt_len"),
        meta.get("src_ip"),
        meta.get("dst_ip"),
        meta.get("protocol"),
    )
class PacketSniffer:
    """Live packet analyser configured by the parsed command-line arguments.

    For each captured packet it can:
    - ask the LLM for a verdict (``--expensive``);
    - keep only private sources (``--local-only``) or only public ones (default);
    - skip whitelisted countries (``--red-alert``);
    - log DNS, Ethernet, IP and transport details;
    - scan payloads for file signatures;
    - record block decisions for sources judged malicious (``--sentry`` ML
      model or LLM).
    """

    # Well-known ports -> (short name, description) used by identify_application().
    COMMON_PORTS = {
        20: ("FTP Data", "File Transfer Protocol - Data channel"),
        21: ("FTP Control", "File Transfer Protocol - Control channel"),
        22: ("SSH", "Secure Shell"),
        23: ("Telnet", "Telnet protocol"),
        25: ("SMTP", "Simple Mail Transfer Protocol"),
        53: ("DNS", "Domain Name System"),
        67: ("DHCP", "DHCP Server"),
        68: ("DHCP", "DHCP Client"),
        80: ("HTTP", "Hypertext Transfer Protocol"),
        110: ("POP3", "Post Office Protocol"),
        119: ("NNTP", "Network News Transfer Protocol"),
        123: ("NTP", "Network Time Protocol"),
        143: ("IMAP", "Internet Message Access Protocol"),
        161: ("SNMP", "Simple Network Management Protocol"),
        443: ("HTTPS", "HTTP Secure"),
        3306: ("MySQL", "MySQL database service"),
        5432: ("PostgreSQL", "PostgreSQL database service"),
        3389: ("RDP", "Remote Desktop Protocol")
    }

    # File that block_ip() appends "Blocking IP <ip>" records to.
    BLOCK_LOG_PATH = "BLOCKER.GARY"

    # Fallback extraction of the query name from scapy's one-line summary,
    # e.g. "... / DNS Qry b'example.com.'".
    _DNS_SUMMARY_RE = re.compile(r"DNS Qry b'([^']+)'")

    def __init__(self, args: argparse.Namespace) -> None:
        """Create the helpers and, if requested, the ML model and XDP collector.

        Args:
            args: Namespace from parse_arguments(). The __main__ block converts
                ``args.whitelist`` to a set of upper-case country codes first.
        """
        self.args = args
        self.logger = logging.getLogger("PacketSniffer")
        self.flagged_logger = logging.getLogger("FlaggedIPLogger")
        self.dns_logger = logging.getLogger("DNSQueryLogger")
        self.ip_lookup = IPLookup(ttl=3600, logger=self.logger)
        self.dns_parser = DNSParser(logger=self.logger)
        self.payload_parser = PayloadParser(logger=self.logger)
        self.async_runner = AsyncRunner()
        self.ml_classifier = None
        if self.args.sentry:
            self.ml_classifier = MLClassifier(model_path=self.args.model_path, logger=self.logger)
        self.xdp_collector = None
        self.xdp_thread = None
        if self.args.xdp:
            if BPF is None:
                self.logger.warning("--xdp requested but the bcc BPF module is not available; continuing without XDP.")
            else:
                try:
                    self.xdp_collector = XDPCollector(self.args.interface, logger=self.logger)
                    self.xdp_thread = threading.Thread(target=self._poll_xdp, daemon=True)
                    self.xdp_thread.start()
                except Exception as e:
                    self.logger.exception("Failed to initialize XDPCollector: %s", e)

    def _poll_xdp(self):
        """Background loop draining XDP perf events until the collector is detached."""
        while True:
            collector = self.xdp_collector
            if collector is None:
                break
            try:
                collector.poll()
            except Exception as e:
                self.logger.exception("Error polling XDP events: %s", e)
                # Back off so a persistent failure does not spin the CPU and flood the log.
                time.sleep(1)

    def _detach_xdp(self) -> None:
        """Detach the XDP program (if any) and stop the polling thread."""
        collector = self.xdp_collector
        if collector is None:
            return
        self.xdp_collector = None
        try:
            collector.detach()
            self.logger.info("XDP program detached from interface %s", self.args.interface)
        except Exception as e:
            self.logger.exception("Failed to detach XDP program: %s", e)

    def identify_application(self, src_port: int, dst_port: int) -> Tuple[str, str]:
        """Map a port pair to ``(name, description)``, preferring the destination port."""
        for port in (dst_port, src_port):
            if port in self.COMMON_PORTS:
                return self.COMMON_PORTS[port]
        return ("Unknown", "Unknown application protocol")

    @staticmethod
    def _source_ip(packet) -> Optional[str]:
        """Return the IPv4 or IPv6 source address of *packet*, or None."""
        if packet.haslayer(IP):
            return packet[IP].src
        if packet.haslayer(IPv6):
            return packet[IPv6].src
        return None

    @staticmethod
    def _ports(packet) -> Tuple[int, int]:
        """Return ``(sport, dport)`` from the TCP or UDP layer, or ``(0, 0)``."""
        if packet.haslayer(TCP):
            return packet[TCP].sport, packet[TCP].dport
        if packet.haslayer(UDP):
            return packet[UDP].sport, packet[UDP].dport
        return 0, 0

    @staticmethod
    def _qname_text(qname) -> str:
        """Return a DNS qname as text, decoding bytes leniently."""
        return qname.decode("utf-8", errors="replace") if isinstance(qname, bytes) else str(qname)

    def block_ip(self, ip: str) -> None:
        """Record a block decision for *ip* by appending a line to BLOCK_LOG_PATH.

        No firewall rule is installed; only the text "Blocking IP <ip>" is
        appended. The file is written directly rather than through a shell so
        that a crafted string cannot inject shell commands.
        """
        self.logger.info("Blocking IP %s (recording in %s)", ip, self.BLOCK_LOG_PATH)
        try:
            with open(self.BLOCK_LOG_PATH, "a", encoding="utf-8") as fh:
                fh.write(f"Blocking IP {ip}\n")
        except OSError as e:
            self.logger.error("Failed to record blocked IP %s: %s", ip, e)

    def _ip_background(self, ip: str) -> str:
        """Return a short ipinfo.io summary for *ip*, or a placeholder text."""
        if getattr(self.args, "no_bgcheck", False):
            return "Background lookup disabled (--no-bgcheck)."
        if ip == "Unknown":
            return "No background info available."
        future = self.async_runner.run_coroutine(self.ip_lookup.get_ip_info(ip))
        try:
            info = future.result(timeout=6)
        except Exception as e:
            # Don't leave a timed-out lookup running on the background loop.
            future.cancel()
            self.logger.exception("Error retrieving IP background info: %s", e)
            return "No background info available."
        if not info:
            return "No background info available."
        return "\n".join(f"{k.capitalize()}: {v}" for k, v in info.items()
                         if k in ("hostname", "city", "region", "country", "org"))

    def log_flagged_ip(self, packet, flagged_signatures: List[Tuple[str, str]],
                       app_name: str, app_details: str) -> None:
        """Write a multi-line alert about *packet* to the FlaggedIPLogger.

        The alert contains the addresses, the ports, the PID of the local
        socket bound to the destination port, optional ipinfo.io background
        for the source (skipped with --no-bgcheck), and the matched signatures.
        """
        source_ip = "Unknown"
        dest_ip = "Unknown"
        port_info = ""
        process_id = "N/A"
        if packet.haslayer(IP):
            source_ip = packet[IP].src
            dest_ip = packet[IP].dst
        elif packet.haslayer(IPv6):
            source_ip = packet[IPv6].src
            dest_ip = packet[IPv6].dst
        if packet.haslayer(TCP):
            tcp_layer = packet[TCP]
            port_info = f"TCP src: {tcp_layer.sport}, dst: {tcp_layer.dport}"
            process_id = get_local_process_info(tcp_layer.dport)
        elif packet.haslayer(UDP):
            udp_layer = packet[UDP]
            port_info = f"UDP src: {udp_layer.sport}, dst: {udp_layer.dport}"
            process_id = get_local_process_info(udp_layer.dport)
        ip_background = self._ip_background(source_ip)
        message = (
            "\n====== FLAGGED IP ALERT ======\n"
            f"Source IP: {source_ip}\n"
            f"Destination IP: {dest_ip}\n"
            f"Application: {app_name} ({app_details})\n"
            f"Port Info: {port_info}\n"
            f"Process ID: {process_id}\n"
            f"IP Background:\n{ip_background}\n"
            f"Flagged Signatures: {flagged_signatures}\n"
            "===============================\n"
        )
        self.flagged_logger.warning(message)

    def parse_payload(self, packet, app_name: str, payload: bytes) -> None:
        """Classify, signature-scan and log an application payload.

        In sentry mode, a payload the ML model flags as malicious has its
        source recorded via block_ip() and is not analysed further.
        """
        if self.args.sentry and self.ml_classifier is not None:
            if self.ml_classifier.classify(payload):
                src_ip = self._source_ip(packet)
                self.logger.warning("Sentry mode: payload classified as malicious. Blocking IP %s",
                                    src_ip or "unknown")
                if src_ip is not None:
                    self.block_ip(src_ip)
                return
        self.logger.info("Parsing payload for application: %s", app_name)
        flagged_signatures = self.payload_parser.analyze_hex_dump(payload)
        if flagged_signatures:
            _, app_details = self.identify_application(*self._ports(packet))
            self.log_flagged_ip(packet, flagged_signatures, app_name, app_details)
        if "HTTP" in app_name:
            self.payload_parser.parse_http_payload(payload)
        elif app_name == "DNS":
            self.dns_parser.parse_dns_payload(payload)
        else:
            self.payload_parser.parse_text_payload(payload)

    def _analyze_with_llm(self, packet) -> None:
        """Expensive mode: ask the LLM about a Raw-bearing packet and act on the verdict."""
        payload = bytes(packet[Raw].load)
        normalized_payload = self.payload_parser.normalize_payload(payload)
        hex_payload = payload.hex()
        packet_summary = packet.summary()
        packet_details = packet.show(dump=True)
        self.logger.info("Expensive mode: analyzing packet with LLM.")
        label = llm_label_packet(packet_summary, normalized_payload, hex_payload, packet_details, logger=self.logger)
        if label != 1:
            self.logger.info("Expensive Mode: Packet deemed benign by LLM.")
            return
        self.logger.warning("Expensive Mode: Packet flagged as malicious by LLM.")
        self.log_flagged_ip(packet, flagged_signatures=[], app_name="Expensive Mode", app_details="LLM flagged malicious")
        src_ip = self._source_ip(packet)
        if src_ip is not None:
            self.block_ip(src_ip)
        else:
            self.logger.warning("Expensive Mode: no IP source address to block.")

    def _source_scope_allowed(self, ip_str: str) -> bool:
        """True if the source fits the scope: private only with --local-only, public only otherwise."""
        is_private = ipaddress.ip_address(ip_str).is_private
        return is_private if self.args.local_only else not is_private

    def _whitelisted_by_country(self, ip_str: str) -> bool:
        """Red-alert mode: True if *ip_str* geolocates to a whitelisted country."""
        if self.args.local_only or not self.args.red_alert:
            return False
        future = self.async_runner.run_coroutine(self.ip_lookup.get_ip_info(ip_str))
        try:
            ip_info = future.result(timeout=6)
            country = (ip_info.get("country") or "").upper()
        except Exception as e:
            future.cancel()
            self.logger.exception("Error during red alert IP filtering: %s", e)
            return False
        # A failed lookup yields "" which must never count as whitelisted.
        if country and country in self.args.whitelist:
            self.logger.info("Skipping packet from whitelisted country: %s", country)
            return True
        return False

    def _log_dns(self, packet, summary_str: str) -> None:
        """Log DNS query names (summary fallback and DNS layer) and parse the DNS message."""
        dns_match = self._DNS_SUMMARY_RE.search(summary_str)
        if dns_match:
            dns_query = dns_match.group(1)
            if not self.dns_parser.is_blacklisted(dns_query):
                self.dns_logger.info("DNS Query (fallback): %s", dns_query)
        if not packet.haslayer(DNS):
            return
        dns_layer = packet[DNS]
        if dns_layer.qr == 0 and dns_layer.qd is not None:
            try:
                # Older scapy exposes a single DNSQR here, newer versions a list.
                questions = [dns_layer.qd] if isinstance(dns_layer.qd, DNSQR) else list(dns_layer.qd)
                dns_query = ", ".join(self._qname_text(q.qname) for q in questions)
            except Exception:
                dns_query = str(dns_layer.qd)
            self.logger.info("DNS Query (from DNS layer): %s", dns_query)
        try:
            raw_dns_payload = bytes(dns_layer)
            self.dns_parser.parse_dns_payload(raw_dns_payload)
        except Exception as e:
            self.logger.exception("Error processing DNS layer: %s", e)

    def _handle_tcp(self, packet) -> None:
        """Log TCP ports/application and dispatch TLS or plaintext payload handling."""
        tcp_layer = packet[TCP]
        self.logger.info("TCP: sport=%s, dport=%s", tcp_layer.sport, tcp_layer.dport)
        app_name, app_details = self.identify_application(tcp_layer.sport, tcp_layer.dport)
        self.logger.info("Identified Application: %s (%s)", app_name, app_details)
        if TLS is not None and packet.haslayer(TLS):
            tls_layer = packet[TLS]
            self.logger.info("TLS Record: %s", tls_layer.summary())
            if TLSClientHello is not None and packet.haslayer(TLSClientHello):
                client_hello = packet[TLSClientHello]
                self.logger.info("TLS ClientHello: %s", client_hello.summary())
                if hasattr(client_hello, 'servernames'):
                    self.logger.info("SNI: %s", client_hello.servernames)
        elif app_name == "HTTPS":
            if packet.haslayer(Raw):
                self.payload_parser.parse_tls_payload(bytes(packet[Raw].load))
        elif packet.haslayer(Raw):
            self.parse_payload(packet, app_name, bytes(packet[Raw].load))

    def _log_layers(self, packet) -> None:
        """Log link, network and transport headers and dispatch payload analysis."""
        if packet.haslayer(Ether):
            eth = packet[Ether]
            self.logger.info("Ethernet: src=%s, dst=%s, type=0x%04x", eth.src, eth.dst, eth.type)
        else:
            # Captures on non-Ethernet links still carry IP traffic worth
            # analysing, so note the absence and keep going.
            self.logger.warning("No Ethernet layer found.")
        if packet.haslayer(IP):
            ip_layer = packet[IP]
            self.logger.info("IPv4: src=%s, dst=%s, ttl=%s, proto=%s",
                             ip_layer.src, ip_layer.dst, ip_layer.ttl, ip_layer.proto)
        elif packet.haslayer(IPv6):
            ip_layer = packet[IPv6]
            self.logger.info("IPv6: src=%s, dst=%s, hlim=%s",
                             ip_layer.src, ip_layer.dst, ip_layer.hlim)
        if packet.haslayer(TCP):
            self._handle_tcp(packet)
        elif packet.haslayer(UDP):
            udp_layer = packet[UDP]
            self.logger.info("UDP: sport=%s, dport=%s", udp_layer.sport, udp_layer.dport)
            app_name, app_details = self.identify_application(udp_layer.sport, udp_layer.dport)
            self.logger.info("Identified Application: %s (%s)", app_name, app_details)
            if packet.haslayer(Raw):
                self.parse_payload(packet, app_name, bytes(packet[Raw].load))
        elif packet.haslayer(ICMP):
            icmp_layer = packet[ICMP]
            self.logger.info("ICMP: type=%s, code=%s", icmp_layer.type, icmp_layer.code)
        else:
            self.logger.warning("Unsupported transport layer.")

    def packet_handler(self, packet) -> None:
        """Per-packet callback passed to scapy's sniff()."""
        if self.args.expensive and packet.haslayer(Raw):
            self._analyze_with_llm(packet)
            return
        if packet.haslayer(ARP):
            # ARP has no IP layer, so it must be handled before the IP checks
            # below, which would otherwise discard it.
            arp = packet[ARP]
            self.logger.info("ARP: op=%s, src=%s, dst=%s", arp.op, arp.psrc, arp.pdst)
            return
        ip_str = self._source_ip(packet)
        if ip_str is None:
            self.logger.warning("Packet without IP/IPv6 layer")
            return
        if not self._source_scope_allowed(ip_str):
            return
        if self._whitelisted_by_country(ip_str):
            return
        self.logger.info("=" * 80)
        summary_str = packet.summary()
        self.logger.info("Packet: %s", summary_str)
        self._log_dns(packet, summary_str)
        self._log_layers(packet)

    def run(self) -> None:
        """Capture on the configured interface until interrupted, then detach XDP."""
        self.logger.info("Starting Enhanced Packet Sniffer on interface '%s'", self.args.interface)
        # The default (public-source) mode handles both IPv4 and IPv6. Local-only
        # mode captures everything so ARP is visible as well.
        bpf_filter = None if self.args.local_only else "ip or ip6"
        try:
            sniff(iface=self.args.interface, prn=self.packet_handler, store=0, filter=bpf_filter)
        except KeyboardInterrupt:
            self.logger.info("Stopping packet capture (KeyboardInterrupt received)...")
        except Exception as e:
            self.logger.exception("Error during packet capture: %s", e)
        finally:
            # Detach explicitly so the XDP program is removed from the interface on exit.
            self._detach_xdp()
async def async_llm_label_packet(packet_summary: str, normalized_payload: str, hex_payload: str, packet_details: str, logger: Optional[logging.Logger] = None) -> int:
"""
Uses the AsyncOpenAI client to determine if a packet is malicious or benign.
Returns 1 if the answer indicates 'malicious', 0 otherwise.
"""
logger = logger or logging.getLogger(__name__)
if AsyncOpenAI is None:
logger.error("AsyncOpenAI client is not available. Please install the openai package (>=1.0.0).")
return 0
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
logger.error("OpenAI API key not found in environment (OPENAI_API_KEY).")
return 0
client = AsyncOpenAI(api_key=api_key)
xdp_details = get_xdp_metadata_details()
prompt = (
f"Examine the following packet details and decide if the packet is malicious or benign.\n\n"
f"Packet Summary:\n{packet_summary}\n\n"
f"Packet Detailed Info (Scapy dump):\n{packet_details}\n\n"
f"Normalized Payload (first 300 characters):\n{normalized_payload[:300]}\n\n"
f"Hex Payload (first 200 characters):\n{hex_payload[:200]}\n\n"
f"{xdp_details}\n\n"
"Answer with a single word: 'malicious' or 'benign'."
)
logger.debug("LLM Prompt:\n%s", prompt)
try:
response = await client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a network security analyst."},
{"role": "user", "content": prompt}
],
max_tokens=10,
temperature=0
)
answer = response.choices[0].message.content.strip().lower()
logger.debug("LLM response: %s", answer)
return 1 if "malicious" in answer else 0
except Exception as e:
logger.exception("Error querying OpenAI API: %s", e)
return 0
def llm_label_packet(packet_summary: str, normalized_payload: str, hex_payload: str, packet_details: str, logger: Optional[logging.Logger] = None) -> int:
    """Blocking wrapper around ``async_llm_label_packet``.

    Runs the coroutine on a private, freshly created event loop, which is
    shut down and closed afterwards. Any exception is logged and reported
    as 0 (benign).
    """
    log = logger or logging.getLogger(__name__)
    private_loop = asyncio.new_event_loop()
    try:
        pending = async_llm_label_packet(
            packet_summary, normalized_payload, hex_payload, packet_details, logger=log
        )
        return private_loop.run_until_complete(pending)
    except Exception as exc:
        log.exception("Error in LLM labeling: %s", exc)
        return 0
    finally:
        private_loop.run_until_complete(private_loop.shutdown_asyncgens())
        private_loop.close()
def build_dataset_main(interface: str, num_samples: int, output_path: str, use_llm: bool, logger: Optional[logging.Logger] = None, timeout: Optional[int] = 60) -> None:
    """
    Captures packets with a Raw payload and labels them.
    If use_llm is True, the LLM (via the async client) is used for automatic labeling;
    otherwise the user is prompted (1 = malicious, anything else = 0).
    Capture stops after num_samples labelled packets or after `timeout`
    seconds (None = no limit), whichever comes first.
    The results are saved to a CSV file with a hex `payload` column and a
    `label` column, the layout MLClassifier.train_model reads.
    """
    logger = logger or logging.getLogger(__name__)
    if num_samples <= 0:
        logger.error("num_samples must be positive (got %d); nothing to capture.", num_samples)
        return
    logger.info("Starting dataset capture on interface %s; capturing %d samples.", interface, num_samples)
    samples: List[Dict[str, Any]] = []
    payload_parser = PayloadParser(logger=logger)

    def packet_callback(packet) -> None:
        # Return nothing: scapy prints any non-None value returned by `prn`
        # and ignores it for stopping. Stopping is done by stop_filter below.
        if not packet.haslayer(Raw) or len(samples) >= num_samples:
            return
        payload = bytes(packet[Raw].load)
        hex_payload = payload.hex()
        normalized_payload = payload_parser.normalize_payload(payload)
        packet_summary = packet.summary()
        packet_details = packet.show(dump=True)
        print("\nPacket captured:")
        print(packet_summary)
        print("Packet Detailed Info:")
        print(packet_details)
        print("Normalized Payload (first 40 characters):", normalized_payload[:40])
        if use_llm:
            label = llm_label_packet(packet_summary, normalized_payload, hex_payload, packet_details, logger=logger)
            print(f"LLM labeled this packet as: {'malicious' if label == 1 else 'benign'}")
        else:
            user_input = input("Label this packet as malicious (1) or benign (0) [default=0]: ").strip()
            label = int(user_input) if user_input in ["0", "1"] else 0
        samples.append({"payload": hex_payload, "label": label})

    def enough_samples(_packet) -> bool:
        return len(samples) >= num_samples

    scapy.sniff(iface=interface, prn=packet_callback, store=0, timeout=timeout, stop_filter=enough_samples)
    if not samples:
        logger.error("No samples were captured.")
        return
    try:
        with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["payload", "label"])
            writer.writeheader()
            writer.writerows(samples)
        logger.info("Dataset built and saved to %s (%d samples).", output_path, len(samples))
    except (OSError, csv.Error) as e:
        logger.error("Failed to write dataset to %s: %s", output_path, e)
def train_model_main(dataset_path: str, model_output_path: str, logger: Optional[logging.Logger] = None) -> None:
    """Entry point for ``--train``: fit a model on *dataset_path* and save it.

    Thin wrapper that delegates to ``MLClassifier.train_model``, supplying
    this module's logger when none is given.
    """
    MLClassifier.train_model(
        dataset_path,
        model_output_path,
        logger=logger if logger else logging.getLogger(__name__),
    )
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse *argv*.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; pass a list for programmatic
            use or tests.

    Returns:
        The parsed namespace. ``whitelist`` is still the raw comma-separated
        string; the __main__ block converts it to a set.
    """
    parser = argparse.ArgumentParser(
        description="Enhanced Packet Sniffer with Red Alert, Whitelist, Sentry Mode, Dataset Building, Training, and LLM-assisted Labeling"
    )
    parser.add_argument("-i", "--interface", type=str, default="enp0s31f6",
                        help="Network interface to sniff on")
    # NOTE(review): --logfile is parsed but not read anywhere in this file;
    # log file paths currently come from logging.yaml.
    parser.add_argument("-l", "--logfile", type=str, default="sniffer.log",
                        help="Path to the main log file")
    parser.add_argument("--no-bgcheck", action="store_true",
                        help="Disable IP background lookup")
    parser.add_argument("--local-only", action="store_true",
                        help="Capture only local (private) traffic")
    parser.add_argument("--red-alert", action="store_true",
                        help="Enable red alert mode: only log packets from non-allied countries")
    parser.add_argument("--whitelist", type=str, default="US",
                        help="Comma-separated list of allied (whitelisted) country codes (default: US)")
    parser.add_argument("--sentry", action="store_true",
                        help="Enable sentry mode (ML-based blocking of malicious packets)")
    parser.add_argument("--model-path", type=str, default="model.pkl",
                        help="Path to the ML model file (used in sentry mode)")
    parser.add_argument("--train", action="store_true",
                        help="Run training mode to build a new ML model from a dataset")
    parser.add_argument("--dataset", type=str, default="",
                        help="Path to the CSV dataset for training")
    parser.add_argument("--build-dataset", action="store_true",
                        help="Capture and build a labeled dataset interactively")
    parser.add_argument("--num-samples", type=int, default=10,
                        help="Number of samples to capture when building the dataset")
    parser.add_argument("--dataset-out", type=str, default="training_data.csv",
                        help="Path to save the built dataset CSV file")
    parser.add_argument("--llm-label", action="store_true",
                        help="Use the OpenAI API to automatically label packets in dataset building mode")
    parser.add_argument("--xdp", action="store_true",
                        help="Enable XDP (Express Data Path) to gather additional packet metadata")
    parser.add_argument("--expensive", action="store_true",
                        help="Use LLM to analyze all packets in realtime (expensive mode)")
    parser.add_argument("-v", "--verbosity", type=str, default="INFO",
                        help="Logging level (DEBUG, INFO, WARNING, ERROR)")
    return parser.parse_args(argv)
if __name__ == "__main__":
    args = parse_arguments()
    # Keep only non-empty codes. An empty entry (from "US," or "") would equal
    # the "" country seen when an IP lookup fails and whitelist those packets.
    args.whitelist = {code.strip().upper() for code in args.whitelist.split(",") if code.strip()}
    logging_level = getattr(logging, args.verbosity.upper(), logging.INFO)
    if not isinstance(logging_level, int):
        # e.g. "-v basic_format" resolves to logging.BASIC_FORMAT, not a level.
        logging_level = logging.INFO
    setup_logging("logging.yaml")
    logging.getLogger().setLevel(logging_level)
    # The sample logging.yaml gives the non-propagating "PacketSniffer" logger
    # its own DEBUG level, so setting only the root level would leave -v
    # without effect on the sniffer's output.
    logging.getLogger("PacketSniffer").setLevel(logging_level)
    logger = logging.getLogger(__name__)
    if args.build_dataset:
        build_dataset_main(args.interface, args.num_samples, args.dataset_out, args.llm_label, logger=logger)
    elif args.train:
        if not args.dataset:
            logger.error("Training mode requires a dataset. Please provide --dataset <path>")
        else:
            train_model_main(args.dataset, args.model_path, logger=logger)
    else:
        sniffer = PacketSniffer(args)
        sniffer.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment