Last active: October 27, 2020 07:07
-
-
Save thehesiod/2e4094a1db1190f7e122e7043f1973a0 to your computer and use it in GitHub Desktop.
Moto Service Helper Class
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import functools | |
import logging | |
import os | |
import threading | |
import socket | |
import http.server | |
from typing import Dict, Any, Optional | |
# Third Party | |
import aiohttp | |
import moto.server | |
import botocore.session | |
import netifaces | |
import wrapt | |
import werkzeug.serving | |
import aiobotocore.session | |
# Name of the environment variable that holds a mock endpoint URL for one
# service; expanded with ``.format(service_name=...)``.
_SERVICE_ENDPOINT_TEMPLATE = "{service_name}_mock_endpoint_url"
def get_free_tcp_port(release_socket: bool = False):
    """Obtain a free TCP port by binding an IPv4 socket to port 0 on all interfaces.

    :param release_socket: when true, close the socket immediately and return
        only the port number. Otherwise the still-bound socket is returned
        together with the port, and the caller is responsible for closing it
        (holding it keeps the port bound until then).
    :return: ``port`` if *release_socket* is true, else ``(socket, port)``.
    """
    sckt = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sckt.bind(('', 0))
        # getsockname() returns (host, port) for AF_INET; only the port is needed.
        port = sckt.getsockname()[1]
    except BaseException:
        # Don't leak the file descriptor if binding/inspection fails.
        sckt.close()
        raise
    if release_socket:
        sckt.close()
        return port
    return sckt, port
# Interface-name prefixes considered when picking the local IP
# ('eth*' / 'en*'); original note: AMI, OSX.
_iface_whitelist_prefixes = {'eth', 'en'}


def get_ip_address():
    """Return the IPv4 address of the first whitelisted network interface.

    Interfaces are visited in ``netifaces.interfaces()`` order. Only names
    starting with one of ``_iface_whitelist_prefixes`` are considered, and the
    first such interface that has an ``AF_INET`` address wins. If that
    interface has several IPv4 addresses, the first one is returned.

    :raises RuntimeError: if no whitelisted interface has an IPv4 address.
    """
    # str.startswith accepts a tuple of prefixes.
    prefixes = tuple(_iface_whitelist_prefixes)
    for iface in netifaces.interfaces():
        if not iface.startswith(prefixes):
            continue
        # {address_family: [ {'addr': ..., ...}, ... ]}
        addrs: Dict[int, Any] = netifaces.ifaddresses(iface)
        inet_addrs = addrs.get(netifaces.AF_INET)
        if inet_addrs:
            return inet_addrs[0]['addr']
    # An explicit raise (unlike `assert False`) survives `python -O`.
    raise RuntimeError(
        "No IPv4 address found on any interface starting with "
        f"{sorted(_iface_whitelist_prefixes)}"
    )
# Enable keep-alive: advertise HTTP/1.1 (persistent connections) on the
# stdlib base request handler instead of its HTTP/1.0 default. Presumably
# this reaches werkzeug's request handler through inheritance — confirm for
# the werkzeug version in use.
setattr(http.server.BaseHTTPRequestHandler, "protocol_version", "HTTP/1.1")
class MotoService:
    """Run a moto server for one AWS service in a background thread.

    When used as an async context manager, instances are shared per service
    name. The first ``async with`` for a name starts the server and registers
    the instance in ``MotoService._services``. Later ``async with`` blocks for
    the same name reuse and ref-count that registered instance, which is what
    ``__aenter__`` returns. The server is stopped when the last user exits.

    An instance can also decorate a coroutine function. A server is then
    started and stopped around each call, without using the shared registry.
    """

    _services: Dict[str, Any] = dict()  # {service_name: registered instance}

    def __init__(self, service_name: str, port: Optional[int] = None, set_endpoint_url_env_var: bool = False):
        """
        :param service_name: moto backend name, passed as ``service=`` to
            ``moto.server.DomainDispatcherApplication``.
        :param port: port to serve on. When falsy, a free port is reserved now
            and its socket is released right before the server binds it.
        :param set_endpoint_url_env_var: when true, a successful ``_start``
            publishes ``endpoint_url`` in the ``<service>_mock_endpoint_url``
            environment variable.
        """
        self._service_name = service_name
        if port:
            self._socket = None
            self._port = port
        else:
            # Keep the socket open so the port stays bound until _server_entry
            # (or __aexit__) closes it.
            self._socket, self._port = get_free_tcp_port()
        self._thread: Optional[threading.Thread] = None
        self._logger = logging.getLogger('MotoService')
        self._refcount: Optional[int] = None
        # The instance __aenter__ returned (self, or the one already registered
        # for this name); __aexit__ must release that one, not necessarily self.
        self._active_service: Optional['MotoService'] = None
        self._ip_address = get_ip_address()
        self._server: Optional[werkzeug.serving.ThreadedWSGIServer] = None
        self._set_endpoint_url_env_var = set_endpoint_url_env_var

    @staticmethod
    def _env_service_name(service_name: str) -> str:
        """Map a service name to the name used in its endpoint env var.

        "dynamodb2" and "dynamodb" share one variable. This is presumably
        moto's backend name vs. the botocore client name — confirm.
        """
        return "dynamodb" if service_name == "dynamodb2" else service_name

    @staticmethod
    def get_service_endpoint_url_from_env(service_name: str):
        """Return the mock endpoint URL stored for *service_name*, or ``None`` if unset."""
        env_var = _SERVICE_ENDPOINT_TEMPLATE.format(
            service_name=MotoService._env_service_name(service_name))
        return os.environ.get(env_var)

    @staticmethod
    def set_service_endpoint_url_from_env(service_name: str, endpoint_url: str):
        """Store *endpoint_url* as the mock endpoint for *service_name*.

        Uses the same name mapping as the getter, so a "dynamodb2" server is
        found by "dynamodb" clients.
        """
        env_var = _SERVICE_ENDPOINT_TEMPLATE.format(
            service_name=MotoService._env_service_name(service_name))
        os.environ[env_var] = endpoint_url

    @property
    def endpoint_url(self):
        """``http://<local ip>:<port>`` of this instance's server."""
        return f'http://{self._ip_address}:{self._port}'

    def __call__(self, func):
        """Decorate coroutine function *func* so a server runs for the duration of each call."""
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            await self._start()
            try:
                return await func(*args, **kwargs)
            finally:
                await self._stop()

        return wrapper

    async def __aenter__(self):
        """Start (or join) the shared server for this service name and return it."""
        svc = self._services.get(self._service_name)
        if svc is None:
            # Registered before awaiting _start, so an __aenter__ running
            # concurrently for the same name shares this instance.
            self._services[self._service_name] = self
            self._refcount = 1
            try:
                await self._start()
            except BaseException:
                # Don't leave a server that failed to start in the registry.
                if self._services.get(self._service_name) is self:
                    del self._services[self._service_name]
                self._refcount = None
                raise
            svc = self
        else:
            svc._refcount += 1
        self._active_service = svc
        return svc

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Release this user's reference; stop the server when none remain."""
        # Release the port this wrapper reserved if its own server never used it.
        if self._socket:
            self._socket.close()
            self._socket = None
        svc = self._active_service if self._active_service is not None else self
        svc._refcount -= 1
        if svc._refcount == 0:
            if self._services.get(svc._service_name) is svc:
                del self._services[svc._service_name]
            await svc._stop()

    def _server_entry(self):
        """Thread target: build the moto WSGI app and serve it until ``shutdown()``."""
        self._main_app = moto.server.DomainDispatcherApplication(moto.server.create_backend_app, service=self._service_name)
        self._main_app.debug = True

        if self._socket:
            self._socket.close()  # release right before we use it
            self._socket = None

        self._server = werkzeug.serving.make_server(self._ip_address, self._port, self._main_app, True)
        self._server.serve_forever()

    async def _start(self):
        """Start the server thread and wait until it answers HTTP.

        Probes ``<endpoint_url>/static/`` up to 10 times (0.5 s timeout, 0.5 s
        pause between attempts).

        :raises Exception: if the server thread dies or never responds.
        """
        self._thread = threading.Thread(target=self._server_entry, daemon=True)
        self._thread.start()

        started = False
        async with aiohttp.ClientSession() as session:
            for _ in range(10):
                if not self._thread.is_alive():
                    # The server thread exited (e.g. make_server raised).
                    break
                try:
                    # we need to bypass the proxies due to monkeypatches
                    async with session.get(self.endpoint_url + '/static/', timeout=0.5):
                        pass
                except (asyncio.TimeoutError, aiohttp.ClientConnectionError):
                    await asyncio.sleep(0.5)
                else:
                    started = True
                    break

        if not started:
            await self._stop()  # pytest.fail doesn't call stop_process
            raise Exception(f"Can not start service: {self._service_name}")

        if self._set_endpoint_url_env_var:
            self.set_service_endpoint_url_from_env(self._service_name, self.endpoint_url)

    async def _stop(self):
        """Shut the server down (if it was created), join its thread and close its socket.

        ``shutdown()`` and ``join()`` block the calling thread, and therefore the
        event loop, until ``serve_forever`` returns.
        """
        server = self._server
        if server:
            server.shutdown()
            self._thread.join()
            # Close the listening socket so the same port can be bound again,
            # e.g. when the decorator path restarts on self._port.
            server.server_close()
            self._server = None
def _wrapt_boto_create_client(wrapped, instance, args, kwargs):
    """wrapt wrapper for ``create_client`` that points clients at moto.

    When the caller passes no ``endpoint_url``, it is taken from
    ``MotoService.get_service_endpoint_url_from_env(service_name)``.
    The access key and secret are always replaced with dummy values.
    Any additional keyword arguments accepted by the wrapped ``create_client``
    (e.g. ones added in newer botocore releases) are forwarded unchanged.

    :param wrapped: the original (bound) ``create_client``.
    :param instance: the session instance (unused; part of wrapt's signature).
    :param args: positional arguments of the intercepted call.
    :param kwargs: keyword arguments of the intercepted call.
    """
    def unwrap_args(service_name, region_name=None, api_version=None,
                    use_ssl=True, verify=None, endpoint_url=None,
                    aws_access_key_id=None, aws_secret_access_key=None,
                    aws_session_token=None, config=None, **extra_kwargs):
        if endpoint_url is None:
            endpoint_url = MotoService.get_service_endpoint_url_from_env(service_name)

        # https://github.com/spulec/moto/issues/2058
        aws_access_key_id = "foobar_key"
        aws_secret_access_key = "foobar_secret"

        # Forward by keyword so extra parameters of newer create_client
        # signatures pass through without depending on positional order.
        return wrapped(
            service_name,
            region_name=region_name,
            api_version=api_version,
            use_ssl=use_ssl,
            verify=verify,
            endpoint_url=endpoint_url,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            config=config,
            **extra_kwargs,
        )

    return unwrap_args(*args, **kwargs)
# https://github.com/spulec/moto/issues/2058
# Remove any real AWS key pair from the environment at import time; the
# create_client wrapper in this module supplies dummy credentials instead.
for key in ('AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'):
    os.environ.pop(key, None)
def patch_boto():
    """Wrap ``botocore.session.Session.create_client`` with ``_wrapt_boto_create_client``.

    Patched clients take their ``endpoint_url`` from the
    ``{service_name}_mock_endpoint_url`` environment variable when the caller
    does not pass one. Calling this while already patched does nothing.
    """
    already_patched = isinstance(botocore.session.Session.create_client, wrapt.ObjectProxy)
    if already_patched:
        return
    wrapt.wrap_function_wrapper('botocore.session', 'Session.create_client', _wrapt_boto_create_client)
def unpatch_boto():
    """Restore the unwrapped ``botocore.session.Session.create_client``; no-op if not patched."""
    current = botocore.session.Session.create_client
    if isinstance(current, wrapt.ObjectProxy):
        botocore.session.Session.create_client = current.__wrapped__
def patch_aioboto():
    """Wrap ``aiobotocore.session.AioSession.create_client`` with ``_wrapt_boto_create_client``.

    Patched clients take their ``endpoint_url`` from the
    ``{service_name}_mock_endpoint_url`` environment variable when the caller
    does not pass one. Calling this while already patched does nothing.
    """
    already_patched = isinstance(aiobotocore.session.AioSession.create_client, wrapt.ObjectProxy)
    if already_patched:
        return
    wrapt.wrap_function_wrapper('aiobotocore.session', 'AioSession.create_client', _wrapt_boto_create_client)
def unpatch_aioboto():
    """Restore the unwrapped ``aiobotocore.session.AioSession.create_client``; no-op if not patched."""
    current = aiobotocore.session.AioSession.create_client
    if isinstance(current, wrapt.ObjectProxy):
        aiobotocore.session.AioSession.create_client = current.__wrapped__
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.