Created October 24, 2014 20:54
-
-
Save lukasgraf/399b04e328a5ad318bb9 to your computer and use it in GitHub Desktop.
LoggingSSLSocket
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from http.server import BaseHTTPRequestHandler, HTTPServer | |
from socket import socket | |
from socketserver import ThreadingMixIn | |
from threading import Thread | |
import logging | |
import ssl | |
class LoggingSSLSocket(ssl.SSLSocket):
    """SSLSocket that writes debug log lines around accept and handshake.

    Instances are not constructed directly; an ``SSLContext`` whose
    ``sslsocket_class`` is set to this class produces them from
    ``wrap_socket()``. Because :meth:`accept` wraps new connections through
    ``self.context``, accepted connections are instances of this class too
    when the context is configured that way.
    """

    def do_handshake(self, *args, **kwargs):
        """Perform the TLS handshake, logging when it starts and finishes.

        Arguments are forwarded unchanged to ``ssl.SSLSocket.do_handshake``
        and its return value is passed back. Any exception propagates, so the
        "Done" line is only logged when the handshake succeeds.
        """
        logger.debug('Starting handshake...')
        result = super().do_handshake(*args, **kwargs)
        logger.debug('Done with handshake.')
        return result

    def accept(self):
        """Accept a remote client and wrap it in a server-side SSL channel.

        Returns a ``(connection, address)`` tuple like
        ``ssl.SSLSocket.accept``: ``connection`` is produced by
        ``self.context.wrap_socket`` and ``address`` is the client's address.
        """
        # Call the plain socket.accept on purpose: ssl.SSLSocket.accept would
        # already wrap the connection, and we wrap it ourselves below.
        newsock, addr = socket.accept(self)
        # Lazy %-args: the message is only rendered if the record is emitted.
        logger.debug("Accepting connection from '%s'...", addr)
        newsock = self.context.wrap_socket(
            newsock,
            do_handshake_on_connect=self.do_handshake_on_connect,
            suppress_ragged_eofs=self.suppress_ragged_eofs,
            server_side=True,
        )
        logger.debug('Done accepting connection.')
        return newsock, addr
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=ssl.CERT_NONE,
                ssl_version=ssl.PROTOCOL_SSLv23, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                ciphers=None):
    """Wrap ``sock`` like the legacy ``ssl.wrap_socket``, yielding a LoggingSSLSocket.

    ``ssl.SSLSocket`` (and therefore ``LoggingSSLSocket``) has had no public
    constructor since Python 3.7. Instead of instantiating the class directly,
    this builds an ``SSLContext`` from the legacy keyword arguments. It then
    uses the context's ``sslsocket_class`` hook so that ``wrap_socket()``
    returns ``LoggingSSLSocket`` instances.

    Parameters mirror the legacy function:
    sock -- the plain socket to wrap.
    keyfile, certfile -- this side's private key / certificate chain files.
    server_side -- wrap as the server end of the connection.
    cert_reqs -- verify mode for the peer certificate (ssl.CERT_*).
    ssl_version -- protocol constant passed to SSLContext (the default
        PROTOCOL_SSLv23 is a deprecated alias of PROTOCOL_TLS).
    ca_certs -- file of CA certificates used to verify the peer.
    do_handshake_on_connect, suppress_ragged_eofs -- forwarded to
        SSLContext.wrap_socket.
    ciphers -- optional OpenSSL cipher list string.

    Raises ValueError if ``server_side`` is true without ``certfile``, or if
    ``keyfile`` is given without ``certfile``.
    """
    # Same argument validation the legacy ssl.wrap_socket performed.
    if server_side and not certfile:
        raise ValueError("certfile must be specified for server-side "
                         "operations")
    if keyfile and not certfile:
        raise ValueError("certfile must be specified")

    context = ssl.SSLContext(ssl_version)
    context.verify_mode = cert_reqs
    if ca_certs:
        context.load_verify_locations(ca_certs)
    if certfile:
        context.load_cert_chain(certfile, keyfile)
    if ciphers:
        context.set_ciphers(ciphers)
    # Documented hook: wrap_socket() instantiates this class. LoggingSSLSocket's
    # accept() re-wraps connections through self.context, so accepted sockets
    # become LoggingSSLSocket instances as well.
    context.sslsocket_class = LoggingSSLSocket
    return context.wrap_socket(
        sock,
        server_side=server_side,
        do_handshake_on_connect=do_handshake_on_connect,
        suppress_ragged_eofs=suppress_ragged_eofs,
    )
class MyHTTPHandler(BaseHTTPRequestHandler):
    """Request handler that answers every GET with a fixed plain-text body.

    The per-request log lines the base class would print to stderr are
    routed to the module-level ``logger`` instead.
    """

    def log_message(self, format, *args):
        """Send a base-class log line to ``logger`` at INFO level.

        Overrides ``BaseHTTPRequestHandler.log_message``. ``format`` and
        ``args`` are the %-style template and values the base class supplies;
        the parameter name ``format`` is kept to match that signature.
        """
        # Render the base-class message, then let logging interpolate lazily.
        logger.info("%s - - %s", self.address_string(), format % args)

    def do_GET(self):
        """Respond to any GET request with status 200 and the body ``test``."""
        body = 'test'.encode("utf-8")
        self.send_response(200)
        # Declare type and length so clients need not rely on connection
        # close to find the end of the body.
        self.send_header("Content-Type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer that handles each incoming request in its own thread.

    ThreadingMixIn is listed first so that its ``process_request``
    takes precedence over the synchronous one inherited via HTTPServer.
    """
# Module-level logger used by LoggingSSLSocket and MyHTTPHandler. Records at
# DEBUG and above are written to server.log with a timestamp prefix.
logger = logging.getLogger('myserver')
handler = logging.FileHandler('server.log')
formatter = logging.Formatter('[%(asctime)s] %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)

if __name__ == "__main__":
    server = ThreadedHTTPServer(('', 443), MyHTTPHandler)

    # ssl.wrap_socket() was deprecated in Python 3.7 and removed in 3.12, so
    # build the equivalent server context explicitly. It presents
    # server.crt/server.key and requires clients to send a certificate that
    # verifies against client.crt.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(certfile='server.crt', keyfile='server.key')
    context.verify_mode = ssl.CERT_REQUIRED
    context.load_verify_locations(cafile='client.crt')
    # Have the context create LoggingSSLSocket instances, so the listening
    # socket's accept() and each connection's handshake are logged.
    context.sslsocket_class = LoggingSSLSocket
    server.socket = context.wrap_socket(server.socket, server_side=True)

    server_thread = Thread(target=server.serve_forever)
    server_thread.start()
    try:
        # Re-prompt on any other input; only an explicit "quit" (or Ctrl-C /
        # end of input below) stops the server.
        while input("Type 'quit' at any time to quit.\n").strip() != "quit":
            pass
    except (KeyboardInterrupt, EOFError):
        pass
    finally:
        server.shutdown()      # stop the serve_forever() loop in the thread
        server.server_close()  # release the listening socket
        server_thread.join()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.