Created March 10, 2016 22:12
-
-
Save zbyte64/2378034a4ecfca1d71d8 to your computer and use it in GitHub Desktop.
TLS authentication with the RethinkDB Python client
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from rethinkdb.net import ConnectionInstance, SocketWrapper, Connection, decodeUTF | |
from rethinkdb.errors import * | |
import socket | |
import time | |
import ssl | |
class TLSConnectionInstance(ConnectionInstance):
    """``ConnectionInstance`` whose socket layer is ``TLSSocketWrapper``."""

    def connect(self, timeout):
        """Open the socket via ``TLSSocketWrapper`` and return the parent connection.

        :param timeout: seconds handed to ``TLSSocketWrapper`` for connecting
            and completing the handshake.
        :returns: ``self._parent`` (the owning connection object).
        """
        wrapper = TLSSocketWrapper(self, timeout)
        self._socket = wrapper
        return self._parent
class TLSSocketWrapper(SocketWrapper):
    """Socket wrapper that optionally upgrades the RethinkDB connection to TLS.

    Opens a TCP connection to the owning connection's host/port. When its
    ``ssl`` setting is a non-empty dict, the socket is wrapped in TLS. The
    RethinkDB handshake is then sent, and the server must reply with the
    null-terminated string ``SUCCESS``.
    """

    def __init__(self, parent, timeout):
        """Connect, optionally negotiate TLS, and perform the handshake.

        :param parent: the connection instance creating this socket; its
            ``_parent`` supplies ``host``, ``port``, ``ssl`` and ``handshake``.
        :param timeout: seconds allowed for the TCP connect and for reading
            the handshake response.
        :raises ReqlAuthError: the server rejected the authorization key.
        :raises ReqlTimeoutError: the connection attempt timed out.
        :raises ReqlDriverError: TLS handshake failure, unexpected server
            reply, or any other connection error.
        """
        self.host = parent._parent.host
        self.port = parent._parent.port
        self._read_buffer = None
        self._socket = None
        # Treat a missing/None ssl setting like an empty dict: plain TCP.
        self.ssl = parent._parent.ssl or {}
        deadline = time.time() + timeout

        try:
            self._socket = socket.create_connection((self.host, self.port), timeout)
            self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

            if self.ssl:
                ssl_context = self._get_ssl_context(self.ssl.get("ca_certs"))
                try:
                    # server_hostname enables SNI and lets the context's
                    # check_hostname verify the certificate against self.host.
                    self._socket = ssl_context.wrap_socket(
                        self._socket, server_hostname=self.host)
                except IOError as exc:
                    self._socket.close()
                    raise ReqlDriverError("SSL handshake failed: %s" % (str(exc),))

            self.sendall(parent._parent.handshake)

            # The response from the server is a null-terminated string.
            response = b''
            while True:
                char = self.recvall(1, deadline)
                if char == b'\0':
                    break
                response += char
        except (ReqlAuthError, ReqlTimeoutError):
            # Already specific enough; propagate unchanged.
            raise
        except ReqlDriverError:
            self.close()
            raise
        except socket.timeout:
            self.close()
            raise ReqlTimeoutError(self.host, self.port)
        except Exception as ex:
            self.close()
            raise ReqlDriverError("Could not connect to %s:%s. Error: %s" %
                                  (self.host, self.port, ex))

        if response != b"SUCCESS":
            self.close()
            message = decodeUTF(response).strip()
            if message == "ERROR: Incorrect authorization key.":
                raise ReqlAuthError(self.host, self.port)
            raise ReqlDriverError("Server dropped connection with message: \"%s\"" %
                                  (message, ))

    def _get_ssl_context(self, ca_certs):
        """Build the client-side SSL context from ``self.ssl``.

        Recognised ``self.ssl`` keys:

        - ``ca_certs``: CA bundle used to verify the server (passed in here).
        - ``certfile`` / ``keyfile``: optional client certificate. ``keyfile``
          may be omitted when ``certfile`` also contains the private key.
        - ``check_hostname``: verify the certificate matches the host
          (default ``True``).

        :param ca_certs: path to a CA bundle; when falsy, the platform's
            default trust store is loaded instead.
        :returns: an ``ssl.SSLContext`` that requires a verified server cert.
        """
        # NOTE(review): PROTOCOL_TLSv1_2 pins TLS 1.2 and is deprecated on
        # modern Pythons; kept here for compatibility with older interpreters.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
        ctx.verify_mode = ssl.CERT_REQUIRED
        # Without this, any certificate signed by the CA is accepted for any
        # host. Set it after verify_mode, which it requires to be CERT_REQUIRED.
        ctx.check_hostname = bool(self.ssl.get("check_hostname", True))
        if ca_certs:
            ctx.load_verify_locations(ca_certs)
        else:
            ctx.load_default_certs()
        certfile = self.ssl.get("certfile")
        if certfile:
            ctx.load_cert_chain(certfile=certfile, keyfile=self.ssl.get("keyfile"))
        return ctx
def connect(host='localhost', port=28015, db=None, auth_key="", timeout=20, ssl=None, **kwargs):
    """Open a RethinkDB connection that can negotiate TLS.

    :param host: server host name (also used for TLS SNI).
    :param port: server driver port.
    :param db: default database, forwarded to ``Connection``.
    :param auth_key: authorization key sent in the handshake.
    :param timeout: seconds allowed for connecting.
    :param ssl: dict with optional keys ``ca_certs``, ``certfile`` and
        ``keyfile``. When empty or ``None``, a plain TCP connection is used.
    :param kwargs: forwarded unchanged to ``Connection``.
    :returns: the result of ``Connection.reconnect``.
    """
    # Use a fresh dict per call instead of a shared mutable default argument.
    if ssl is None:
        ssl = {}
    conn = Connection(TLSConnectionInstance, host, port, db, auth_key, timeout, ssl, **kwargs)
    return conn.reconnect(timeout=timeout)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.