Created
April 14, 2022 20:45
-
-
Save jonashaag/3773351576fcc56632f285277029865c to your computer and use it in GitHub Desktop.
Simple caching Conda proxy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import atexit
import base64
import logging
import os
import pickle
from typing import Optional

import diskcache

import proxy2
# Module-level logger; handlers/level are configured in main() via logging.basicConfig().
logger = logging.getLogger("conda_proxy")
class DiskCache:
    """Size-bounded, persistent blob cache keyed by request path.

    Thin wrapper around :class:`diskcache.Cache` configured with an LFU
    eviction policy, so the most frequently requested packages survive
    longest once the size limit is hit.
    """

    def __init__(self, path, size: int):
        """Open (or create) the cache at directory `path`, capped at `size` bytes."""
        self._cache = diskcache.Cache(
            path, eviction_policy="least-frequently-used", size_limit=size
        )

    def _key(self, path: str) -> str:
        """Encode `path` (base64) so it is a safe, reversible cache key."""
        return base64.b64encode(path.encode()).decode()

    def is_cached(self, path: str) -> bool:
        """Return True if a blob is currently stored for `path`."""
        return self._key(path) in self._cache

    def read_bytes(self, path: str) -> Optional[bytes]:
        """Return the cached blob for `path`, or None when absent.

        Fixed annotation: the original claimed `-> bytes`, but
        `diskcache.Cache.get` returns None on a miss (callers must handle it).
        """
        return self._cache.get(self._key(path))

    def write_bytes(self, path: str, blob: bytes) -> None:
        """Store `blob` under `path`.

        Uses `Cache.add`, which is a no-op when the key already exists —
        fine here because cached package responses are immutable.
        """
        self._cache.add(self._key(path), blob)

    def close(self) -> None:
        """Flush and close the underlying on-disk cache."""
        self._cache.close()
class CondaProxyRequestHandler(proxy2.ProxyRequestHandler):
    """Proxy handler serving conda requests from the local caches.

    `request_handler` may short-circuit the upstream request by returning a
    cached `(response, body)` tuple; `response_handler` populates the caches
    after a miss. Repodata entries expire after `repodata_ttl` seconds;
    package blobs live in the size-bounded `packages_cache`.
    """

    def request_handler(self, req_body):
        """Return a cached `(res, res_body)` tuple, or None to fetch upstream.

        Sets `self.cache_miss` so `response_handler` knows whether to store
        the upstream response.
        """
        # Bug fix: the original checked membership and then read the entry in
        # a second operation. Both caches can drop an entry in between (TTL
        # expiry for repodata, LFU eviction for packages), yielding a
        # KeyError or pickle.loads(None). A single .get() lookup is atomic.
        cached = repodata_cache.get(self.path)
        if cached is not None:
            logger.info(f"Found {self.path} in cache.")
            self.cache_miss = False
            return cached
        blob = packages_cache.read_bytes(self.path)
        if blob is not None:
            logger.info(f"Found {self.path} in cache.")
            self.cache_miss = False
            return pickle.loads(blob)
        self.cache_miss = True
        return None

    def response_handler(self, req_body, res, res_body):
        """After a cache miss, store the upstream response in the right cache."""
        if self.cache_miss:
            if self.path.endswith("repodata.json"):
                # Repodata changes upstream, so it only lives for repodata_ttl seconds.
                logger.info(f"Adding {self.path} to cache.")
                repodata_cache.set(self.path, (res, res_body), expire=repodata_ttl)
            else:
                # Package archives are immutable: cache without expiry.
                logger.info(f"Adding {self.path} to cache.")
                packages_cache.write_bytes(self.path, pickle.dumps((res, res_body)))
# Runtime configuration, all overridable through environment variables.
persistence_path = os.environ.get("CONDA_PROXY_CACHE_PATH", "/tmp/uvproxy")  # on-disk cache location
cache_max_size = int(os.environ.get("CONDA_PROXY_CACHE_SIZE", 1e9))  # bytes before LFU eviction kicks in
repodata_ttl = int(os.environ.get("CONDA_PROXY_REPODATA_TTL", 3600))  # seconds a repodata entry stays valid
timeout = int(os.environ.get("CONDA_PROXY_HTTP_TIMEOUT", "300"))  # upstream HTTP timeout, seconds

# We need two different caches with their own strategies: packages are
# immutable and size-bounded (LFU), repodata entries are TTL-expired.
packages_cache = DiskCache(persistence_path, cache_max_size)
repodata_cache = diskcache.Cache(persistence_path)
# Flush and close both caches cleanly on interpreter exit.
atexit.register(packages_cache.close)
atexit.register(repodata_cache.close)
def main():
    """Serve the caching proxy forever on localhost:8080 (IPv6 loopback)."""
    import http.server
    import socket
    import socketserver

    class ThreadingHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
        # AF_INET6 so "localhost" binds the v6 loopback (dual-stack on most
        # platforms); daemon threads die with the main process.
        address_family = socket.AF_INET6
        daemon_threads = True

    # Bug fix: CONDA_PROXY_HTTP_TIMEOUT was parsed at module level but never
    # applied — proxy2's ProxyRequestHandler defaulted to its own timeout.
    # Override the class attribute so upstream connections honour the config.
    CondaProxyRequestHandler.timeout = timeout

    logging.basicConfig(level="DEBUG")
    httpd = ThreadingHTTPServer(("localhost", 8080), CondaProxyRequestHandler)
    httpd.serve_forever()
if __name__ == "__main__": | |
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# From: https://github.com/inaz2/proxy2/pull/6 | |
# BSD 3 clause | |
import http.client | |
import http.server | |
import os | |
import re | |
import ssl | |
import string | |
import threading | |
import time | |
import urllib.parse | |
import OpenSSL | |
import ssl_wrapper | |
def join_with_script_dir(path):
    """Resolve *path* relative to the directory containing this script."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(script_dir, path)
class ProxyRequestHandler(http.server.BaseHTTPRequestHandler):
    """MITM HTTP/HTTPS proxy handler.

    Handles CONNECT by minting a per-host certificate signed by a local CA
    and wrapping the client socket in TLS, so subsequent requests can be
    read in the clear and relayed upstream. Subclasses customize behavior
    through the `request_handler` / `response_handler` hooks.
    """

    # NOTE(review): these CA/cert paths appear unused in this class — the
    # ssl_wrapper module's own paths are used below; confirm before removing.
    cakey = join_with_script_dir('ca.key')
    cacert = join_with_script_dir('ca.crt')
    certkey = join_with_script_dir('cert.key')
    certdir = join_with_script_dir('certs/')
    # Upstream connection timeout in seconds (class attribute, so callers can
    # override it before serving).
    timeout = 10
    # Serializes on-disk certificate generation across handler threads.
    lock = threading.Lock()

    def __init__(self, *args, **kwargs):
        # Per-origin upstream connection cache.
        # NOTE(review): a fresh handler instance (and hence a fresh
        # threading.local) is created per connection, so `tls.conns` may not
        # outlive a single client connection — confirm reuse works as intended.
        self.tls = threading.local()
        self.tls.conns = {}
        super().__init__(*args, **kwargs)

    def do_CONNECT(self):
        """Accept a CONNECT tunnel and transparently intercept its TLS."""
        hostname = self.path.split(':')[0]
        ippat = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")
        # subjectAltName entries use "IP:" for addresses, "DNS:" for names.
        cert_category = "DNS"
        if ippat.match(hostname):
            cert_category = "IP"
        certpath = "%s/%s.crt" % (ssl_wrapper.cert_dir.rstrip('/'), hostname)
        with self.lock:
            # Mint and persist a per-host certificate on first use.
            if not os.path.isfile(certpath):
                # Millisecond timestamp as the X.509 serial number.
                x509_serial = int("%d" % (time.time() * 1000))
                # Validity window: now until one year from now, in seconds.
                valid_time_interval = (0, 60 * 60 * 24 * 365)
                cert_request = ssl_wrapper.create_cert_request(ssl_wrapper.cert_key_obj, CN=hostname)
                cert = ssl_wrapper.create_certificate(
                    cert_request, (ssl_wrapper.ca_crt_obj, ssl_wrapper.ca_key_obj), x509_serial,
                    valid_time_interval,
                    subject_alt_names=[
                        string.Template("${category}:${hostname}").substitute(hostname=hostname, category=cert_category)
                    ]
                )
                with open(certpath, 'wb+') as f:
                    f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert))
        # Tell the client the tunnel is up, then switch the socket to TLS.
        self.wfile.write("HTTP/1.1 {} {}\r\n".format(200, 'Connection Established').encode('latin-1'))
        self.wfile.write(b'\r\n')
        # NOTE(review): ssl.wrap_socket is deprecated since 3.7 and removed in
        # Python 3.12 — this needs SSLContext.wrap_socket on modern interpreters.
        self.connection = ssl.wrap_socket(self.connection,
                                          keyfile=ssl_wrapper.cert_key,
                                          certfile=certpath,
                                          server_side=True)
        # Rebuild the buffered file objects on top of the TLS socket.
        self.rfile = self.connection.makefile("rb", self.rbufsize)
        self.wfile = self.connection.makefile("wb", self.wbufsize)
        conntype = self.headers.get('Proxy-Connection', '')
        if conntype.lower() != 'close':
            self.close_connection = False

    def do_GET(self):
        """Relay one request: consult the hooks, then proxy upstream."""
        content_length = int(self.headers.get('Content-Length', 0))
        req_body = self.rfile.read(content_length) if content_length else None
        # Relative paths arrive after a CONNECT (intercepted TLS); rebuild the
        # absolute URL from the Host header so _make_req can route it.
        if self.path[0] == '/':
            if isinstance(self.connection, ssl.SSLSocket):
                self.path = "https://{}{}".format(self.headers['Host'], self.path)
            else:
                self.path = "http://{}{}".format(self.headers['Host'], self.path)
        req_body_modified = self.request_handler(req_body)
        if req_body_modified is False:
            # Hook vetoed the request.
            self.send_error(403)
            return
        if isinstance(req_body_modified, tuple):
            # Hook supplied a complete (response, body) pair — skip upstream.
            res, res_body = req_body_modified
        else:
            if req_body_modified is not None:
                # Hook rewrote the request body; fix up Content-Length.
                req_body = req_body_modified
                if 'Content-Length' in self.headers:
                    del self.headers['Content-Length']
                self.headers['Content-Length'] = str(len(req_body_modified))
            res, res_body = self._make_req(req_body)
        if 'Content-Length' not in res.msg:
            res.msg['Content-Length'] = str(len(res_body))
        setattr(res, 'headers', self.filter_headers(res.msg))
        # Relay status line, filtered headers, and body back to the client.
        self.wfile.write(f"HTTP/1.1 {res.status} {res.reason}\r\n".encode("ascii"))
        for k, v in res.msg.items():
            self.send_header(k, v)
        self.end_headers()
        if res_body:
            self.wfile.write(res_body)
        self.wfile.flush()

    def _make_req(self, req_body):
        """Send the request upstream and return `(res, res_body)`.

        Reuses one connection per (scheme, netloc) origin and invokes
        `response_handler` before returning.
        """
        url = urllib.parse.urlsplit(self.path)
        scheme, netloc, path = url.scheme, url.netloc, (url.path + '?' + url.query if url.query else url.path)
        assert scheme in ('http', 'https')
        origin = (scheme, netloc)
        if netloc:
            # Rewrite Host to the upstream authority.
            if 'Host' in self.headers:
                del self.headers['Host']
            self.headers['Host'] = netloc
        setattr(self, 'headers', self.filter_headers(self.headers))
        # Make connection to upstream, reusing a cached one per origin.
        conn = self.tls.conns.get(origin)
        if conn is None:
            conn = self.tls.conns[origin] = {
                "https": http.client.HTTPSConnection,
                "http": http.client.HTTPConnection,
            }[scheme](netloc, timeout=self.timeout)
        try:
            conn.request(self.command, path, req_body, dict(self.headers))
            res = conn.getresponse()
            res_body = res.read()
            self.response_handler(req_body, res, res_body)
            return res, res_body
        except Exception:
            # Drop the (possibly broken) cached connection before re-raising.
            self.tls.conns.pop(origin, None)
            raise

    # All verbs share the same relay logic.
    do_HEAD = do_GET
    do_POST = do_GET
    do_PUT = do_GET
    do_DELETE = do_GET
    do_OPTIONS = do_GET

    def filter_headers(self, headers):
        """Strip hop-by-hop headers and unsupported encodings; return `headers`."""
        # http://tools.ietf.org/html/rfc2616#section-13.5.1
        hop_by_hop = (
            'connection',
            'keep-alive',
            'proxy-authenticate',
            'proxy-authorization',
            'te',
            'trailers',
            'transfer-encoding',
            'upgrade'
        )
        for k in hop_by_hop:
            # email.message-style mappings: deleting an absent key is a no-op,
            # hence no existence check here.
            del headers[k]
        # accept only supported encodings
        if 'Accept-Encoding' in headers:
            ae = headers['Accept-Encoding']
            filtered_encodings = [x for x in re.split(r',\s*', ae) if x in ('identity', 'gzip', 'x-gzip', 'deflate')]
            del headers['Accept-Encoding']
            headers['Accept-Encoding'] = ', '.join(filtered_encodings)
        return headers

    def request_handler(self, req_body):
        """Hook: return False to block, bytes to replace the request body,
        a `(res, res_body)` tuple to answer directly, or None to pass through."""
        pass

    def response_handler(self, req_body, res, res_body):
        """Hook: observe an upstream response (req/res/body) before relaying."""
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment