#!/usr/bin/env python
|
|
|
import argparse
import itertools
import json
import logging
import os
import re
import shutil
import subprocess
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.error import HTTPError
from urllib.request import urlretrieve
|
|
|
# Headers of the fetched file that do_GET() copies onto the 200 response;
# any other header returned by the fetch is not forwarded.
ALLOWED_HEADERS = [
    "Content-Type",
    "Content-Length",
]
|
|
|
|
|
class BaseGetHandler(BaseHTTPRequestHandler):
    """
    GET handler that fetches a file from the URL chosen by get_url(), runs
    scan_file() on the downloaded copy, and relays the file to the client
    only when the scan returns 0.

    Subclasses must implement get_url() and scan_file().
    """

    # https://github.com/python/cpython/blob/v3.11.3/Lib/http/server.py#L566-L569
    # Not available in some older versions of Python
    control_char_table = str.maketrans(
        {c: rf"\x{c:02x}" for c in itertools.chain(range(0x20), range(0x7F, 0xA0))}
    )
    control_char_table[ord("\\")] = r"\\"

    def scan_file(self, filepath):
        """
        Scan the file at filepath and return 0 if it is safe to serve, or
        non-zero if it should be blocked.
        """
        raise NotImplementedError("scan_file() not implemented")

    def get_url(self):
        """
        Return the URL to fetch, or None if the path is not allowed.
        """
        raise NotImplementedError("get_url() not implemented")

    def do_GET(self):
        """
        Handle a GET request.

        Responds with:
          - 400 if get_url() rejects the requested path
          - the upstream status code if the fetch raises HTTPError
          - 502 if the fetch fails with any other OSError
          - 500 if scan_file() returns non-zero
          - 200 plus the file contents otherwise

        The downloaded file is always removed once it has been scanned,
        whether or not it was served.
        """
        url = self.get_url()

        if not url:
            self.send_response(400, f"Invalid path requested: {self.path}")
            self.end_headers()
            return

        logging.info("Fetching %s", url)
        try:
            filepath, headers = urlretrieve(url)
        except HTTPError as e:
            # Upstream answered with an error status: pass it through
            self.send_response(e.code, e.reason)
            self.end_headers()
            self.wfile.write(e.reason.encode() + b"\n")
            return
        except OSError as e:
            # URLError, timeouts and socket errors are all OSError subclasses.
            # Without this the exception escapes the handler and the client
            # only sees a dropped connection.
            msg = f"Failed to fetch {url}"
            logging.error("%s: %s", msg, e)
            self.send_response(502, msg)
            self.end_headers()
            self.wfile.write(msg.encode() + b"\n")
            return

        try:
            rc = self.scan_file(filepath)
            if rc != 0:
                msg = f"BLOCKED: {url} may contain malicious content"
                logging.warning(msg)
                self.send_response(500, msg)
                self.end_headers()
                self.wfile.write(msg.encode() + b"\n")
                return

            self.send_response(200)
            for header in ALLOWED_HEADERS:
                if header in headers:
                    self.send_header(header, headers[header])
            self.end_headers()
            with open(filepath, "rb") as f:
                # Stream in chunks so a large file is never held in memory
                # in one piece.
                shutil.copyfileobj(f, self.wfile)
        finally:
            os.remove(filepath)

    # Override default webserver log format

    def _format_log(self, format, args):
        """
        Apply args to format and escape control characters and backslashes
        using control_char_table.
        """
        message = format % args
        return message.translate(self.control_char_table)

    def log_error(self, format, *args):
        """Send the escaped message to logging.error()."""
        logging.error(self._format_log(format, args))

    def log_message(self, format, *args):
        """Send the escaped message to logging.info()."""
        logging.info(self._format_log(format, args))
|
|
|
|
|
def make_handler(cfg):
    """
    Build a BaseGetHandler subclass bound to the configuration in cfg.

    Keys read from cfg:
      - "scancommand": list of command arguments; "{filepath}" in any
        argument is replaced by the path of the file to scan
      - "conda" (optional): mapping with "condaserver" (URL prefix) and
        "channels", which maps a channel name to a mapping with
        "platforms" and "allowed_re" (regular expressions for filenames)

    Returns the handler class (not an instance).
    """
    # Requests must have the form /<channel>/<platform>/<filename>.
    # Compiled once here instead of on every request.
    path_re = re.compile(r"^/(?P<ch>[\w-]+)/(?P<pl>[\w-]+)/(?P<fn>[\w\-\.!]+)$")

    class Handler(BaseGetHandler):
        def scan_file(self, filepath):
            """Run the configured scan command and return its exit code."""
            return subprocess.call(
                [c.format(filepath=filepath) for c in cfg["scancommand"]]
            )

        def get_url(self):
            """Return the upstream URL for self.path, or None if not allowed."""
            return self.get_conda_url()

        def get_conda_url(self):
            """
            Map self.path to a URL under the configured conda server.

            Returns None when there is no conda configuration, the path is
            malformed, the channel or platform is not configured, or the
            filename matches none of the channel's allowed_re patterns.
            """
            # .get(): a config without a "conda" section is handled like an
            # empty one instead of raising KeyError on every request.
            conda_cfg = cfg.get("conda")
            if not conda_cfg:
                return None

            m = path_re.match(self.path)
            if not m:
                return None

            ch = m.group("ch")
            pl = m.group("pl")
            fn = m.group("fn")

            channel_cfg = conda_cfg["channels"].get(ch)
            if not channel_cfg:
                return None
            if pl not in channel_cfg["platforms"]:
                return None

            # re.match anchors only at the start of fn, so a pattern without
            # a trailing "$" also accepts longer filenames.
            if any(re.match(regex, fn) for regex in channel_cfg["allowed_re"]):
                return f"{conda_cfg['condaserver']}/{ch}/{pl}/{fn}"
            return None

    return Handler
|
|
|
|
|
def run(cfg):
    """
    Start a threaded HTTP server on 0.0.0.0 at cfg["port"], using the
    handler class built by make_handler(cfg), and block in serve_forever().
    """
    address = ("0.0.0.0", cfg["port"])
    # Context manager so the listening socket is closed when serve_forever()
    # exits, including on an exception such as KeyboardInterrupt.
    with ThreadingHTTPServer(address, make_handler(cfg)) as server:
        server.serve_forever()
|
|
|
|
|
def main():
    """
    Parse the command line, load the JSON config file given by
    --config/-c, and start the server with it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", required=True, help="Config file")
    args = parser.parse_args()
    # Explicit encoding: do not depend on the platform's default locale
    # when reading the JSON config.
    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)
    run(cfg)
|
|
|
|
|
if __name__ == "__main__":
    # Configure the root logger before main() starts handling requests:
    # timestamped, level-tagged lines at INFO and above.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )
    main()