Created
December 11, 2020 16:41
-
-
Save adiroiban/20f938db677f66da5de2c37a8b3a3fd9 to your computer and use it in GitHub Desktop.
Some twisted proto helpers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2012 Adi Roiban. | |
# See LICENSE for details. | |
""" | |
Protocol to help with tests. | |
This comes in addition to standard twisted.test.proto_helpers | |
""" | |
from io import BytesIO | |
from StringIO import StringIO | |
from bunch import Bunch | |
from mock import patch | |
from OpenSSL import SSL | |
from twisted.internet import address, defer, protocol | |
from twisted.internet.abstract import _ConsumerMixin | |
from twisted.internet.error import ConnectionAborted, ConnectError | |
from twisted.internet.protocol import ServerFactory, Protocol | |
from twisted.internet.task import Clock | |
from twisted.internet.tcp import Connector, Port | |
from twisted.internet.ssl import ( | |
DefaultOpenSSLContextFactory, | |
ClientContextFactory, | |
) | |
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol | |
from twisted.internet._newtls import ( | |
ConnectionMixin, | |
ClientMixin, | |
ServerMixin, | |
startTLS, | |
) | |
from twisted.protocols import basic, loopback | |
from twisted.python.failure import Failure | |
from twisted.test.proto_helpers import ( | |
StringTransportWithDisconnection as | |
TwistedStringTransportWithDisconnection, | |
) | |
class AuditingProtocol(Protocol, object):
    """
    Protocol keeping a log of the connection events it sees.

    `connection_made` gets a True entry for each connectionMade call,
    `data_received` gets each received chunk, and `connection_lost` gets
    the text of each disconnection reason. `lost_deferred` is fired with
    that text when the connection is lost.
    """

    def __init__(self):
        # Fired with the text of the reason once the connection is lost.
        self.lost_deferred = defer.Deferred()
        self.data_received = []
        self.connection_lost = []
        self.connection_made = []

    def connectionMade(self):
        """
        See: Protocol.
        """
        self.connection_made += [True]

    def dataReceived(self, line):
        """
        See: Protocol.
        """
        self.data_received += [line]

    def connectionLost(self, reason):
        """
        See: Protocol.
        """
        message = str(reason.value)
        self.connection_lost += [message]
        self.lost_deferred.callback(message)
class EchoProtocol(Protocol):
    """
    Protocol writing back everything it receives, while recording each
    received chunk in `value`.
    """

    def __init__(self):
        self.value = []

    def dataReceived(self, data):
        """
        Record `data` and send it back on the transport.
        """
        received = self.value
        received.append(data)
        self.transport.write(data)
class SSLInspectorProtocol(Protocol):
    """
    Force the handshake, get the remote certificate / chain and other SSL/TLS
    related info, and then close the connection.

    Call triggerHandshake() once connected and wait on the returned
    deferred. It fires with None once the peer details were collected, or
    errbacks with the disconnection reason when the connection is lost
    before that.
    """

    # The peer certificate and chain as obtained after the handshake.
    peer_certificate = None
    peer_chain = None
    # Deferred which is fired when we got the peer details.
    _handshake_done = None
    # Set once the peer details were collected, so that the disconnection
    # we trigger ourselves afterwards is not reported as a failure.
    _peer_details_received = False

    def triggerHandshake(self):
        """
        Send some data to trigger the handshake finalization.

        Returns a deferred fired when the peer details are available.
        """
        # Create the deferred before writing, in case the write leads to
        # dataReceived being called right away.
        self._handshake_done = defer.Deferred()
        self.transport.write('Some data\r\n')
        return self._handshake_done

    def connectionLost(self, reason):
        """
        Called when the connection is closed.

        Errback the handshake deferred with `reason` only when the
        handshake was triggered and the peer details were not collected.
        Does nothing when triggerHandshake was never called.
        """
        done = self._handshake_done
        if done is None or done.called or self._peer_details_received:
            return
        done.errback(reason)

    def dataReceived(self, data):
        """
        Called when data was received.

        By this time we should have a finalized handshake, so collect the
        peer details, close the connection and fire the handshake deferred.
        Data received after the details were collected is ignored.
        """
        if self._peer_details_received:
            return

        self.peer_certificate = self.transport.getPeerCertificate()
        self.peer_chain = self.transport._tlsConnection.get_peer_cert_chain()
        # Mark success before closing: loseConnection() may call
        # connectionLost() synchronously (as the string transports do),
        # which must not errback a successful handshake.
        self._peer_details_received = True

        # Close the connection once we are done.
        self.transport.loseConnection()

        done = self._handshake_done
        if done is not None and not done.called:
            done.callback(None)
class AccumulatingLineProtocol(basic.LineReceiver):
    """
    Stores each received line and can fire deferreds on connection, on each
    received line and on disconnection.

    factory.protocol_open_deferred, when present and not yet fired, is
    fired with the protocol instance on connection made.
    protocol.close_deferred, when set, is fired on connection close.
    protocol.line_received_deferred is fired with each new line which is
    received; a new deferred is created for the next line.
    """

    # Protocol connection indicators.
    connection_made = False
    connection_close = False
    # The reason why connection was closed.
    close_reason = None
    # An optional deferred called on connection done.
    close_deferred = None
    # Deferred fired with the next received line.
    line_received_deferred = None
    # Remote peer as available at connection time.
    peer = None
    # Lines received so far.
    lines = None
    factory = None

    def connectionMade(self):
        """
        See: Protocol.
        """
        self.lines = []
        self.connection_made = True
        self.peer = self.transport.getPeer()
        # Set up all the state before firing the open deferred, as its
        # callbacks might already use the protocol or deliver data to it.
        self.line_received_deferred = defer.Deferred()

        open_deferred = getattr(self.factory, 'protocol_open_deferred', None)
        # A deferred shared by the factory can only be fired once, so only
        # the first connection fires it.
        if open_deferred is not None and not open_deferred.called:
            open_deferred.callback(self)

    def lineReceived(self, line):
        """
        See: Protocol.
        """
        self.lines.append(line)
        # Install the deferred for the next line before firing the current
        # one, so that lines delivered from its callbacks do not hit an
        # already fired deferred.
        received, self.line_received_deferred = (
            self.line_received_deferred, defer.Deferred())
        received.callback(line)

    def connectionLost(self, reason):
        """
        See: Protocol.
        """
        self.connection_close = True
        self.close_reason = reason
        close_deferred, self.close_deferred = self.close_deferred, None
        if close_deferred is not None:
            close_deferred.callback(None)
class AccumulatingDatagramServerProtocol(protocol.DatagramProtocol):
    """
    A datagram protocol used for accumulating all received data as a server.

    It has a set of deferreds which can be used for waiting for various
    connection events. Each deferred is fired only by the first matching
    event; later events are still recorded in the instance attributes.
    """

    def __init__(self):
        self.start_deferred = defer.Deferred()
        self.stop_deferred = defer.Deferred()
        # Fired with the data of the first received datagram.
        self.received_deferred = defer.Deferred()
        self.started = False
        self.stopped = False
        # Address of the latest datagram sender; False until one arrives.
        self.client_address = False
        # Data of every received datagram, in arrival order.
        self.received_data = []

    def stopProtocol(self):
        """
        See: DatagramProtocol.
        """
        self.stopped = True
        # The protocol can be stopped more than once, but a deferred can
        # only be fired once.
        if not self.stop_deferred.called:
            self.stop_deferred.callback(None)

    def startProtocol(self):
        """
        See: DatagramProtocol.
        """
        self.started = True
        if not self.start_deferred.called:
            self.start_deferred.callback(None)

    def datagramReceived(self, data, addr):
        """
        See: DatagramProtocol.
        """
        self.client_address = addr
        self.received_data.append(data)
        # Only the first datagram fires the deferred; firing it again
        # would raise AlreadyCalledError.
        if not self.received_deferred.called:
            self.received_deferred.callback(data)
class StringTransportWithDisconnection(
        TwistedStringTransportWithDisconnection, object):
    """
    In-memory (StringIO based) transport which can also be aborted.
    """

    def abortConnection(self):
        """
        Abort the transport, reporting ConnectionAborted to the protocol.
        """
        # FIXME:1370:
        # Check if fix is included in latest Twisted release and remove this
        # patch.
        # https://twistedmatrix.com/trac/ticket/8161
        reason = ConnectionAborted()
        return self._closeConnection(reason)

    def _closeConnection(self, reason):
        """
        Mark the transport as disconnected and call the protocol's
        connectionLost with `reason` wrapped in a Failure.

        Does nothing when the transport is already disconnected.
        """
        # The Twisted implementation does not provide this helper.
        if not self.connected:  # noqa:cover
            return
        self.connected = False
        failure = Failure(reason)
        self.protocol.connectionLost(failure)
class StringTLSTransport(StringTransportWithDisconnection):
    """
    FIXME:3600:
    DEPRECATED: String transport with TLS.

    Switching to TLS only attaches an OpenSSL connection to the transport
    and its protocol; writes are not overridden by this class.
    """

    context = None

    def __init__(self, certificate=None):
        super(StringTLSTransport, self).__init__()
        # Returned by getPeerCertificate.
        self.test_peer_certificate = certificate
        self.TLS = False
        # Self reference; presumably for code reaching the transport through
        # a `.transport` attribute - TODO confirm.
        self.transport = self

    def startTLS(self, context):
        """
        Mark TLS as active and attach a new OpenSSL connection, built from
        `context`, to this transport and to its protocol.
        """
        self.context = context
        self.TLS = True
        connection = SSL.Connection(context.getContext(), None)
        self._tlsConnection = connection
        self.protocol._tlsConnection = connection

    def stopTLS(self):
        """
        Mark TLS as inactive and detach the OpenSSL connection from the
        protocol.
        """
        self.TLS = False
        self.protocol._tlsConnection = None

    def getPeerCertificate(self):
        """
        Return the certificate given at initialization.
        """
        return self.test_peer_certificate
class _StringSTARTTLSTransport(
        StringTransportWithDisconnection, ConnectionMixin):
    """
    String transport with TLS start/stop capabilities for both
    client and server side.

    TLS is set up with Twisted's STARTTLS helper, but as there is no real
    peer, the OpenSSL connection is patched so that shutdown always
    succeeds and the peer certificate is the one given at initialization.
    While TLS is active, written data is kept in clear in a separate
    buffer (see tls_value) instead of the main string buffer.
    """

    def __init__(self, certificate=None):
        super(_StringSTARTTLSTransport, self).__init__()
        # A reference to the last context used.
        self.context = None
        # Data written over TLS.
        # None until the first TLS write, which is treated as the handshake.
        self._tls_io = None
        # Fake the peer certificate.
        self.test_peer_certificate = certificate
        # Reference to the last TLS protected protocol.
        self._tls_protocol = None

    def startTLS(self, context, normal=True):
        """
        Switch the transport from clear to secure mode.

        `normal` is passed as-is to Twisted's startTLS helper.

        Raises AssertionError when TLS was already started on this
        transport; `self.context` is never reset, so this also applies
        after TLS was stopped.
        """
        if self.context is not None:
            raise AssertionError('SSL/TLS already started.')
        self.context = context
        # This is the module-level twisted.internet._newtls.startTLS, with
        # this class used as the bypass for writes below the TLS layer.
        startTLS(self, context, normal, _StringSTARTTLSTransport)
        # Keep a copy of the tls protocol so that we can fake its
        # shutdown after stop tls.
        self._tls_protocol = self.protocol
        # We don't have a real peer, so shutdown will always fail.
        # Here we pretend that all is ok.
        try:
            self.protocol._tlsConnection.shutdown = lambda: None
            self.protocol._tlsConnection.get_peer_certificate = (
                lambda: self.test_peer_certificate)
        except AttributeError:
            # On PyOpenSSL 0.13 OpenSSL.SSL.Connection is a C object so we
            # do a more aggressive mocking.
            # The Twisted API is using both bio_* and non bio version in both
            # client and server side operations.
            def recv(length):
                """
                Called when the consumer wants data from the connection.
                Always raises SSL.ZeroReturnError.
                """
                # Just signal that shutdown is complete.
                raise SSL.ZeroReturnError()

            # Replace the whole OpenSSL connection with a stub exposing
            # only the methods listed here.
            self.protocol._tlsConnection = Bunch(
                shutdown=lambda: None,
                bio_shutdown=lambda: None,
                recv=recv,
                bio_read=lambda length: b'',
                get_peer_certificate=lambda: self.test_peer_certificate,
            )

    def finalizeTLSShutdown(self):
        """
        Fake the finalization of TLS shutdown as received from the
        remote peer.

        Fires the shutdown deferred of the TLS protocol saved by startTLS,
        so startTLS must have been called first.
        """
        self._tls_protocol._tlsShutdownDeferred.callback(None)

    def tls_clear(self):
        """
        Clear the data sent over TLS.

        While nothing was written over TLS yet, this does nothing, so the
        next TLS write is still treated as the handshake.
        """
        if not self._tls_io:
            return
        self._tls_io = BytesIO()

    def tls_value(self):
        """
        Return the clear text data as it would have been written
        over a TLS/SSL protected channel.

        Returns b'' when nothing was written over TLS yet.
        """
        if not self._tls_io:
            return b''
        return self._tls_io.getvalue()

    def write(self, bytes):
        """
        Write the bytes.

        In TLS mode the data goes to the TLS buffer (the very first TLS
        write is dropped as being the handshake), otherwise to the main
        string buffer.
        """
        if self.TLS:
            if self._tls_io is None:
                # First time, we write the handshake, but we ignore it
                # for the purpose of the test as we only care about the
                # payload data.
                self._tls_io = BytesIO()
            else:
                self._tls_io.write(bytes)
        else:
            self.io.write(bytes)

    def getPeerCertificate(self):
        """
        Return the certificate of the remote peer.

        This is the fake certificate given at initialization.
        """
        return self.test_peer_certificate
class StringSTARTTLSClientTransport(_StringSTARTTLSTransport, ClientMixin):
    """
    String transport for the client side of a connection, which can be
    switched to TLS via STARTTLS.
    """
class StringSTARTTLSServerTransport(_StringSTARTTLSTransport, ServerMixin):
    """
    String transport for the server side of a connection, which can be
    switched to TLS via STARTTLS.
    """
def _start_tls(klass, context_factory, protocol, certificate=None): | |
""" | |
To reuse the code we are using the STARTTLS logic for firing the | |
secure connection. | |
""" | |
base_transport = klass(certificate=certificate) | |
base_transport.protocol = protocol | |
base_transport.startTLS(context_factory) | |
# But the transport is then updated to look like one which was not | |
# started with STARTTLS. | |
base_transport._tlsConnection = base_transport.protocol._tlsConnection | |
base_transport.protocol = None | |
return base_transport | |
def StringTLSClientTransport(context_factory, protocol, certificate=None):
    """
    Return a client-side string transport which is secured by TLS from
    the start, without using STARTTLS.
    """
    return _start_tls(
        StringSTARTTLSClientTransport, context_factory, protocol,
        certificate=certificate,
    )
def StringTLSServerTransport(context_factory, protocol, certificate=None):
    """
    Return a server-side string transport which is secured by TLS from
    the start, without using STARTTLS.
    """
    return _start_tls(
        StringSTARTTLSServerTransport, context_factory, protocol,
        certificate=certificate,
    )
class InMemoryConsumer(_ConsumerMixin):
    """
    Consumer accumulating everything written to it in an in-memory buffer.
    """

    connected = True
    disconnecting = False
    disconnected = False

    def __init__(self, data=None):
        # Use a fresh StringIO buffer unless one is provided.
        self._data = StringIO() if data is None else data

    def registerProducer(self, producer, streaming):
        """
        Register `producer` as _ConsumerMixin does, then ask it for data
        right away.
        """
        registered = super(InMemoryConsumer, self).registerProducer(
            producer, streaming)
        producer.resumeProducing()
        return registered

    def write(self, data):
        """
        Append `data` to the buffer.
        """
        self._data.write(data)

    def value(self):
        """
        Return everything written so far.
        """
        return self._data.getvalue()

    @property
    def isConnected(self):
        """
        Same as `connected`.
        """
        return self.connected
class StreamPullProducer(object):
    """
    Pull producer writing the content of `file` to `consumer`, one chunk
    per resumeProducing call.

    The caller may set `deferred`: it is fired with None at end of file,
    or errbacked when the consumer calls stopProducing.
    """

    # The chunk is big enough so that it will read most data from one call.
    CHUNK_SIZE = 8092

    def __init__(self, consumer, file):
        self.consumer = consumer
        self.file = file
        self.deferred = None

    def resumeProducing(self):
        """
        Write the next chunk to the consumer. Once there is nothing left,
        drop the file, unregister from the consumer and fire `deferred`.
        """
        chunk = self.file.read(self.CHUNK_SIZE) if self.file else ''
        if chunk:
            self.consumer.write(chunk)
            return

        # End of file, or no file at all.
        self.file = None
        self.consumer.unregisterProducer()
        pending = self.deferred
        if pending:
            pending.callback(None)
            self.deferred = None

    def pauseProducing(self):
        """
        Nothing to do: data is only produced on resumeProducing.
        """

    def stopProducing(self):
        """
        Errback `deferred`, when set, as the consumer wants no more data.
        """
        if not self.deferred:
            return
        self.deferred.errback(
            Exception("Consumer asked us to stop producing"))
        self.deferred = None
class ConnectionTrackingServerFactory(ServerFactory, object):
    """
    Server factory remembering every protocol it builds.
    """

    def __init__(self):
        # Deferred for protocols to fire when connected (for example,
        # AccumulatingLineProtocol fires it from connectionMade).
        self.protocol_open_deferred = defer.Deferred()
        # Every protocol built so far, oldest first.
        self.protocol_instances = []
        # The most recently built protocol.
        self.protocol_instance = None

    def buildProtocol(self, addr):
        """
        Build the protocol as ServerFactory does, and record it.
        """
        parent = super(ConnectionTrackingServerFactory, self)
        built = parent.buildProtocol(addr)
        self.protocol_instances.append(built)
        self.protocol_instance = built
        return built
def serverFactoryForProtocol(protocol_class):
    """
    Return a new ConnectionTrackingServerFactory which builds
    `protocol_class` instances.
    """
    tracking_factory = ConnectionTrackingServerFactory()
    tracking_factory.protocol = protocol_class
    return tracking_factory
# FIXME:1370:
# Patch the loopback code to support abortConnection, as our forked version
# doesn't support it.
# Aborting a loopback transport is done as a regular loseConnection().
loopback._LoopbackTransport.abortConnection = (
    lambda self: self.loseConnection())
# No-op producer methods; presumably so that the loopback transport can be
# used where pause/resume are called on it - TODO confirm.
loopback._LoopbackTransport.pauseProducing = lambda self: None
loopback._LoopbackTransport.resumeProducing = lambda self: None
class InMemorySocket(Bunch):
    """
    Fake socket which never touches the network.

    Only a small subset of the socket API is provided.
    """

    def __init__(self, host='127.0.0.1', port=4224):
        self._host = host
        self._port = port
        # The port number doubles as the fake file descriptor.
        self.fileno = port

    def close(self):
        """
        See socket.close.
        Nothing was opened, so there is nothing to release.
        """

    def getsockname(self):
        """
        See socket.getsockname.
        Return the local (host, port) pair.
        """
        local_address = (self._host, self._port)
        return local_address

    def setblocking(self, flag):
        """
        See socket.setblocking.
        `flag` is ignored as blocking mode makes no difference here.
        """

    def recv(self, bufsize):
        """
        See socket.recv.
        Always return no data: data is exchanged through the high level
        transport and Protocol.dataReceived instead.
        """
        return b''
class InMemoryReactorAbstract(Clock, object):
    """
    A simple reactor which connects without touching the network using
    a client side endpoint.

    It is initialized with a list of (address, port) tuples for which
    connections are allowed.
    For client connections, a tuple can also be
    (address, port, make_connection); when `make_connection` is False the
    connection is not triggered right away.

    At this point it can not be used for connecting a client and a server
    using the same reactor.
    """

    def __init__(self, expected_addresses, client_transport_factory=None):
        super(InMemoryReactorAbstract, self).__init__()
        # Delay the import as mk is also importing proto_helpers.
        from chevah.server.testing import mk
        self._mk = mk
        # Work on a copy, as the connect methods consume it with pop(0).
        self._expected_addresses = expected_addresses[:]
        self.latest_protocol = None
        # Updated by subclasses when a client transport is created.
        self.latest_transport = None
        self._clientTransportFactory = client_transport_factory
        # Ports listening for connections, keyed by port number.
        self._ports = {}

    def addReader(self, port):
        """
        Called when we are waiting for incoming connections.
        """
        # Keep a reference, in case we want to initiate a client connection.
        self._ports[port.port] = port

    def removeReader(self, port):
        """
        Called when we are no longer waiting for incoming connections.
        """
        try:
            del self._ports[port.port]
        except (KeyError, AttributeError):
            # Might be a client connection (we don't keep a record of these,
            # and they might not even have a `port` attribute)... or it might
            # be a port which is not listening yet.
            pass

    def removeWriter(self, port):
        """
        Called when we should remove a connection from the reactor loop.

        Does nothing as we don't keep a record of writers.
        """

    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
        """
        Do or reject the connection based on reactor configuration.

        Not implemented here; see the subclasses.
        """
        raise NotImplementedError('connectTCP not implemented.')

    def connectSSL(
        self, host, port, factory, contextFactory=None, timeout=30,
        bindAddress=None,
    ):
        """
        Do or reject the connection based on reactor configuration.

        Not implemented here; see the subclasses. The signature follows
        IReactorSSL.connectSSL, with `contextFactory` kept optional for
        backward compatibility.
        """
        raise NotImplementedError('connectSSL not implemented.')

    def listenTCP(self, port, factory, backlog=50, interface=''):
        """
        Set up a listening port.

        Code copied from Twisted, with the exception of the fake socket
        injection.
        """
        p = Port(
            port=port,
            factory=factory,
            backlog=backlog,
            interface=interface,
            reactor=self,
            )
        # We inject our own socket to not touch the network.
        p._preexistingSocket = InMemorySocket()
        p.startListening()
        return p

    def triggerConnectionWithClose(
        self,
        server_protocol, server_port,
        client_protocol, client_port=1234,
        server_host='127.0.0.1',
        client_host='1.2.3.4',
    ):
        """
        Rig a connection between `server_protocol` and `client_protocol`
        over in-memory loopback transports.

        Each transport reports the given host/port pairs for its own side
        (getHost) and for the other side (getPeer).

        Will return a deferred which is fired when the connection is closed.
        """
        # Default loopbackAsync code does not allow injecting the peers into
        # the transport.
        server_to_client = loopback._LoopbackQueue()
        client_to_server = loopback._LoopbackQueue()
        server_address = address.IPv4Address('TCP', server_host, server_port)
        client_address = address.IPv4Address('TCP', client_host, client_port)

        server_transport = loopback._LoopbackTransport(server_to_client)
        server_transport.getPeer = lambda: client_address
        server_transport.getHost = lambda: server_address
        # Each fake socket gets its own side's host and port number.
        server_transport.socket = InMemorySocket(
            host=server_host, port=server_port)

        client_transport = loopback._LoopbackTransport(client_to_server)
        client_transport.getPeer = lambda: server_address
        client_transport.getHost = lambda: client_address
        client_transport.socket = InMemorySocket(
            host=client_host, port=client_port)

        server_protocol.makeConnection(server_transport)
        client_protocol.makeConnection(client_transport)

        closed = loopback._loopbackAsyncBody(
            server=server_protocol,
            serverToClient=server_to_client,
            client=client_protocol,
            clientToServer=client_to_server,
            pumpPolicy=loopback.identityPumpPolicy,
            )
        self.latest_protocol = server_protocol
        return closed

    def triggerClientConnectionWithClose(self, port, client_protocol):
        """
        Rig a client connection to the server listening on `port` and using
        the `client_protocol` to communicate on the client side.

        Will return a deferred which is fired when the connection is closed.

        Raises AssertionError when no server is listening on `port`.
        """
        try:
            listening_port = self._ports[port]
        except KeyError:  # noqa:cover
            raise AssertionError(
                'No server in this reactor is listening to %s.' % (port,))

        server_protocol = listening_port.factory.buildProtocol(
            ('1.2.3.4', 1234))
        # The server side of the connection is on the listening port.
        return self.triggerConnectionWithClose(
            server_protocol=server_protocol,
            server_port=port,
            client_protocol=client_protocol,
            client_port=1234,
            )

    def triggerTLSClientConnectionWithClose(self, port, client_protocol):
        """
        Rig a TLS client connection similar to
        triggerClientConnectionWithClose.

        `client_protocol` is wrapped into a client side TLS protocol.
        """
        tls_factory = TLSMemoryBIOFactory(ClientContextFactory(), True, None)
        tls_protocol = TLSMemoryBIOProtocol(
            tls_factory, client_protocol, _connectWrapped=True)
        return self.triggerClientConnectionWithClose(
            port=port, client_protocol=tls_protocol)
def tls_wrap_server_protocol(
    protocol,
    certificate_path='test_data/pki/server-cert-and-key-2048.pem',
):
    """
    Wrap the `protocol` into a server side TLS protocol.

    `certificate_path` is a PEM file holding both the certificate and its
    private key; it defaults to the 2048 bits test server certificate.
    The OpenSSL context allows all ciphers, with the security level from
    chevah.server.testing.OPENSSL_SECLEVEL.

    Returns a TLSMemoryBIOProtocol wrapping `protocol`.
    """
    def get_context(method):
        # Called by DefaultOpenSSLContextFactory with its SSL method; this
        # is unrelated to the wrapped `protocol`.
        from chevah.server.testing import OPENSSL_SECLEVEL
        ssl_context = SSL.Context(method)
        ssl_context.set_cipher_list('ALL' + OPENSSL_SECLEVEL)
        return ssl_context

    context_factory = DefaultOpenSSLContextFactory(
        privateKeyFileName=certificate_path,
        certificateFileName=certificate_path,
        _contextFactory=get_context,
        )
    # NOTE(review): the second TLSMemoryBIOFactory argument is `isClient`;
    # True looks at odds with wrapping a *server* protocol - confirm it is
    # intended before relying on it.
    tls_factory = TLSMemoryBIOFactory(context_factory, True, None)
    return TLSMemoryBIOProtocol(
        tls_factory, protocol, _connectWrapped=True)
class InMemoryTCPReactor(InMemoryReactorAbstract):
    """
    A simple reactor which implements only TCP connections.
    """

    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
        """
        Do or reject the connection based on reactor configuration.

        The (host, port) pair must match the next expected address.
        Otherwise, or when no address is expected anymore, the failure is
        reported through factory.clientConnectionFailed and None is
        returned.

        On match, the protocol is built and, unless the expected address
        has a False `make_connection` flag, connected to a string
        transport. A Connector is then created and connected with socket
        creation patched to use an InMemorySocket, and it is returned.
        """
        try:
            expected_address = self._expected_addresses.pop(0)
        except Exception:  # noqa:cover
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! Unexpected connection'))
            return

        # Expected addresses are (host, port, make_connection) or just
        # (host, port), in which case the connection is always made.
        try:
            expected_host, expected_port, make_connection = expected_address
        except ValueError:
            make_connection = True
            expected_host, expected_port = expected_address

        if (host, port) == (expected_host, expected_port):
            protocol = factory.buildProtocol((host, port))
            if make_connection:
                # Use a simple string transport to fake a non-existent
                # server connection.
                if self._clientTransportFactory is None:
                    transport = self._mk.makeStringTransportWithDisconnection()
                else:
                    transport = self._clientTransportFactory(protocol)
                protocol.makeConnection(transport)
                transport.protocol = protocol
                self.latest_transport = transport
        else:
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! unknown address/port'))
            return

        self.latest_protocol = protocol
        # Make it similar to Twisted code, but use a fake socket.
        connector = Connector(
            host, port, factory, timeout, bindAddress, reactor=self)
        # Rig the socket creation so that we don't touch the network.
        with patch(
            'twisted.internet.tcp.Client.createInternetSocket',
            return_value=InMemorySocket(),
        ):
            connector.connect()
        return connector
class InMemorySTARTTLSReactor(InMemoryReactorAbstract):
    """
    In-memory reactor supporting only TCP client connections, which can
    later be upgraded to TLS/SSL via STARTTLS.
    """

    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
        """
        Connect `factory`'s protocol to a string transport when
        (host, port) matches the next expected address. Otherwise report
        the failure through factory.clientConnectionFailed.

        `timeout` and `bindAddress` are ignored. Always returns None.
        """
        try:
            expected = self._expected_addresses.pop(0)
        except Exception:  # noqa:cover
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! Unexpected connection'))
            return

        matches = (host, port) == expected
        if not matches:
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! unknown address/port'))
            return

        protocol = factory.buildProtocol((host, port))
        make_transport = self._clientTransportFactory
        if make_transport is None:
            # NOTE(review): this is a client connection, yet the default is
            # the server side STARTTLS transport (ServerMixin) - confirm
            # this is intended.
            transport = StringSTARTTLSServerTransport()
        else:
            transport = make_transport(protocol)
        protocol.makeConnection(transport)
        transport.protocol = protocol
        self.latest_protocol = protocol
        self.latest_transport = transport
class InMemorySSLReactor(InMemoryReactorAbstract):
    """
    A simple reactor which connects only with SSL.

    It is initialized with a list of (address, port) tuples for which
    connections are allowed.
    """

    def connectSSL(
        self, host, port, factory, contextFactory, timeout=30,
        bindAddress=None,
    ):
        """
        Do or reject the connection based on reactor configuration.

        On match with the next expected address, the protocol is connected
        to a string transport and `contextFactory` is stored as
        `latest_context_factory`; no TLS is done. Otherwise the failure is
        reported through factory.clientConnectionFailed.
        """
        try:
            expected_address = self._expected_addresses.pop(0)
        except Exception:
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! Unexpected connection'))
            return

        if (host, port) == expected_address:
            protocol = factory.buildProtocol((host, port))
            if self._clientTransportFactory is None:
                transport = self._mk.makeStringTransportWithDisconnection()
            else:
                transport = self._clientTransportFactory(protocol)
            protocol.makeConnection(transport)
            transport.protocol = protocol
            self.latest_protocol = protocol
            self.latest_context_factory = contextFactory
            self.latest_transport = transport
        else:
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! unknown address/port'))

    def listenSSL(
            self, port, factory, contextFactory, backlog=50, interface=''):
        """
        Taken from Twisted code.

        Listen on `port` with `factory` wrapped into a server side
        TLSMemoryBIOFactory, and return the listening port.
        """
        tls_factory = TLSMemoryBIOFactory(contextFactory, False, factory)
        listening_port = self.listenTCP(port, tls_factory, backlog, interface)
        listening_port._type = 'TLS'
        # Register by port number, like addReader does, so that the port
        # can be found by number and is dropped again by removeReader.
        self._ports[listening_port.port] = listening_port
        return listening_port
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment