Last active
September 3, 2022 05:57
-
-
Save flisboac/48762e176061b520a606868bc4ce089f to your computer and use it in GitHub Desktop.
Utility Python library and CLI capable of downloading a whole certificate chain; barely tested
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Requires at least Python 3.7, and typing_extensions | |
from __future__ import annotations | |
import collections.abc | |
import datetime | |
import functools | |
import os | |
import pathlib | |
import re | |
import shutil | |
import ssl | |
import sys | |
import tempfile | |
import urllib.request | |
from abc import ABCMeta, abstractmethod | |
from contextlib import contextmanager | |
from dataclasses import dataclass | |
from typing import ( | |
Any, | |
Iterable, | |
Iterator, | |
List, | |
Mapping, | |
NamedTuple, | |
Sequence, | |
Sized, | |
TextIO, | |
Tuple, | |
TypeVar, | |
) | |
from typing_extensions import TypedDict | |
# Text encoding used when certificate PEM text is written to files.
_CERT_ENCODING = "utf-8"

# Matches one PEM-encoded certificate: the BEGIN marker, the base64 body
# (possibly spanning several LF/CRLF-terminated lines), an optional END
# marker and an optional trailing line break.
_PEM_CERTIFICATE_RE = re.compile(
    r"(?P<content>-----BEGIN CERTIFICATE-----[a-zA-Z0-9+\/=\r\n]+(-----END CERTIFICATE-----)?\r?\n?)",
    flags=re.MULTILINE,
)

# Building blocks of the accepted host specification, roughly
# ``[, ][ENVNAME=]HOSTNAME[:PORT]``.
# Raw string: "\s" is an invalid escape sequence in a plain string literal.
_WS_RE_S = r"\s*"
_ENVNAME_RE_S = r"[a-zA-Z][\-_a-zA-Z0-9]*"
# Either a plain host name / IPv4 address, or a bracketed IPv6 literal such
# as ``[::1]`` or ``[2001:db8::1]`` (dots allow IPv4-mapped tails).
_HOSTNAME_RE_S = r"([^\[\]\\#?&:=,;]+|\[[0-9a-fA-F:.]+\])"
_HOSTPORT_RE_S = r"\d+"
_INPUT_HOST_RE = re.compile(
    rf"\A({_WS_RE_S},{_WS_RE_S})?"
    rf"((?P<envname>{_ENVNAME_RE_S}){_WS_RE_S}\={_WS_RE_S})?"
    rf"(?P<hostname>{_HOSTNAME_RE_S})"
    rf"(:(?P<hostport>{_HOSTPORT_RE_S}))?"
)

_T = TypeVar("_T")
class RawCertificateProperty(NamedTuple):
    """One ``(attribute name, value)`` pair from an ``ssl``-decoded certificate."""

    key: str
    value: str


# Dates are kept exactly as the ``ssl`` module renders them (for example
# ``"Jan  5 09:34:43 2018 GMT"``); see ``ssl.cert_time_to_seconds``.
RawCertificateDatetime = str

# One relative distinguished name (RDN). It is a *sequence* of pairs rather
# than a single pair because an RDN may be multi-valued (several attributes
# in one RDN, e.g. ``CN=foo+OU=bar``), which is why the ``ssl`` module
# represents every RDN as a tuple of ``(key, value)`` tuples.
RawCertificateProperties = Sequence[RawCertificateProperty]


class _RawCertificateInfo_Optional(TypedDict, total=False):
    # Keys present only when the certificate carries the matching extension.
    OCSP: Sequence[str]
    caIssuers: Sequence[str]
    crlDistributionPoints: Sequence[str]
    # Certificates without a Subject Alternative Name extension (common for
    # root CAs) decode without this key, so it is optional too.
    subjectAltName: Sequence[RawCertificateProperty]


class _RawCertificateInfo_Required(TypedDict, total=True):
    # Keys present in every decoded certificate.
    subject: Sequence[RawCertificateProperties]
    issuer: Sequence[RawCertificateProperties]
    version: int
    serialNumber: str
    notBefore: RawCertificateDatetime
    notAfter: RawCertificateDatetime


class RawCertificateInfo(
    _RawCertificateInfo_Required,
    _RawCertificateInfo_Optional,
    TypedDict,
):
    """Shape of the dictionary returned when ``ssl`` decodes a certificate.

    Split across two bases so required and optional keys can use different
    ``total=`` settings.
    """
@dataclass(frozen=True)
class CertificateHostInfo:
    """Address of a TLS server whose certificate is to be retrieved.

    Attributes:
        name: Host name or IP address of the server.
        port: TCP port of the server.
        sni_name: Optional alternative name; when set, :attr:`server_name`
            returns it instead of ``name``.
    """

    name: str
    port: int
    sni_name: str | None = None

    @classmethod
    def parse(cls, value: str, *, sni_name: str | None = None) -> CertificateHostInfo:
        """Parse a ``[ENVNAME=]HOST[:PORT]`` specification.

        The port defaults to 443 when omitted. A bracketed IPv6 literal such
        as ``[::1]`` is stored without its brackets.

        Args:
            value: The host specification to parse.
            sni_name: Stored as-is in the result's ``sni_name``.

        Raises:
            ValueError: If ``value`` does not match the expected syntax.
        """
        match = _INPUT_HOST_RE.match(value)
        if not match:
            raise ValueError(f"Invalid certificate host/hostname value: {value}")
        hostname = match.group("hostname")
        # Brackets only delimit an IPv6 literal in textual form; socket APIs
        # (used by ssl.get_server_certificate) expect the bare address.
        if hostname.startswith("[") and hostname.endswith("]"):
            hostname = hostname[1:-1]
        return cls(
            name=hostname,
            port=int(match.group("hostport") or "443"),
            sni_name=sni_name,
        )

    @property
    def server_name(self) -> str:
        """``sni_name`` when it is set (non-empty), otherwise ``name``."""
        return self.sni_name or self.name
@dataclass(frozen=True)
class CertificatePrincipal(
    collections.abc.Mapping,
    Mapping[str, str],
):
    """Read-only mapping of distinguished-name attribute names to values.

    Used for certificate subjects and issuers. Two principals compare equal
    when their underlying attribute mappings are equal.
    """

    _properties: Mapping[str, str]

    @classmethod
    def from_raw(
        cls,
        value: Sequence[RawCertificateProperties],
    ) -> CertificatePrincipal:
        """Build a principal from an ``ssl``-decoded subject or issuer.

        ``value`` is a sequence of RDNs, each a sequence of ``(key, value)``
        pairs. Every pair of every RDN is kept, so multi-valued RDNs are not
        truncated to their first attribute, and empty RDNs are skipped rather
        than raising ``IndexError``. If a key repeats, the last value wins.
        """
        return cls({key: val for rdn in value for key, val in rdn})

    def __len__(self) -> int:
        return len(self._properties)

    def __iter__(self) -> Iterator[str]:
        yield from self._properties

    def __getitem__(self, key: str) -> str:
        return self._properties[key]
class Certificate(metaclass=ABCMeta):
    """Abstract read-only view of an X.509 certificate (or chain of them)."""

    @property
    @abstractmethod
    def subject(self) -> CertificatePrincipal:
        """Distinguished name of the certificate's subject."""

    @property
    @abstractmethod
    def issuer(self) -> CertificatePrincipal:
        """Distinguished name of the certificate's issuer."""

    @property
    @abstractmethod
    def host(self) -> CertificateHostInfo | None:
        """Server the certificate was obtained from, if known."""

    @property
    @abstractmethod
    def pem_content(self) -> str:
        """PEM-encoded certificate text."""

    @property
    @abstractmethod
    def der_content(self) -> bytes:
        """DER-encoded certificate bytes."""

    @property
    @abstractmethod
    def not_before(self) -> datetime.datetime:
        """Start of the validity period."""

    @property
    @abstractmethod
    def not_after(self) -> datetime.datetime:
        """End of the validity period."""

    def is_root(self) -> bool:
        """Return whether the certificate is self-issued (subject == issuer).

        Subclasses may override this with a more specific check.
        """
        return self.subject == self.issuer
class SingleCertificate(Certificate):
    """One X.509 certificate, decoded lazily from its PEM or DER encoding.

    Whichever of ``pem_content``/``der_content`` was not supplied is derived
    from the other on first access. Validity dates and the decoded raw info
    may be supplied up front; anything not supplied is computed on first
    access (through the ``ssl`` module) and cached on the instance.
    """

    def __init__(
        self,
        *,
        host: CertificateHostInfo | None = None,
        raw_info: RawCertificateInfo | None = None,
        pem_content: str | None = None,
        der_content: bytes | None = None,
        location: str | pathlib.PurePath | None = None,
        not_before: datetime.datetime | None = None,
        not_after: datetime.datetime | None = None,
    ) -> None:
        """Create a certificate.

        Args:
            host: Server the certificate was obtained from, if any.
            raw_info: Already-decoded certificate fields, if available.
            pem_content: PEM encoding of the certificate.
            der_content: DER encoding of the certificate.
            location: Path associated with the certificate (stored as a
                ``pathlib.Path``), if any.
            not_before: Start of the validity period, if already known.
            not_after: End of the validity period, if already known.
        """
        self._host = host
        self._raw_info_ = raw_info
        self._pem_content = pem_content
        self._der_content = der_content
        self._not_before = not_before
        self._not_after = not_after
        self._location = pathlib.Path(location) if location is not None else None
        self._subject: CertificatePrincipal | None = None
        self._issuer: CertificatePrincipal | None = None

    @property
    def subject(self) -> CertificatePrincipal:
        """Subject distinguished name, decoded on first access."""
        if self._subject is None:
            self._subject = self._get_subject()
        return self._subject

    @property
    def issuer(self) -> CertificatePrincipal:
        """Issuer distinguished name, decoded on first access."""
        if self._issuer is None:
            self._issuer = self._get_issuer()
        return self._issuer

    @property
    def host(self) -> CertificateHostInfo | None:
        return self._host

    @property
    def location(self) -> pathlib.Path | None:
        return self._location

    @property
    def pem_content(self) -> str:
        """PEM text; derived from the DER bytes when not supplied."""
        if self._pem_content is None:
            self._pem_content = self._get_pem_content()
        return self._pem_content

    @property
    def der_content(self) -> bytes:
        """DER bytes; derived from the PEM text when not supplied."""
        if self._der_content is None:
            self._der_content = self._get_der_content()
        return self._der_content

    @property
    def not_before(self) -> datetime.datetime:
        """Start of validity as a naive UTC datetime."""
        if self._not_before is None:
            self._not_before = self._get_not_before_date()
        return self._not_before

    @property
    def not_after(self) -> datetime.datetime:
        """End of validity as a naive UTC datetime."""
        if self._not_after is None:
            self._not_after = self._get_not_after_date()
        return self._not_after

    @property
    def _raw_info(self) -> RawCertificateInfo:
        if self._raw_info_ is None:
            self._raw_info_ = self._get_raw_info()
        return self._raw_info_

    def is_root(self) -> bool:
        return self.subject == self.issuer

    def to_chain(self) -> CertificateChain:
        """Build a chain that starts at this certificate.

        Issuers are first fetched through ``caIssuers`` URLs (via
        ``_download``). If the last certificate obtained is still not
        self-issued, its issuers are then looked up by subject in the system
        CA store. Each phase stops at a self-issued certificate, when no
        issuer is found, or when a certificate would repeat (cycle).
        """
        certificates: List[SingleCertificate] = [self]
        self._extend_with_downloaded_issuers(certificates)
        self._extend_with_system_issuers(certificates)
        return CertificateChain(certificates)

    @staticmethod
    def _extend_with_downloaded_issuers(certificates: List[SingleCertificate]) -> None:
        """Append issuers reached by following caIssuers URLs from the last entry."""
        seen_downloads = set()
        current = certificates[-1]
        while not current.is_root():
            issuer_url = current._get_issuer_url()
            if issuer_url is None:
                break
            der_content = _download(issuer_url)
            if der_content in seen_downloads:
                # The issuer URLs form a cycle; stop instead of looping forever.
                break
            seen_downloads.add(der_content)
            current = SingleCertificate(der_content=der_content)
            certificates.append(current)

    @staticmethod
    def _extend_with_system_issuers(certificates: List[SingleCertificate]) -> None:
        """Append issuers found in the system CA store until a root is reached."""
        current = certificates[-1]
        # The is_root() check matters: a self-signed CA's issuer equals its own
        # subject, so without it the CA would keep matching itself forever.
        while not current.is_root():
            issuer = next(
                (ca for ca in get_system_ca_certificates() if current.issuer == ca.subject),
                None,
            )
            if issuer is None or any(issuer is known for known in certificates):
                break
            certificates.append(issuer)
            current = issuer

    def __repr__(self) -> str:
        props = ", ".join([f"{k}={v!r}" for k, v in vars(self).items()])
        return f"{type(self).__name__}({props})"

    def _get_der_content(self) -> bytes:
        assert self._pem_content is not None, "Missing certificate content."
        return ssl.PEM_cert_to_DER_cert(self._pem_content)

    def _get_pem_content(self) -> str:
        assert self._der_content is not None, "Missing certificate content."
        return ssl.DER_cert_to_PEM_cert(self._der_content)

    def _get_raw_info(self) -> RawCertificateInfo:
        # NOTE: ssl._ssl._test_decode_cert is a private CPython helper that
        # decodes a PEM certificate from a file path, hence the temporary file.
        # Being private, it may change between Python versions.
        with _open_temp_rw_text_file(suffix=".pem") as (tmp_file, tmp_path):
            tmp_file.write(self.pem_content)
            tmp_file.flush()
            info = ssl._ssl._test_decode_cert(str(tmp_path))
        return info

    def _get_not_before_date(self) -> datetime.datetime:
        return self._parse_datetime(self._raw_info["notBefore"])

    def _get_not_after_date(self) -> datetime.datetime:
        return self._parse_datetime(self._raw_info["notAfter"])

    def _get_subject(self) -> CertificatePrincipal:
        return CertificatePrincipal.from_raw(self._raw_info["subject"])

    def _get_issuer(self) -> CertificatePrincipal:
        return CertificatePrincipal.from_raw(self._raw_info["issuer"])

    def _parse_datetime(self, value: str) -> datetime.datetime:
        """Convert an ``ssl``-formatted certificate date to a naive UTC datetime."""
        timestamp = ssl.cert_time_to_seconds(value)
        # Same result as the deprecated datetime.utcfromtimestamp().
        return datetime.datetime.fromtimestamp(
            timestamp, tz=datetime.timezone.utc
        ).replace(tzinfo=None)

    def _get_issuer_url(self) -> str | None:
        """First ``caIssuers`` URL of the decoded certificate, or None."""
        ca_issuers = self._raw_info.get("caIssuers") or ()
        return ca_issuers[0] if len(ca_issuers) > 0 else None
class CertificateBundle(
    collections.abc.Sequence,
    Sequence[SingleCertificate],
):
    """Ordered, read-only sequence of certificates given as PEM, DER or objects.

    The bundle need not form a proper chain (see :meth:`is_proper_chain`).
    Certificates and the combined PEM/DER representations are computed on
    first access and cached.
    """

    def __init__(
        self,
        certificates: Iterable[SingleCertificate] | None = None,
        *,
        host: CertificateHostInfo | None = None,
        pem_content: str | Iterable[str] | None = None,
        der_content: bytes | None = None,
        location: str | pathlib.PurePath | None = None,
    ) -> None:
        """Create a bundle.

        Args:
            certificates: Pre-built certificates. When given, PEM/DER input
                is not parsed to obtain the certificates.
            host: Server the bundle came from; attached to the first
                certificate parsed from PEM/DER input.
            pem_content: One PEM text (possibly containing several
                certificates) or several PEM texts.
            der_content: DER bytes of a single certificate.
            location: Path associated with the bundle, if any.
        """
        self._host = host
        self._input_pem_content = pem_content
        self._input_der_content = der_content
        self._location = pathlib.Path(location) if location is not None else None
        self._certificates_: Sequence[SingleCertificate] | None = (
            tuple(certificates) if certificates is not None else None
        )
        self._proper_chain: bool | None = None
        self._pem_content: str | None = None
        self._der_content: bytes | None = None
        self._pem_content_list_: Sequence[str] | None = None

    @property
    def location(self) -> pathlib.PurePath | None:
        return self._location

    @property
    def pem_content(self) -> str:
        """All certificates as one PEM text."""
        if self._pem_content is None:
            self._pem_content = self._get_pem_content()
        return self._pem_content

    @property
    def der_content(self) -> bytes:
        """Concatenated DER encodings of all certificates, in order."""
        if self._der_content is None:
            self._der_content = self._get_der_content()
        return self._der_content

    def is_proper_chain(self) -> bool:
        """Whether each certificate is issued by the next one (cached)."""
        if self._proper_chain is None:
            self._proper_chain = self._is_proper_chain()
        return self._proper_chain

    def to_chain(self) -> CertificateChain:
        """Turn the bundle into a chain, completing it from its last certificate.

        Raises:
            ValueError: If the bundle is not a proper chain.
        """
        if not self.is_proper_chain():
            raise ValueError("This certificate bundle is not a proper certificate chain.")
        *leading, last = self._certificates
        return CertificateChain([*leading, *last.to_chain()])

    @property
    def _pem_content_list(self) -> Sequence[str]:
        if self._pem_content_list_ is None:
            self._pem_content_list_ = self._get_pem_content_list()
        return self._pem_content_list_

    @property
    def _certificates(self) -> Sequence[SingleCertificate]:
        if self._certificates_ is None:
            self._certificates_ = self._get_certificates()
        return self._certificates_

    def __len__(self) -> int:
        return len(self._certificates)

    def __getitem__(self, key: Any) -> SingleCertificate:
        return self._certificates[key]

    def __iter__(self) -> Iterator[SingleCertificate]:
        yield from self._certificates

    def __repr__(self) -> str:
        props = ", ".join([f"{k}={v!r}" for k, v in vars(self).items()])
        return f"{type(self).__name__}({props})"

    def _get_pem_content_list(self) -> Sequence[str]:
        """One PEM text per certificate, from PEM input, DER input or objects."""
        if self._input_pem_content is not None:
            return tuple(_split_pem_content(_join_pem_content(self._input_pem_content)))
        if self._input_der_content is not None:
            return tuple(
                _split_pem_content(ssl.DER_cert_to_PEM_cert(self._input_der_content))
            )
        # A bundle built only from certificate objects can still render PEM.
        assert (
            self._certificates_ is not None
        ), "Either PEM or DER content must be provided."
        return tuple(c.pem_content for c in self._certificates_)

    def _get_pem_content(self) -> str:
        return _join_pem_content(self._pem_content_list)

    def _get_der_content(self) -> bytes:
        # ssl.PEM_cert_to_DER_cert handles exactly one certificate, so convert
        # each one separately and concatenate the (self-delimiting) DER blobs.
        return b"".join(ssl.PEM_cert_to_DER_cert(pem) for pem in self._pem_content_list)

    def _get_certificates(self) -> Sequence[SingleCertificate]:
        # Only the first (target) certificate is associated with the host.
        return tuple(
            SingleCertificate(pem_content=pem, host=self._host if i == 0 else None)
            for i, pem in enumerate(self._pem_content_list)
        )

    def _is_proper_chain(self) -> bool:
        return _is_proper_certificate_chain(self._certificates)
class CertificateChain(
    Certificate,
    collections.abc.Sequence,
    Sequence[SingleCertificate],
):
    """Ordered certificate chain: target certificate first, furthest issuer last.

    As a :class:`Certificate`, the chain reports the subject, issuer, host
    and validity period of its target certificate; its PEM/DER content covers
    every certificate in the chain.
    """

    def __init__(
        self,
        certificates: Iterable[SingleCertificate] | None = None,
    ) -> None:
        self._certificates = list(certificates or ())
        self._pem_content: str | None = None
        self._der_content: bytes | None = None

    @property
    def root(self) -> SingleCertificate:
        """Last certificate (the root when the chain is complete)."""
        return self._certificates[-1]

    @property
    def target(self) -> SingleCertificate:
        """First certificate, the one the chain was built for."""
        return self._certificates[0]

    @property
    def subject(self) -> CertificatePrincipal:
        return self.target.subject

    @property
    def issuer(self) -> CertificatePrincipal:
        return self.target.issuer

    @property
    def host(self) -> CertificateHostInfo | None:
        return self.target.host

    @property
    def raw_info(self) -> RawCertificateInfo:
        """Decoded fields of the target certificate."""
        return self.target._raw_info

    @property
    def pem_content(self) -> str:
        """All certificates of the chain as one PEM text."""
        if self._pem_content is None:
            self._pem_content = self._get_pem_content()
        return self._pem_content

    @property
    def der_content(self) -> bytes:
        """Concatenated DER encodings of all certificates, in chain order."""
        if self._der_content is None:
            self._der_content = self._get_der_content()
        return self._der_content

    @property
    def not_before(self) -> datetime.datetime:
        return self.target.not_before

    @property
    def not_after(self) -> datetime.datetime:
        return self.target.not_after

    def is_root(self) -> bool:
        return self.target.is_root()

    def with_root(self, *certificates: SingleCertificate) -> CertificateChain:
        """Return a new chain with ``certificates`` appended at the end."""
        return CertificateChain((*self._certificates, *certificates))

    def __len__(self) -> int:
        return len(self._certificates)

    def __getitem__(self, key: Any) -> SingleCertificate:
        return self._certificates[key]

    def __iter__(self) -> Iterator[SingleCertificate]:
        yield from self._certificates

    def __repr__(self) -> str:
        props = ", ".join([f"{k}={v!r}" for k, v in vars(self).items()])
        return f"{type(self).__name__}({props})"

    def _get_pem_content(self) -> str:
        return _join_pem_content([c.pem_content for c in self._certificates])

    def _get_der_content(self) -> bytes:
        # ssl.PEM_cert_to_DER_cert handles exactly one certificate, so convert
        # each one separately and concatenate the (self-delimiting) DER blobs.
        return b"".join(
            ssl.PEM_cert_to_DER_cert(c.pem_content) for c in self._certificates
        )
@functools.lru_cache(maxsize=None)
def get_system_ca_certificates(cls: Any = None) -> Sequence[SingleCertificate]:
    """Return the CA certificates loaded by a default ``ssl`` client context.

    The result is cached (per argument value). ``cls`` is ignored and only
    kept so existing callers that pass an argument keep working; call this
    function without arguments.

    Note: ``SSLContext.get_ca_certs`` does not list certificates from a
    ``capath`` directory that have not been used yet, so on some systems the
    returned set may be incomplete.
    """
    default_context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
    return tuple(
        SingleCertificate(der_content=der_content)
        for der_content in default_context.get_ca_certs(binary_form=True)
    )
def get_server_certificate_pem(
    host: CertificateHostInfo | str,
    *,
    ssl_version: int | None = None,
    ca_certs: str | pathlib.PurePath | None = None,
    timeout: int | None = None,
) -> str:
    """Fetch the PEM-encoded certificate presented by a TLS server.

    Args:
        host: Server address, or a ``[ENVNAME=]HOST[:PORT]`` string parsed
            with :meth:`CertificateHostInfo.parse`.
        ssl_version: Forwarded to ``ssl.get_server_certificate`` when given.
        ca_certs: CA bundle file forwarded to ``ssl.get_server_certificate``
            when given; ``ssl`` then validates the server certificate.
        timeout: Forwarded to ``ssl.get_server_certificate`` when given
            (supported by Python 3.10+).

    Returns:
        The server certificate as PEM text, as returned by ``ssl``.
    """
    if isinstance(host, str):
        host = CertificateHostInfo.parse(host)
    params: dict = {}
    if ssl_version is not None:
        params["ssl_version"] = ssl_version
    if ca_certs is not None:
        # Pass the CA bundle path itself (as a string), not another option.
        params["ca_certs"] = str(ca_certs)
    if timeout is not None:
        params["timeout"] = timeout
    address = (host.server_name, host.port)
    return ssl.get_server_certificate(address, **params)
def get_server_certificate_chain(
    host: CertificateHostInfo | str,
    *,
    ssl_version: int | None = None,
    ca_certs: str | pathlib.PurePath | None = None,
    timeout: int | None = None,
) -> CertificateChain:
    """Fetch a server's certificate and extend it into a certificate chain.

    The certificate returned by :func:`get_server_certificate_pem` is wrapped
    in a :class:`CertificateBundle`. If that bundle already forms a proper
    chain it is completed as a whole; otherwise the chain is built starting
    from its first certificate.
    """
    target = CertificateHostInfo.parse(host) if isinstance(host, str) else host
    fetched_pem = get_server_certificate_pem(
        target,
        ssl_version=ssl_version,
        ca_certs=ca_certs,
        timeout=timeout,
    )
    bundle = CertificateBundle(pem_content=fetched_pem, host=target)
    return bundle.to_chain() if bundle.is_proper_chain() else bundle[0].to_chain()
def _is_proper_certificate_chain(chain: Iterable[SingleCertificate]) -> bool:
    """Report whether ``chain`` is internally consistent.

    An empty chain is never proper. A single certificate is proper only if it
    is a root. Longer chains are proper when every certificate's issuer equals
    the subject of the certificate that follows it; the last certificate does
    not have to be a root.
    """
    certificates = list(chain)
    if not certificates:
        return False
    if len(certificates) == 1:
        return certificates[0].is_root()
    return all(
        child.issuer == parent.subject
        for child, parent in zip(certificates, certificates[1:])
    )
def _split_pem_content(content: str) -> Iterator[str]:
    """Lazily yield every PEM certificate block found in ``content``, in order."""
    yield from (found["content"] for found in _PEM_CERTIFICATE_RE.finditer(content))
def _join_pem_content(contents: str | Iterable[str]) -> str:
    """Join PEM texts with newlines, normalising line breaks.

    A single string is treated as a one-element list. Every run of LF/CRLF
    line breaks collapses into a single LF, which also removes blank lines.
    """
    parts = [contents] if isinstance(contents, str) else contents
    joined = "\n".join(parts)
    return re.sub(r"(\r?\n)+", "\n", joined)
def _download(url: str, *, timeout: float | None = None) -> bytes:
    """Return the body fetched from ``url``.

    Only ``http`` and ``https`` URLs are accepted. The URL typically comes
    from a certificate's ``caIssuers`` field, i.e. from data a remote server
    supplied, and ``urllib`` would otherwise also open ``file:`` (local
    files) or ``ftp:`` URLs.

    Args:
        url: URL to fetch.
        timeout: Optional timeout in seconds; urllib's default when omitted.

    Raises:
        ValueError: If ``url`` does not use the http or https scheme.
    """
    import urllib.parse

    scheme = urllib.parse.urlsplit(url).scheme.lower()
    if scheme not in ("http", "https"):
        raise ValueError(f"Refusing to download from a non-HTTP(S) URL: {url}")
    options = {} if timeout is None else {"timeout": timeout}
    with urllib.request.urlopen(url, **options) as response:
        return response.read()
@contextmanager
def _open_temp_dir() -> Iterator[pathlib.Path]:
    """Create a fresh, uniquely named directory under the system temp dir.

    Yields the directory path. The directory and its contents are removed
    when the ``with`` block exits, including when it exits with an exception.
    """
    path = pathlib.Path(tempfile.gettempdir()) / f"scc-py-dir-{os.urandom(24).hex()}"
    os.makedirs(path, exist_ok=False)
    try:
        yield path
    finally:
        # ignore_errors: a cleanup failure must not mask the body's exception.
        shutil.rmtree(path, ignore_errors=True)
@contextmanager
def _open_temp_rw_text_file(
    *, suffix: str | None = ""
) -> Iterator[Tuple[TextIO, pathlib.Path]]:
    """Create a uniquely named UTF-8 text file in the system temp dir.

    Yields ``(file, path)`` with the file open for reading and writing. On
    exit the file is closed and deleted.

    Args:
        suffix: Appended to the generated file name; ``None`` means none.
    """
    path = (
        pathlib.Path(tempfile.gettempdir())
        / f"scc-py-file-{os.urandom(24).hex()}{suffix or ''}"
    )
    # "x+" creates the file exclusively, so an existing file (or symlink) at
    # that path is never reused; it is opened outside the try so that a
    # failed open never deletes a file this function did not create.
    file = open(path, "x+", encoding="utf-8")
    try:
        with file:
            yield file, path
    finally:
        try:
            path.unlink()
        except FileNotFoundError:
            pass
if __name__ == "__main__":
    # USAGE: python <this file> [ENVNAME=]HOST[:PORT]
    # Writes the PEM-encoded certificate chain of HOST to standard output.
    if len(sys.argv) < 2:
        sys.stderr.write(f"usage: {sys.argv[0]} [ENVNAME=]HOST[:PORT]\n")
        sys.exit(2)
    sys.stdout.write(get_server_certificate_chain(sys.argv[1]).pem_content)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment