Skip to content

Instantly share code, notes, and snippets.

@nicolas17
Created June 26, 2016 05:48
Show Gist options
  • Save nicolas17/98b1acb6520b9ce1b32f11379ba060c3 to your computer and use it in GitHub Desktop.
Save nicolas17/98b1acb6520b9ce1b32f11379ba060c3 to your computer and use it in GitHub Desktop.
Amazon Lambda runtime
# -*- coding: utf-8 -*-
"""
aws_lambda.bootstrap.py
Amazon Lambda
Copyright (c) 2013 Amazon. All rights reserved.
Lambda runtime implementation
"""
from __future__ import print_function
import code
import json
import logging
import os
import socket
import sys
import traceback
import wsgi
import runtime as lambda_runtime
import imp
import fcntl
import time
import decimal
def _get_handlers(handler, mode, suppress_init):
    """
    Return the (init_handler, request_handler) pair for the given handler string.

    When suppress_init is truthy the user's module is loaded lazily on the
    first request; otherwise it is loaded right away.
    """
    loader = _get_handlers_delayed if suppress_init else _get_handlers_immediate
    return loader(handler, mode)
class number_str(float):
    """
    A float whose repr() is the string form of the value it was built from.

    decimal_serializer wraps Decimal values in this type so that the text
    produced by repr() is the original number's own str().
    """

    def __init__(self, o):
        # float.__new__ has already consumed `o`; keep the source object
        # around so __repr__ can print it verbatim.
        self.o = o

    def __repr__(self):
        original = self.o
        return str(original)
def decimal_serializer(o):
    """
    json.dumps `default` hook: wrap Decimal values in number_str.

    Raises TypeError for any other type.
    """
    if not isinstance(o, decimal.Decimal):
        raise TypeError(repr(o) + " is not JSON serializable")
    return number_str(o)
"""
delay loading the user's code until an invoke occurs, to ensure we don't crash the runtime.
"""
def _get_handlers_delayed(handler, mode):
real_request_handler = [None]
def request_handler(*arg):
if real_request_handler[0] is None:
init_handler, request_handler = _get_handlers_immediate(handler, mode)
real_request_handler[0] = request_handler
init_handler()
return request_handler(*arg)
else:
return real_request_handler[0](*arg)
return lambda: None, request_handler
def load_handler_failed_handler(e, modname):
    """
    Build a request handler that reports why loading module `modname` failed.

    Must be called while exception `e` is being handled: the generic branch
    reads the active traceback from sys.exc_info().
    """
    if isinstance(e, ImportError):
        fault = wsgi.FaultException("Unable to import module '{}'".format(modname), str(e), None)
    elif isinstance(e, SyntaxError):
        location = "File \"%s\" Line %s\n\t%s" % (e.filename, e.lineno, e.text)
        fault = wsgi.FaultException("Syntax error in module '{}'".format(modname), str(e), location)
    else:
        active_tb = sys.exc_info()[2]
        frames = traceback.format_list(traceback.extract_tb(active_tb))
        # Drop the first frame of the formatted traceback.
        fault = wsgi.FaultException("module initialization error", str(e), frames[1:])
    return make_fault_handler(fault)
def make_fault_handler(fault):
    """Return a handler that ignores its arguments and raises `fault`."""
    def raise_fault(*args):
        raise fault
    return raise_fault
def _get_handlers_immediate(handler, mode):
    """
    Load the user's module now and return (init_handler, request_handler).

    `handler` has the form "module.function". Any failure is turned into a
    request handler that raises the corresponding fault when invoked; the
    init handler defaults to a no-op and is replaced by the module's `init`
    attribute when present.
    """
    noop_init = lambda: None

    # Historical note: modules used to be loaded differently
    # (#m = imp.load_module(modname, globals(), locals(), [])), which was
    # causing intermittent build failures for unknown reasons. Using the imp
    # module seems to remove these failures. The imp module appears to be more
    # extreme in that it reloads the module if it is already loaded, and it
    # likely doesn't use any caches when searching for the module but does a
    # full directory search, which is what we want.
    try:
        modname, fname = handler.rsplit('.', 1)
    except ValueError as split_error:
        bad_handler = wsgi.FaultException("Bad handler '{}'".format(handler), str(split_error), None)
        return noop_init, make_fault_handler(bad_handler)

    module_file = None
    try:
        module_file, pathname, desc = imp.find_module(modname)
        if module_file is None and desc[2] == imp.C_BUILTIN:
            builtin_fault = wsgi.FaultException("Cannot use built-in module {} as a handler module".format(modname), None, None)
            return noop_init, make_fault_handler(builtin_fault)
        m = imp.load_module(modname, module_file, pathname, desc)
    except Exception as load_error:
        # Called inside the except block so sys.exc_info() is still set.
        return noop_init, load_handler_failed_handler(load_error, modname)
    finally:
        if module_file is not None:
            module_file.close()

    init_handler = getattr(m, 'init', noop_init)

    try:
        request_handler = make_final_handler(getattr(m, fname), mode)
    except AttributeError as missing:
        missing_fault = wsgi.FaultException("Handler '{}' missing on module '{}'".format(fname, modname), str(missing), None)
        request_handler = make_fault_handler(missing_fault)
    return init_handler, request_handler
def set_environ(credentials):
    """
    Export the truthy entries of `credentials` as AWS environment variables.

    'key' -> AWS_ACCESS_KEY_ID, 'secret' -> AWS_SECRET_ACCESS_KEY,
    'session' -> AWS_SESSION_TOKEN and AWS_SECURITY_TOKEN.
    Missing or falsy entries leave the existing variables untouched.
    """
    #TODO delete from environ if params not found
    exports = (
        ('key', ('AWS_ACCESS_KEY_ID',)),
        ('secret', ('AWS_SECRET_ACCESS_KEY',)),
        ('session', ('AWS_SESSION_TOKEN', 'AWS_SECURITY_TOKEN')),
    )
    for field, variable_names in exports:
        value = credentials.get(field)
        if not value:
            continue
        for variable_name in variable_names:
            os.environ[variable_name] = value
def wait_for_start(ctrl_sock):
    """
    Receive the start message on `ctrl_sock`, export its credentials to the
    environment, report the runtime as running, and return the message fields
    as (invokeid, mode, handler, suppress_init, credentials).
    """
    start_message = lambda_runtime.recv_start(ctrl_sock)
    invokeid, mode, handler, suppress_init, credentials = start_message
    set_environ(credentials)
    lambda_runtime.report_running(invokeid)
    return (invokeid, mode, handler, suppress_init, credentials)
def wait_for_invoke(ctrl_sock):
    """
    Receive the next invoke message on `ctrl_sock`, export its credentials to
    the environment, and return the message fields as
    (invokeid, data_sock, credentials, event_body, context_objs, invoked_function_arn).
    """
    invoke_message = lambda_runtime.receive_invoke(ctrl_sock)
    invokeid, data_sock, credentials, event_body, context_objs, invoked_function_arn = invoke_message
    set_environ(credentials)
    return (invokeid, data_sock, credentials, event_body, context_objs, invoked_function_arn)
def make_final_handler(handlerfn, mode):
    """
    Adapt the user's function to the calling convention of `mode`.

    "event" returns `handlerfn` unchanged; "http" returns a wrapper taking a
    socket descriptor; any other mode returns a handler that raises a fault.
    """
    if mode == "event":
        return handlerfn

    if mode == "http":
        def http_handler(sockfd):
            invoke_http(handlerfn, sockfd)
        return http_handler

    def invalid_mode_handler(sockfd):
        raise wsgi.FaultException("specified mode is invalid: " + mode)
    return invalid_mode_handler
def invoke_http(handlerfn, sockfd):
    """
    Serve one HTTP request from `sockfd` with `handlerfn` via wsgi.handle_one,
    raising wsgi.FaultException if fault data is returned.
    """
    fault_data = wsgi.handle_one(sockfd, ('localhost', 80), handlerfn)
    if not fault_data:
        return
    raise wsgi.FaultException(fault_data.msg, fault_data.except_value, fault_data.trace)
def try_or_raise(function, error_message):
    """
    Call `function()` and return its result; on any Exception raise a
    JsonError carrying the active exc_info and `error_message`.
    """
    try:
        outcome = function()
    except Exception:
        raise JsonError(sys.exc_info(), error_message)
    return outcome
def make_error(errorMessage, errorType, stackTrace): #stackTrace is an array
    """
    Build the error dict, including only the fields whose value is truthy.

    Possible keys: 'errorMessage', 'errorType', 'stackTrace'.
    """
    candidates = (
        ('errorMessage', errorMessage),
        ('errorType', errorType),
        ('stackTrace', stackTrace),
    )
    return {key: value for key, value in candidates if value}
def handle_http_request(request_handler, invokeid, sockfd):
    """
    Run `request_handler` on `sockfd`, report any FaultException, then close
    the descriptor and report the invocation as done.

    report_done is issued even if closing the descriptor fails.
    """
    try:
        request_handler(sockfd)
    except wsgi.FaultException as fault:
        lambda_runtime.report_fault(invokeid, fault.msg, fault.except_value, fault.trace)
    finally:
        try:
            os.close(sockfd)
        except Exception:
            print("Error closing original data connection descriptor", file=sys.stderr)
            traceback.print_exc()
        finally:
            lambda_runtime.report_done(invokeid, None, None)
def to_json(obj):
    """Serialize `obj` to a JSON string, using decimal_serializer as the fallback hook."""
    serialized = json.dumps(obj, default=decimal_serializer)
    return serialized
def handle_event_request(request_handler, invokeid, event_body, context_objs, invoked_function_arn):
    """
    Run one event-mode invocation and report its outcome.

    Parses the client context and event body as JSON, calls
    `request_handler(json_input, context)`, serializes the return value, and
    finally calls report_done with the JSON result. Any failure is reported
    as a fault and the result becomes a JSON error document with errortype
    "unhandled".
    """
    errortype = None
    try:
        raw_client_context = context_objs.get('client_context')
        if raw_client_context:
            client_context = try_or_raise(lambda: json.loads(raw_client_context), "Unable to parse client context")
        else:
            client_context = raw_client_context
        context = LambdaContext(invokeid, context_objs, client_context, invoked_function_arn)
        json_input = try_or_raise(lambda: json.loads(event_body), "Unable to parse input as json")
        response = request_handler(json_input, context)
        result = try_or_raise(lambda: to_json(response), "An error occurred during JSON serialization of response")
    except wsgi.FaultException as fault:
        lambda_runtime.report_fault(invokeid, fault.msg, fault.except_value, None)
        result = to_json(make_error(fault.msg, None, None))
        errortype = "unhandled"
    except JsonError as json_error:
        result = to_json(report_fault_helper(invokeid, json_error.exc_info, json_error.msg))
        errortype = "unhandled"
    except Exception:
        result = to_json(report_fault_helper(invokeid, sys.exc_info(), None))
        errortype = "unhandled"
    lambda_runtime.report_done(invokeid, errortype, result)
def report_fault_helper(invokeid, exc_info, msg):
    """
    Report the exception described by `exc_info` as a fault and return the
    matching error dict (see make_error).

    invokeid -- id of the invocation the fault belongs to
    exc_info -- (type, value, traceback) triple as returned by sys.exc_info()
    msg      -- optional message; when given it is reported first and the
                exception text second, otherwise the exception text and the
                exception type name are reported

    Leading traceback frames whose filename contains "/bootstrap.py" are
    dropped so the reported trace starts at the first frame outside this file.
    """
    etype, value, tb = exc_info
    if msg:
        msgs = [msg, str(value)]
    else:
        msgs = [str(value), etype.__name__]

    frames = list(traceback.extract_tb(tb))
    for index, frame in enumerate(frames):
        if "/bootstrap.py" not in frame[0]: # filename of the tb entry
            frames = frames[index:]
            break

    trace_text = (
        "Traceback (most recent call last):\n"
        + ''.join(traceback.format_list(frames))
        + ''.join(traceback.format_exception_only(etype, value))
    )
    lambda_runtime.report_fault(invokeid, msgs[0], msgs[1], trace_text)

    # On Python 3.5+ extract_tb yields FrameSummary objects, which json.dumps
    # cannot encode; plain (filename, lineno, name, line) tuples serialize the
    # same way the Python 2 tuples always did.
    tb_tuples = [tuple(frame) for frame in frames]
    return make_error(str(value), etype.__name__, tb_tuples)
class CustomFile(object):
    """
    File-object wrapper that routes writes through lambda_runtime.log_bytes.

    Every attribute not defined here is delegated to the wrapped object.
    """

    def __init__(self, fd):
        self._fd = fd

    def __getattr__(self, attr):
        # Only reached for attributes not found on CustomFile itself.
        wrapped = self._fd
        return getattr(wrapped, attr)

    def write(self, msg):
        """Log `msg` against the wrapped object's descriptor, then flush it."""
        wrapped = self._fd
        lambda_runtime.log_bytes(msg, wrapped.fileno())
        wrapped.flush()

    def writelines(self, msgs):
        """Log each message in turn, flushing after every one."""
        for msg in msgs:
            self.write(msg)
class CognitoIdentity(object):
    """Slot-only record holding the Cognito identity fields of an invocation."""
    __slots__ = [
        "cognito_identity_id",
        "cognito_identity_pool_id",
    ]
class Client(object):
    """Slot-only record describing the client application in a client context."""
    __slots__ = [
        "installation_id",
        "app_title",
        "app_version_name",
        "app_version_code",
        "app_package_name",
    ]
class ClientContext(object):
    """Slot-only record for the parsed client context: custom, env and client."""
    __slots__ = [
        'custom',
        'env',
        'client',
    ]
def make_obj_from_dict(_class, _dict, fields=None):
    """
    Instantiate `_class` with no arguments and populate it from `_dict`.

    _class -- class to instantiate
    _dict  -- source mapping; None yields None instead of an instance
    fields -- attribute names to copy; defaults to `_class.__slots__`

    Returns the populated instance, or None when `_dict` is None.
    """
    if _dict is None:
        return None
    obj = _class()
    # Forward `fields`; it used to be accepted here but silently ignored.
    set_obj_from_dict(obj, _dict, fields)
    return obj


def set_obj_from_dict(obj, _dict, fields=None):
    """
    Copy `fields` from `_dict` onto `obj`, using None for missing keys.

    fields -- attribute names to copy; defaults to the `__slots__` of
              `obj`'s class
    """
    if fields is None:
        fields = obj.__class__.__slots__
    for field in fields:
        setattr(obj, field, _dict.get(field, None))
class LambdaContext(object):
    """
    Context object handed to event-mode handlers as their second argument.

    Function metadata is read from the AWS_LAMBDA_* environment variables;
    a missing variable raises KeyError.
    """

    def __init__(self, invokeid, context_objs, client_context, invoked_function_arn=None):
        env = os.environ
        self.aws_request_id = invokeid
        self.log_group_name = env['AWS_LAMBDA_LOG_GROUP_NAME']
        self.log_stream_name = env['AWS_LAMBDA_LOG_STREAM_NAME']
        self.function_name = env["AWS_LAMBDA_FUNCTION_NAME"]
        self.memory_limit_in_mb = env['AWS_LAMBDA_FUNCTION_MEMORY_SIZE']
        self.function_version = env['AWS_LAMBDA_FUNCTION_VERSION']
        self.invoked_function_arn = invoked_function_arn

        self.client_context = make_obj_from_dict(ClientContext, client_context)
        if self.client_context is not None:
            # The nested 'client' dict becomes a Client record as well.
            client_dict = self.client_context.client
            self.client_context.client = make_obj_from_dict(Client, client_dict)
        self.identity = make_obj_from_dict(CognitoIdentity, context_objs)

    def get_remaining_time_in_millis(self):
        """Return whatever lambda_runtime.get_remaining_time() reports."""
        remaining = lambda_runtime.get_remaining_time()
        return remaining

    def log(self, msg):
        """Send str(msg) to the runtime as a console message."""
        text = str(msg)
        lambda_runtime.send_console_message(text)
class LambdaLoggerHandler(logging.Handler):
    """logging handler that forwards formatted records to the runtime console."""

    def __init__(self):
        super(LambdaLoggerHandler, self).__init__()

    def emit(self, record):
        formatted = self.format(record)
        lambda_runtime.send_console_message(formatted)
class LambdaLoggerFilter(logging.Filter):
    """Filter that tags every record with the current request id; never drops records."""

    def filter(self, record):
        request_id = _GLOBAL_AWS_REQUEST_ID
        # The formatter needs the attribute to exist even before any invoke.
        record.aws_request_id = request_id if request_id else ""
        return True
class JsonError(Exception):
    """Raised by try_or_raise; carries the original exc_info and a message."""

    def __init__(self, exc_info, msg):
        self.exc_info, self.msg = exc_info, msg
# Alias of the socket module's private default-timeout sentinel; not
# referenced anywhere else in the visible code.
_GLOBAL_DEFAULT_TIMEOUT = socket._GLOBAL_DEFAULT_TIMEOUT
# Request id of the invocation being handled: main() assigns it before each
# request and LambdaLoggerFilter copies it onto every log record.
_GLOBAL_AWS_REQUEST_ID = None
def main():
    """
    Runtime entry point: set up output and logging, load the handler, then
    serve invocations forever.

    Steps, in order:
      * on Python 2, force the default encoding to utf-8
      * wrap sys.stdout / sys.stderr in CustomFile
      * install LambdaLoggerHandler on the root logger with UTC timestamps
      * read the control socket descriptor from LAMBDA_CONTROL_SOCKET
      * wait for the start message, put LAMBDA_TASK_ROOT first on sys.path,
        load the handlers and run the init handler
      * loop: wait for an invoke and dispatch it according to `mode`

    Raises Exception if LAMBDA_CONTROL_SOCKET is not set, and KeyError if
    LAMBDA_TASK_ROOT is missing.
    """
    global _GLOBAL_AWS_REQUEST_ID

    if sys.version_info[0] < 3:
        reload(sys)
        sys.setdefaultencoding('utf-8')

    sys.stdout = CustomFile(sys.stdout)
    sys.stderr = CustomFile(sys.stderr)

    logging.Formatter.converter = time.gmtime
    logger = logging.getLogger()
    logger_handler = LambdaLoggerHandler()
    # %(msecs)03d keeps the fractional part three digits wide; a bare
    # %(msecs)d would render 5 ms as ".5Z", which reads as half a second.
    logger_handler.setFormatter(logging.Formatter('[%(levelname)s]\t%(asctime)s.%(msecs)03dZ\t%(aws_request_id)s\t%(message)s\n', '%Y-%m-%dT%H:%M:%S'))
    logger_handler.addFilter(LambdaLoggerFilter())
    logger.addHandler(logger_handler)

    ctrl_sock = os.getenv('LAMBDA_CONTROL_SOCKET')
    if ctrl_sock is None:
        raise Exception("LAMBDA_CONTROL_SOCKET not set")

    (invokeid, mode, handler, suppress_init, credentials) = wait_for_start(int(ctrl_sock))

    # User code must win over anything else on the import path.
    sys.path.insert(0, os.environ['LAMBDA_TASK_ROOT'])

    init_handler, request_handler = _get_handlers(handler, mode, suppress_init)
    try:
        init_handler()
    except wsgi.FaultException as e:
        lambda_runtime.report_fault(invokeid, e.msg, e.except_value, e.trace)
    finally:
        # The start phase is always reported done, even when init raised.
        lambda_runtime.report_done(invokeid, None, None)

    while True:
        (invokeid, sockfd, credentials, event_body, context_objs, invoked_function_arn) = wait_for_invoke(int(ctrl_sock))
        _GLOBAL_AWS_REQUEST_ID = invokeid
        if mode == "http":
            handle_http_request(request_handler, invokeid, sockfd)
        elif mode == "event":
            handle_event_request(request_handler, invokeid, event_body, context_objs, invoked_function_arn)
# Start the runtime loop only when this file is executed as a script.
if __name__ == '__main__':
    main()
"""
aws_lambda.wsgi
Amazon Lambda
Copyright (c) 2013 Amazon. All rights reserved.
Lambda wsgi implementation
"""
from __future__ import print_function
try:
# for python 3
from http.server import BaseHTTPRequestHandler, HTTPServer
import urllib.request, urllib.parse, urllib.error
from urllib.parse import unquote
except ImportError:
# for python 2
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
import urllib
from urllib import unquote
import socket
from wsgiref.simple_server import ServerHandler
import sys
import os
import traceback
class FaultData(object):
    """
    Plain record with three fields: msg, except_value and trace.

    msg is always stored as str(msg). except_value and trace are stored as
    strings, or as None when None was passed. A trace that is neither None
    nor a string is joined with "\\n".join; if it cannot be joined it is
    converted with str() instead.
    """

    def __init__(self, msg, except_value=None, trace=None):
        try:
            string_types = basestring  # Python 2
        except NameError:
            string_types = str  # Python 3
        if trace is not None and not isinstance(trace, string_types):
            try:
                trace = "\n".join(trace)
            except TypeError:
                trace = str(trace)
        self.msg = str(msg)
        self.except_value = None if except_value is None else str(except_value)
        self.trace = None if trace is None else str(trace)
class FaultException(Exception):
    """
    Exception form of FaultData: exposes msg, except_value and trace,
    normalized the same way FaultData normalizes them.
    """

    def __init__(self, msg, except_value=None, trace=None):
        normalized = FaultData(msg, except_value, trace)
        self.msg, self.except_value, self.trace = (
            normalized.msg,
            normalized.except_value,
            normalized.trace,
        )
def handle_one(sockfd, client_addr, app):
    """
    Serve a single HTTP request read from descriptor `sockfd` with `app`.

    sockfd      -- descriptor passed to socket.fromfd
    client_addr -- client address handed to the request handler
    app         -- application handed to the request handler

    Returns the request handler's `fault` attribute (a FaultData object if a
    fault occurs, else None). On socket.error the error is printed and the
    process exits with status 1.
    """
    # Bound up front so the cleanup below is safe when fromfd itself fails;
    # otherwise the unbound name would mask the intended sys.exit(1).
    sock = None
    try:
        sock = socket.fromfd(sockfd, socket.AF_INET, socket.SOCK_STREAM)
        handler = WSGIGir_RequestHandler(sock, client_addr, app)
        return handler.fault
    except socket.error as e:
        print("Error building a socket object: {}".format(e))
        sys.exit(1)
    finally:
        if sock is not None:
            sock.close()
class Handler(ServerHandler):
    """
    ServerHandler variant that runs once, single-threaded and single-process,
    and records a FaultData in `fault` when the response cannot be produced.
    """
    wsgi_run_once = True

    def __init__(self, stdin, stdout, stderr, environ, request_handler):
        """set multithread=False and multiprocess=False"""
        self.stdin, self.stdout, self.stderr = stdin, stdout, stderr
        self.base_env = environ
        self.request_handler = request_handler # back-pointer for logging
        self.wsgi_multithread = self.wsgi_multiprocess = False
        self.fault = None

    def _ensure_status_and_environ(self):
        # There is a bug in some versions of wsgi where code here fails because status is None or environ is None
        self.environ = self.environ or {'SERVER_PROTOCOL' : 'HTTP/1.0'}
        self.status = self.status or "500 Internal server error"

    def handle_error(self):
        """Catch errors that occur when serializing the return value to HTTP response and report fault"""
        self._ensure_status_and_environ()
        exc_value, exc_traceback = sys.exc_info()[1:]
        frames = traceback.format_list(traceback.extract_tb(exc_traceback))
        self.fault = FaultData("Unable to convert result into http response", exc_value, frames)
        ServerHandler.handle_error(self)

    def close(self):
        """Fill in a missing status/environ, then run the base class close."""
        self._ensure_status_and_environ()
        ServerHandler.close(self)
# Header-access helpers, defined once according to the running Python major
# version because the request handler's `headers` object differs between them.
if sys.version_info[0] >= 3:
    def get_content_type_helper(self):
        """Return the request's content type."""
        return self.headers.get_content_type()

    def get_headers_helper(self):
        """Return the request headers as (name, value) pairs."""
        return self.headers.items()

    def get_length_helper(self):
        """Return the content-length header value, if any."""
        return self.headers.get('content-length')

    def parse_header_helper(h):
        """Unpack a (name, value) pair."""
        name, value = h
        return (name, value)
else:
    def get_content_type_helper(self):
        """Return the request's content type header."""
        return self.headers.typeheader

    def get_headers_helper(self):
        """Return the raw header lines."""
        return self.headers.headers

    def get_length_helper(self):
        """Return the content-length header value, if any."""
        return self.headers.getheader('content-length')

    def parse_header_helper(h):
        """Split a raw "Name: value" line at the first colon."""
        name, value = h.split(':', 1)
        return (name, value)
class WSGIGir_RequestHandler(BaseHTTPRequestHandler):
    """WSGI HTTP request handler

    Inherits the HTTP request handler base class and takes the application to
    run as a constructor argument. Most of it is taken from the wsgiref
    package's WSGIRequestHandler class. After construction, `fault` holds a
    FaultData describing what went wrong, or None.
    """

    def __init__(self, request, client_address, app):
        # Both attributes must exist before the base constructor runs.
        self.app = app
        self.fault = None
        BaseHTTPRequestHandler.__init__(self, request, client_address, '')

    def get_app(self):
        """This function returns the application which has to be run"""
        return self.app

    def get_environ(self):
        """Build the WSGI environ dict for the parsed request."""
        path, _, query = self.path.partition('?')
        client_ip = self.client_address[0]

        # Set up base environment
        env = {
            'CONTENT_LENGTH': '',
            'GATEWAY_INTERFACE': 'CGI/1.1', # TODO we may change this for simple-compute
            'SCRIPT_NAME': '',
            'SERVER_PROTOCOL': self.request_version,
            'REQUEST_METHOD': self.command,
            'PATH_INFO': unquote(path),
            'QUERY_STRING': query,
        }

        host = self.address_string()
        if host != client_ip:
            env['REMOTE_HOST'] = host
        env['REMOTE_ADDR'] = client_ip

        # 2 vs 3
        content_type = get_content_type_helper(self)
        if content_type is None:
            content_type = self.headers.type
        env['CONTENT_TYPE'] = content_type

        length = get_length_helper(self)
        if length:
            env['CONTENT_LENGTH'] = length

        for raw_header in get_headers_helper(self):
            name, value = parse_header_helper(raw_header)
            key = name.replace('-', '_').upper()
            value = value.strip()
            if key in env:
                continue # skip content length, type,etc.
            http_key = 'HTTP_' + key
            if http_key in env:
                env[http_key] += ',' + value # comma-separate multiple headers
            else:
                env[http_key] = value
        return env

    def get_stderr(self):
        """Return the stream handed to the WSGI handler as its stderr."""
        return sys.stderr

    def send_error(self, code, message=None):
        """Detect errors that occur when reading the HTTP request"""
        if message is None and code in self.responses:
            message = self.responses[code][0]
        self.fault = FaultData("Unable to parse HTTP request", message)
        BaseHTTPRequestHandler.send_error(self, code, message)

    def handle(self):
        """Handle a single HTTP request"""
        self.raw_requestline = self.rfile.readline()
        if not self.parse_request(): #An error code has been sent, just exit
            return

        handler = Handler(
            self.rfile, self.wfile, self.get_stderr(), self.get_environ(), self
        )

        def wrapped_app(environ, start_response):
            """Catch user code exceptions so we can report a fault"""
            try:
                return self.get_app()(environ, start_response)
            except FaultException as fault:
                self.fault = FaultData(fault.msg, fault.except_value, fault.trace)
                return handler.error_output(environ, start_response)
            except Exception as error:
                frames = traceback.format_list(traceback.extract_tb(sys.exc_info()[2]))
                self.fault = FaultData("Failure while running task", error, frames[1:])
                return handler.error_output(environ, start_response)

        handler.run(wrapped_app) # pass wrapped application to handler to run it.
        # A fault recorded here takes precedence over one recorded by the handler.
        self.fault = self.fault or handler.fault
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment