WIP: hyper-h2 + tornado 4.1

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import io
import ssl
import sys
import copy
import base64
import socket
import httplib
import urlparse
import functools
import collections

import h2.errors
import h2.events
import h2.settings
import h2.connection
import h2.exceptions

from tornado import (
    httputil, log, stack_context,
    simple_httpclient, netutil
)
from tornado.escape import _unicode, utf8
from tornado.httpclient import (
    HTTPResponse, HTTPError, HTTPRequest, _RequestProxy
)


class HTTP20Response(HTTPResponse):

    def __init__(self, *args, **kwargs):
        self.pushed_responses = kwargs.pop('pushed_responses', [])
        self.new_request = kwargs.pop('new_request', None)
        super(HTTP20Response, self).__init__(*args, **kwargs)


class HTTP2Error(HTTPError):
    pass


class ConnectionTimeout(HTTP2Error):
    def __init__(self, time_cost=None):
        # initialize HTTPError so the exception carries a code and stringifies
        super(ConnectionTimeout, self).__init__(599, 'Connection timed out')
        self.time_cost = time_cost


class HTTP2ConnectionClosed(HTTP2Error):
    def __init__(self, reason=None):
        super(HTTP2ConnectionClosed, self).__init__(599, 'Connection closed')
        self.reason = reason


class _RequestTimeout(Exception):
    pass


class SimpleAsyncHTTP20Client(simple_httpclient.SimpleAsyncHTTPClient):
    MAX_CONNECTION_BACKOFF = 10

    def initialize(self, io_loop, host, port=None, max_streams=200,
                   hostname_mapping=None, max_buffer_size=104857600,
                   resolver=None, defaults=None, secure=True,
                   cert_options=None, enable_push=False, **conn_kwargs):
        super(SimpleAsyncHTTP20Client, self).initialize(
            io_loop=io_loop, max_clients=1,
            hostname_mapping=hostname_mapping, max_buffer_size=max_buffer_size,
            resolver=resolver, defaults=defaults, max_header_size=None,
        )
        self.max_streams = max_streams
        self.host = host
        self.port = port
        self.secure = secure
        self.enable_push = enable_push
        self.connection_factory = _HTTP20ConnectionFactory(
            io_loop=self.io_loop, host=host, port=port,
            max_buffer_size=self.max_buffer_size, secure=secure,
            tcp_client=self.tcp_client, cert_options=cert_options,
        )
        # back-off
        self.connection_backoff = 0
        self.next_connect_time = 0
        # open connection
        self.connection = None
        self.io_stream = None
        self.connection_factory.make_connection(
            self._on_connection_ready, self._on_connection_close)

    def _adjust_settings(self, event):
        log.gen_log.debug('settings updated: %r', event.__dict__)
        settings = event.changed_settings.get(h2.settings.MAX_CONCURRENT_STREAMS)
        if settings:
            self.max_clients = min(settings.new_value, self.max_streams)
            if settings.new_value > settings.original_value:
                self._process_queue()

    def _on_connection_close(self, io_stream, reason):
        if self.io_stream is not io_stream:
            return
        connection = self.connection
        self.io_stream = None
        self.connection = None
        if connection is not None:
            connection.on_connection_close(io_stream.error)
        self.connection_backoff = min(
            self.connection_backoff + 1, self.MAX_CONNECTION_BACKOFF)
        now_time = self.io_loop.time()
        self.next_connect_time = max(
            self.next_connect_time,
            now_time + self.connection_backoff)
        if io_stream is None:
            log.gen_log.info(
                'Connection to %s:%s failed due to %r. Reconnect in %.2f seconds',
                self.host, self.port, reason, self.next_connect_time - now_time)
        else:
            log.gen_log.info(
                'Connection closed due to %r. Reconnect in %.2f seconds',
                reason, self.next_connect_time - now_time)
        self.io_loop.add_timeout(
            self.next_connect_time, self.connection_factory.make_connection,
            self._on_connection_ready, self._on_connection_close)
        # move active requests back to pending
        for key, (request, callback) in self.active.items():
            self.queue.appendleft((key, request, callback))
        self.active.clear()

    def _connection_terminated(self, event):
        self._on_connection_close(
            self.io_stream, 'Server requested: ERR 0x%x' % event.error_code)

    def _on_connection_ready(self, io_stream):
        # back-off
        self.next_connect_time = max(self.io_loop.time(), self.next_connect_time)
        self.connection_backoff = 0
        self.io_stream = io_stream
        self.connection = _HTTP20ConnectionContext(
            io_stream=io_stream, secure=self.secure,
            enable_push=self.enable_push,
            max_buffer_size=self.max_buffer_size,
        )
        self.connection.add_event_handler(
            h2.events.RemoteSettingsChanged, self._adjust_settings
        )
        self.connection.add_event_handler(
            h2.events.ConnectionTerminated, self._connection_terminated
        )
        self._process_queue()

    def _process_queue(self):
        if not self.connection:
            return
        super(SimpleAsyncHTTP20Client, self)._process_queue()

    def _handle_request(self, request, release_callback, final_callback):
        _HTTP20Stream(
            self.io_loop, self.connection, request,
            self.host, release_callback, final_callback
        )


class _HTTP20ConnectionFactory(object):

    def __init__(self, io_loop, host, port, max_buffer_size, tcp_client,
                 secure=True, cert_options=None, connect_timeout=None):
        self.start_time = io_loop.time()
        self.io_loop = io_loop
        self.max_buffer_size = max_buffer_size
        self.tcp_client = tcp_client
        self.cert_options = collections.defaultdict(lambda: None, **(cert_options or {}))
        if port is None:
            port = 443 if secure else 80
        self.host = host
        self.port = port
        self.connect_timeout = connect_timeout
        self.ssl_options = self._get_ssl_options(self.cert_options) if secure else None

    def make_connection(self, ready_callback, close_callback):
        if self.connect_timeout:
            timed_out = [False]
            start_time = self.io_loop.time()

            def _on_timeout():
                timed_out[0] = True
                close_callback(None, ConnectionTimeout(self.io_loop.time() - start_time))

            def _on_connect(io_stream):
                if timed_out[0]:
                    io_stream.close()
                    return
                self.io_loop.remove_timeout(timeout_handle)
                self._on_connect(io_stream, ready_callback, close_callback)

            timeout_handle = self.io_loop.add_timeout(
                start_time + self.connect_timeout,
                stack_context.wrap(_on_timeout))
        else:
            _on_connect = functools.partial(
                self._on_connect,
                ready_callback=ready_callback,
                close_callback=close_callback,
            )
        with stack_context.ExceptionStackContext(
                functools.partial(self._handle_exception, close_callback)):
            self.tcp_client.connect(
                self.host, self.port, af=socket.AF_UNSPEC,
                ssl_options=self.ssl_options,
                max_buffer_size=self.max_buffer_size,
                callback=_on_connect)

    @classmethod
    def _handle_exception(cls, close_callback, typ, value, tb):
        close_callback(None, value)
        return True
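
    # _get_ssl_options() below reads the keys 'validate_cert', 'ca_certs',
    # 'client_key' and 'client_cert' from the cert_options mapping; any key
    # may be left unset. For example (placeholder paths, not from the gist):
    #
    #     cert_options = {
    #         'validate_cert': True,
    #         'ca_certs': '/path/to/ca-bundle.crt',
    #         'client_key': '/path/to/client.key',
    #         'client_cert': '/path/to/client.crt',
    #     }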

    @classmethod
    def _get_ssl_options(cls, cert_options):
        ssl_options = {}
        if cert_options['validate_cert']:
            ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
        if cert_options['ca_certs'] is not None:
            ssl_options["ca_certs"] = cert_options['ca_certs']
        else:
            ssl_options["ca_certs"] = simple_httpclient._default_ca_certs()
        if cert_options['client_key'] is not None:
            ssl_options["keyfile"] = cert_options['client_key']
        if cert_options['client_cert'] is not None:
            ssl_options["certfile"] = cert_options['client_cert']
        # SSL interoperability is tricky. We want to disable
        # SSLv2 for security reasons; it wasn't disabled by default
        # until openssl 1.0. The best way to do this is to use
        # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
        # until 3.2. Python 2.7 adds the ciphers argument, which
        # can also be used to disable SSLv2. As a last resort
        # on python 2.6, we set ssl_version to TLSv1. This is
        # more narrow than we'd like since it also breaks
        # compatibility with servers configured for SSLv3 only,
        # but nearly all servers support both SSLv3 and TLSv1:
        # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
        if sys.version_info >= (2, 7):
            # In addition to disabling SSLv2, we also exclude certain
            # classes of insecure ciphers.
            ssl_options["ciphers"] = "DEFAULT:!SSLv2:!EXPORT:!DES"
        else:
            # This is really only necessary for pre-1.0 versions
            # of openssl, but python 2.6 doesn't expose version
            # information.
            ssl_options["ssl_version"] = ssl.PROTOCOL_TLSv1
        ssl_options = netutil.ssl_options_to_context(ssl_options)
        ssl_options.set_alpn_protocols(['h2'])
        return ssl_options

    def _on_connect(self, io_stream, ready_callback, close_callback):
        io_stream.set_close_callback(lambda: close_callback(io_stream, io_stream.error))
        self.io_loop.add_callback(ready_callback, io_stream)
        io_stream.set_nodelay(True)


class _HTTP20ConnectionContext(object):
"""maintenance a http/2 connection state on specific io_stream | |
""" | |

    def __init__(self, io_stream, secure, enable_push, max_buffer_size):
        self.io_stream = io_stream
        self.schema = 'https' if secure else 'http'
        self.enable_push = bool(enable_push)
        self.max_buffer_size = max_buffer_size
        self.is_closed = False
        # h2 contexts
        self.stream_delegates = {}
        self.event_handlers = {}  # connection level event, event -> handler
        self.h2_conn = h2.connection.H2Connection(client_side=True)
        self.h2_conn.initiate_connection()
        self.h2_conn.update_settings({
            h2.settings.ENABLE_PUSH: int(self.enable_push),
        })
        self._setup_reading()
        self._flush_to_stream()

    def on_connection_close(self, reason):
        if self.is_closed:
            return
        self.is_closed = True
        for delegate in self.stream_delegates.values():
            delegate.on_connection_close(reason)

    # h2 related
    def _on_connection_streaming(self, data):
        """feed bytes read from the IOStream into the h2 state machine"""
        if self.is_closed:
            return
        try:
            events = self.h2_conn.receive_data(data)
        except h2.exceptions.ProtocolError as err:
            self.h2_conn.close_connection(h2.errors.PROTOCOL_ERROR)
            self._flush_to_stream()
            self.io_stream.close()
            self.on_connection_close(err)
            return
        if events:
            self._process_events(events)
            self._flush_to_stream()

    def _flush_to_stream(self):
        """flush h2 connection data to IOStream"""
        data_to_send = self.h2_conn.data_to_send()
        if data_to_send:
            self.io_stream.write(data_to_send)

    def handle_request(self, request):
        http2_headers = [
            (':authority', request.headers.pop('Host')),
            (':path', request.url),
            (':scheme', self.schema),
            (':method', request.method),
        ] + request.headers.items()
        stream_id = self.h2_conn.get_next_available_stream_id()
        self.h2_conn.send_headers(stream_id, http2_headers, end_stream=not request.body)
        if request.body:
            self.h2_conn.send_data(stream_id, request.body, end_stream=True)
        self._flush_to_stream()
        return stream_id

    def add_stream_delegate(self, stream_id, stream_delegate):
        self.stream_delegates[stream_id] = stream_delegate

    def remove_stream_delegate(self, stream_id):
        del self.stream_delegates[stream_id]

    def add_event_handler(self, event_type, event_handler):
        self.event_handlers[event_type] = event_handler

    def remove_event_handler(self, event_type):
        del self.event_handlers[event_type]

    def reset_stream(self, stream_id, reason=h2.errors.REFUSED_STREAM, flush=False):
        if self.is_closed:
            return
        try:
            self.h2_conn.reset_stream(stream_id, reason)
        except h2.exceptions.StreamClosedError:
            return
        else:
            if flush:
                self._flush_to_stream()

    def _process_events(self, events):
        for event in events:
            if isinstance(event, h2.events.DataReceived):
                if event.flow_controlled_length:
                    self.h2_conn.increment_flow_control_window(
                        event.flow_controlled_length)
            if isinstance(event, h2.events.PushedStreamReceived):
                stream_id = event.parent_stream_id
            else:
                stream_id = getattr(event, 'stream_id', None)
            if stream_id is not None and stream_id != 0:
                if stream_id in self.stream_delegates:
                    stream_delegate = self.stream_delegates[stream_id]
                    with stack_context.ExceptionStackContext(stream_delegate.handle_exception):
                        stream_delegate.handle_event(event)
                else:
                    self.reset_stream(stream_id)
                    log.gen_log.warning('unexpected stream: %s, event: %r', stream_id, event)
                continue
            event_type = type(event)
            if event_type in self.event_handlers:
                try:
                    self.event_handlers[event_type](event)
                except Exception as err:
                    log.gen_log.exception('Exception while handling event: %r', err)
                continue
            log.gen_log.debug('ignored event: %r, %r', event, event.__dict__)

    def _setup_reading(self, *_):
        if self.is_closed:
            return
        self.io_stream.read_bytes(
            num_bytes=65535, callback=self._setup_reading,
            streaming_callback=self._on_connection_streaming)
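

# A stream delegate registered via _HTTP20ConnectionContext.add_stream_delegate()
# is expected to provide handle_event(event), handle_exception(typ, value, tb)
# and on_connection_close(reason); _HTTP20Stream below implements that contract.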
class _HTTP20Stream(httputil.HTTPMessageDelegate):
    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(
            self, io_loop, context, request, default_host=None,
            release_callback=None, final_callback=None, stream_id=None):
        self.start_time = io_loop.time()
        self.io_loop = io_loop
        self.context = context
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.chunks = []
        self.headers = None
        self.code = None
        self.reason = None
        self._timeout = None
        self._pushed_streams = {}
        self._pushed_responses = {}
        self._stream_ended = False
        self._finalized = False
        with stack_context.ExceptionStackContext(self.handle_exception):
            if request.request_timeout:
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + request.request_timeout,
                    stack_context.wrap(self._on_timeout))
            if stream_id is None:
                self.request = self.prepare_request(request, default_host)
                self.stream_id = self.context.handle_request(self.request)
            else:
                self.request = request
                self.stream_id = stream_id
            self.context.add_stream_delegate(self.stream_id, self)

    @classmethod
    def build_http_headers(cls, headers):
        http_headers = httputil.HTTPHeaders()
        for name, value in headers:
            http_headers.add(name, value)
        return http_headers

    def from_push_stream(self, event):
        headers = self.build_http_headers(event.headers)
        method = headers.pop(':method')
        scheme = headers.pop(':scheme')
        authority = headers.pop(':authority')
        path = headers.pop(':path')
        full_url = '%s://%s%s' % (scheme, authority, path)
        request = HTTPRequest(url=full_url, method=method, headers=headers)
        return _HTTP20Stream(
            io_loop=self.io_loop, context=self.context,
            request=request, stream_id=event.pushed_stream_id,
            final_callback=functools.partial(
                self.finish_push_stream, event.pushed_stream_id)
        )

    def finish_push_stream(self, stream_id, response):
        if self._finalized:
            return
        self._pushed_responses[stream_id] = response
        if not self._stream_ended:
            return
        if len(self._pushed_streams) == len(self._pushed_responses):
            self.finish()

    @classmethod
    def prepare_request(cls, request, default_host):
        parsed = urlparse.urlsplit(_unicode(request.url))
        if (request.method not in cls._SUPPORTED_METHODS and
                not request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % request.method)
        request.follow_redirects = False
        for key in ('network_interface',
                    'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password',
                    'expect_100_continue', 'body_producer',
                    ):
            if getattr(request, key, None):
                raise NotImplementedError('%s not supported' % key)
        request.headers.pop('Connection', None)
        if "Host" not in request.headers:
            if not parsed.netloc:
                request.headers['Host'] = default_host
            elif '@' in parsed.netloc:
                request.headers["Host"] = parsed.netloc.rpartition('@')[-1]
            else:
                request.headers["Host"] = parsed.netloc
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif request.auth_username is not None:
            username = request.auth_username
            password = request.auth_password or ''
        if username is not None:
            if request.auth_mode not in (None, "basic"):
raise ValueError("unsupported auth_mode %s", | |
request.auth_mode) | |
            auth = utf8(username) + b":" + utf8(password)
            request.headers["Authorization"] = (
                b"Basic " + base64.b64encode(auth))
        if request.user_agent:
            request.headers["User-Agent"] = request.user_agent
        if not request.allow_nonstandard_methods:
            # Some HTTP methods nearly always have bodies while others
            # almost never do. Fail in this case unless the user has
            # opted out of sanity checks with allow_nonstandard_methods.
            body_expected = request.method in ("POST", "PATCH", "PUT")
            body_present = (request.body is not None or
                            request.body_producer is not None)
            if ((body_expected and not body_present) or
                    (body_present and not body_expected)):
                raise ValueError(
                    'Body must %sbe None for method %s (unless '
                    'allow_nonstandard_methods is true)' %
                    ('not ' if body_expected else '', request.method))
        if request.body is not None:
            # When body_producer is used the caller is responsible for
            # setting Content-Length (or else chunked encoding will be used).
            request.headers["Content-Length"] = str(len(
                request.body))
        if (request.method == "POST" and
                "Content-Type" not in request.headers):
            request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if request.decompress_response:
            request.headers["Accept-Encoding"] = "gzip"
        request.url = (
            (parsed.path or '/') +
            (('?' + parsed.query) if parsed.query else '')
        )
        return request

    def headers_received(self, first_line, headers):
        self.headers = headers
        self.code = first_line.code
        self.reason = first_line.reason
        if self.request.header_callback is not None:
            # Reassemble the start line.
            self.request.header_callback('%s %s %s\r\n' % first_line)
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
            self.request.header_callback('\r\n')

    def _run_callback(self, response):
        if self._finalized:
            return
        if self.release_callback is not None:
            self.release_callback()
        self.io_loop.add_callback(self.final_callback, response)
        self._finalized = True

    def handle_event(self, event):
        if isinstance(event, h2.events.ResponseReceived):
            headers = self.build_http_headers(event.headers)
            status_code = int(headers.pop(':status'))
            start_line = httputil.ResponseStartLine(
                'HTTP/2.0', status_code, httplib.responses.get(status_code, 'Unknown')
            )
            self.headers_received(start_line, headers)
        elif isinstance(event, h2.events.DataReceived):
            self.data_received(event.data)
        elif isinstance(event, h2.events.StreamEnded):
            self._stream_ended = True
            self.context.remove_stream_delegate(self.stream_id)
            if len(self._pushed_responses) == len(self._pushed_streams):
                self.finish()
        elif isinstance(event, h2.events.PushedStreamReceived):
            stream = self.from_push_stream(event)
            self._pushed_streams[event.pushed_stream_id] = stream
        else:
            log.gen_log.warning('ignored event: %r, %r', event, event.__dict__)

    def finish(self):
        self._remove_timeout()
        self._unregister_unfinished_streams()
        data = b''.join(self.chunks)
        original_request = getattr(self.request, "original_request",
                                   self.request)
        new_request = None
        if (self.request.follow_redirects and
                self.request.max_redirects > 0 and
                self.code in (301, 302, 303, 307)):
            assert isinstance(self.request, _RequestProxy)
            new_request = copy.copy(self.request.request)
            new_request.url = urlparse.urljoin(self.request.url,
                                               self.headers["Location"])
            new_request.max_redirects = self.request.max_redirects - 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # Client SHOULD make a GET request after a 303.
            # According to the spec, 302 should be followed by the same
            # method as the original request, but in practice browsers
            # treat 302 the same as 303, and many servers use 302 for
            # compatibility with pre-HTTP/1.1 user agents which don't
            # understand the 303 status.
            if self.code in (302, 303):
                new_request.method = "GET"
                new_request.body = None
                for h in ["Content-Length", "Content-Type",
                          "Content-Encoding", "Transfer-Encoding"]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
        if self.request.streaming_callback:
            buff = io.BytesIO()
        else:
            buff = io.BytesIO(data)  # TODO: don't require one big string?
        response = HTTP20Response(
            original_request, self.code, reason=self.reason,
            headers=self.headers, request_time=self.io_loop.time() - self.start_time,
            buffer=buff, effective_url=self.request.url,
            pushed_responses=self._pushed_responses.values(),
            new_request=new_request,
        )
        self._run_callback(response)

    def data_received(self, chunk):
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)

    def handle_exception(self, typ, value, tb):
        if isinstance(value, _RequestTimeout) and self._stream_ended:
            self.finish()
            return True
        self._remove_timeout()
        self._unregister_unfinished_streams()
        self.context.remove_stream_delegate(self.stream_id)
        # TODO: should we reset & flush immediately?
        self.context.reset_stream(self.stream_id, flush=True)
        response = HTTP20Response(
            self.request, 599, error=value,
            request_time=self.io_loop.time() - self.start_time,
        )
        self._run_callback(response)
        return True

    def _unregister_unfinished_streams(self):
        for stream_id in self._pushed_streams:
            if stream_id not in self._pushed_responses:
                self.context.remove_stream_delegate(stream_id)

    def _remove_timeout(self):
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None

    def _on_timeout(self):
        self._timeout = None
        self.connection_timeout = True
        raise _RequestTimeout()

    def on_connection_close(self, reason=None):
        try:
            raise HTTP2ConnectionClosed(reason)
        except Exception:
            self.handle_exception(*sys.exc_info())
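

if __name__ == '__main__':
    # Minimal usage sketch: the client is pinned to a single origin because
    # HTTP/2 multiplexes every request over one connection. The host below is
    # only a placeholder endpoint.
    from tornado import ioloop

    io_loop = ioloop.IOLoop.current()
    client = SimpleAsyncHTTP20Client(
        io_loop=io_loop, host='http2.example.com', secure=True,
        force_instance=True)

    def on_response(response):
        print response.code, len(response.body or '')
        io_loop.stop()

    client.fetch('https://http2.example.com/', callback=on_response)
    io_loop.start()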