Skip to content

Instantly share code, notes, and snippets.

@mmerickel
Last active August 23, 2019 20:22
Show Gist options
  • Save mmerickel/2516052c6ba5e20b299f7666fe6426d3 to your computer and use it in GitHub Desktop.
Save mmerickel/2516052c6ba5e20b299f7666fe6426d3 to your computer and use it in GitHub Desktop.
subparse base cli
from subparse import command
# Declares the ``shell`` subcommand. The function body is only a docstring and
# registers no extra arguments on ``parser``.
# NOTE(review): the '.shell' argument presumably names the module that holds
# the command's ``main(cli, args)`` implementation — confirm against subparse.
@command('.shell')
def shell(parser):
    """ Launch a python interpreter."""
# Declares the ``run`` subcommand and its single ``-m`` option.
# NOTE(review): the '.run' argument presumably names the module that holds the
# command's ``main(cli, args)`` implementation — confirm against subparse.
@command('.run')
def run(parser):
    """ Run a script.
    By default, this command will execute a ``main(cli, args)`` function in
    the specified script. If the path points directly at a callable then it
    will be used instead but must still receive the args.
    If an 'options' callable is found attached to the main function then
    it will be invoked with an argparse parser and used to compute the
    arguments.
    """
    # This module only imports ``command``; import argparse here so that
    # ``argparse.REMAINDER`` resolves instead of raising NameError.
    import argparse

    # Everything after ``-m`` is captured verbatim: the first item is the
    # script/callable to run, the rest are forwarded to it as its arguments.
    parser.add_argument(
        '-m',
        dest='module',
        required=True,
        nargs=argparse.REMAINDER,
    )
from contextlib import contextmanager
from getpass import getpass
import importlib_metadata
import logging
import logging.config
import os
import plaster
from pyramid.decorator import reify
import sentry_sdk
from subparse import CLI
import sys
from myapp.utils.sentry import init_sentry_from_settings
from myapp.utils.settings import asbool, load_settings_from_file
from myapp.utils.tm import tm_context
from .validation import ValidationError
# Name of the environment variable that supplies the default value of the
# -c/--config-file option (read in generic_options).
CONFIG_ENVIRON_KEY = 'MYAPP_CONFIG'
class MyApp(object):
    """Application object handed to CLI commands.

    Holds the config file path and service flags, lazily builds the
    settings / plaster loader / service containers on first access, and
    offers small helpers for console output and interactive prompting.
    """

    # Console streams used by read_stdin(), out() and error(). They are bound
    # once, when the class body is executed.
    stdin = sys.stdin
    stdout = sys.stdout
    stderr = sys.stderr

    def __init__(self, config_file, service_flags=None):
        """Remember the config file path and optional service flags."""
        self.config_file = config_file
        self.service_flags = service_flags

    def read_stdin(self, text=True):
        """Read all of stdin, as text by default or as bytes if ``text`` is false."""
        if text:
            return self.stdin.read()
        else:
            return self.stdin.buffer.read()

    def out(self, msg):
        """Write ``msg`` to stdout, making sure it ends with a newline."""
        self.stdout.write(msg)
        if not msg.endswith('\n'):
            self.stdout.write('\n')

    @reify
    def _log(self):
        return logging.getLogger(__name__)

    def error(self, msg):
        """Write ``msg`` to stderr, making sure it ends with a newline."""
        self.stderr.write(msg)
        if not msg.endswith('\n'):
            self.stderr.write('\n')

    def abort(self, error, code=1):
        """Print ``error`` to stderr and raise ``AbortCLI`` with ``code``."""
        self.error(error)
        raise AbortCLI(error, code)

    @reify
    def plaster(self):
        """The plaster loader for the config file (wsgi protocol)."""
        return plaster.get_loader(self.config_file, protocols=['wsgi'])

    @reify
    def settings(self):
        """Settings loaded from the config file (computed once)."""
        return load_settings_from_file(self.config_file)

    @reify
    def service_factory(self):
        """Service factory built from the settings and the service flags."""
        from myapp.services import make_service_factory
        flags = self.service_flags or {}
        return make_service_factory(self.settings, flags)

    @reify
    def services(self):
        """The app-wide service container (created once and cached)."""
        return self.service_factory.create_container()

    @contextmanager
    def services_context(self):
        """Yield a fresh service container wrapped in its own ``tm_context``."""
        services = self.service_factory.create_container()
        tm = services.get(name='tm')
        with tm_context(tm):
            yield services

    def get_service(self, *args, **kwargs):
        """Look up a service in the cached app-wide container."""
        return self.services.get(*args, **kwargs)

    @property
    def tm(self):
        return self.get_service(name='tm')

    @property
    def db(self):
        return self.get_service(name='db')

    def prompt(self,
               prompt,
               validator=None,
               attempts=3,
               confirm=False,
               secure=False,
               default=None,
               ):
        """Interactively ask for a value, retrying on invalid input.

        :param prompt: text shown to the user.
        :param validator: optional callable that receives the entered text
            and returns the value to use; it signals bad input by raising
            ``ValidationError``, whose message is printed before retrying.
        :param attempts: how many tries the user gets before aborting.
        :param confirm: falsy to skip confirmation, otherwise the text of the
            confirmation prompt. ``True`` selects a generic ``'Confirm:'``
            prompt.
        :param secure: read the input without echo (via ``getpass``).
        :param default: value substituted for empty input.
        :returns: the validated value.
        :raises AbortCLI: when every attempt failed.
        """
        if default is None:
            default = ''
        # ``confirm`` doubles as the confirmation prompt text; a bare True
        # used to crash on ``True.strip()``, so give it a generic prompt.
        if confirm is True:
            confirm = 'Confirm:'
        while attempts > 0:
            raw_value = _get_input(prompt, secure=secure) or default
            try:
                if validator is not None:
                    value = validator(raw_value)
                else:
                    value = raw_value
            except ValidationError as ex:
                self.error(ex.message)
            else:
                if not confirm:
                    return value
                confirm_value = _get_input(confirm, secure=secure) or default
                # Accept a match against what the user typed as well as the
                # validated value: validators may transform the input (e.g.
                # into a bool or a user object), in which case the typed
                # confirmation could otherwise never match.
                if confirm_value == raw_value or confirm_value == value:
                    return value
                self.error('Values do not match.')
            attempts -= 1
        self.abort('Too many attempts, aborting.')

    def input_file(self, path, **kw):
        """Delegate to ``.utils.input_file`` with this app as first argument."""
        from .utils import input_file
        return input_file(self, path, **kw)

    def output_file(self, path, **kw):
        """Delegate to ``.utils.output_file`` with this app as first argument."""
        from .utils import output_file
        return output_file(self, path, **kw)
def _get_input(prompt, secure=False):
prompt = prompt.strip() + ' '
if secure:
value = getpass(prompt)
else:
value = input(prompt)
value = value.strip()
return value
class AbortCLI(Exception):
    """Raised to stop the CLI with an error message and an exit code.

    :param message: human readable reason, also available as ``.message``.
    :param code: process exit code, available as ``.code``.
    """

    def __init__(self, message, code):
        # Initialise the base class with the message only, so that str(exc)
        # is the message rather than the repr of the (message, code) tuple.
        super().__init__(message)
        self.message = message
        self.code = code
def context_factory(
    cli,
    args,
    ignore_config_file=False,
    ignore_schema=False,
    without_tm=False,
    service_flags=None,
):
    """Generator that builds the ``MyApp`` for a command and yields it.

    Steps, in order: optionally start the hupper reloader (when
    ``args.reload`` is set), create the app, abort if the config file is
    required but missing, configure logging, initialise sentry from the
    settings, then yield the app — inside ``tm_context(app.tm)`` unless
    ``ignore_config_file``, ``ignore_schema`` or ``without_tm`` is set.

    Any exception coming back from the command may drop into a pdb
    post-mortem session (see ``_wants_post_mortem``) and is then re-raised.
    """
    if getattr(args, 'reload', False):
        import hupper
        reloader = hupper.start_reloader(
            __name__ + '.main',
            shutdown_interval=30,
        )
        reloader.watch_files([args.config_file])

    app = MyApp(config_file=args.config_file, service_flags=service_flags)
    uses_config_file = not ignore_config_file

    if uses_config_file and not os.path.exists(app.config_file):
        app.abort('Invalid config file, does not exist.')

    # Logging: --quiet wins, then a bare INFO root logger when there is no
    # config file to read, otherwise whatever the config file declares.
    if args.quiet:
        logging.getLogger('').setLevel(logging.CRITICAL)
    elif ignore_config_file:
        logging.getLogger('').setLevel(logging.INFO)
    else:
        app.plaster.setup_logging()

    if uses_config_file:
        init_sentry_from_settings(app.settings)

    def _wants_post_mortem(ex):
        # An explicit ``args.pdb is True`` always debugs. Otherwise only an
        # undecided flag (None) combined with the 'app.auto_pdb' setting
        # does, and never for a deliberate AbortCLI. The checks run in this
        # order so the settings are only touched when actually needed.
        if args.pdb is True:
            return True
        if args.pdb is not None:
            return False
        if ignore_config_file:
            return False
        if not asbool(app.settings.get('app.auto_pdb'), False):
            return False
        return not isinstance(ex, AbortCLI)

    try:
        if ignore_config_file or ignore_schema or without_tm:
            yield app
        else:
            with tm_context(app.tm):
                yield app
    except Exception as ex:
        if _wants_post_mortem(ex):
            import pdb  # noqa T100
            pdb.post_mortem()
        raise
def generic_options(parser):
    """Register the options shared by every command on ``parser``.

    * ``-c/--config-file``: defaults to the ``CONFIG_ENVIRON_KEY``
      environment variable, falling back to ``'site.ini'``.
    * ``-q/--quiet``: boolean flag.
    * ``--pdb`` / ``--no-pdb``: tri-state ``args.pdb`` — ``True``, ``False``
      or ``None`` when neither flag is given, which lets context_factory
      fall back to the ``app.auto_pdb`` setting.
    """
    default_config = os.environ.get(CONFIG_ENVIRON_KEY, 'site.ini')
    parser.add_argument(
        '-c', '--config-file',
        default=default_config,
    )
    parser.add_argument(
        '-q', '--quiet',
        action='store_true',
    )
    # default=None on both flags: with the implicit store_true default of
    # False, ``args.pdb`` could never be None and the auto-pdb fallback was
    # unreachable.
    parser.add_argument(
        '--pdb',
        action='store_true',
        default=None,
    )
    # '---no-pdb' (three dashes) was the original, mistyped spelling; it is
    # kept as an alias so existing invocations keep working.
    parser.add_argument(
        '--no-pdb', '---no-pdb',
        action='store_false',
        dest='pdb',
        default=None,
    )
def main():
    """Console entry point: build the CLI, load the commands and run it.

    Returns whatever ``cli.run()`` returns, or the exit code carried by an
    ``AbortCLI``. Any other exception is reported to sentry and re-raised.
    """
    app_version = importlib_metadata.version('myapp')
    cli = CLI(version=app_version, context_factory=context_factory)
    cli.add_generic_options(generic_options)
    cli.load_commands('.commands')

    try:
        result = cli.run()
    except AbortCLI as aborted:
        # A deliberate abort has already printed its message; just exit.
        result = aborted.code
    except Exception:
        sentry_sdk.capture_exception()
        raise
    return result
import argparse
from subparse import parse_docstring
from myapp.utils.resolver import maybe_resolve
def main(cli, args):
    """Resolve the target named by ``-m`` and call it with its arguments.

    ``args.module`` is the raw remainder captured after ``-m``: its first
    item is resolved with ``maybe_resolve``, and if the result is not
    callable its ``main`` attribute is used instead. When the callable has
    an ``options`` attribute, that is invoked with a fresh argparse parser
    and the remaining items are parsed into a namespace; otherwise they are
    passed through as a plain list.

    :returns: whatever the resolved callable returns.
    """
    # ``-m`` with nothing after it yields an empty list; abort with a
    # readable message instead of failing on the unpacking below.
    if not args.module:
        cli.abort('No script specified, expected: -m <script> [args ...]')

    module_name, *module_args = args.module
    fn = maybe_resolve(module_name)
    if not callable(fn):
        fn = fn.main

    run_options = getattr(fn, 'options', None)
    if run_options:
        short_desc, long_desc = parse_docstring(run_options.__doc__)
        # NOTE(review): short_desc is only used when long_desc is non-empty;
        # confirm that a summary-only docstring is meant to give no
        # description.
        if long_desc:
            long_desc = short_desc + '\n\n' + long_desc
        parser = argparse.ArgumentParser(
            prog=module_name,
            description=long_desc,
            formatter_class=argparse.RawTextHelpFormatter,
        )
        run_options(parser)
        module_args = parser.parse_args(module_args)
    return fn(cli, module_args)
def make_default_shell(interact=None):
    """Build a shell runner backed by the stdlib interactive console.

    :param interact: callable invoked as ``interact(banner, local=env)``;
        defaults to ``code.interact``.
    :returns: ``shell(env, help)`` which starts the console with ``env`` as
        its namespace and a banner made of the Python version line followed
        by ``help`` (if any).
    """
    import sys

    # Resolved lazily: a module-level ``interact`` default was never defined
    # in this module and made the ``def`` itself fail with NameError.
    if interact is None:
        from code import interact

    def shell(env, help):
        cprt = 'Type "help" for more information.'
        banner = 'Python %s on %s\n%s' % (sys.version, sys.platform, cprt)
        banner += ('\n\n' + help + '\n') if help else '\n'
        interact(banner, local=env)
    return shell
def make_ipython_shell(IPShellFactory=None):
    """Build a shell runner backed by IPython, if available.

    :param IPShellFactory: callable invoked as
        ``IPShellFactory(argv=[], user_ns=env, config=c)``. When omitted,
        ``IPython.start_ipython`` is used.
    :returns: ``shell(env, help)``, or ``None`` when no factory was given
        and IPython cannot be imported.
    """
    # Only fall back to IPython when the caller did not supply a factory;
    # previously the argument was always overwritten (or ignored entirely
    # when IPython was missing).
    if IPShellFactory is None:
        try:
            import IPython
        except ImportError:
            return None
        IPShellFactory = IPython.start_ipython

    def shell(env, help):
        from traitlets.config import Config
        c = Config()
        # banner2 is set to the help text plus a newline, or left empty.
        c.TerminalInteractiveShell.banner2 = help + '\n' if help else ''
        IPShellFactory(argv=[], user_ns=env, config=c)
    return shell
def make_shell():
    """Return the IPython shell runner, or the default one as a fallback."""
    ipython_shell = make_ipython_shell()
    if ipython_shell is not None:
        return ipython_shell
    return make_default_shell()
def run_shell(env, help='', *, shell=None):
    """Start an interactive shell over ``env`` and return its result.

    ``shell`` is a ``shell(env, help)`` callable; when it is not supplied
    one is obtained from ``make_shell()``.
    """
    runner = make_shell() if shell is None else shell
    return runner(env, help)
from contextlib import suppress
import textwrap
from transaction.interfaces import NoTransaction
from myapp.utils.resolver import maybe_resolve
from .utils import run_shell
def main(cli, args):
    """Open an interactive shell pre-loaded with app objects.

    Builds the namespace exposed to the shell (the cli object, settings,
    services, a few resolved modules and helpers), shows a help banner that
    lists it, runs the shell, and finally aborts any transaction that is
    still open on ``cli.tm``.
    """
    # Local import: this module never defined the ``log`` name used during
    # cleanup below, which raised NameError whenever a transaction was
    # pending.
    import logging
    log = logging.getLogger(__name__)

    env = {
        'cli': cli,
        'settings': cli.settings,
        'model': maybe_resolve('myapp.model'),
        'S': maybe_resolve('myapp.services'),
        'tm': cli.tm,
        'db': cli.get_service(name='db'),
        'services': cli.services,
        # helper methods
        'find_user': maybe_resolve('.utils.find_user'),
        # extra imports for fun
        'datetime': maybe_resolve('datetime.datetime'),
        'timedelta': maybe_resolve('datetime.timedelta'),
        'json': maybe_resolve('json'),
        'pytz': maybe_resolve('pytz'),
        'sa': maybe_resolve('sqlalchemy'),
    }
    help = textwrap.dedent(
        """
        Imports:
        datetime datetime.datetime
        sa sqlalchemy
        timedelta datetime.timedelta
        json json
        pytz pytz
        model myapp.model
        S myapp.services
        Context:
        cli The CLI application object.
        settings The settings parsed from the ini file.
        Helpers Helper functions bound to a service container.
        Service context:
        db database session
        services services factory
        tm transaction manager
        Helper methods:
        find_user(id) -> User
        """).strip()
    run_shell(env, help)

    # ``tm.get()`` raises NoTransaction when nothing is pending, in which
    # case the suppress() block simply skips the abort.
    with suppress(NoTransaction):
        cli.tm.get()
        log.debug('aborting pending changes')
        cli.tm.abort()
from datetime import datetime
class ValidationError(ValueError):
    """Raised by validators to reject a user-supplied value.

    :param message: text shown to the user, also available as ``.message``.
    :param invalid_value: the rejected value, available as
        ``.invalid_value``. Optional, so that a call site passing only a
        message still raises a ``ValidationError`` rather than a
        ``TypeError``.
    """

    def __init__(self, message, invalid_value=None):
        super().__init__(message)
        self.message = message
        self.invalid_value = invalid_value
def validate_text(min=1, max=None, field='Value', next_validator=None):
    """Build a validator enforcing a length range on a value.

    :param min: minimum length, or ``None`` for no lower bound.
    :param max: maximum length, or ``None`` for no upper bound.
    :param field: label used at the start of the error message.
    :param next_validator: optional callable applied to values that pass the
        length check; its return value becomes the result.
    :returns: ``validator(value)`` which returns the value (or the result of
        ``next_validator``) and raises ``ValidationError`` when the length
        is out of range.
    """
    # The message is computed once, up front; it stays None only when there
    # are no bounds at all, in which case the validator never raises.
    msg = None
    if min is not None and max is not None:
        assert 0 < min <= max
        if min == max:
            # The singular/plural wording used to be swapped, reporting
            # "exactly 1 character" for every fixed length other than 1.
            unit = 'character' if min == 1 else 'characters'
            msg = f'{field} must be exactly {min} {unit}.'
        else:
            msg = f'{field} must be between {min} and {max} characters.'
    elif min is not None:
        assert min > 0
        if min > 1:
            msg = f'{field} must be at least {min} characters.'
        else:
            msg = f'{field} must not be empty.'
    elif max is not None:
        assert max > 0
        if max == 1:
            msg = f'{field} must be at most 1 character.'
        else:
            msg = f'{field} cannot be longer than {max} characters.'

    def validator(value):
        if min is not None and len(value) < min:
            raise ValidationError(msg, value)
        if max is not None and len(value) > max:
            raise ValidationError(msg, value)
        if next_validator is not None:
            return next_validator(value)
        return value
    return validator
# Ready-made validators built on validate_text. Both only check the length
# of the input: 1 to 255 characters for an email, at least 8 for a password.
validate_email = validate_text(min=1, max=255, field='Email')
validate_password = validate_text(min=8, field='Password')
def validate_bool(value):
    """Turn a case-insensitive "y"/"n" answer into ``True``/``False``.

    Raises ``ValidationError`` (carrying the lower-cased input) for any
    other answer.
    """
    answer = value.lower()
    meanings = {'y': True, 'n': False}
    if answer in meanings:
        return meanings[answer]
    raise ValidationError('Please answer "y" or "n".', answer)
def validate_choice(choices):
    """Build a validator that only accepts members of ``choices``.

    The rejection message lists the choices sorted and double-quoted; the
    list is rendered once, when the validator is created.
    """
    quoted = [f'"{choice}"' for choice in sorted(choices)]
    choices_str = ', '.join(quoted)

    def validator(value):
        if value in choices:
            return value
        raise ValidationError(f'Value must be one of {choices_str}', value)
    return validator
def validate_user(account_svc):
    """Build a validator that resolves an email into a user.

    The value is looked up with ``account_svc.find_user_by_email``; the
    validator returns what was found and raises ``ValidationError`` when
    the lookup yields ``None``.
    """
    def validator(value):
        found = account_svc.find_user_by_email(value)
        if found is not None:
            return found
        raise ValidationError('Could not find user.', value)
    return validator
def validate_date(value):
    """Parse a ``YYYY-MM-DD`` string into a ``datetime.date``.

    :raises ValidationError: when ``value`` does not match the format; the
        rejected value is attached as ``invalid_value``.
    """
    try:
        parsed = datetime.strptime(value, '%Y-%m-%d')
    except (ValueError, TypeError) as ex:
        # ValidationError takes the offending value as its second argument;
        # leaving it out turned every bad date into a TypeError.
        raise ValidationError(
            'Invalid date format, must be YYYY-MM-DD.',
            value,
        ) from ex
    return parsed.date()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment