Last active
March 18, 2024 09:28
-
-
Save max-arnold/004f0827d563638039bed8adb413df95 to your computer and use it in GitHub Desktop.
Test Yandex Cloud Functions written in Python and invoked as API Gateway integrations locally
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
Test Yandex Cloud Functions written in Python and invoked as API Gateway integrations locally. | |
Heavily based on https://github.com/amancevice/python-lambda-gateway | |
MIT License | |
Copyright (c) 2020 Alexander Mancevice | |
Copyright (c) 2022-2023 Max Arnold | |
""" | |
import argparse | |
import asyncio | |
import datetime | |
import importlib | |
import json | |
import logging | |
import os | |
import random | |
import re | |
import socket | |
import string | |
import sys | |
import time | |
import uuid | |
from contextlib import contextmanager | |
from http import server | |
from urllib import parse | |
def set_stream_logger(name, level=logging.DEBUG, format_string=None):
    """
    Attach a stream handler to the named logger and wrap it in an adapter.

    Adapted from boto3.set_stream_logger().

    :param str name: Logger name
    :param int level: Level applied to both the logger and the new handler
    :param str format_string: Record format; when omitted, an access-log
        style layout that references the ``addr`` extra field is used
    :returns logging.LoggerAdapter: Adapter injecting ``addr="::1"``
    """
    fmt = format_string
    if fmt is None:
        fmt = "%(addr)s - - [%(asctime)s] %(levelname)s - %(message)s"

    # Build the handler first so it is fully configured before it is attached.
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level)
    stream_handler.setFormatter(logging.Formatter(fmt, "%-d/%b/%Y %H:%M:%S"))

    target = logging.getLogger(name)
    target.setLevel(level)
    target.addHandler(stream_handler)

    # The adapter supplies the ``addr`` field the default format expects.
    return logging.LoggerAdapter(target, {"addr": "::1"})


# Module-wide logger used by the event proxy.
logger = set_stream_logger(__name__)
@contextmanager
def context_start(timeout=None):
    """
    Provide a mock YCF context object for the duration of a ``with`` block.

    :param int timeout: YCF timeout in seconds, forwarded to RuntimeContext
    """
    runtime_context = RuntimeContext(timeout)
    yield runtime_context
class RuntimeContext:
    """
    Mock YCF runtime context object.

    All identifiers are generated once when the context is created, so
    repeated reads (and the aliases ``aws_request_id``,
    ``invoked_function_arn`` and ``log_stream_name``) return consistent
    values for the lifetime of one invocation.

    :param int timeout: YCF timeout in seconds (30 is used when falsy)
    """

    def __init__(self, timeout=None):
        # timezone.utc is equivalent to datetime.UTC but also exists on
        # Python versions older than 3.11.
        self._start = datetime.datetime.now(datetime.timezone.utc)
        self._timeout = timeout or 30
        self._function_name = self._random_id()
        self._function_version = self._random_id()
        self._function_folder_id = self._random_id()
        self._log_group_name = self._random_id()
        self._request_id = str(uuid.uuid1())
        self._access_token = str(uuid.uuid1())

    @staticmethod
    def _random_id():
        """Return a random 20-character lowercase alphanumeric identifier."""
        alphabet = string.ascii_lowercase + string.digits
        return "".join(random.choice(alphabet) for _ in range(20))

    @property
    def function_name(self):
        """Mock function identifier (stable for this context)."""
        return self._function_name

    @property
    def function_version(self):
        """Mock function version identifier (stable for this context)."""
        return self._function_version

    @property
    def function_folder_id(self):
        """Mock folder identifier (stable for this context)."""
        return self._function_folder_id

    @property
    def invoked_function_arn(self):
        """Alias of :attr:`function_name`."""
        return self.function_name

    @property
    def memory_limit_in_mb(self):
        """Fixed mock memory limit."""
        return 256

    @property
    def aws_request_id(self):
        """Alias of :attr:`request_id`."""
        return self.request_id

    @property
    def request_id(self):
        """Mock request identifier (stable for this context)."""
        return self._request_id

    @property
    def log_group_name(self):
        """Mock log group identifier (stable for this context)."""
        return self._log_group_name

    @property
    def log_stream_name(self):
        """Alias of :attr:`function_version`."""
        return self.function_version

    @property
    def deadline_ms(self):
        """
        Invocation deadline as Unix epoch milliseconds.

        Uses ``datetime.timestamp()`` so the aware UTC deadline converts
        correctly regardless of the host's local time zone.
        """
        deadline = self._start + datetime.timedelta(seconds=self._timeout)
        return int(deadline.timestamp() * 1000)

    @property
    def token(self):
        """Mock IAM token payload; a fresh dict is returned on each read."""
        return {
            "access_token": self._access_token,
            "expires_in": 42299,
            "token_type": "Bearer",
        }

    def get_remaining_time_in_millis(self):
        """
        Get remaining TTL for YCF context.

        :returns: Milliseconds left before the timeout, or 0 once expired
        """
        elapsed = datetime.datetime.now(datetime.timezone.utc) - self._start
        remaining_time_in_s = self._timeout - elapsed.total_seconds()
        if remaining_time_in_s < 0:
            return 0
        return remaining_time_in_s * 1000
class EventProxy:
    """
    Loads the user's YCF handler and invokes it with API Gateway events.

    :param str handler: Handler signature, ``<module path>.<function>``
    :param str src_path: Directory added to ``sys.path`` to find the module
    :param str url_pattern: Regex (may contain named groups) that request
        paths must match
    :param str operation: Operation ID reported in v1.0 events
    :param int timeout: Invocation timeout in seconds (``None`` = no limit)
    """

    def __init__(self, handler, src_path, url_pattern, operation, timeout=None):
        self.handler = handler
        self.src_path = src_path
        self.url_pattern = url_pattern
        self.operation = operation
        self.timeout = timeout

    def get_handler(self):
        """
        Load handler function.

        The module is imported on first use and reloaded on every later
        call, so source edits are picked up between requests.

        :returns function: YCF handler function
        :raises ValueError: If the signature has no module part, the module
            cannot be imported, or the function is missing from the module
        """
        *path, func = self.handler.split(".")
        name = ".".join(path)
        if not name:
            raise ValueError(f"Bad handler signature '{self.handler}'")
        try:
            # Compare the same (absolute) value that gets appended, otherwise
            # a relative src_path would be appended again on every call.
            src_dir = os.path.abspath(self.src_path)
            if src_dir not in sys.path:
                sys.path.append(src_dir)
            # Decide between import and reload *before* importing; after
            # import_module() the name is always present in sys.modules,
            # which would execute the module twice on first load.
            loaded = sys.modules.get(name)
            if loaded is None:
                module = importlib.import_module(name)
            else:
                module = importlib.reload(loaded)
            return getattr(module, func)
        except ModuleNotFoundError as e:
            logger.exception("Module import error: %s", e)
            raise ValueError(f"Unable to import module '{name}'") from e
        except AttributeError as e:
            logger.exception("Handler import error: %s", e)
            raise ValueError(
                f"Handler '{func}' missing on module '{name}'"
            ) from e

    def get_httpMethod(self, event):
        """
        Helper to get httpMethod from v0.1 or v1.0 events.

        :param dict event: YCF event object
        :returns str: HTTP request method
        :raises ValueError: If the event carries neither layout
        """
        if event.get("version") == "1.0":
            return event["requestContext"]["httpMethod"]
        elif event.get("httpMethod"):
            return event["httpMethod"]
        raise ValueError(  # pragma: no cover
            f"Unknown API Gateway payload version: {event.get('version')}"
        )

    def get_path(self, event):
        """
        Helper to get path from v0.1 or v1.0 events.

        :param dict event: YCF event object
        :returns str: Request path (``path`` for v1.0, ``url`` otherwise)
        :raises ValueError: If the event carries neither layout
        """
        if event.get("version") == "1.0":
            return event["path"]
        elif event.get("url"):
            return event["url"]
        raise ValueError(  # pragma: no cover
            f"Unknown API Gateway payload version: {event.get('version')}"
        )

    def invoke(self, event):
        """
        Invoke the handler synchronously with a fresh mock context.

        :param dict event: YCF event object
        :returns dict: YCF invocation result
        """
        with context_start(self.timeout) as context:
            logger.info('Invoking "%s"', self.handler)
            return asyncio.run(self.invoke_async_with_timeout(event, context))

    def match_pattern(self, path):
        """
        Match a request path against the configured URL pattern.

        :param str path: Request path
        :returns: ``re.Match`` on success, ``None`` otherwise
        """
        return re.match(self.url_pattern, path)

    async def invoke_async(self, event, context=None):
        """
        Wrapper to invoke the YCF handler asynchronously.

        :param dict event: YCF event object
        :param RuntimeContext context: Mock YCF runtime context
        :returns dict: YCF invocation result, a 403 response when the path
            does not match the URL pattern, or a 502 response when loading
            or running the handler raises
        """
        httpMethod = self.get_httpMethod(event)
        path = self.get_path(event)

        # Reject request if not matching the pattern
        if not self.match_pattern(path):
            err = f"Rejected {path} :: URL pattern is {self.url_pattern}"
            logger.error(err)
            return self.jsonify(httpMethod, 403, message="Forbidden")

        # Get & invoke YCF handler
        try:
            handler = self.get_handler()
            loop = asyncio.get_running_loop()
            # The handler is a plain callable, so run it in the loop's
            # default executor to keep the event loop free for the timeout.
            return await loop.run_in_executor(None, handler, event, context)
        except Exception as err:
            logger.exception(err)
            message = "Internal server error"
            return self.jsonify(httpMethod, 502, message=message)

    async def invoke_async_with_timeout(self, event, context=None):
        """
        Wrapper to invoke the YCF handler with a timeout.

        :param dict event: YCF event object
        :param RuntimeContext context: Mock YCF runtime context
        :returns dict: YCF invocation result or a 504 timeout response
        """
        try:
            coroutine = self.invoke_async(event, context)
            return await asyncio.wait_for(coroutine, self.timeout)
        except asyncio.TimeoutError:
            httpMethod = self.get_httpMethod(event)
            message = "Endpoint request timed out"
            return self.jsonify(httpMethod, 504, message=message)

    @staticmethod
    def jsonify(httpMethod, statusCode, **kwargs):
        """
        Convert dict into API Gateway response object.

        :params str httpMethod: HTTP request method
        :params int statusCode: Response status code
        :params dict kwargs: Response object
        :returns dict: Response with JSON ``body`` (empty for HEAD),
            ``statusCode`` and content headers
        """
        body = "" if httpMethod in ["HEAD"] else json.dumps(kwargs)
        return {
            "body": body,
            "statusCode": statusCode,
            "headers": {
                "Content-Type": "application/json",
                "Content-Length": len(body),
            },
        }
class YCFRequestHandler(server.SimpleHTTPRequestHandler):
    """
    HTTP request handler that converts requests into YCF API Gateway events,
    forwards them to the configured proxy and writes the result back.

    ``proxy`` and ``version`` are class attributes installed via
    :meth:`set_proxy` before the server starts.
    """

    def do_DELETE(self):
        self.invoke("DELETE")

    def do_GET(self):
        self.invoke("GET")

    def do_HEAD(self):
        self.invoke("HEAD")

    def do_OPTIONS(self):
        self.invoke("OPTIONS")

    def do_PATCH(self):
        self.invoke("PATCH")

    def do_POST(self):
        self.invoke("POST")

    def do_PUT(self):
        self.invoke("PUT")

    def get_body(self):
        """
        Get request body to forward to YCF handler.

        :returns str: Decoded body, or "" when Content-Length is missing,
            not a number, or not positive
        """
        try:
            content_length = int(self.headers.get("Content-Length"))
        except (TypeError, ValueError):
            # TypeError: header absent (None); ValueError: not an integer.
            return ""
        if content_length <= 0:
            # A negative size would make read() block until EOF.
            return ""
        return self.rfile.read(content_length).decode()

    def get_event(self, httpMethod):
        """
        Get YCF input event object.

        :param str httpMethod: HTTP request method
        :return dict: YCF event object
        :raises ValueError: If the configured payload version is unknown
        """
        if self.version == "0.1":
            return self.get_event_v01(httpMethod)
        elif self.version == "1.0":
            return self.get_event_v10(httpMethod)
        raise ValueError(  # pragma: no cover
            f"Unknown API Gateway payload version: {self.version}"
        )

    def get_params(self):
        """
        Extract named path parameters from the request path.

        The query string is stripped first: the URL pattern describes the
        path only, so matching the raw request target would drop every
        parameter as soon as a query string is present.

        :return dict: Named groups of the URL pattern match, or {}
        """
        match = self.proxy.match_pattern(parse.urlparse(self.path).path)
        return match.groupdict() if match else {}

    def _request_parts(self, httpMethod):
        """Collect the pieces shared by the v0.1 and v1.0 event payloads."""
        url = parse.urlparse(self.path)
        # parse_qs keeps every value of a repeated key; the single-value
        # view uses the last one, as dict(parse_qsl(...)) would.
        multi_query = parse.parse_qs(url.query)
        headers = dict(self.headers)
        headers["X-Forwarded-For"] = "1.1.1.1"
        headers["X-Forwarded-Proto"] = "http"
        req_time = datetime.datetime.now(datetime.timezone.utc)
        path_params = self.get_params()
        return {
            "path": url.path,
            "headers": headers,
            "multi_headers": {k: [v] for k, v in headers.items()},
            "query": {k: v[-1] for k, v in multi_query.items()},
            "multi_query": multi_query,
            "path_params": path_params,
            "multi_path_params": {k: [v] for k, v in path_params.items()},
            "request_context": {
                "identity": {
                    "sourceIp": "1.1.1.1",
                    "userAgent": "Mozilla/5.0",
                },
                "httpMethod": httpMethod,
                "requestId": str(uuid.uuid1()),
                "requestTime": req_time.strftime("%d/%b/%Y:%H:%M:%S +0000"),
                # timestamp() converts the aware UTC time correctly whatever
                # the host's local time zone is.
                "requestTimeEpoch": int(req_time.timestamp()),
            },
        }

    def get_event_v01(self, httpMethod):
        """
        Get YCF input event object (v0.1).

        :param str httpMethod: HTTP request method
        :return dict: YCF event object
        """
        parts = self._request_parts(httpMethod)
        return {
            "httpMethod": httpMethod,
            "headers": parts["headers"],
            "url": parts["path"],
            "params": parts["path_params"],  # {"param": "123"}
            "multiValueParams": parts["multi_path_params"],  # {"param": ["123"]}
            "pathParams": dict(parts["path_params"]),  # {"param": "123"}
            "multiValueHeaders": parts["multi_headers"],
            "queryStringParameters": parts["query"],
            "multiValueQueryStringParameters": parts["multi_query"],
            "requestContext": parts["request_context"],
            "body": self.get_body(),
            "isBase64Encoded": False,
            "path": parts["path"],
        }

    def get_event_v10(self, httpMethod):
        """
        Get YCF input event object (v1.0).

        :param str httpMethod: HTTP request method
        :return dict: YCF event object
        """
        parts = self._request_parts(httpMethod)
        return {
            "httpMethod": httpMethod,
            "headers": parts["headers"],
            "multiValueHeaders": parts["multi_headers"],
            "queryStringParameters": parts["query"],
            "multiValueQueryStringParameters": parts["multi_query"],
            "requestContext": parts["request_context"],
            "version": "1.0",
            "resource": parts["path"],
            "path": parts["path"],  # /slug/123
            "pathParameters": parts["path_params"],  # {"param": "123"}
            "body": self.get_body(),
            "isBase64Encoded": False,
            "parameters": dict(parts["path_params"]),  # {"param": "123"}
            "multiValueParameters": parts["multi_path_params"],  # {"param": ["123"]}
            "operationId": self.proxy.operation,
        }

    def invoke(self, httpMethod):
        """
        Proxy requests to YCF handler and write the HTTP response.

        OPTIONS requests are answered directly with the CORS headers; all
        other methods are forwarded to the proxy.

        :param str httpMethod: HTTP request method
        """
        # Get YCF event
        event = self.get_event(httpMethod)
        cors = {
            "access-control-allow-headers": "*",
            "access-control-allow-methods": "OPTIONS, GET, HEAD, POST",
            "access-control-allow-origin": "*",
        }
        if httpMethod == "OPTIONS":
            status = 200
            headers = cors
            mvheaders = {}
            body = ""
        else:
            # Get YCF result
            res = self.proxy.invoke(event)
            if not isinstance(res, dict):
                # A handler that returns None (or any non-dict) becomes an
                # empty 500 response instead of crashing the request thread.
                res = {}
            # Parse response
            status = res.get("statusCode") or 500
            # Merge into a new dict so the handler's result is not mutated.
            headers = {**(res.get("headers") or {}), **cors}
            mvheaders = res.get("multiValueHeaders") or {}
            body = res.get("body") or ""

        # Send response
        self.send_response(status)
        for key, val in headers.items():
            self.send_header(key, val)
        for key, val in mvheaders.items():
            for v in val:
                self.send_header(key, v)
        self.end_headers()
        self.wfile.write(body.encode())

    @classmethod
    def set_proxy(cls, proxy, version):
        """
        Set up YCFRequestHandler.

        :param EventProxy proxy: Proxy used to invoke the handler
        :param str version: API Gateway payload version ("0.1" or "1.0")
        """
        cls.proxy = proxy
        cls.version = version
def get_best_family(*address):  # pragma: no cover
    """
    Resolve the address family and socket address to bind to.

    Delegates to ``http.server._get_best_family`` when it exists and falls
    back to an equivalent lookup for Python 3.7, where it does not.

    :params tuple address: host/port tuple
    :returns tuple: ``(family, sockaddr)``
    """
    try:
        # Python 3.8+
        return server._get_best_family(*address)
    except AttributeError:
        # Python 3.7 -- taken from http.server._get_best_family() in 3.8
        candidates = socket.getaddrinfo(
            *address,
            type=socket.SOCK_STREAM,
            flags=socket.AI_PASSIVE,
        )
        first = next(iter(candidates))
        # Each entry is (family, type, proto, canonname, sockaddr).
        return first[0], first[4]
def run(httpd):
    """
    Announce the listening address, then serve until interrupted.

    The server is always shut down on exit, whether ``serve_forever``
    returns, is interrupted from the keyboard, or raises.

    :param object httpd: ThreadingHTTPServer instance
    """
    host, port = httpd.socket.getsockname()[:2]
    # IPv6 literals need brackets inside a URL.
    if ":" in host:
        url_host = f"[{host}]"
    else:
        url_host = host
    banner = f"Serving HTTP on {host} port {port} (http://{url_host}:{port}) ...\n"
    sys.stderr.write(banner)
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        sys.stderr.write("\nKeyboard interrupt received, exiting.\n")
    finally:
        httpd.shutdown()
def export_variables(env_file):
    """
    Export environment variables from JSON file.

    The file must contain a single JSON object; every key and value is
    converted with ``str()`` before being stored in ``os.environ``.

    :param str env_file: Path to the JSON file
    :raises ValueError: If the JSON top level is not an object
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(env_file, encoding="utf-8") as json_file:
        env_vars = json.load(json_file)
    if not isinstance(env_vars, dict):
        raise ValueError(
            f"Environment file '{env_file}' must contain a JSON object"
        )
    for env_name, env_value in env_vars.items():
        os.environ[str(env_name)] = str(env_value)
def get_opts():
    """
    Parse the command line.

    :returns argparse.Namespace: Parsed options (``env``, ``src_path``,
        ``operation``, ``url_pattern``, ``bind``, ``port``, ``timeout``,
        ``payload_version``) plus the positional ``HANDLER``
    """
    # Each entry is (flags, add_argument keyword arguments), in help order.
    option_table = (
        (
            ("-e", "--env"),
            {
                "dest": "env",
                "help": "Path to environment JSON file",
                "metavar": "ENV",
            },
        ),
        (
            ("-s", "--src-path"),
            {
                "dest": "src_path",
                "help": "Set base path for source code",
                "metavar": "SRC_PATH",
                "default": "",
            },
        ),
        (
            ("-o", "--operation"),
            {
                "dest": "operation",
                "help": "Operation ID",
                "metavar": "OPERATION",
                "default": "operation_id",
            },
        ),
        (
            ("-u", "--url-pattern"),
            {
                "dest": "url_pattern",
                "help": "URL pattern regex with named parameter groups",
                "metavar": "URL_PATTERN",
                "default": "/",
            },
        ),
        (
            ("-b", "--bind"),
            {
                "dest": "bind",
                "metavar": "ADDR",
                "help": "Specify alternate bind address [default: all interfaces]",
            },
        ),
        (
            ("-p", "--port"),
            {
                "dest": "port",
                "default": 8000,
                "help": "Specify alternate port [default: 8000]",
                "type": int,
            },
        ),
        (
            ("-t", "--timeout"),
            {
                "dest": "timeout",
                "help": "YCF timeout.",
                "metavar": "SECONDS",
                "type": int,
            },
        ),
        (
            ("-V", "--payload-version"),
            {
                "choices": ["0.1", "1.0"],
                "default": "1.0",
                "help": "API Gateway payload version [default: 1.0]",
            },
        ),
        (
            ("HANDLER",),
            {
                "help": "YCF handler signature",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Start a simple YC API Gateway server",
    )
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
def main():
    """
    Command-line entrypoint: configure the handler class and serve.
    """
    opts = get_opts()

    # Export the environment first so the handler module sees it on import.
    if opts.env:
        export_variables(opts.env)

    # The user-supplied pattern is anchored so it must match the whole path.
    anchored_pattern = f"^{opts.url_pattern}$"
    event_proxy = EventProxy(
        opts.HANDLER,
        opts.src_path,
        anchored_pattern,
        opts.operation,
        opts.timeout,
    )
    YCFRequestHandler.set_proxy(event_proxy, opts.payload_version)

    # Pick the address family matching the requested bind address.
    family, bind_address = get_best_family(opts.bind, opts.port)
    server.ThreadingHTTPServer.address_family = family

    # Start server
    with server.ThreadingHTTPServer(bind_address, YCFRequestHandler) as httpd:
        run(httpd)


if __name__ == "__main__":  # pragma: no cover
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment