Last active
February 22, 2024 20:16
-
-
Save jborean93/28a3e44e3645d0ba56ad876adf33164a to your computer and use it in GitHub Desktop.
A test HTTP server with TLS enabled to test out some TLS behaviour for web based commands
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
"""Test TLS Enabled Web Server | |
A script that can start a temporary TLS enabled web server. This server | |
supports a basic GET request and will return metadata on the request from the | |
client. By default it will create an ephemeral certificate when starting up but | |
a custom certificate can be provided. Also supports client authentication by | |
providing a CA bundle to use for verification or using --tls-client-auth to | |
generate a new set of keys. | |
""" | |
from __future__ import annotations | |
import argparse | |
import datetime | |
import http.server | |
import json | |
import os | |
import os.path | |
import pathlib | |
import socket | |
import ssl | |
import sys | |
import typing as t | |
from cryptography import x509 | |
from cryptography.hazmat.backends import default_backend | |
from cryptography.hazmat.primitives.asymmetric import ec, rsa, types | |
from cryptography.hazmat.primitives.hashes import SHA256 | |
from cryptography.hazmat.primitives.serialization import ( | |
BestAvailableEncryption, | |
Encoding, | |
NoEncryption, | |
PrivateFormat, | |
pkcs12, | |
) | |
from cryptography.x509.oid import ExtendedKeyUsageOID | |
# argcomplete is optional; when it is importable parse_args() hooks it into
# the argument parser, otherwise the script runs without it.
try:
    import argcomplete
except ImportError:
    HAS_ARGCOMPLETE = False
else:
    HAS_ARGCOMPLETE = True

# Stem of this script's file name, used as the prefix of the certificate
# files that are generated next to the script.
FILE_NAME = pathlib.Path(__file__).stem
class HTTPHandler(http.server.BaseHTTPRequestHandler):
    """Request handler that echoes TLS and header details back to the client."""

    def do_GET(self):
        """Reply to a GET request with a JSON description of the request.

        The JSON body contains the negotiated TLS protocol and cipher, the
        subject of the client certificate (``None`` when the peer presented
        none) and the request headers. The TLS details are also printed to
        stdout.
        """
        # The status line and headers go out before the body is assembled.
        self.send_response(200)
        self.send_header("Content-type", "application/json; charset=utf-8")
        self.end_headers()

        # cipher() returns a tuple whose first item is the cipher name and
        # whose second item is the protocol version.
        negotiated = self.connection.cipher()
        tls_info = {
            "protocol": negotiated[1],
            "cipher": negotiated[0],
            "client_cert": None,
        }

        # A DER blob is only available when the client sent a certificate.
        der_cert = self.connection.getpeercert(binary_form=True)
        if der_cert:
            subject = x509.load_der_x509_certificate(der_cert).subject
            tls_info["client_cert"] = subject.rfc4514_string()

        print(f"TLS Client {tls_info}")

        payload = json.dumps(
            {
                "tls": tls_info,
                "request_headers": dict(self.headers),
            }
        )
        self.wfile.write(payload.encode("utf-8"))
def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse the command line options of the test TLS server.

    Args:
        argv: The argument strings to parse, without the program name.

    Returns:
        The namespace produced by argparse. Path options are converted to
        ``pathlib.Path`` with ``~`` and environment variables expanded, and
        the protocol options are lower-cased.
    """

    def expand_path(value: str) -> pathlib.Path:
        # Expand environment variables first, then a leading "~".
        return pathlib.Path(os.path.expanduser(os.path.expandvars(value)))

    protocol_choices = ["default", "tlsv1_2", "tlsv1_3"]

    parser = argparse.ArgumentParser(
        prog="tls_server.py",
        description="Test TLS HTTP Server in Python",
    )

    # Server certificate options.
    parser.add_argument(
        "--tls-cert",
        type=expand_path,
        help="Path to PEM encoded certificate with option embedded key, will use self signed certificate if not set.",
    )
    parser.add_argument(
        "--tls-key",
        type=expand_path,
        help="Path to PEM encoded key for the certificate if not present in --tls-cert.",
    )
    parser.add_argument(
        "--tls-key-pass",
        type=str,
        help="The password needed to decrypt the TLS key provided, can be omitted if the key is not encrypted.",
    )

    # Client authentication: either a user supplied CA or generated keys,
    # never both.
    client_auth_group = parser.add_mutually_exclusive_group()
    client_auth_group.add_argument(
        "--tls-client-ca",
        type=expand_path,
        help="Path to a TLS CA bundle file or directory to use with identifying the client. This enforces client cert authentication if set.",
    )
    client_auth_group.add_argument(
        "--tls-client-auth",
        action="store_true",
        help="Require TLS Client authentication through pre-generated certificates next to this script",
    )

    # Protocol and cipher restrictions.
    parser.add_argument(
        "--tls-min-protocol",
        choices=protocol_choices,
        default="default",
        type=str.lower,
        help="The minimum TLS protocol to allow, the default is the default for Python.",
    )
    parser.add_argument(
        "--tls-max-protocol",
        choices=protocol_choices,
        default="default",
        type=str.lower,
        help="The maximum TLS protocol to allow, the default is the default for Python.",
    )
    parser.add_argument(
        "--tls-ciphers",
        default="",
        type=str,
        help="The TLS cipher suites to allow in the format of the OpenSSL cipher list string, this cannot restrict ciphers in TLS 1.3.",
    )

    parser.add_argument(
        "--port",
        type=int,
        default=0,
        help="The port to listen on, defaults to an ephemeral port available on the host",
    )

    if HAS_ARGCOMPLETE:
        argcomplete.autocomplete(parser)

    return parser.parse_args(argv)
def generate_cert(
    subject: str,
    *,
    issuer: (
        tuple[x509.Certificate, types.CertificateIssuerPrivateKeyTypes] | None
    ) = None,
    key_type: t.Literal["rsa", "ecdsa"] = "rsa",
    extensions: list[tuple[x509.ExtensionType, bool]] | None = None,
) -> tuple[x509.Certificate, types.CertificateIssuerPrivateKeyTypes]:
    """Generate a new private key and an X.509 certificate for it.

    Args:
        subject: Value used as the common name of the certificate subject.
            The remaining subject attributes are fixed placeholder values.
        issuer: Certificate and private key of the issuing CA. When omitted
            the certificate is self signed.
        key_type: ``"rsa"`` for a 2048 bit RSA key or ``"ecdsa"`` for a
            SECP384R1 elliptic curve key.
        extensions: Optional ``(extension, critical)`` pairs to add to the
            certificate.

    Returns:
        The signed certificate and the private key generated for it.

    Raises:
        ValueError: If ``key_type`` is not one of the supported values.
    """
    # Validate before doing any work so that a typo such as "RSA" is reported
    # instead of silently producing an ECDSA key.
    if key_type not in ("rsa", "ecdsa"):
        raise ValueError(f"Unknown key_type '{key_type}', expected 'rsa' or 'ecdsa'")

    private_key: rsa.RSAPrivateKey | ec.EllipticCurvePrivateKey
    if key_type == "rsa":
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend(),
        )
    else:
        private_key = ec.generate_private_key(
            curve=ec.SECP384R1(),
        )

    subject_name = x509.Name(
        [
            x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "Au"),
            x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, "State"),
            x509.NameAttribute(x509.NameOID.LOCALITY_NAME, "City"),
            x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, "Organization"),
            x509.NameAttribute(x509.NameOID.COMMON_NAME, subject),
        ]
    )

    # Self signed unless an issuer is supplied, in which case the issuer's
    # subject and key are used to sign the new certificate.
    if issuer:
        issuer_name = issuer[0].subject
        sign_key: types.CertificateIssuerPrivateKeyTypes = issuer[1]
    else:
        issuer_name = subject_name
        sign_key = private_key

    # The validity window starts one day in the past and lasts 365 days.
    not_before = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
        days=1
    )
    not_after = not_before + datetime.timedelta(days=365)

    builder = (
        x509.CertificateBuilder()
        .subject_name(subject_name)
        .issuer_name(issuer_name)
        .public_key(private_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(not_before)
        .not_valid_after(not_after)
    )
    for ext, critical in extensions or []:
        builder = builder.add_extension(ext, critical)

    return builder.sign(sign_key, SHA256()), private_key
def serialize_cert(
    cert: x509.Certificate,
    key: types.CertificateIssuerPrivateKeyTypes,
    path: pathlib.Path,
    *,
    key_password: bytes | None = None,
    cert_only: bool = False,
    generate_pfx: bool = False,
) -> None:
    """Write a certificate, and optionally its key, to disk as PEM.

    Args:
        cert: The certificate to write.
        key: The private key belonging to the certificate.
        path: Destination of the PEM file. The private key, when included,
            is written before the certificate.
        key_password: Password used to encrypt the private key. The key is
            written unencrypted when this is empty or ``None``.
        cert_only: Leave the private key out of the PEM file.
        generate_pfx: Also write a PKCS#12 file next to ``path`` with a
            ``.pfx`` suffix. It is protected with ``key_password``, or with
            ``b"password"`` when no key password is given.
    """
    cert_pem = cert.public_bytes(Encoding.PEM)

    key_pem = b""
    if not cert_only:
        if key_password:
            key_encryption = BestAvailableEncryption(key_password)
        else:
            key_encryption = NoEncryption()
        key_pem = key.private_bytes(
            encoding=Encoding.PEM,
            format=PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=key_encryption,
        )

    # key_pem is empty for cert_only, so this yields key followed by
    # certificate or just the certificate.
    with open(path, "wb") as pem_file:
        pem_file.write(key_pem + cert_pem)

    if not generate_pfx:
        return

    friendly_name = cert.subject.rfc4514_string().encode()
    pfx_encryption = BestAvailableEncryption(key_password or b"password")
    pfx_blob = pkcs12.serialize_key_and_certificates(
        friendly_name,
        key,
        cert,
        None,
        pfx_encryption,
    )
    path.with_suffix(".pfx").write_bytes(pfx_blob)
def _resolve_tls_version(option: str, value: str) -> ssl.TLSVersion:
    """Map a protocol option value such as ``tlsv1_2`` to ``ssl.TLSVersion``."""
    wanted = value.lower()
    for tls_version in ssl.TLSVersion:
        if tls_version.name.lower() == wanted:
            return tls_version
    raise ValueError(f"Unknown {option} '{value}' specified")


def _generate_ca(
    temp_files: list[pathlib.Path],
) -> tuple[x509.Certificate, types.CertificateIssuerPrivateKeyTypes, pathlib.Path]:
    """Create a CA and write its certificate (no key) next to this script."""
    ca_cert, ca_key = generate_cert(
        "TlsWebServerCA",
        extensions=[(x509.BasicConstraints(ca=True, path_length=None), True)],
    )
    ca_path = pathlib.Path(__file__).parent / f"{FILE_NAME}_ca.pem"
    # Register the path before writing so a failed write is still cleaned up.
    temp_files.append(ca_path)
    serialize_cert(ca_cert, ca_key, ca_path, cert_only=True)
    return ca_cert, ca_key, ca_path


def _build_tls_context(
    args: argparse.Namespace,
    temp_files: list[pathlib.Path],
) -> ssl.SSLContext:
    """Build the server context, recording every file it creates in temp_files."""
    context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)

    if args.tls_min_protocol != "default":
        context.minimum_version = _resolve_tls_version(
            "--tls-min-protocol", args.tls_min_protocol
        )
    if args.tls_max_protocol != "default":
        context.maximum_version = _resolve_tls_version(
            "--tls-max-protocol", args.tls_max_protocol
        )
    if args.tls_ciphers:
        context.set_ciphers(args.tls_ciphers)

    my_ca: (
        tuple[x509.Certificate, types.CertificateIssuerPrivateKeyTypes, pathlib.Path]
        | None
    ) = None

    if args.tls_client_ca:
        # Client certificates are verified against the user supplied CA.
        context.verify_mode = ssl.VerifyMode.CERT_REQUIRED
        tls_client_ca: pathlib.Path = args.tls_client_ca
        if tls_client_ca.is_dir():
            context.load_verify_locations(capath=str(tls_client_ca.absolute()))
        elif tls_client_ca.exists():
            context.load_verify_locations(cafile=str(tls_client_ca.absolute()))
        else:
            raise ValueError(
                f"Certificate CA verify path '{tls_client_ca}' does not exist"
            )

    elif args.tls_client_auth:
        # Generate a CA plus a client certificate signed by it, and trust
        # that CA for client verification.
        context.verify_mode = ssl.VerifyMode.CERT_REQUIRED
        my_ca = _generate_ca(temp_files)
        client_cert, client_key = generate_cert(
            "TlsWebServerClient",
            issuer=(my_ca[0], my_ca[1]),
            extensions=[
                (
                    x509.KeyUsage(
                        digital_signature=True,
                        content_commitment=False,
                        key_encipherment=False,
                        data_encipherment=False,
                        key_agreement=False,
                        key_cert_sign=False,
                        crl_sign=False,
                        encipher_only=False,
                        decipher_only=False,
                    ),
                    True,
                ),
                (x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), False),
            ],
        )
        client_pem_path = pathlib.Path(__file__).parent / f"{FILE_NAME}_client.pem"
        # Both files hold the client private key; register them before they
        # are written so they are removed even if a later step fails.
        temp_files.append(client_pem_path)
        temp_files.append(client_pem_path.with_suffix(".pfx"))
        serialize_cert(
            client_cert,
            client_key,
            client_pem_path,
            generate_pfx=True,
        )
        context.load_verify_locations(cafile=str(my_ca[2].absolute()))

    if args.tls_cert:
        context.load_cert_chain(
            certfile=str(args.tls_cert.absolute()),
            keyfile=str(args.tls_key.absolute()) if args.tls_key else None,
            password=args.tls_key_pass,
        )
        return context

    # No certificate supplied: issue an RSA and an ECDSA server certificate
    # from the generated CA and load both into the context.
    if not my_ca:
        my_ca = _generate_ca(temp_files)

    hostname = socket.gethostname()
    san = x509.SubjectAlternativeName(
        [
            x509.DNSName(hostname),
            x509.DNSName("localhost"),
        ]
    )
    for key_type in ["rsa", "ecdsa"]:
        cert, key = generate_cert(
            hostname,
            issuer=(my_ca[0], my_ca[1]),
            extensions=[(san, False)],
            key_type=key_type,  # type: ignore[arg-type]  # This is the literal string
        )

        # load_cert_chain() only accepts files, so the key is written
        # encrypted with a throwaway password and removed straight away.
        tls_key_pass = os.urandom(32)
        temp_cert = (
            pathlib.Path(__file__).parent / f"tls_server_temp_cert_{key_type}.pem"
        )
        try:
            serialize_cert(
                cert,
                key,
                temp_cert,
                key_password=tls_key_pass,
            )
            context.load_cert_chain(
                certfile=str(temp_cert.absolute()),
                password=tls_key_pass,
            )
        finally:
            temp_cert.unlink(missing_ok=True)

    return context


def create_tls_context(
    args: argparse.Namespace,
) -> tuple[ssl.SSLContext, list[pathlib.Path]]:
    """Create the server side TLS context described by the parsed arguments.

    Reads ``tls_min_protocol``, ``tls_max_protocol``, ``tls_ciphers``,
    ``tls_client_ca``, ``tls_client_auth``, ``tls_cert``, ``tls_key`` and
    ``tls_key_pass`` from ``args``. When no server certificate is supplied a
    CA and server certificates are generated, and ``tls_client_auth``
    additionally generates a client certificate; the generated files are
    written next to this script.

    Args:
        args: The namespace returned by ``parse_args``.

    Returns:
        The configured context and the list of files written to disk that the
        caller should delete once the server stops.

    Raises:
        ValueError: If a protocol name is unknown or the client CA path does
            not exist.
    """
    temp_files: list[pathlib.Path] = []
    try:
        context = _build_tls_context(args, temp_files)
    except BaseException:
        # Nothing is returned on failure, so the caller could never remove
        # the files (some contain private keys) written up to this point.
        for leftover in temp_files:
            leftover.unlink(missing_ok=True)
        raise
    return context, temp_files
def main() -> None:
    """Run the TLS test server until it is interrupted.

    Parses the command line, builds the TLS context, then serves HTTPS
    requests on the requested port. Files generated while building the TLS
    context are removed when the server stops.
    """
    args = parse_args(sys.argv[1:])
    tls_context, temp_files = create_tls_context(args)

    try:
        # The context manager calls server_close() on exit so the listening
        # socket is released even when wrapping or serving fails.
        with http.server.HTTPServer(("", args.port), HTTPHandler) as httpd:
            httpd.socket = tls_context.wrap_socket(
                httpd.socket,
                server_side=True,
            )

            print(f"Listening on {httpd.server_address}")
            httpd.serve_forever()
    except KeyboardInterrupt:
        # Ctrl+C is the only way to stop serve_forever() here, so treat it as
        # a normal shutdown rather than letting a traceback escape.
        print("Shutting down")
    finally:
        for temp_file in temp_files:
            temp_file.unlink(missing_ok=True)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment