Last active
November 9, 2016 16:08
-
-
Save moriyoshi/8f203adeef94835897deb3e51a57a0ee to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import timedelta | |
import http.server | |
import mock | |
import socket | |
import socketserver | |
import ssl | |
import tempfile | |
import time | |
import threading | |
import urllib.request | |
import arrow | |
import pytest | |
from OpenSSL import crypto | |
def generate_dummy_certificate(cn, signer=None, ca=False):
    """Create a throwaway RSA key pair and an X.509 certificate for it.

    Args:
        cn: Common Name placed in the certificate subject.
        signer: Optional dict with ``'cert'`` and ``'key'`` entries (the same
            shape this function returns) identifying the issuer.  When
            ``None`` the certificate is self-signed.
        ca: When true, the certificate gets CA basic constraints and
            certificate/CRL signing key usage; otherwise it is marked as a
            non-CA certificate.

    Returns:
        A dict ``{'cert': crypto.X509, 'key': crypto.PKey}``.
    """
    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)

    x509 = crypto.X509()
    # Extensions only exist in X.509 version 3.  The version field is
    # zero-based, so v3 is encoded as 2; without this the certificate would
    # claim to be v1 while still carrying extensions.
    x509.set_version(2)
    if ca:
        x509.add_extensions([
            crypto.X509Extension(b'basicConstraints', True, b'CA:TRUE'),
            crypto.X509Extension(b'keyUsage', False, b'digitalSignature,keyCertSign,cRLSign'),
        ])
    else:
        x509.add_extensions([
            crypto.X509Extension(b'basicConstraints', False, b'CA:FALSE'),
            crypto.X509Extension(b'keyUsage', False, b'digitalSignature,dataEncipherment,keyEncipherment'),
        ])

    # Valid from the epoch until one year from now.
    x509.set_notBefore(b'19700101000000Z')
    not_after = arrow.utcnow() + timedelta(days=365)
    x509.set_notAfter(not_after.strftime('%Y%m%d%H%M%SZ').encode('utf-8'))

    # The argument is just a dummy; PyOpenSSL does not allow us to make an
    # empty X509Name object, so copy the (empty) subject and fill it in.
    subject = crypto.X509Name(x509.get_subject())
    subject.C = 'JP'
    subject.ST = 'Tokyo'
    subject.L = 'Shibuya-ku'
    subject.O = 'XXX'
    subject.OU = 'XXX'
    subject.CN = cn

    x509.set_pubkey(key)
    x509.set_subject(subject)
    # NOTE: every certificate produced here uses serial number 1.
    x509.set_serial_number(1)

    # Self-signed unless a signer is supplied.
    if signer is not None:
        issuer_name = signer['cert'].get_subject()
        signing_key = signer['key']
    else:
        issuer_name = subject
        signing_key = key
    x509.set_issuer(issuer_name)
    x509.sign(signing_key, 'sha256')
    return {'cert': x509, 'key': key}
class TLSTCPServer(socketserver.TCPServer):
    """A ``socketserver.TCPServer`` whose listening socket is wrapped in TLS.

    Args:
        server_address: ``(host, port)`` tuple to bind to.
        RequestHandlerClass: Handler class instantiated per request.
        bind_and_activate: When true, bind and start listening immediately.
        cert_file: Path to the PEM certificate chain presented to clients.
        key_file: Path to the PEM private key matching ``cert_file``.
    """

    # Enables SO_REUSEADDR in TCPServer.server_bind() so the address can be
    # rebound right away.  (Was misspelled ``arrow_reuse_address``, which
    # socketserver never reads.)
    allow_reuse_address = True

    def __init__(self, server_address, RequestHandlerClass, bind_and_activate, cert_file, key_file):
        # Only BaseServer.__init__ is called: the socket has to be created
        # and TLS-wrapped here before it is bound.
        socketserver.BaseServer.__init__(self, server_address, RequestHandlerClass)
        # Purpose.CLIENT_AUTH yields a server-side context with the library's
        # default protocol settings instead of pinning the obsolete TLS 1.0.
        ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        ssl_ctx.load_cert_chain(cert_file, key_file)
        self.socket = ssl_ctx.wrap_socket(
            socket.socket(self.address_family, self.socket_type),
            server_side=True,
        )
        if bind_and_activate:
            try:
                self.server_bind()
                self.server_activate()
            except BaseException:
                # Release the socket on any failure, then propagate.
                self.server_close()
                raise

    def server_bind(self):
        """Bind the socket and record the resolved server name and port."""
        socketserver.TCPServer.server_bind(self)
        host, port = self.socket.getsockname()[:2]
        self.server_name = socket.getfqdn(host)
        self.server_port = port
def main():
    """Serve one HTTPS request from a freshly generated certificate chain.

    Generates a root CA, an intermediate CA and a leaf certificate, writes
    them to temporary PEM files, starts a ``TLSTCPServer`` on an ephemeral
    localhost port in a background thread, fetches ``/`` from it with a
    client context that trusts the generated CA bundle, and prints the body.
    """
    root = generate_dummy_certificate('Root CA', None, ca=True)
    imm = generate_dummy_certificate('Intermediate', root, ca=True)
    mine = generate_dummy_certificate('mycert', imm)

    with tempfile.NamedTemporaryFile() as ca_bundle_file, \
            tempfile.NamedTemporaryFile() as cert_file, \
            tempfile.NamedTemporaryFile() as key_file:
        # Server chain: leaf first, then intermediate, then root.
        cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, mine['cert']))
        cert_file.write(b'\n')
        cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, imm['cert']))
        cert_file.write(b'\n')
        cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, root['cert']))
        cert_file.flush()

        # Trust anchors handed to the client.
        ca_bundle_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, imm['cert']))
        ca_bundle_file.write(b'\n')
        ca_bundle_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, root['cert']))
        ca_bundle_file.flush()

        key_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, mine['key']))
        key_file.flush()

        class MyHandler(http.server.BaseHTTPRequestHandler):
            def do_GET(self):
                self.send_response(200)
                self.send_header('Content-Length', '3')
                self.end_headers()
                self.wfile.write(b'HEY')

        server = TLSTCPServer(('localhost', 0), MyHandler, True, cert_file.name, key_file.name)
        try:
            # service_actions() runs once per serve_forever() loop iteration,
            # so hooking it tells us the serving loop is actually running.
            ev = threading.Event()
            old_service_actions = server.service_actions

            def service_actions():
                ev.set()
                old_service_actions()

            server.service_actions = service_actions

            def run():
                try:
                    server.serve_forever()
                except Exception:
                    import traceback
                    traceback.print_exc()
                finally:
                    # Never leave the main thread blocked on ev.wait() if
                    # the serving loop died before its first iteration.
                    ev.set()

            t = threading.Thread(target=run)
            t.start()
            try:
                ev.wait()
                # Purpose.SERVER_AUTH is the client-side purpose: it turns on
                # certificate verification against the CA bundle.
                # (Purpose.CLIENT_AUTH builds a server-side context.)
                ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=ca_bundle_file.name)
                # The leaf certificate's CN is 'mycert', not 'localhost', so
                # only the chain is verified, not the host name.
                ctx.check_hostname = False
                url = 'https://localhost:{.server_port}'.format(server)
                # The timeout keeps a dead server from hanging the script.
                with urllib.request.urlopen(url, context=ctx, timeout=10) as response:
                    print(response.read())
            finally:
                server.shutdown()
                t.join()
        finally:
            # Release the listening socket even when the request failed.
            server.server_close()
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment