Last active: August 23, 2019 20:22
-
-
Save mmerickel/2516052c6ba5e20b299f7666fe6426d3 to your computer and use it in GitHub Desktop.
Base CLI built on subparse
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from subparse import command | |
@command('.shell')
def shell(parser):
    """ Launch a python interpreter."""
    # Registers the command and adds no command-specific arguments.
    # NOTE(review): subparse presumably shows this docstring as the command's
    # help text, and '.shell' looks like the (relative) module holding the
    # command's ``main`` -- confirm against subparse's docs before changing
    # either.
@command('.run')
def run(parser):
    """ Run a script.

    By default, this command will execute a ``main(cli, args)`` function in
    the specified script. If the path points directly at a callable then it
    will be used instead but must still receive the args.

    If an 'options' callable is found attached to the main function then
    it will be invoked with an argparse parser and used to compute the
    arguments.
    """
    # Local import: this module only imports ``command`` at top level, and
    # argparse.REMAINDER is needed below.
    import argparse

    # Capture everything after -m verbatim: the target followed by its own
    # arguments. Presumably consumed by the '.run' command module as
    # ``module_name, *module_args``.
    parser.add_argument(
        '-m',
        dest='module',
        required=True,
        nargs=argparse.REMAINDER,
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import contextmanager | |
from getpass import getpass | |
import importlib_metadata | |
import logging | |
import logging.config | |
import os | |
import plaster | |
from pyramid.decorator import reify | |
import sentry_sdk | |
from subparse import CLI | |
import sys | |
from myapp.utils.sentry import init_sentry_from_settings | |
from myapp.utils.settings import asbool, load_settings_from_file | |
from myapp.utils.tm import tm_context | |
from .validation import ValidationError | |
# Environment variable whose value becomes the default --config-file
# (falling back to 'site.ini'); read by generic_options.
CONFIG_ENVIRON_KEY = 'MYAPP_CONFIG'
class MyApp(object):
    """Context object handed to every command as ``cli``.

    Holds the config-file path and lazily builds (once per instance, via
    ``reify``) the settings, the plaster loader and the service container.
    Also offers small terminal helpers for output and interactive prompts.
    """

    # Streams are captured when the class is defined; override them on an
    # instance or subclass to redirect I/O.
    stdin = sys.stdin
    stdout = sys.stdout
    stderr = sys.stderr

    def __init__(self, config_file, service_flags=None):
        """
        :param config_file: Path of the ini file settings are read from.
        :param service_flags: Optional mapping passed to
            ``make_service_factory`` (``None`` means no flags).
        """
        self.config_file = config_file
        self.service_flags = service_flags

    def read_stdin(self, text=True):
        """Return all of stdin as ``str`` (default) or, if not *text*, ``bytes``."""
        if text:
            return self.stdin.read()
        return self.stdin.buffer.read()

    def out(self, msg):
        """Write *msg* to stdout, adding a trailing newline if missing."""
        self.stdout.write(msg)
        if not msg.endswith('\n'):
            self.stdout.write('\n')

    @reify
    def _log(self):
        return logging.getLogger(__name__)

    def error(self, msg):
        """Write *msg* to stderr, adding a trailing newline if missing."""
        self.stderr.write(msg)
        if not msg.endswith('\n'):
            self.stderr.write('\n')

    def abort(self, error, code=1):
        """Report *error* on stderr and stop the command.

        :raises AbortCLI: always, carrying *error* and the exit *code*.
        """
        self.error(error)
        raise AbortCLI(error, code)

    @reify
    def plaster(self):
        """plaster loader for the config file (``wsgi`` protocol)."""
        return plaster.get_loader(self.config_file, protocols=['wsgi'])

    @reify
    def settings(self):
        """Settings loaded from the config file."""
        return load_settings_from_file(self.config_file)

    @reify
    def service_factory(self):
        """Service factory built from ``settings`` and ``service_flags``."""
        # Deferred import, presumably to keep CLI start-up light or to avoid
        # an import cycle.
        from myapp.services import make_service_factory
        flags = self.service_flags or {}
        return make_service_factory(self.settings, flags)

    @reify
    def services(self):
        """The shared service container for this app instance."""
        return self.service_factory.create_container()

    @contextmanager
    def services_context(self):
        """Yield a *new* service container, run inside ``tm_context``.

        Unlike ``services``, every call creates a separate container; the
        ``with`` block runs inside ``tm_context`` for that container's
        ``tm`` service (see myapp.utils.tm for commit/abort semantics).
        """
        services = self.service_factory.create_container()
        tm = services.get(name='tm')
        with tm_context(tm):
            yield services

    def get_service(self, *args, **kwargs):
        """Look up a service in the shared container (``services.get``)."""
        return self.services.get(*args, **kwargs)

    @property
    def tm(self):
        """The ``tm`` service from the shared container."""
        return self.get_service(name='tm')

    @property
    def db(self):
        """The ``db`` service from the shared container."""
        return self.get_service(name='db')

    def prompt(self,
               prompt,
               validator=None,
               attempts=3,
               confirm=False,
               secure=False,
               default=None,
               ):
        """Interactively ask the user for a value.

        :param prompt: Text shown to the user (whitespace-normalised).
        :param validator: Optional callable that checks/converts the raw
            input; raising :class:`ValidationError` rejects it and re-asks.
        :param attempts: Number of tries before aborting.
        :param confirm: Falsy to skip confirmation; otherwise the prompt
            text (a string) used to ask for the value a second time.
        :param secure: Read input without echo (``getpass``).
        :param default: Used when the user enters nothing; ``None`` means
            the empty string.
        :returns: The validated value (whatever *validator* returned).
        :raises AbortCLI: after *attempts* failed tries (via ``abort``).
        """
        if default is None:
            default = ''
        while attempts > 0:
            raw = _get_input(prompt, secure=secure)
            if not raw:
                raw = default
            try:
                value = raw if validator is None else validator(raw)
            except ValidationError as ex:
                self.error(ex.message)
            else:
                if not confirm:
                    return value
                confirm_raw = _get_input(confirm, secure=secure)
                if not confirm_raw:
                    confirm_raw = default
                # Compare what the user typed both times, not the validated
                # value: validators may return non-string objects (bools,
                # dates, users) that would never equal the confirmation text.
                if confirm_raw == raw:
                    return value
                self.error('Values do not match.')
            attempts -= 1
        self.abort('Too many attempts, aborting.')

    def input_file(self, path, **kw):
        """Delegate to ``.utils.input_file`` with this app as context."""
        from .utils import input_file
        return input_file(self, path, **kw)

    def output_file(self, path, **kw):
        """Delegate to ``.utils.output_file`` with this app as context."""
        from .utils import output_file
        return output_file(self, path, **kw)
def _get_input(prompt, secure=False):
    """Ask for one line of input and return it with surrounding whitespace removed.

    The prompt is normalised to end in exactly one space. When *secure* is
    true, ``getpass`` is used so the typed text is not echoed.
    """
    read_line = getpass if secure else input
    answer = read_line(prompt.strip() + ' ')
    return answer.strip()
class AbortCLI(Exception):
    """Raised to stop the CLI with a message and a process exit code.

    ``main`` catches it and returns ``code``. When raised via
    ``MyApp.abort`` the message has already been written to stderr.
    """

    def __init__(self, message, code=1):
        # Hand the message to Exception so str(ex), ex.args and tracebacks
        # show it instead of an empty string.
        super().__init__(message)
        self.message = message
        self.code = code
def context_factory(
    cli,
    args,
    ignore_config_file=False,
    ignore_schema=False,
    without_tm=False,
    service_flags=None,
):
    """Create the per-invocation :class:`MyApp` and yield it to the command.

    Written as a generator that yields exactly once; ``main`` passes it to
    ``subparse.CLI(context_factory=...)``, which presumably drives it like a
    context manager (TODO confirm against subparse's docs).

    :param cli: The subparse CLI object; not used in this function.
    :param args: Parsed arguments. Reads ``config_file``, ``quiet`` and
        ``pdb`` (from ``generic_options``) and an optional ``reload``.
    :param ignore_config_file: Skip the config-file existence check,
        config-driven logging, Sentry setup, the ``tm_context`` wrapper and
        the ``app.auto_pdb`` lookup.
    :param ignore_schema: Only effect here: skip the ``tm_context`` wrapper.
    :param without_tm: Skip the ``tm_context`` wrapper.
    :param service_flags: Forwarded to :class:`MyApp`.
    :raises AbortCLI: if the config file does not exist (via ``app.abort``).
    """
    if getattr(args, 'reload', False):
        # Imported lazily: hupper is only needed when args.reload is truthy.
        import hupper
        # Hand control to hupper with this module's ``main`` as the worker
        # entry point, and add the config file to its watch list.
        reloader = hupper.start_reloader(
            __name__ + '.main',
            shutdown_interval=30,
        )
        reloader.watch_files([args.config_file])
    app = MyApp(
        config_file=args.config_file,
        service_flags=service_flags,
    )
    if not ignore_config_file and not os.path.exists(app.config_file):
        app.abort('Invalid config file, does not exist.')
    # Logging: --quiet raises the root logger to CRITICAL; without a config
    # file fall back to INFO on the root logger; otherwise configure logging
    # from the config file via plaster.
    if args.quiet:
        root_logger = logging.getLogger('')
        root_logger.setLevel(logging.CRITICAL)
    elif ignore_config_file:
        root_logger = logging.getLogger('')
        root_logger.setLevel(logging.INFO)
    else:
        app.plaster.setup_logging()
    if not ignore_config_file:
        init_sentry_from_settings(app.settings)
    try:
        if ignore_config_file or ignore_schema or without_tm:
            yield app
        else:
            # Run the whole command inside tm_context(app.tm) (see
            # myapp.utils.tm for its commit/abort semantics).
            with tm_context(app.tm):
                yield app
    except Exception as ex:
        # Post-mortem debugging: always when args.pdb is True; when
        # args.pdb is None, only if the 'app.auto_pdb' setting is truthy and
        # the error is not a deliberate AbortCLI. NOTE(review): the None
        # branch requires the --pdb/--no-pdb option to default to None
        # rather than False -- check its declared default.
        if args.pdb is True or (
            args.pdb is None
            and not ignore_config_file
            and asbool(app.settings.get('app.auto_pdb'), False)
            and not isinstance(ex, AbortCLI)
        ):
            import pdb  # noqa T100
            pdb.post_mortem()
        raise
def generic_options(parser):
    """Register the options shared by every command.

    ``-c/--config-file`` defaults to the value of the ``MYAPP_CONFIG``
    environment variable, or ``site.ini`` when it is unset.
    ``--pdb`` / ``--no-pdb`` store True / False into ``pdb``; when neither
    is given ``pdb`` is None.
    """
    default_config = os.environ.get(CONFIG_ENVIRON_KEY, 'site.ini')
    parser.add_argument(
        '-c', '--config-file',
        default=default_config,
    )
    parser.add_argument(
        '-q', '--quiet',
        action='store_true',
    )
    # Both flags share dest='pdb'. The default must be None, not
    # store_true's implicit False, so "no flag given" stays distinguishable
    # from an explicit --no-pdb: context_factory only consults the
    # 'app.auto_pdb' setting when args.pdb is None.
    parser.add_argument(
        '--pdb',
        action='store_true',
        default=None,
    )
    parser.add_argument(
        # '---no-pdb' (three dashes) was the original, mistyped spelling;
        # kept as an alias so existing invocations keep working.
        '--no-pdb', '---no-pdb',
        action='store_false',
        dest='pdb',
        default=None,
    )
def main():
    """Top-level entry point: build the subparse CLI and run a command.

    The CLI version comes from the installed ``myapp`` distribution
    metadata, and commands are loaded from the relative ``.commands``
    package. Returns the command's result, or the exit code carried by an
    :class:`AbortCLI`. Any other exception is reported to Sentry and
    re-raised. ``context_factory`` also names this function as hupper's
    worker entry point.
    """
    application = CLI(
        version=importlib_metadata.version('myapp'),
        context_factory=context_factory,
    )
    application.add_generic_options(generic_options)
    application.load_commands('.commands')

    try:
        result = application.run()
    except AbortCLI as aborted:
        result = aborted.code
    except Exception:
        sentry_sdk.capture_exception()
        raise
    return result
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from subparse import parse_docstring | |
from myapp.utils.resolver import maybe_resolve | |
def main(cli, args):
    """Resolve the ``-m`` target and invoke it as ``fn(cli, args)``.

    ``args.module`` holds the target followed by its own arguments. The
    target is resolved with ``maybe_resolve``; if the result is not
    callable (e.g. a module), its ``main`` attribute is used instead. When
    that function has an ``options`` attribute, it is called with a fresh
    argparse parser and the remaining arguments are parsed into a
    namespace; otherwise the raw argument list is passed through.

    :returns: Whatever the target function returns.
    :raises AbortCLI: if ``-m`` was given without a target (via
        ``cli.abort``; *cli* is expected to be the MyApp context).
    """
    if not args.module:
        cli.abort('No script specified, usage: -m <target> [args...]')
    module_name, *module_args = args.module
    fn = maybe_resolve(module_name)
    if not callable(fn):
        fn = fn.main
    run_options = getattr(fn, 'options', None)
    if run_options:
        short_desc, long_desc = parse_docstring(run_options.__doc__)
        # Fall back to the summary alone when there is no extended text, so
        # the script's --help still shows a description.
        if long_desc:
            description = short_desc + '\n\n' + long_desc
        else:
            description = short_desc
        parser = argparse.ArgumentParser(
            prog=module_name,
            description=description,
            formatter_class=argparse.RawTextHelpFormatter,
        )
        run_options(parser)
        module_args = parser.parse_args(module_args)
    return fn(cli, module_args)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def make_default_shell(interact=None):
    """Return a shell runner backed by the stdlib interactive console.

    :param interact: Callable with the ``code.interact(banner, local=...)``
        calling convention; defaults to :func:`code.interact`. Injectable
        for testing.
    :returns: ``shell(env, help)``, which starts the console with *env* as
        its namespace and *help* appended to the standard banner.
    """
    import sys

    # Resolve lazily: the previous default ``interact=interact`` referred to
    # a name that does not exist at module level (NameError on import).
    if interact is None:
        from code import interact

    def shell(env, help):
        cprt = 'Type "help" for more information.'
        banner = 'Python %s on %s\n%s' % (sys.version, sys.platform, cprt)
        banner += ('\n\n' + help + '\n') if help else '\n'
        interact(banner, local=env)

    return shell
def make_ipython_shell(IPShellFactory=None):
    """Return an IPython-backed shell runner, or ``None`` if unavailable.

    :param IPShellFactory: Callable with the
        ``IPython.start_ipython(argv=, user_ns=, config=)`` calling
        convention. When omitted, ``IPython.start_ipython`` is used; if
        IPython is not installed, ``None`` is returned so the caller can
        fall back to another shell. A supplied factory is now honoured
        instead of being silently replaced.
    :returns: ``shell(env, help)`` or ``None``.
    """
    if IPShellFactory is None:
        try:
            import IPython
        except ImportError:
            return None
        IPShellFactory = IPython.start_ipython

    def shell(env, help):
        # traitlets is an IPython dependency; imported only when the shell
        # actually starts.
        from traitlets.config import Config
        c = Config()
        c.TerminalInteractiveShell.banner2 = help + '\n' if help else ''
        IPShellFactory(argv=[], user_ns=env, config=c)

    return shell
def make_shell():
    """Choose the best available interactive shell runner.

    IPython is preferred; the stdlib console runner is used whenever
    ``make_ipython_shell`` reports IPython as unavailable (returns None).
    """
    ipython_shell = make_ipython_shell()
    return make_default_shell() if ipython_shell is None else ipython_shell
def run_shell(env, help='', *, shell=None):
    """Run an interactive shell over *env* and return the runner's result.

    :param env: Namespace dict exposed inside the shell.
    :param help: Extra text shown in the shell banner.
    :param shell: A ``shell(env, help)`` runner; when ``None`` one is
        picked by ``make_shell``.
    """
    runner = shell if shell is not None else make_shell()
    return runner(env, help)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import suppress | |
import textwrap | |
from transaction.interfaces import NoTransaction | |
from myapp.utils.resolver import maybe_resolve | |
from .utils import run_shell | |
def main(cli, args):
    """Start an interactive shell preloaded with application objects.

    The namespace exposes the CLI context, settings, the shared service
    container, the ``tm`` and ``db`` services and a few convenience imports
    (listed in the banner text below). When the shell exits -- normally or
    via an exception -- any transaction still open on ``cli.tm`` is aborted,
    so work done in the shell is not left pending.
    """
    # Logger for this module; previously referenced as ``log`` without ever
    # being defined (NameError on exit).
    import logging
    log = logging.getLogger(__name__)

    env = {
        'cli': cli,
        'settings': cli.settings,
        'model': maybe_resolve('myapp.model'),
        'S': maybe_resolve('myapp.services'),
        'tm': cli.tm,
        'db': cli.get_service(name='db'),
        'services': cli.services,
        # helper methods
        'find_user': maybe_resolve('.utils.find_user'),
        # extra imports for fun
        'datetime': maybe_resolve('datetime.datetime'),
        'timedelta': maybe_resolve('datetime.timedelta'),
        'json': maybe_resolve('json'),
        'pytz': maybe_resolve('pytz'),
        'sa': maybe_resolve('sqlalchemy'),
    }
    help_text = textwrap.dedent(
        """
        Imports:
          datetime   datetime.datetime
          sa         sqlalchemy
          timedelta  datetime.timedelta
          json       json
          pytz       pytz
          model      myapp.model
          S          myapp.services

        Context:
          cli        The CLI application object.
          settings   The settings parsed from the ini file.

        Service context:
          db         database session
          services   service container
          tm         transaction manager

        Helper methods:
          find_user(id) -> User
        """).strip()
    try:
        run_shell(env, help_text)
    finally:
        # tm.get() raises NoTransaction when no transaction is active; only
        # abort when there is one to discard.
        with suppress(NoTransaction):
            cli.tm.get()
            log.debug('aborting pending changes')
            cli.tm.abort()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime | |
class ValidationError(ValueError):
    """Raised by a validator when user input is rejected.

    :param message: Human-readable reason; ``MyApp.prompt`` prints it and
        asks again.
    :param invalid_value: The rejected input. Optional, so the error can
        also be raised when the offending value is not at hand.
    """

    def __init__(self, message, invalid_value=None):
        super().__init__(message)
        self.message = message
        self.invalid_value = invalid_value
def validate_text(min=1, max=None, field='Value', next_validator=None):
    """Build a validator that enforces a length range on a string.

    :param min: Minimum length (inclusive) or ``None`` for no minimum;
        must be positive when given.
    :param max: Maximum length (inclusive) or ``None`` for no maximum;
        must be positive and ``>= min`` when given.
    :param field: Name used in the error message.
    :param next_validator: Optional validator applied after the length
        check; its return value becomes the validator's result.
    :returns: ``validator(value)`` returning the (possibly transformed)
        value, or raising :class:`ValidationError` with a message
        describing the allowed length.
    """
    if min is not None and max is not None:
        assert 0 < min <= max
        if min == max:
            # Singular only for exactly one character.
            if min == 1:
                msg = f'{field} must be exactly 1 character.'
            else:
                msg = f'{field} must be exactly {min} characters.'
        else:
            msg = f'{field} must be between {min} and {max} characters.'
    elif min is not None:
        assert min > 0
        if min > 1:
            msg = f'{field} must be at least {min} characters.'
        else:
            msg = f'{field} must not be empty.'
    elif max is not None:
        assert max > 0
        if max == 1:
            msg = f'{field} must be at most 1 character.'
        else:
            msg = f'{field} cannot be longer than {max} characters.'
    else:
        # No bounds: the validator below can never fail a length check.
        msg = None

    def validator(value):
        if min is not None and len(value) < min:
            raise ValidationError(msg, value)
        if max is not None and len(value) > max:
            raise ValidationError(msg, value)
        if next_validator is not None:
            return next_validator(value)
        return value

    return validator
# Ready-made length validators. Email: 1-255 characters (length only, no
# format check). Password: at least 8 characters, no upper bound.
validate_email = validate_text(field='Email', min=1, max=255)
validate_password = validate_text(field='Password', min=8)
def validate_bool(value):
    """Map a case-insensitive ``y``/``n`` answer to ``True``/``False``.

    :raises ValidationError: for any other answer (carrying the lowercased
        input as ``invalid_value``).
    """
    answer = value.lower()
    if answer in ('y', 'n'):
        return answer == 'y'
    raise ValidationError('Please answer "y" or "n".', answer)
def validate_choice(choices):
    """Build a validator accepting only members of *choices*.

    The error message lists the allowed values, quoted and sorted; it is
    computed once when the validator is built.
    """
    quoted = [f'"{choice}"' for choice in sorted(choices)]
    choices_str = ', '.join(quoted)

    def validator(value):
        if value in choices:
            return value
        raise ValidationError(f'Value must be one of {choices_str}', value)

    return validator
def validate_user(account_svc):
    """Build a validator that turns an email address into a user object.

    :param account_svc: Object exposing ``find_user_by_email(email)``,
        which returns ``None`` when no user matches.
    :returns: ``validator(email)`` returning the found user or raising
        :class:`ValidationError`.
    """
    def validator(value):
        found = account_svc.find_user_by_email(value)
        if found is not None:
            return found
        raise ValidationError('Could not find user.', value)

    return validator
def validate_date(value):
    """Parse a ``YYYY-MM-DD`` string into a :class:`datetime.date`.

    :raises ValidationError: if *value* is not a string in that format.
    """
    try:
        dt = datetime.strptime(value, '%Y-%m-%d')
    except (TypeError, ValueError):
        # strptime raises ValueError on a format mismatch and TypeError on
        # non-string input. Pass the offending value so callers can inspect
        # ``invalid_value``; the parse error itself adds no useful context.
        raise ValidationError(
            'Invalid date format, must be YYYY-MM-DD.', value,
        ) from None
    return dt.date()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.