|
#!/usr/bin/env python3 |
|
"""Local CONNECT proxy that rewrites Mix Hex.pm requests for Envoy compatibility.""" |
|
from __future__ import annotations |
|
|
|
import argparse |
|
import contextlib |
|
import errno |
|
import logging |
|
import os |
|
import selectors |
|
import socket |
|
import ssl |
|
import subprocess |
|
import sys |
|
import threading |
|
from http.client import HTTPConnection, HTTPSConnection |
|
from pathlib import Path |
|
from typing import Dict, Optional, Tuple |
|
from urllib.parse import urlparse |
|
|
|
# Logger shared by every component of the proxy.
LOG = logging.getLogger("hex_proxy")

# Default listening endpoint used when no --listen-host/--listen-port is given.
DEFAULT_HOST, DEFAULT_PORT = "127.0.0.1", 8956

# CONNECT targets that get intercepted: exact host names, plus any host name
# ending with one of the suffixes.
HEX_HOSTS = {"hex.pm"}
HEX_SUFFIXES = (".hex.pm",)
|
|
|
|
|
class HexProxyError(Exception):
    """Recoverable failure while handling a single proxied connection or request.

    Code in this module catches it to log a warning and, where possible,
    answer the client with an error status instead of crashing.
    """
|
|
|
|
|
def should_intercept_host(host: str) -> bool:
    """Tell whether a CONNECT target belongs to Hex.pm.

    The comparison is case-insensitive: *host* matches when it equals an entry
    of ``HEX_HOSTS`` or ends with one of ``HEX_SUFFIXES``.
    """
    candidate = host.lower()
    if candidate in HEX_HOSTS:
        return True
    # str.endswith accepts a tuple of alternatives, so one call covers them all.
    return candidate.endswith(tuple(HEX_SUFFIXES))
|
|
|
|
|
class CertificateAuthority:
    """Create and cache the TLS material used to intercept connections.

    A self-signed root certificate is generated under ``cert_dir`` with the
    ``openssl`` command-line tool when it does not exist yet.  Per-host leaf
    certificates signed by that root are issued on demand and memoised for the
    lifetime of this object.
    """

    # Characters copied verbatim into generated file names.  Everything else is
    # escaped so that a crafted host such as ``../../x.hex.pm`` cannot make the
    # files land outside ``cert_dir``.
    _FILENAME_SAFE = frozenset("abcdefghijklmnopqrstuvwxyz0123456789.-")

    # OpenSSL request configuration for the self-signed root certificate.
    _CA_CONFIG = "\n".join(
        (
            "[req]",
            "default_bits = 2048",
            "prompt = no",
            "default_md = sha256",
            "distinguished_name = dn",
            "x509_extensions = v3_ca",
            "",
            "[dn]",
            "CN = Hex Proxy Root",
            "",
            "[v3_ca]",
            "subjectKeyIdentifier = hash",
            "authorityKeyIdentifier = keyid:always,issuer",
            "basicConstraints = critical, CA:true",
            "keyUsage = critical, digitalSignature, cRLSign, keyCertSign",
            "",
        )
    )

    # OpenSSL request configuration for a leaf certificate; ``{hostname}`` is
    # substituted with the requested host name.
    _SERVER_CONFIG = "\n".join(
        (
            "[req]",
            "default_bits = 2048",
            "prompt = no",
            "default_md = sha256",
            "req_extensions = req_ext",
            "distinguished_name = dn",
            "",
            "[dn]",
            "CN = {hostname}",
            "",
            "[req_ext]",
            "subjectAltName = @alt_names",
            "extendedKeyUsage = serverAuth",
            "",
            "[alt_names]",
            "DNS.1 = {hostname}",
            "",
        )
    )

    def __init__(self, cert_dir: Path) -> None:
        """Prepare ``cert_dir`` and make sure a root certificate exists.

        Raises:
            OSError: the directory cannot be created or ``openssl`` cannot be
                executed (``FileNotFoundError`` when the binary is missing).
            subprocess.CalledProcessError: ``openssl`` exited with an error.
        """
        self.cert_dir = cert_dir
        self.cert_dir.mkdir(parents=True, exist_ok=True)
        self.ca_cert = self.cert_dir / "ca.pem"
        self.ca_key = self.cert_dir / "ca.key"
        # Lower-cased host name -> (certificate path, private key path).
        self._issued: Dict[str, Tuple[Path, Path]] = {}
        self._ensure_ca()

    def ca_path(self) -> Path:
        """Return the path of the root certificate (PEM)."""
        return self.ca_cert

    @classmethod
    def _file_stem(cls, hostname: str) -> str:
        """Map a host name to a file-name stem that contains no path separators.

        ``*`` becomes ``wildcard``; any character outside ``_FILENAME_SAFE`` is
        replaced by ``_`` followed by its six-digit hexadecimal code point, which
        keeps the mapping collision-free for distinct escaped names.
        """
        stem = hostname.lower().replace("*", "wildcard")
        return "".join(
            ch if ch in cls._FILENAME_SAFE else f"_{ord(ch):06x}" for ch in stem
        )

    @staticmethod
    def _run_openssl(arguments: list[str]) -> None:
        """Run ``openssl`` with *arguments*, logging its stderr when it fails.

        Raises:
            subprocess.CalledProcessError: non-zero exit status.
            FileNotFoundError: the ``openssl`` executable is not available.
        """
        try:
            subprocess.run(
                ["openssl", *arguments],
                check=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
            )
        except subprocess.CalledProcessError as exc:
            detail = (exc.stderr or b"").decode("utf-8", errors="replace").strip()
            LOG.error("openssl %s failed: %s", arguments[0], detail)
            raise

    @staticmethod
    def _restrict_key(key_path: Path) -> None:
        """Best-effort: make a private key readable by its owner only."""
        try:
            key_path.chmod(0o600)
        except OSError as exc:
            LOG.debug("Could not restrict permissions on %s: %s", key_path, exc)

    def _ensure_ca(self) -> None:
        """Generate the root key and certificate unless both files already exist."""
        if self.ca_cert.exists() and self.ca_key.exists():
            return
        LOG.info("Generating Hex proxy certificate authority")
        ca_cfg = self.cert_dir / "ca.cnf"
        ca_cfg.write_text(self._CA_CONFIG)
        self._run_openssl(
            [
                "req",
                "-x509",
                "-nodes",
                "-newkey",
                "rsa:2048",
                "-keyout",
                str(self.ca_key),
                "-out",
                str(self.ca_cert),
                "-days",
                "365",
                "-config",
                str(ca_cfg),
            ]
        )
        self._restrict_key(self.ca_key)

    def issue_certificate(self, hostname: str) -> Tuple[Path, Path]:
        """Return ``(certificate, private key)`` paths for *hostname*.

        The pair is created on first use (valid for 30 days, signed by the
        root) and served from the in-memory cache afterwards.  The lookup is
        case-insensitive.

        Raises:
            subprocess.CalledProcessError: ``openssl`` exited with an error.
            OSError: files could not be written or ``openssl`` is missing.
        """
        key = hostname.lower()
        cached = self._issued.get(key)
        if cached is not None:
            return cached

        stem = self._file_stem(key)
        cert_file = self.cert_dir / f"{stem}.pem"
        key_file = self.cert_dir / f"{stem}.key"
        csr_file = self.cert_dir / f"{stem}.csr"
        server_cfg = self.cert_dir / f"{stem}.cnf"

        try:
            server_cfg.write_text(self._SERVER_CONFIG.format(hostname=hostname))
            self._run_openssl(
                [
                    "req",
                    "-new",
                    "-nodes",
                    "-newkey",
                    "rsa:2048",
                    "-keyout",
                    str(key_file),
                    "-out",
                    str(csr_file),
                    "-config",
                    str(server_cfg),
                ]
            )
            self._run_openssl(
                [
                    "x509",
                    "-req",
                    "-in",
                    str(csr_file),
                    "-CA",
                    str(self.ca_cert),
                    "-CAkey",
                    str(self.ca_key),
                    "-CAcreateserial",
                    "-out",
                    str(cert_file),
                    "-days",
                    "30",
                    "-extensions",
                    "req_ext",
                    "-extfile",
                    str(server_cfg),
                ]
            )
        finally:
            # The CSR and the config are scratch files; remove them whether or
            # not either openssl step succeeded.
            for scratch in (csr_file, server_cfg):
                with contextlib.suppress(FileNotFoundError):
                    scratch.unlink()

        self._restrict_key(key_file)
        self._issued[key] = (cert_file, key_file)
        return cert_file, key_file
|
|
|
|
|
def read_connect_line(sock: socket.socket) -> Tuple[str, int]:
    """Consume a CONNECT request from *sock* and return ``(host, port)``.

    The request line and all header lines up to the blank separator are read;
    the headers themselves are discarded.  The port defaults to 443 when the
    target carries none.  Bracketed IPv6 literals are not supported: the
    target is split at its first ``:``.

    Raises:
        HexProxyError: the client closed the connection, the request is not a
            well-formed CONNECT, the host is not ASCII, the port is not a
            number in 1..65535, or reading from the socket failed or timed out.
    """
    try:
        # buffering=0 reads straight from the socket, so nothing past the
        # header block (for example the TLS handshake that follows) is
        # swallowed into a buffer that is discarded on return.
        with sock.makefile("rb", buffering=0) as reader:
            line = reader.readline()
            if not line:
                raise HexProxyError("client closed connection before CONNECT request")
            try:
                method, target, _version = line.strip().split()
            except ValueError as exc:
                raise HexProxyError(f"invalid CONNECT request line: {line!r}") from exc
            if method.upper() != b"CONNECT":
                raise HexProxyError(f"unsupported method {method.decode(errors='ignore')}")

            host_bytes, _, port_bytes = target.partition(b":")
            try:
                # UnicodeDecodeError is a ValueError subclass, so one handler
                # covers both a non-ASCII host and a non-numeric port.
                host = host_bytes.decode("ascii")
                port = int(port_bytes or b"443")
            except ValueError as exc:
                raise HexProxyError(f"invalid CONNECT target: {target!r}") from exc
            if not host or not 0 < port < 65536:
                raise HexProxyError(f"invalid CONNECT target: {target!r}")

            # Drain the remaining header lines up to the blank separator.
            while True:
                header_line = reader.readline()
                if not header_line or header_line in (b"\r\n", b"\n"):
                    break
    except OSError as exc:
        # Includes socket timeouts; report them like any other bad request so
        # callers need to handle a single exception type.
        raise HexProxyError(f"failed to read CONNECT request: {exc}") from exc
    return host, port
|
|
|
|
|
def read_http_request(stream) -> Optional[Tuple[str, str, str, Dict[str, str], bytes]]:
    """Parse one HTTP/1.x request from a binary file-like *stream*.

    Returns:
        ``(method, target, version, headers, body)`` where header names are
        lower-cased (a repeated header keeps its last value), or ``None`` when
        the stream is at end-of-file before a request line.

    Raises:
        HexProxyError: malformed request line or header, invalid or negative
            Content-Length, or a body shorter than announced.
    """
    line = stream.readline()
    # RFC 7230 section 3.5: ignore empty line(s) received before the request line.
    while line in (b"\r\n", b"\n"):
        line = stream.readline()
    if not line:
        return None

    parts = line.decode("iso-8859-1").strip().split()
    if len(parts) != 3:
        raise HexProxyError(f"malformed request line: {line!r}")
    method, target, version = parts

    headers: Dict[str, str] = {}
    while True:
        header_line = stream.readline()
        # End of headers: blank separator line, or the peer stopped sending.
        if not header_line or header_line in (b"\r\n", b"\n"):
            break
        name, separator, value = header_line.decode("iso-8859-1").partition(":")
        if not separator:
            raise HexProxyError(f"invalid header line: {header_line!r}")
        headers[name.strip().lower()] = value.strip()

    raw_length = headers.get("content-length", "0")
    try:
        content_length = int(raw_length)
    except ValueError as exc:
        raise HexProxyError(f"invalid Content-Length: {raw_length!r}") from exc
    if content_length < 0:
        # read(-n) would mean "read until EOF" and block on a live connection.
        raise HexProxyError(f"invalid Content-Length: {raw_length!r}")

    body = stream.read(content_length) if content_length else b""
    if len(body) != content_length:
        raise HexProxyError("connection closed before the full request body arrived")
    return method, target, version, headers, body
|
|
|
|
|
class UpstreamClient:
    """Perform single HTTP(S) requests, optionally through an upstream proxy.

    Attributes:
        verify_path: CA bundle used to verify upstream TLS peers, if any.
        disable_verify: when true, upstream certificates are not verified.
        proxy: parsed upstream proxy URL, or ``None`` for direct connections.
    """

    # Socket timeout (seconds) for upstream connections, so that a stalled
    # origin or proxy cannot block the calling thread indefinitely.
    TIMEOUT_SECONDS = 30

    def __init__(
        self,
        upstream_proxy: Optional[str],
        verify_path: Optional[str],
        disable_verify: bool,
    ) -> None:
        self.verify_path = verify_path
        self.disable_verify = disable_verify
        self.proxy = urlparse(upstream_proxy) if upstream_proxy else None

    def _create_ssl_context(self) -> ssl.SSLContext:
        """Build the client-side TLS context according to the verify settings."""
        if self.disable_verify:
            context = ssl.create_default_context()
            # check_hostname must be cleared before verify_mode can be CERT_NONE.
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
            return context
        if self.verify_path:
            return ssl.create_default_context(cafile=self.verify_path)
        return ssl.create_default_context()

    def _proxy_address(self) -> Tuple[str, int]:
        """Return ``(host, port)`` of the configured proxy.

        The port defaults to 443 for an ``https`` proxy URL and 80 otherwise.

        Raises:
            HexProxyError: unsupported proxy scheme or missing proxy host.
        """
        proxy = self.proxy
        if proxy.scheme not in {"http", "https"}:
            raise HexProxyError(f"unsupported proxy scheme: {proxy.scheme}")
        if not proxy.hostname:
            raise HexProxyError("upstream proxy host missing")
        default_port = 443 if proxy.scheme == "https" else 80
        return proxy.hostname, proxy.port or default_port

    def _connect(self, parsed):
        """Create the (not yet connected) connection object for *parsed*.

        Returns:
            ``(connection, request_target)``; the target is an absolute URI
            only when a plain-HTTP URL is sent through a proxy.
        """
        timeout = self.TIMEOUT_SECONDS
        if parsed.scheme == "https":
            context = self._create_ssl_context()
            if not self.proxy:
                conn = HTTPSConnection(
                    parsed.hostname, parsed.port or 443, timeout=timeout, context=context
                )
                return conn, build_path(parsed)
            host, port = self._proxy_address()
            # NOTE(review): the same HTTPSConnection + set_tunnel() setup is
            # used for http:// and https:// proxy URLs; confirm that an
            # https:// proxy actually works with this combination.
            conn = HTTPSConnection(host, port, timeout=timeout, context=context)
            conn.set_tunnel(parsed.hostname, parsed.port or 443)
            return conn, build_path(parsed)

        if not self.proxy:
            conn = HTTPConnection(parsed.hostname, parsed.port or 80, timeout=timeout)
            return conn, build_path(parsed)
        host, port = self._proxy_address()
        if self.proxy.scheme == "https":
            conn = HTTPSConnection(
                host, port, timeout=timeout, context=self._create_ssl_context()
            )
        else:
            conn = HTTPConnection(host, port, timeout=timeout)
        return conn, build_path(parsed, absolute=True)

    def request(
        self, method: str, url: str, headers: Dict[str, str], body: bytes
    ) -> Tuple[int, str, Dict[str, str], bytes]:
        """Send one request and return the fully-read response.

        Returns:
            ``(status, reason, headers, body)`` with lower-cased header names.

        Raises:
            HexProxyError: unsupported URL or proxy configuration, or any
                network, TLS or HTTP protocol failure while talking upstream.
        """
        # Function-scope import: the module only imports the connection classes.
        from http.client import HTTPException

        parsed = urlparse(url)
        if parsed.scheme not in {"https", "http"}:
            raise HexProxyError(f"unsupported URL scheme: {parsed.scheme}")
        if not parsed.hostname:
            raise HexProxyError(f"URL has no host: {url}")

        conn = None
        try:
            conn, path = self._connect(parsed)
            # An empty body is passed as None so http.client does not add a
            # "Content-Length: 0" header to body-less GET/HEAD requests.
            conn.request(method, path or "/", body=body or None, headers=headers)
            response = conn.getresponse()
            status = response.status
            reason = response.reason
            response_headers = {key.lower(): value for key, value in response.getheaders()}
            data = response.read()
        except (OSError, HTTPException, ValueError) as exc:
            # Surface transport failures as the module's own error type so
            # callers can answer 502 instead of dropping the client.
            raise HexProxyError(f"{method} {url} failed: {exc}") from exc
        finally:
            if conn is not None:
                conn.close()
        return status, reason, response_headers, data
|
|
|
|
|
def build_path(parsed, *, absolute: bool = False) -> str:
    """Build the HTTP request target for a URL parsed with ``urlparse``.

    Args:
        parsed: result of ``urllib.parse.urlparse``.
        absolute: return the absolute form (``scheme://netloc/path?query``)
            used when talking to a forward proxy, instead of the origin form.

    Returns:
        The path (``/`` when empty) followed by ``;params`` and ``?query``
        when present.  The fragment is deliberately left out: per RFC 7230
        section 5.1 it is not part of the target URI sent to a server.
    """
    pieces = [parsed.path or "/"]
    if parsed.params:
        pieces.append(f";{parsed.params}")
    if parsed.query:
        pieces.append(f"?{parsed.query}")
    path = "".join(pieces)
    if absolute:
        return f"{parsed.scheme}://{parsed.netloc}{path}"
    return path
|
|
|
|
|
class HttpProxyWorker(threading.Thread): |
|
def __init__(self, client_sock: socket.socket, upstream_client: UpstreamClient) -> None: |
|
super().__init__(daemon=True) |
|
self.client_sock = client_sock |
|
self.upstream_client = upstream_client |
|
|
|
def run(self) -> None: |
|
try: |
|
self.client_sock.settimeout(30) |
|
with contextlib.ExitStack() as stack: |
|
rfile = stack.enter_context(self.client_sock.makefile("rb")) |
|
wfile = stack.enter_context(self.client_sock.makefile("wb")) |
|
while True: |
|
request = read_http_request(rfile) |
|
if request is None: |
|
break |
|
method, target, _version, headers, body = request |
|
response = self.handle_request(method, target, headers, body) |
|
if response is None: |
|
break |
|
status_code, reason, resp_headers, resp_body = response |
|
wfile.write(f"HTTP/1.1 {status_code} {reason}\r\n".encode("iso-8859-1")) |
|
for key, value in resp_headers.items(): |
|
wfile.write(f"{key}: {value}\r\n".encode("iso-8859-1")) |
|
wfile.write(b"\r\n") |
|
if resp_body: |
|
wfile.write(resp_body) |
|
wfile.flush() |
|
except HexProxyError as exc: |
|
LOG.warning("HTTP proxy error: %s", exc) |
|
except socket.timeout as exc: |
|
LOG.debug("HTTP proxy timeout: %s", exc) |
|
except Exception: # pragma: no cover - defensive logging |
|
LOG.exception("Unhandled HTTP proxy exception") |
|
finally: |
|
with contextlib.suppress(Exception): |
|
self.client_sock.close() |
|
|
|
def handle_request( |
|
self, |
|
method: str, |
|
target: str, |
|
headers: Dict[str, str], |
|
body: bytes, |
|
) -> Optional[Tuple[int, str, Dict[str, str], bytes]]: |
|
if method.upper() not in {"GET", "HEAD"}: |
|
return 405, "Method Not Allowed", {"content-length": "0"}, b"" |
|
|
|
if target.startswith("http://") or target.startswith("https://"): |
|
url = target |
|
else: |
|
host = headers.get("host") |
|
if not host: |
|
raise HexProxyError("missing Host header for HTTP request") |
|
url = f"http://{host}{target}" |
|
|
|
cleaned_headers = { |
|
key: value |
|
for key, value in headers.items() |
|
if key |
|
not in { |
|
"connection", |
|
"proxy-connection", |
|
"keep-alive", |
|
"te", |
|
"transfer-encoding", |
|
"content-length", |
|
} |
|
} |
|
|
|
try: |
|
LOG.debug("Forwarding HTTP %s %s", method, url) |
|
status, reason, upstream_headers, data = self.upstream_client.request( |
|
method, url, cleaned_headers, body |
|
) |
|
except HexProxyError as exc: |
|
LOG.warning("Upstream HTTP request failed: %s", exc) |
|
return 502, "Bad Gateway", {"content-length": "0"}, b"" |
|
|
|
hop_by_hop = { |
|
"connection", |
|
"proxy-connection", |
|
"keep-alive", |
|
"transfer-encoding", |
|
"te", |
|
"trailer", |
|
} |
|
response_headers = { |
|
key: value for key, value in upstream_headers.items() if key not in hop_by_hop |
|
} |
|
response_headers.setdefault("content-length", str(len(data))) |
|
return status, reason, response_headers, data |
|
|
|
|
|
class ProxyWorker(threading.Thread): |
|
def __init__( |
|
self, |
|
client_sock: socket.socket, |
|
ssl_context: ssl.SSLContext, |
|
upstream_host: str, |
|
upstream_client: UpstreamClient, |
|
) -> None: |
|
super().__init__(daemon=True) |
|
self.client_sock = client_sock |
|
self.ssl_context = ssl_context |
|
self.upstream_host = upstream_host |
|
self.upstream_client = upstream_client |
|
|
|
def run(self) -> None: |
|
try: |
|
self.client_sock.settimeout(30) |
|
self.client_sock.sendall(b"HTTP/1.1 200 Connection Established\r\n\r\n") |
|
with self.ssl_context.wrap_socket(self.client_sock, server_side=True) as tls_sock: |
|
tls_sock.settimeout(30) |
|
with contextlib.ExitStack() as stack: |
|
rfile = stack.enter_context(tls_sock.makefile("rb")) |
|
wfile = stack.enter_context(tls_sock.makefile("wb")) |
|
while True: |
|
request = read_http_request(rfile) |
|
if request is None: |
|
break |
|
method, target, _version, headers, body = request |
|
response = self.handle_request(method, target, headers, body) |
|
if response is None: |
|
break |
|
status_code, reason, resp_headers, resp_body = response |
|
wfile.write(f"HTTP/1.1 {status_code} {reason}\r\n".encode("iso-8859-1")) |
|
for key, value in resp_headers.items(): |
|
header_line = f"{key}: {value}\r\n" |
|
wfile.write(header_line.encode("iso-8859-1")) |
|
wfile.write(b"\r\n") |
|
if resp_body: |
|
wfile.write(resp_body) |
|
wfile.flush() |
|
except HexProxyError as exc: |
|
LOG.warning("Proxy worker error: %s", exc) |
|
except (socket.timeout, ssl.SSLError) as exc: |
|
LOG.debug("Socket error: %s", exc) |
|
except Exception: # pragma: no cover - defensive logging |
|
LOG.exception("Unhandled proxy worker exception") |
|
finally: |
|
with contextlib.suppress(Exception): |
|
self.client_sock.close() |
|
|
|
def handle_request( |
|
self, |
|
method: str, |
|
target: str, |
|
headers: Dict[str, str], |
|
body: bytes, |
|
) -> Optional[Tuple[int, str, Dict[str, str], bytes]]: |
|
if method.upper() not in {"GET", "HEAD"}: |
|
return 405, "Method Not Allowed", {"content-length": "0"}, b"" |
|
|
|
if target.startswith("http://") or target.startswith("https://"): |
|
url = target |
|
else: |
|
host = headers.get("host", self.upstream_host) |
|
url = f"https://{host}{target}" |
|
|
|
cleaned_headers = { |
|
key: value |
|
for key, value in headers.items() |
|
if key |
|
not in { |
|
"connection", |
|
"proxy-connection", |
|
"keep-alive", |
|
"te", |
|
"transfer-encoding", |
|
"content-length", |
|
} |
|
} |
|
|
|
try: |
|
LOG.debug("Forwarding %s %s", method, url) |
|
status, reason, upstream_headers, data = self.upstream_client.request( |
|
method, url, cleaned_headers, body |
|
) |
|
except HexProxyError as exc: |
|
LOG.warning("Upstream request failed: %s", exc) |
|
return 502, "Bad Gateway", {"content-length": "0"}, b"" |
|
|
|
hop_by_hop = { |
|
"connection", |
|
"proxy-connection", |
|
"keep-alive", |
|
"transfer-encoding", |
|
"te", |
|
"trailer", |
|
} |
|
response_headers = { |
|
key: value for key, value in upstream_headers.items() if key not in hop_by_hop |
|
} |
|
response_headers.setdefault("content-length", str(len(data))) |
|
return status, reason, response_headers, data |
|
|
|
|
|
class TunnelWorker(threading.Thread):
    """Relay an opaque CONNECT tunnel between the client and its target.

    The target is reached directly, or through the upstream proxy (with a
    CONNECT request of its own) when one is configured.  Bytes are then copied
    in both directions until either side closes.
    """

    # Upper bound on the size of the upstream proxy's CONNECT response headers.
    _MAX_CONNECT_RESPONSE = 65536

    def __init__(
        self,
        client_sock: socket.socket,
        target_host: str,
        target_port: int,
        upstream_proxy: Optional[str],
        verify_path: Optional[str],
        disable_verify: bool,
    ) -> None:
        super().__init__(daemon=True)
        self.client_sock = client_sock
        self.target_host = target_host
        self.target_port = target_port
        self.upstream_proxy = urlparse(upstream_proxy) if upstream_proxy else None
        self.verify_path = verify_path
        self.disable_verify = disable_verify
        # Bytes the upstream proxy sent right after its CONNECT response
        # headers; they belong to the tunnelled stream and are replayed to
        # the client once the tunnel is confirmed.
        self._early_data = b""

    def run(self) -> None:
        """Thread entry point: open the tunnel, confirm it, relay, clean up."""
        try:
            self.client_sock.settimeout(30)
            upstream_sock = self._open_tunnel()
        except (HexProxyError, OSError) as exc:
            # Refused/unreachable targets are expected failures, not bugs.
            LOG.warning("Failed to establish tunnel for %s:%s: %s", self.target_host, self.target_port, exc)
            self._send_error_response()
            return
        except Exception:  # pragma: no cover - defensive logging
            LOG.exception("Unexpected tunnel error")
            self._send_error_response()
            return

        try:
            self.client_sock.sendall(b"HTTP/1.1 200 Connection Established\r\n\r\n")
            if self._early_data:
                self.client_sock.sendall(self._early_data)
            self._pump(upstream_sock)
        except OSError as exc:
            # Broken pipes and send timeouts end the tunnel; keep them out of
            # the thread's uncaught-exception hook.
            LOG.debug("Tunnel for %s:%s closed: %s", self.target_host, self.target_port, exc)
        finally:
            with contextlib.suppress(Exception):
                upstream_sock.close()
            with contextlib.suppress(Exception):
                self.client_sock.close()

    def _create_ssl_context(self) -> ssl.SSLContext:
        """Build the TLS context used to reach an ``https`` upstream proxy."""
        if self.disable_verify:
            context = ssl.create_default_context()
            # check_hostname must be cleared before verify_mode can be CERT_NONE.
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
            return context
        if self.verify_path:
            return ssl.create_default_context(cafile=self.verify_path)
        return ssl.create_default_context()

    def _open_tunnel(self) -> socket.socket:
        """Return a connected socket whose byte stream reaches the target.

        Raises:
            HexProxyError: bad proxy configuration or the proxy refused the
                CONNECT request.
            OSError: the TCP connection or TLS handshake failed.
        """
        proxy = self.upstream_proxy
        if not proxy:
            return socket.create_connection((self.target_host, self.target_port), timeout=30)

        if proxy.scheme not in {"http", "https"}:
            raise HexProxyError(f"unsupported upstream proxy scheme: {proxy.scheme}")
        host = proxy.hostname
        if host is None:
            raise HexProxyError("upstream proxy host missing")
        port = proxy.port or (443 if proxy.scheme == "https" else 80)

        raw_sock = socket.create_connection((host, port), timeout=30)
        try:
            if proxy.scheme == "https":
                raw_sock = self._create_ssl_context().wrap_socket(raw_sock, server_hostname=host)
            self._negotiate_connect(raw_sock)
        except BaseException:
            # Do not leak the proxy connection when the negotiation fails.
            with contextlib.suppress(Exception):
                raw_sock.close()
            raise
        return raw_sock

    def _negotiate_connect(self, proxy_sock: socket.socket) -> None:
        """Send CONNECT to the upstream proxy and require a 200 response."""
        connect_request = (
            f"CONNECT {self.target_host}:{self.target_port} HTTP/1.1\r\n"
            f"Host: {self.target_host}:{self.target_port}\r\n\r\n"
        ).encode("ascii")
        proxy_sock.sendall(connect_request)

        response = b""
        while b"\r\n\r\n" not in response:
            if len(response) > self._MAX_CONNECT_RESPONSE:
                raise HexProxyError("upstream proxy sent an oversized CONNECT response")
            chunk = proxy_sock.recv(4096)
            if not chunk:
                raise HexProxyError("upstream proxy closed connection during CONNECT negotiation")
            response += chunk

        head, _, self._early_data = response.partition(b"\r\n\r\n")
        status_line = head.split(b"\r\n", 1)[0]
        if not status_line.startswith((b"HTTP/1.1 200", b"HTTP/1.0 200")):
            raise HexProxyError(
                f"upstream proxy tunnel failed: {status_line.decode('iso-8859-1', errors='replace')}"
            )

    def _pump(self, upstream_sock: socket.socket) -> None:
        """Copy bytes both ways until either side closes or resets."""
        upstream_sock.settimeout(30)
        selector = selectors.DefaultSelector()
        try:
            # The registration data is the socket each source forwards to.
            selector.register(self.client_sock, selectors.EVENT_READ, upstream_sock)
            selector.register(upstream_sock, selectors.EVENT_READ, self.client_sock)
            while True:
                # An empty result (idle period) simply polls again.
                for key, _ in selector.select(timeout=30):
                    if not self._relay(key.fileobj, key.data):
                        return
        finally:
            selector.close()

    @staticmethod
    def _relay(src: socket.socket, dst: socket.socket) -> bool:
        """Forward what is readable on *src* to *dst*.

        Returns:
            ``False`` when *src* reached end-of-stream or was reset and the
            tunnel should be torn down, ``True`` otherwise.
        """
        while True:
            try:
                data = src.recv(8192)
            except socket.timeout:
                return True
            except OSError as exc:
                if isinstance(exc, ConnectionResetError) or exc.errno == errno.ECONNRESET:
                    return False
                raise
            if not data:
                return False
            dst.sendall(data)
            # A TLS socket may hold decrypted bytes that select() cannot see;
            # keep reading until the SSL layer is drained.
            if not (isinstance(src, ssl.SSLSocket) and src.pending()):
                return True

    def _send_error_response(self) -> None:
        """Best-effort 502 to the client, then close its socket."""
        with contextlib.suppress(Exception):
            self.client_sock.sendall(b"HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\n\r\n")
        with contextlib.suppress(Exception):
            self.client_sock.close()
|
|
|
|
|
class HexProxyServer: |
|
def __init__( |
|
self, |
|
listen_host: str, |
|
listen_port: int, |
|
cert_authority: CertificateAuthority, |
|
upstream_client: UpstreamClient, |
|
) -> None: |
|
self.listen_host = listen_host |
|
self.listen_port = listen_port |
|
self.cert_authority = cert_authority |
|
self.upstream_client = upstream_client |
|
self._shutdown = threading.Event() |
|
|
|
def serve_forever(self) -> None: |
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock: |
|
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
|
server_sock.bind((self.listen_host, self.listen_port)) |
|
server_sock.listen(5) |
|
LOG.info("Hex proxy listening on %s:%s", self.listen_host, self.listen_port) |
|
while not self._shutdown.is_set(): |
|
try: |
|
client_sock, _addr = server_sock.accept() |
|
except OSError: |
|
if self._shutdown.is_set(): |
|
break |
|
raise |
|
try: |
|
client_sock.settimeout(5) |
|
peek = client_sock.recv(7, socket.MSG_PEEK) |
|
except socket.timeout: |
|
client_sock.close() |
|
continue |
|
except OSError as exc: |
|
LOG.debug("Failed to peek client data: %s", exc) |
|
client_sock.close() |
|
continue |
|
|
|
if not peek: |
|
client_sock.close() |
|
continue |
|
|
|
if peek.upper().startswith(b"CONNECT"): |
|
try: |
|
upstream_host, upstream_port = read_connect_line(client_sock) |
|
except HexProxyError as exc: |
|
LOG.warning("CONNECT negotiation failed: %s", exc) |
|
client_sock.close() |
|
continue |
|
|
|
if should_intercept_host(upstream_host): |
|
LOG.debug("Intercepting %s:%s for header sanitization", upstream_host, upstream_port) |
|
cert_file, key_file = self.cert_authority.issue_certificate(upstream_host) |
|
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) |
|
ssl_context.load_cert_chain(certfile=str(cert_file), keyfile=str(key_file)) |
|
worker: threading.Thread = ProxyWorker( |
|
client_sock, |
|
ssl_context, |
|
upstream_host, |
|
self.upstream_client, |
|
) |
|
else: |
|
LOG.debug("Tunneling %s:%s without interception", upstream_host, upstream_port) |
|
upstream_proxy = self.upstream_client.proxy.geturl() if self.upstream_client.proxy else None |
|
worker = TunnelWorker( |
|
client_sock, |
|
upstream_host, |
|
upstream_port, |
|
upstream_proxy, |
|
self.upstream_client.verify_path, |
|
self.upstream_client.disable_verify, |
|
) |
|
else: |
|
LOG.debug("Handling plain HTTP proxy request") |
|
worker = HttpProxyWorker(client_sock, self.upstream_client) |
|
worker.start() |
|
|
|
def shutdown(self) -> None: |
|
self._shutdown.set() |
|
|
|
|
|
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the proxy's command-line options.

    Args:
        argv: argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        Namespace with ``listen_host``, ``listen_port``, ``upstream_proxy``,
        ``cert_dir`` and ``verbose``.
    """
    default_cert_dir = Path.home() / ".cache" / "hex-proxy-cert"
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--listen-host", {"default": DEFAULT_HOST}),
        ("--listen-port", {"type": int, "default": DEFAULT_PORT}),
        (
            "--upstream-proxy",
            {"help": "Explicit upstream proxy URL. Defaults to HTTP_PROXY/HTTPS_PROXY env vars."},
        ),
        (
            "--cert-dir",
            {
                "default": str(default_cert_dir),
                "help": "Directory where the generated TLS certificate is stored.",
            },
        ),
        ("--verbose", {"action": "store_true", "help": "Enable debug logging"}),
    )

    cli = argparse.ArgumentParser(description="Local proxy that sanitizes Hex.pm requests for mix")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args(argv)
|
|
|
|
|
def main(argv: Optional[list[str]] = None) -> int:
    """Configure logging and the upstream settings, then run the proxy.

    The upstream proxy comes from ``--upstream-proxy`` or, failing that, from
    the first non-empty of ``HTTPS_PROXY``, ``https_proxy``, ``HTTP_PROXY``
    and ``http_proxy``.  The CA bundle for upstream verification is taken
    from ``REQUESTS_CA_BUNDLE`` or ``SSL_CERT_FILE``; setting
    ``HEX_PROXY_DISABLE_UPSTREAM_VERIFY`` to 1/true/yes/on disables upstream
    verification.

    Returns:
        Process exit status: 0 after a keyboard interrupt or clean stop,
        1 when the certificate authority or the listening socket cannot be
        set up.
    """
    args = parse_args(argv)

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="[%(levelname)s] %(message)s",
    )

    upstream_proxy = args.upstream_proxy
    if not upstream_proxy:
        # The --upstream-proxy help text promises both variable families.
        for variable in ("HTTPS_PROXY", "https_proxy", "HTTP_PROXY", "http_proxy"):
            upstream_proxy = os.environ.get(variable)
            if upstream_proxy:
                break
    if upstream_proxy:
        LOG.info("Forwarding requests through upstream proxy %s", upstream_proxy)
    else:
        LOG.warning("No upstream proxy configured; direct egress must be allowed")

    verify_path = os.environ.get("REQUESTS_CA_BUNDLE") or os.environ.get("SSL_CERT_FILE")
    disable_verify_env = os.environ.get("HEX_PROXY_DISABLE_UPSTREAM_VERIFY", "")
    disable_verify = disable_verify_env.lower() in {"1", "true", "yes", "on"}
    if disable_verify:
        LOG.warning("Upstream TLS verification disabled; relying on MITM proxy trust")

    try:
        cert_authority = CertificateAuthority(Path(args.cert_dir))
    except subprocess.CalledProcessError as exc:
        # openssl was found and ran, but reported a failure.
        LOG.error("openssl failed to generate certificates (exit status %s)", exc.returncode)
        return 1
    except FileNotFoundError as exc:
        # Raised by subprocess when the openssl executable is missing, or by
        # the file system when a path below the certificate directory is.
        LOG.error(
            "Unable to initialise the certificate authority; check that the openssl binary "
            "is installed and that %s can be created: %s",
            args.cert_dir,
            exc,
        )
        return 1
    except OSError as exc:
        LOG.error("Unable to prepare certificate directory %s: %s", args.cert_dir, exc)
        return 1

    upstream_client = UpstreamClient(upstream_proxy, verify_path, disable_verify)

    server = HexProxyServer(args.listen_host, args.listen_port, cert_authority, upstream_client)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        LOG.info("Shutting down proxy")
        server.shutdown()
    except OSError as exc:
        # Typically the listening address is already in use or not permitted.
        LOG.error("Proxy server failed on %s:%s: %s", args.listen_host, args.listen_port, exc)
        return 1
    return 0
|
|
|
|
|
# Script entry point: run the proxy and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())