Created
June 26, 2016 05:48
-
-
Save nicolas17/98b1acb6520b9ce1b32f11379ba060c3 to your computer and use it in GitHub Desktop.
Amazon Lambda runtime
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
aws_lambda.bootstrap.py | |
Amazon Lambda | |
Copyright (c) 2013 Amazon. All rights reserved. | |
Lambda runtime implementation
""" | |
from __future__ import print_function | |
import code | |
import json | |
import logging | |
import os | |
import socket | |
import sys | |
import traceback | |
import wsgi | |
import runtime as lambda_runtime | |
import imp | |
import fcntl | |
import time | |
import decimal | |
def _get_handlers(handler, mode, suppress_init):
    """Return ``(init_handler, request_handler)`` for the handler string.

    When *suppress_init* is truthy the user's module is loaded lazily, on the
    first request; otherwise it is loaded immediately.
    """
    loader = _get_handlers_delayed if suppress_init else _get_handlers_immediate
    return loader(handler, mode)
class number_str(float):
    """A float that remembers the value it was constructed from.

    ``repr()`` returns ``str()`` of that original value, so a wrapped
    ``decimal.Decimal`` reprs as its exact decimal text rather than as a
    float approximation.

    NOTE(review): Python 3's json encoder formats float subclasses with
    ``float.__repr__`` and so bypasses this override; confirm whether exact
    decimal output is still expected there.
    """

    def __init__(self, o):
        # float.__new__ has already fixed the numeric value; keep the source
        # object around for __repr__.
        self.o = o

    def __repr__(self):
        original = self.o
        return str(original)
def decimal_serializer(o):
    """``default=`` hook for json.dumps that understands ``decimal.Decimal``.

    Decimals are wrapped in number_str; anything else raises TypeError just
    as json's own default hook would.
    """
    if not isinstance(o, decimal.Decimal):
        raise TypeError(repr(o) + " is not JSON serializable")
    return number_str(o)
""" | |
delay loading the user's code until an invoke occurs, to ensure we don't crash the runtime. | |
""" | |
def _get_handlers_delayed(handler, mode): | |
real_request_handler = [None] | |
def request_handler(*arg): | |
if real_request_handler[0] is None: | |
init_handler, request_handler = _get_handlers_immediate(handler, mode) | |
real_request_handler[0] = request_handler | |
init_handler() | |
return request_handler(*arg) | |
else: | |
return real_request_handler[0](*arg) | |
return lambda: None, request_handler | |
def load_handler_failed_handler(e, modname):
    """Build a request handler that reports why loading *modname* failed.

    ImportError and SyntaxError get dedicated messages. Any other exception
    is reported as a module initialization error, together with the
    traceback of the exception currently being handled. This function is
    meant to be called from inside an ``except`` block.
    """
    if isinstance(e, ImportError):
        fault = wsgi.FaultException("Unable to import module '{}'".format(modname), str(e), None)
    elif isinstance(e, SyntaxError):
        location = "File \"%s\" Line %s\n\t%s" % (e.filename, e.lineno, e.text)
        fault = wsgi.FaultException("Syntax error in module '{}'".format(modname), str(e), location)
    else:
        frames = traceback.extract_tb(sys.exc_info()[2])
        formatted = traceback.format_list(frames)
        # The first frame is dropped. Presumably it is the runtime's own
        # loading call rather than user code; confirm against callers.
        fault = wsgi.FaultException("module initialization error", str(e), formatted[1:])
    return make_fault_handler(fault)
def make_fault_handler(fault):
    """Return a callable that raises *fault* every time it is invoked.

    The callable accepts and ignores any positional arguments, so it can
    stand in for either an event-mode or an http-mode request handler.
    """
    def raise_fault(*_ignored):
        raise fault
    return raise_fault
def _get_handlers_immediate(handler, mode):
    """Load the user's handler module now and return ``(init_handler, request_handler)``.

    *handler* is a ``"module.function"`` string, split on its last dot.

    Failures are not raised from here. These include a malformed handler
    string, an Exception while finding or loading the module, a built-in
    module, or a missing function. Instead the returned request handler
    raises a ``wsgi.FaultException`` describing the problem when it is
    invoked.

    ``init_handler`` is the module's ``init`` attribute if it has one, and a
    no-op otherwise.
    """
    noop_init = lambda: None
    try:
        modname, fname = handler.rsplit('.', 1)
    except ValueError as e:
        fault = wsgi.FaultException("Bad handler '{}'".format(handler), str(e), None)
        return noop_init, make_fault_handler(fault)

    # Modules are loaded with imp.find_module/imp.load_module rather than an
    # import statement. The import-based approach caused intermittent build
    # failures for unknown reasons. imp reloads the module if it is already
    # loaded, and appears to do a full directory search instead of relying
    # on caches.
    # NOTE(review): imp.find_module does not handle dotted names, so handlers
    # in sub-packages ("pkg.mod.fn") end up as an import fault here.
    file_handle = None
    try:
        file_handle, pathname, desc = imp.find_module(modname)
        if file_handle is None and desc[2] == imp.C_BUILTIN:
            fault = wsgi.FaultException("Cannot use built-in module {} as a handler module".format(modname), None, None)
            return noop_init, make_fault_handler(fault)
        m = imp.load_module(modname, file_handle, pathname, desc)
    except Exception as e:
        return noop_init, load_handler_failed_handler(e, modname)
    finally:
        # find_module returns an open file for source modules; always release it.
        if file_handle is not None:
            file_handle.close()

    init_handler = getattr(m, 'init', noop_init)
    try:
        user_fn = getattr(m, fname)
    except AttributeError as e:
        fault = wsgi.FaultException("Handler '{}' missing on module '{}'".format(fname, modname), str(e), None)
        return init_handler, make_fault_handler(fault)
    return init_handler, make_final_handler(user_fn, mode)
def set_environ(credentials):
    """Export the AWS credentials from *credentials* into ``os.environ``.

    Only truthy values are written. Missing or empty entries leave any
    existing variable untouched. The session token is exported under both
    AWS_SESSION_TOKEN and AWS_SECURITY_TOKEN.
    """
    # TODO delete from environ if params not found
    targets = (
        ('key', ('AWS_ACCESS_KEY_ID',)),
        ('secret', ('AWS_SECRET_ACCESS_KEY',)),
        ('session', ('AWS_SESSION_TOKEN', 'AWS_SECURITY_TOKEN')),
    )
    for cred_name, env_names in targets:
        value = credentials.get(cred_name)
        if not value:
            continue
        for env_name in env_names:
            os.environ[env_name] = value
def wait_for_start(ctrl_sock):
    """Block until the start message arrives on the control socket.

    Exports the received credentials and acknowledges with report_running.
    Returns ``(invokeid, mode, handler, suppress_init, credentials)``.
    """
    start_message = lambda_runtime.recv_start(ctrl_sock)
    invokeid, mode, handler, suppress_init, credentials = start_message
    set_environ(credentials)
    lambda_runtime.report_running(invokeid)
    return invokeid, mode, handler, suppress_init, credentials
def wait_for_invoke(ctrl_sock):
    """Block until the next invoke arrives on the control socket.

    Exports the invoke's credentials. Returns ``(invokeid, data_sock,
    credentials, event_body, context_objs, invoked_function_arn)``.
    """
    invoke = lambda_runtime.receive_invoke(ctrl_sock)
    invokeid, data_sock, credentials, event_body, context_objs, invoked_function_arn = invoke
    set_environ(credentials)
    return invokeid, data_sock, credentials, event_body, context_objs, invoked_function_arn
def make_final_handler(handlerfn, mode):
    """Adapt the user's function *handlerfn* to the runtime's invocation *mode*.

    * ``"event"``: *handlerfn* itself is returned.
    * ``"http"``: returns a callable taking a socket fd that serves one HTTP
      request with *handlerfn* as the WSGI app.
    * anything else: returns a callable that raises wsgi.FaultException when
      invoked.
    """
    if mode == "event":
        return handlerfn
    if mode == "http":
        def http_handler(sockfd):
            invoke_http(handlerfn, sockfd)
        return http_handler

    def invalid_mode_handler(sockfd):
        raise wsgi.FaultException("specified mode is invalid: " + mode)
    return invalid_mode_handler
def invoke_http(handlerfn, sockfd):
    """Serve one HTTP request on *sockfd* using *handlerfn* as the WSGI app.

    If wsgi.handle_one reports fault data, it is re-raised as a
    wsgi.FaultException.
    """
    fault_data = wsgi.handle_one(sockfd, ('localhost', 80), handlerfn)
    if not fault_data:
        return
    raise wsgi.FaultException(fault_data.msg, fault_data.except_value, fault_data.trace)
def try_or_raise(function, error_message):
    """Call *function* and return its result.

    Any Exception is re-raised as ``JsonError(sys.exc_info(), error_message)``,
    which keeps the original exception info for fault reporting.
    """
    try:
        outcome = function()
    except Exception:
        raise JsonError(sys.exc_info(), error_message)
    return outcome
def make_error(errorMessage, errorType, stackTrace):
    """Build the error payload dict for a failed invocation.

    Only truthy arguments are included, under the keys 'errorMessage',
    'errorType' and 'stackTrace' (inserted in that order). *stackTrace* is
    an array of stack frames.
    """
    candidates = (
        ('errorMessage', errorMessage),
        ('errorType', errorType),
        ('stackTrace', stackTrace),
    )
    result = {}
    for key, value in candidates:
        if value:
            result[key] = value
    return result
def handle_http_request(request_handler, invokeid, sockfd):
    """Run one http-mode invocation on data descriptor *sockfd*, then report it done.

    A wsgi.FaultException from *request_handler* is reported as a fault and
    swallowed. Any other exception propagates to the caller, but only after
    the finally clauses below have closed *sockfd* and called report_done.
    """
    try:
        request_handler(sockfd)
    except wsgi.FaultException as e:
        lambda_runtime.report_fault(invokeid, e.msg, e.except_value, e.trace)
    finally:
        try:
            os.close(sockfd)
        except Exception as e:
            # A failed close is logged to stderr with its traceback, not raised.
            print("Error closing original data connection descriptor", file=sys.stderr)
            traceback.print_exc()
        finally:
            # report_done is sent even if closing the descriptor failed.
            lambda_runtime.report_done(invokeid, None, None)
def to_json(obj): | |
return json.dumps(obj, default=decimal_serializer) | |
def handle_event_request(request_handler, invokeid, event_body, context_objs, invoked_function_arn):
    """Run one event-mode invocation and report its outcome.

    The function runs these steps in order:

    1. Parse the optional client context.
    2. Build a LambdaContext.
    3. Parse the JSON event body.
    4. Call ``request_handler(event, context)``.
    5. JSON-encode the result.

    On success, report_done gets errortype None and the JSON result. On any
    Exception, a fault is reported and report_done gets errortype
    "unhandled" with a JSON error payload.
    """
    errortype = None
    try:
        ctx_json = context_objs.get('client_context')
        parsed_ctx = ctx_json
        if ctx_json:
            parsed_ctx = try_or_raise(lambda: json.loads(ctx_json), "Unable to parse client context")
        context = LambdaContext(invokeid, context_objs, parsed_ctx, invoked_function_arn)
        event = try_or_raise(lambda: json.loads(event_body), "Unable to parse input as json")
        handler_output = request_handler(event, context)
        response = try_or_raise(lambda: to_json(handler_output), "An error occurred during JSON serialization of response")
    except wsgi.FaultException as fault:
        lambda_runtime.report_fault(invokeid, fault.msg, fault.except_value, None)
        response = to_json(make_error(fault.msg, None, None))
        errortype = "unhandled"
    except JsonError as json_error:
        response = to_json(report_fault_helper(invokeid, json_error.exc_info, json_error.msg))
        errortype = "unhandled"
    except Exception:
        response = to_json(report_fault_helper(invokeid, sys.exc_info(), None))
        errortype = "unhandled"
    lambda_runtime.report_done(invokeid, errortype, response)
def report_fault_helper(invokeid, exc_info, msg):
    """Report an exception as a fault for *invokeid* and return an error payload.

    *exc_info* is a ``(type, value, traceback)`` triple as returned by
    sys.exc_info(). Leading traceback frames from bootstrap.py are dropped so
    the user's own code comes first. If every frame is in bootstrap.py, all
    frames are kept.

    The fault message and detail depend on *msg*. When *msg* is given, they
    are *msg* and str(value). Otherwise they are str(value) and the
    exception type name.

    Returns the make_error() dict. Its 'stackTrace' is a list of plain
    ``(filename, lineno, name, line)`` tuples, so it stays JSON-serializable.
    """
    etype, value, tb = exc_info
    if msg:
        fault_msg, fault_detail = msg, str(value)
    else:
        fault_msg, fault_detail = str(value), etype.__name__

    frames = traceback.extract_tb(tb)
    for i, frame in enumerate(frames):
        if "/bootstrap.py" not in frame[0]:  # frame[0] is the filename
            frames = frames[i:]
            break

    formatted = ("Traceback (most recent call last):\n"
                 + ''.join(traceback.format_list(frames))
                 + ''.join(traceback.format_exception_only(etype, value)))
    lambda_runtime.report_fault(invokeid, fault_msg, fault_detail, formatted)

    # On Python 3.5+ extract_tb yields FrameSummary objects, which json.dumps
    # cannot encode (to_json would raise inside the caller's except handler).
    # tuple() gives the same 4-tuples that Python 2 returns.
    stack = [tuple(frame) for frame in frames]
    return make_error(str(value), etype.__name__, stack)
class CustomFile(object):
    """File-like wrapper that sends writes through lambda_runtime.log_bytes.

    Writes are logged against the wrapped file's descriptor and then the
    wrapped file is flushed. Every other attribute is delegated to the
    wrapped file object.
    """

    def __init__(self, fd):
        self._fd = fd

    def __getattr__(self, attr):
        # Only reached for attributes not found on the wrapper itself.
        return getattr(self._fd, attr)

    def write(self, msg):
        wrapped = self._fd
        lambda_runtime.log_bytes(msg, wrapped.fileno())
        wrapped.flush()

    def writelines(self, msgs):
        wrapped = self._fd
        for line in msgs:
            lambda_runtime.log_bytes(line, wrapped.fileno())
        wrapped.flush()
class CognitoIdentity(object):
    """Cognito identity of the caller (populated from the invoke's context objects)."""

    __slots__ = [
        "cognito_identity_id",
        "cognito_identity_pool_id",
    ]
class Client(object):
    """Client-application details carried in the client context."""

    __slots__ = [
        "installation_id",
        "app_title",
        "app_version_name",
        "app_version_code",
        "app_package_name",
    ]
class ClientContext(object):
    """Parsed client context: custom values, environment and client details."""

    __slots__ = [
        'custom',
        'env',
        'client',
    ]
def make_obj_from_dict(_class, _dict, fields=None):
    """Instantiate *_class* with no arguments and populate it from *_dict*.

    Returns None when *_dict* is None. *fields* selects which attributes are
    copied. When it is None, set_obj_from_dict falls back to the class's
    ``__slots__``.
    """
    if _dict is None:
        return None
    obj = _class()
    # Forward *fields*: it was previously accepted but silently ignored.
    set_obj_from_dict(obj, _dict, fields)
    return obj
def set_obj_from_dict(obj, _dict, fields=None):
    """Set each attribute in *fields* on *obj* from *_dict*.

    Missing keys become None. When *fields* is None the names in the object's
    class ``__slots__`` are used.
    """
    names = obj.__class__.__slots__ if fields is None else fields
    for name in names:
        setattr(obj, name, _dict.get(name, None))
class LambdaContext(object):
    """Context object passed as the second argument to event-mode handlers.

    Function metadata is read from the AWS_LAMBDA_* environment variables, so
    those values are strings. This includes memory_limit_in_mb. The client
    context and caller identity are converted into slot objects.
    """

    def __init__(self, invokeid, context_objs, client_context, invoked_function_arn=None):
        env = os.environ
        self.aws_request_id = invokeid
        self.log_group_name = env['AWS_LAMBDA_LOG_GROUP_NAME']
        self.log_stream_name = env['AWS_LAMBDA_LOG_STREAM_NAME']
        self.function_name = env["AWS_LAMBDA_FUNCTION_NAME"]
        self.memory_limit_in_mb = env['AWS_LAMBDA_FUNCTION_MEMORY_SIZE']
        self.function_version = env['AWS_LAMBDA_FUNCTION_VERSION']
        self.invoked_function_arn = invoked_function_arn

        parsed_ctx = make_obj_from_dict(ClientContext, client_context)
        if parsed_ctx is not None:
            parsed_ctx.client = make_obj_from_dict(Client, parsed_ctx.client)
        self.client_context = parsed_ctx
        self.identity = make_obj_from_dict(CognitoIdentity, context_objs)

    def get_remaining_time_in_millis(self):
        """Return the runtime's remaining-time value for this invocation."""
        return lambda_runtime.get_remaining_time()

    def log(self, msg):
        """Send str(msg) to the Lambda console."""
        text = str(msg)
        lambda_runtime.send_console_message(text)
class LambdaLoggerHandler(logging.Handler):
    """logging handler that forwards each formatted record to the Lambda console.

    logging.Handler's own constructor is used unchanged (level NOTSET).
    """

    def emit(self, record):
        formatted = self.format(record)
        lambda_runtime.send_console_message(formatted)
class LambdaLoggerFilter(logging.Filter):
    """Tag every log record with the current request id; never filters anything out."""

    def filter(self, record):
        request_id = _GLOBAL_AWS_REQUEST_ID
        record.aws_request_id = request_id if request_id else ""
        return True
class JsonError(Exception):
    """Wraps a failure raised by try_or_raise together with a descriptive message.

    Attributes:
        exc_info: the ``(type, value, traceback)`` triple of the original error.
        msg: message describing what was being attempted.
    """

    def __init__(self, exc_info, msg):
        self.exc_info, self.msg = exc_info, msg
# Alias of the socket module's private default-timeout sentinel.
# NOTE(review): not referenced in the visible code; presumably kept for other
# modules. Confirm before removing.
_GLOBAL_DEFAULT_TIMEOUT = socket._GLOBAL_DEFAULT_TIMEOUT
# Request id of the invocation currently being handled. main() sets it for
# each invoke and LambdaLoggerFilter copies it onto log records. It is None
# until the first invoke.
_GLOBAL_AWS_REQUEST_ID = None
def _install_lambda_logging():
    """Route stdout, stderr and root-logger output through the Lambda runtime."""
    if sys.version_info[0] < 3:
        # Python 2 only: make UTF-8 the default encoding. reload(sys) is
        # presumably needed to get setdefaultencoding back.
        reload(sys)
        sys.setdefaultencoding('utf-8')
    sys.stdout = CustomFile(sys.stdout)
    sys.stderr = CustomFile(sys.stderr)
    # Timestamps are rendered in UTC, matching the trailing "Z" in the format.
    logging.Formatter.converter = time.gmtime
    console_handler = LambdaLoggerHandler()
    console_handler.setFormatter(logging.Formatter(
        '[%(levelname)s]\t%(asctime)s.%(msecs)dZ\t%(aws_request_id)s\t%(message)s\n',
        '%Y-%m-%dT%H:%M:%S'))
    console_handler.addFilter(LambdaLoggerFilter())
    logging.getLogger().addHandler(console_handler)


def main():
    """Runtime entry point.

    The steps are:

    1. Install logging.
    2. Wait for the start message.
    3. Load the user handler.
    4. Run its init handler.
    5. Serve invokes forever.
    """
    global _GLOBAL_AWS_REQUEST_ID
    _install_lambda_logging()

    ctrl_sock = os.getenv('LAMBDA_CONTROL_SOCKET')
    if ctrl_sock is None:
        raise Exception("LAMBDA_CONTROL_SOCKET not set")
    ctrl_fd = int(ctrl_sock)

    invokeid, mode, handler, suppress_init, _credentials = wait_for_start(ctrl_fd)
    sys.path.insert(0, os.environ['LAMBDA_TASK_ROOT'])
    init_handler, request_handler = _get_handlers(handler, mode, suppress_init)
    try:
        init_handler()
    except wsgi.FaultException as e:
        lambda_runtime.report_fault(invokeid, e.msg, e.except_value, e.trace)
    finally:
        lambda_runtime.report_done(invokeid, None, None)

    while True:
        invokeid, sockfd, _credentials, event_body, context_objs, invoked_function_arn = wait_for_invoke(ctrl_fd)
        _GLOBAL_AWS_REQUEST_ID = invokeid
        if mode == "http":
            handle_http_request(request_handler, invokeid, sockfd)
        elif mode == "event":
            handle_event_request(request_handler, invokeid, event_body, context_objs, invoked_function_arn)
        # NOTE(review): any other mode leaves the invoke unanswered (no
        # report_done). Confirm the runtime never sends one.


if __name__ == '__main__':
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
aws_lambda.wsgi | |
Amazon Lambda | |
Copyright (c) 2013 Amazon. All rights reserved. | |
Lambda wsgi implementation | |
""" | |
from __future__ import print_function | |
try: | |
# for python 3 | |
from http.server import BaseHTTPRequestHandler, HTTPServer | |
import urllib.request, urllib.parse, urllib.error | |
from urllib.parse import unquote | |
except ImportError: | |
# for python 2 | |
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer | |
import urllib | |
from urllib import unquote | |
import socket | |
from wsgiref.simple_server import ServerHandler | |
import sys | |
import os | |
import traceback | |
class FaultData(object):
    """Plain record of a fault: ``msg``, ``except_value`` and ``trace``.

    msg is mandatory and is stored as str(msg). except_value and trace are
    optional; each is stored as a string or None. A non-string trace is
    joined with "\\n" when it is an iterable of strings, and falls back to
    str(trace) when joining raises TypeError.
    """

    def __init__(self, msg, except_value=None, trace=None):
        try:
            string_types = basestring  # Python 2
        except NameError:
            string_types = str  # Python 3
        if trace is not None and not isinstance(trace, string_types):
            try:
                trace = "\n".join(trace)
            except TypeError:
                trace = str(trace)
        self.msg = str(msg)
        self.except_value = None if except_value is None else str(except_value)
        self.trace = None if trace is None else str(trace)
class FaultException(Exception):
    """Exception carrying fault details: ``msg``, ``except_value`` and ``trace``.

    The arguments are normalized through FaultData. msg ends up as a string;
    except_value and trace end up as strings or None.
    """

    def __init__(self, msg, except_value=None, trace=None):
        normalized = FaultData(msg, except_value, trace)
        self.msg, self.except_value, self.trace = (
            normalized.msg, normalized.except_value, normalized.trace)
def handle_one(sockfd, client_addr, app):
    """Serve a single HTTP request on file descriptor *sockfd* with WSGI *app*.

    Returns the request handler's FaultData if a fault occurred, else None.
    If building the socket (or serving on it) raises socket.error, the error
    is printed and the process exits with status 1.

    socket.fromfd works on a duplicate of *sockfd*. The socket object built
    here is closed on the way out; *sockfd* itself is left to the caller.
    """
    sock = None
    try:
        sock = socket.fromfd(sockfd, socket.AF_INET, socket.SOCK_STREAM)
        handler = WSGIGir_RequestHandler(sock, client_addr, app)
        return handler.fault
    except socket.error as e:
        print("Error building a socket object: {}".format(e))
        sys.exit(1)
    finally:
        # sock is still None when fromfd() itself failed. An unconditional
        # close() used to raise UnboundLocalError there and mask the exit.
        if sock is not None:
            sock.close()
class Handler(ServerHandler):
    """wsgiref ServerHandler for exactly one request that records response faults.

    ``self.fault`` stays None unless handle_error runs. In that case it holds
    a FaultData describing the failure to turn the app's result into an
    HTTP response.
    """

    wsgi_run_once = True

    def __init__(self, stdin, stdout, stderr, environ, request_handler):
        """Store streams and base environment; multithread and multiprocess are off.

        ServerHandler.__init__ is not called. The attributes are assigned
        directly.
        """
        self.stdin, self.stdout, self.stderr = stdin, stdout, stderr
        self.base_env = environ
        self.request_handler = request_handler  # back-pointer for logging
        self.wsgi_multithread = self.wsgi_multiprocess = False
        self.fault = None

    def _fill_missing_status_and_environ(self):
        # There is a bug in some versions of wsgi where error handling and
        # close() fail because status is None or environ is None.
        self.environ = self.environ or {'SERVER_PROTOCOL': 'HTTP/1.0'}
        self.status = self.status or "500 Internal server error"

    def handle_error(self):
        """Record a fault for errors raised while serializing the app's result."""
        self._fill_missing_status_and_environ()
        _, exc_value, exc_tb = sys.exc_info()
        trace = traceback.format_list(traceback.extract_tb(exc_tb))
        self.fault = FaultData("Unable to convert result into http response", exc_value, trace)
        ServerHandler.handle_error(self)

    def close(self):
        self._fill_missing_status_and_environ()
        ServerHandler.close(self)
# Header-access helpers that hide the Python 2 / Python 3 differences in the
# BaseHTTPRequestHandler ``headers`` object.
if sys.version_info[0] >= 3:
    def get_content_type_helper(self):
        return self.headers.get_content_type()

    def get_headers_helper(self):
        return self.headers.items()

    def get_length_helper(self):
        return self.headers.get('content-length')

    def parse_header_helper(h):
        # Python 3 headers are already (name, value) pairs.
        name, value = h
        return (name, value)
else:
    def get_content_type_helper(self):
        return self.headers.typeheader

    def get_headers_helper(self):
        return self.headers.headers

    def get_length_helper(self):
        return self.headers.getheader('content-length')

    def parse_header_helper(h):
        # Python 2 headers are raw "Name: value" lines.
        name, value = h.split(':', 1)
        return (name, value)
class WSGIGir_RequestHandler(BaseHTTPRequestHandler):
    """WSGI HTTP request handler.

    Inherits the HTTP request handler base class and takes the WSGI
    application as input. Most of the logic is taken from the wsgiref
    package's WSGIRequestHandler class.

    After construction, ``self.fault`` describes the fault if one occurred.
    The cause may be an unparseable request, an exception in the
    application, or a failure converting its result. Otherwise the
    attribute is None.
    """

    def __init__(self, request, client_address, app):
        self.app = app
        self.fault = None
        # The base-class constructor handles the request right away, so app
        # and fault must be set first. '' is passed in place of a server object.
        BaseHTTPRequestHandler.__init__(self, request, client_address, '')

    def get_app(self):
        """Return the application which has to be run."""
        return self.app

    def get_environ(self):
        """Build the WSGI environ dict for the parsed request."""
        env = {}
        env['CONTENT_LENGTH'] = ''
        env['GATEWAY_INTERFACE'] = 'CGI/1.1'  # TODO we may change this for simple-compute
        env['SCRIPT_NAME'] = ''
        env['SERVER_PROTOCOL'] = self.request_version
        env['REQUEST_METHOD'] = self.command
        if '?' in self.path:
            path, query = self.path.split('?', 1)
        else:
            path, query = self.path, ''
        env['PATH_INFO'] = unquote(path)
        env['QUERY_STRING'] = query

        host = self.address_string()
        if host != self.client_address[0]:
            env['REMOTE_HOST'] = host
        env['REMOTE_ADDR'] = self.client_address[0]

        # The module-level helpers hide Python 2 vs 3 header API differences.
        content_type = get_content_type_helper(self)
        if content_type is None:
            env['CONTENT_TYPE'] = self.headers.type
        else:
            env['CONTENT_TYPE'] = content_type
        length = get_length_helper(self)
        if length:
            env['CONTENT_LENGTH'] = length

        for h in get_headers_helper(self):
            k, v = parse_header_helper(h)
            k = k.replace('-', '_').upper()
            v = v.strip()
            if k in env:
                continue  # skip content length, type, etc.
            if 'HTTP_' + k in env:
                env['HTTP_' + k] += ',' + v  # comma-separate multiple headers
            else:
                env['HTTP_' + k] = v
        return env

    def get_stderr(self):
        return sys.stderr

    def send_error(self, code, message=None, explain=None):
        """Record a request-parsing fault, then send the HTTP error response.

        Python 3's parse_request passes a third ``explain`` argument for
        over-long or too-many headers. It is forwarded only when given, so
        the two-argument Python 2 base method keeps working.
        """
        if message is None and code in self.responses:
            message = self.responses[code][0]
        self.fault = FaultData("Unable to parse HTTP request", message)
        if explain is None:
            BaseHTTPRequestHandler.send_error(self, code, message)
        else:
            BaseHTTPRequestHandler.send_error(self, code, message, explain)

    def handle(self):
        """Handle a single HTTP request."""
        self.raw_requestline = self.rfile.readline()
        if not self.parse_request():  # An error code has been sent, just exit
            return

        handler = Handler(
            self.rfile, self.wfile, self.get_stderr(), self.get_environ(), self
        )

        def wrapped_app(environ, start_response):
            """Catch user code exceptions so we can report a fault."""
            try:
                return self.get_app()(environ, start_response)
            except FaultException as e:
                self.fault = FaultData(e.msg, e.except_value, e.trace)
                return handler.error_output(environ, start_response)
            except Exception as e:
                trace = traceback.format_list(traceback.extract_tb(sys.exc_info()[2]))
                # Drop the first frame (this wrapper) from the user's trace.
                self.fault = FaultData("Failure while running task", e, trace[1:])
                return handler.error_output(environ, start_response)

        handler.run(wrapped_app)  # pass wrapped application to handler to run it.
        # A fault from the app or request keeps priority over a serialization fault.
        self.fault = self.fault or handler.fault
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment