Skip to content

Instantly share code, notes, and snippets.

@EmbraceLife
Created August 19, 2024 05:25
Show Gist options
  • Save EmbraceLife/87f69df227eafef3f4fc3774184404e2 to your computer and use it in GitHub Desktop.
Save EmbraceLife/87f69df227eafef3f4fc3774184404e2 to your computer and use it in GitHub Desktop.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/07_meta.ipynb.
# %% auto 0
# Public API of the meta module
__all__ = [
    'test_sig', 'FixSigMeta', 'PrePostInitMeta', 'AutoInit', 'NewChkMeta', 'BypassNewMeta',
    'empty2none', 'anno_dict', 'use_kwargs_dict', 'use_kwargs', 'delegates', 'method', 'funcs_kwargs',
]
# %% ../nbs/07_meta.ipynb
from .imports import *
from .test import *
from contextlib import contextmanager
from copy import copy
import inspect
# %% ../nbs/07_meta.ipynb
def test_sig(f, b):
    "Assert that `str(inspect.signature(f))` equals the string `b`"
    actual = str(inspect.signature(f))
    test_eq(actual, b)
# %% ../nbs/07_meta.ipynb
def _rm_self(sig):
sigd = dict(sig.parameters)
sigd.pop('self')
return sig.replace(parameters=sigd.values())
# %% ../nbs/07_meta.ipynb
class FixSigMeta(type):
    "A metaclass that fixes the signature on classes that override `__new__`"
    def __new__(cls, name, bases, dict):
        new_cls = super().__new__(cls, name, bases, dict)
        init = new_cls.__init__
        # Record the `__init__` signature (minus `self`) as the class's `__signature__`,
        # which `inspect.signature(new_cls)` will then report.
        if init is not object.__init__:
            new_cls.__signature__ = _rm_self(inspect.signature(init))
        return new_cls
# %% ../nbs/07_meta.ipynb
class PrePostInitMeta(FixSigMeta):
    "A metaclass that calls optional `__pre_init__` and `__post_init__` methods"
    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls)
        # When `__new__` returns something that is not exactly a `cls` instance, hand it back untouched.
        if type(obj) != cls: return obj
        if hasattr(obj, '__pre_init__'): obj.__pre_init__(*args, **kwargs)
        obj.__init__(*args, **kwargs)
        if hasattr(obj, '__post_init__'): obj.__post_init__(*args, **kwargs)
        return obj
# %% ../nbs/07_meta.ipynb
class AutoInit(metaclass=PrePostInitMeta):
    "Base whose subclasses don't need to call `super().__init__` themselves"
    def __pre_init__(self, *args, **kwargs):
        # `PrePostInitMeta` calls this before the subclass's own `__init__`, so the next
        # `__init__` in the MRO after `AutoInit` always runs first.
        super().__init__(*args, **kwargs)
# %% ../nbs/07_meta.ipynb
class NewChkMeta(FixSigMeta):
    "Metaclass that returns an existing instance unchanged when it is the sole constructor argument"
    def __call__(cls, x=None, *args, **kwargs):
        is_lone_instance = x is not None and not args and not kwargs and isinstance(x, cls)
        if is_lone_instance: return x
        return super().__call__(x, *args, **kwargs)
# %% ../nbs/07_meta.ipynb
class BypassNewMeta(FixSigMeta):
    "Metaclass: casts `x` to this class if it's of type `cls._bypass_type`"
    def __call__(cls, x=None, *args, **kwargs):
        if hasattr(cls, '_new_meta'):
            # A class-provided factory takes over construction entirely.
            x = cls._new_meta(x, *args, **kwargs)
        else:
            bypass_type = getattr(cls, '_bypass_type', object)
            if not isinstance(x, bypass_type) or len(args) or len(kwargs):
                x = super().__call__(x, *args, **kwargs)
        # Re-class the object in place rather than copying it.
        if cls != x.__class__: x.__class__ = cls
        return x
# %% ../nbs/07_meta.ipynb
def empty2none(p):
    "Replace `Parameter.empty` with `None`"
    # `Parameter.empty` is a sentinel, so test identity. `==` would call `p.__eq__`, which can
    # return a non-bool (e.g. arrays) or claim equality with anything.
    return None if p is inspect.Parameter.empty else p
# %% ../nbs/07_meta.ipynb
def anno_dict(f):
    "`__annotations__` dictionary of `f` with `Parameter.empty` values cast to `None`; `{}` if `f` has none"
    annos = getattr(f, '__annotations__', {})
    return {name: empty2none(ann) for name, ann in annos.items()}
# %% ../nbs/07_meta.ipynb
def _mk_param(n,d=None): return inspect.Parameter(n, inspect.Parameter.KEYWORD_ONLY, default=d)
# %% ../nbs/07_meta.ipynb
def use_kwargs_dict(keep=False, **kwargs):
    "Decorator: replace `**kwargs` in signature with keyword-only params named by `kwargs`, using their values as defaults"
    def _f(f):
        sig = inspect.signature(f)
        params = dict(sig.parameters)
        var_kw = params.pop('kwargs')
        # Names already present in the signature are left as they are.
        extra = {name: _mk_param(name, dflt) for name, dflt in kwargs.items() if name not in params}
        params.update(extra)
        # Optionally re-append `**kwargs` after the new keyword-only params.
        if keep: params['kwargs'] = var_kw
        f.__signature__ = sig.replace(parameters=params.values())
        return f
    return _f
# %% ../nbs/07_meta.ipynb
def use_kwargs(names, keep=False):
    "Decorator: replace `**kwargs` in signature with keyword-only params named in `names` (default `None`)"
    def _f(f):
        sig = inspect.signature(f)
        params = dict(sig.parameters)
        var_kw = params.pop('kwargs')
        params.update({name: _mk_param(name) for name in names if name not in params})
        if keep: params['kwargs'] = var_kw
        f.__signature__ = sig.replace(parameters=params.values())
        return f
    return _f
# %% ../nbs/07_meta.ipynb
def delegates(to:FunctionType=None, # Delegatee
              keep=False, # Keep `kwargs` in decorated function?
              but:list=None): # Exclude these parameters from signature
    "Decorator: replace `**kwargs` in signature with params from `to`"
    if but is None: but = []
    def _f(f):
        # With no `to`, `f` is a class whose `__init__` delegates to its base class's `__init__`.
        if to is None: to_f,from_f = f.__base__.__init__,f.__init__
        else: to_f,from_f = to.__init__ if isinstance(to,type) else to,f
        # Unwrap bound/class methods so attributes are set on the underlying functions.
        from_f = getattr(from_f,'__func__',from_f)
        to_f = getattr(to_f,'__func__',to_f)
        # Already delegated (without `keep`): `**kwargs` is gone, so there is nothing to replace.
        if hasattr(from_f,'__delwrap__'): return f
        sig = inspect.signature(from_f)
        sigd = dict(sig.parameters)
        k = sigd.pop('kwargs')
        # Pull in `to`'s defaulted params that are not already present or excluded, as keyword-only.
        # Identity test for the sentinel: `!=` on an arbitrary default value can misbehave.
        s2 = {n:p.replace(kind=inspect.Parameter.KEYWORD_ONLY) for n,p in inspect.signature(to_f).parameters.items()
              if p.default is not inspect.Parameter.empty and n not in sigd and n not in but}
        # Copy annotations only for the params actually added. Copying all of `to`'s annotations
        # also overwrote `f`'s own `return` annotation.
        anno = {n:a for n,a in getattr(to_f, "__annotations__", {}).items() if n in s2}
        sigd.update(s2)
        if keep: sigd['kwargs'] = k
        else: from_f.__delwrap__ = to_f
        from_f.__signature__ = sig.replace(parameters=sigd.values())
        if hasattr(from_f, '__annotations__'): from_f.__annotations__.update(anno)
        return f
    return _f
# %% ../nbs/07_meta.ipynb
def method(f):
    "Mark `f` as a method by binding it with `MethodType`"
    # Py3's `MethodType` refuses `None` as the instance, so bind to a dummy `1` instead.
    dummy_instance = 1
    return MethodType(f, dummy_instance)
# %% ../nbs/07_meta.ipynb
def _funcs_kwargs(cls, as_method):
    "Patch `cls.__init__` so each name in `cls._methods` may be passed as a kwarg that overrides that attribute"
    orig_init = cls.__init__
    def _init(self, *args, **kwargs):
        for name in cls._methods:
            fn = kwargs.pop(name, None)
            if fn is None: continue
            if as_method: fn = method(fn)
            # Rebind bound methods (including those made by `method`) to this instance.
            if isinstance(fn, MethodType): fn = MethodType(fn.__func__, self)
            setattr(self, name, fn)
        orig_init(self, *args, **kwargs)
    functools.update_wrapper(_init, orig_init)
    cls.__init__ = use_kwargs(cls._methods)(_init)
    if hasattr(cls, '__signature__'): cls.__signature__ = _rm_self(inspect.signature(cls.__init__))
    return cls
# %% ../nbs/07_meta.ipynb
def funcs_kwargs(as_method=False):
    "Replace methods in `cls._methods` with those from `kwargs`"
    # Used bare (`@funcs_kwargs`), the decorated class arrives as `as_method`.
    if not callable(as_method): return partial(_funcs_kwargs, as_method=as_method)
    return _funcs_kwargs(as_method, False)
---
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_basics.ipynb.
# %% auto 0
# Public API of the basics module (the names exported by `from ... import *`)
__all__ = ['defaults', 'null', 'num_methods', 'rnum_methods', 'inum_methods', 'arg0', 'arg1', 'arg2', 'arg3', 'arg4', 'Self',
           'ifnone', 'maybe_attr', 'basic_repr', 'is_array', 'listify', 'tuplify', 'true', 'NullType', 'tonull',
           'get_class', 'mk_class', 'wrap_class', 'ignore_exceptions', 'exec_local', 'risinstance', 'ver2tuple', 'Inf',
           'in_', 'ret_true', 'ret_false', 'stop', 'gen', 'chunked', 'otherwise', 'custom_dir', 'AttrDict', 'NS',
           'get_annotations_ex', 'eval_type', 'type_hints', 'annotations', 'anno_ret', 'signature_ex', 'union2tuple',
           'argnames', 'with_cast', 'store_attr', 'attrdict', 'properties', 'camel2words', 'camel2snake', 'snake2camel',
           'class2attr', 'getcallable', 'getattrs', 'hasattrs', 'setattrs', 'try_attrs', 'GetAttrBase', 'GetAttr',
           'delegate_attr', 'ShowPrint', 'Int', 'Str', 'Float', 'partition', 'flatten', 'concat', 'strcat', 'detuplify',
           'replicate', 'setify', 'merge', 'range_of', 'groupby', 'last_index', 'filter_dict', 'filter_keys',
           'filter_values', 'cycle', 'zip_cycle', 'sorted_ex', 'not_', 'argwhere', 'filter_ex', 'renumerate', 'first',
           'only', 'nested_attr', 'nested_setdefault', 'nested_callable', 'nested_idx', 'set_nested_idx', 'val2idx',
           'uniqueify', 'loop_first_last', 'loop_first', 'loop_last', 'first_match', 'last_match', 'fastuple', 'bind',
           'mapt', 'map_ex', 'compose', 'maps', 'partialler', 'instantiate', 'using_attr', 'copy_func', 'patch_to',
           'patch', 'patch_property', 'compile_re', 'ImportEnum', 'StrEnum', 'str_enum', 'ValEnum', 'Stateful',
           'NotStr', 'PrettyString', 'even_mults', 'num_cpus', 'add_props', 'typed', 'exec_new', 'exec_import',
           'str2bool', 'lt', 'gt', 'le', 'ge', 'eq', 'ne', 'add', 'sub', 'mul', 'truediv', 'is_', 'is_not', 'mod']
# %% ../nbs/01_basics.ipynb
from .imports import *
import builtins,types,typing
import pprint
from copy import copy
try: from types import UnionType
except ImportError: UnionType = None
# %% ../nbs/01_basics.ipynb
# Module-wide attribute bag; presumably filled with shared default settings elsewhere (TODO confirm)
defaults = SimpleNamespace()
# %% ../nbs/01_basics.ipynb
def ifnone(a, b):
    "Return `a`, unless it is `None`, in which case return `b`"
    if a is None: return b
    return a
# %% ../nbs/01_basics.ipynb
def maybe_attr(o, attr):
    "Return `o.attr` when it exists, otherwise `o` itself (`getattr(o,attr,o)`)"
    fallback = o
    return getattr(o, attr, fallback)
# %% ../nbs/01_basics.ipynb
def basic_repr(flds=None):
    "Build a minimal `__repr__` showing `module.Class(fld=value, ...)`; `flds` may be a comma-separated string"
    if isinstance(flds, str): flds = re.split(', *', flds)
    names = list(flds or [])
    def _f(self):
        cls = type(self)
        prefix = f'{cls.__module__}.{cls.__name__}'
        if names:
            body = ', '.join(f'{n}={getattr(self, n)!r}' for n in names)
            return f'{prefix}({body})'
        return f'<{prefix}>'
    return _f
# %% ../nbs/01_basics.ipynb
def is_array(x):
    "Whether `x` looks array-like: has an `__array__` or an `iloc` attribute"
    return any(hasattr(x, name) for name in ('__array__', 'iloc'))
# %% ../nbs/01_basics.ipynb
def listify(o=None, *rest, use_list=False, match=None):
    """Convert `o` (plus any `rest` items) to a `list`.

    `None` -> `[]`; lists are returned as-is; str/bytes/array-likes are wrapped; other iterables
    are expanded. With `match` (a length or a collection), a single item is repeated to that length,
    and otherwise the length must already match."""
    if rest: o = (o,) + rest
    if use_list: res = list(o)
    elif o is None: res = []
    elif isinstance(o, list): res = o
    elif isinstance(o, (str, bytes)) or is_array(o): res = [o]
    elif is_iter(o): res = list(o)
    else: res = [o]
    if match is None: return res
    if is_coll(match): match = len(match)
    if len(res) == 1: return res * match
    assert len(res) == match, 'Match length mismatch'
    return res
# %% ../nbs/01_basics.ipynb
def tuplify(o, use_list=False, match=None):
    "Make `o` a tuple, using the same conversion rules as `listify`"
    as_list = listify(o, use_list=use_list, match=match)
    return tuple(as_list)
# %% ../nbs/01_basics.ipynb
def true(x):
    "Test whether `x` is truthy; collections with >0 elements are considered `True`"
    try: return bool(len(x))
    # Objects without a usable `len` fall back to `bool`. Catch only `Exception`, so that
    # KeyboardInterrupt/SystemExit are not swallowed as the former bare `except` did.
    except Exception: return bool(x)
# %% ../nbs/01_basics.ipynb
class NullType:
    "A falsy object whose attribute access, calls, and indexing all return `null` again"
    def _always_null(self, *args, **kwargs): return null
    __getattr__ = __call__ = __getitem__ = _always_null
    del _always_null
    def __bool__(self): return False

# Shared singleton returned by every operation on a `NullType`
null = NullType()
# %% ../nbs/01_basics.ipynb
def tonull(x):
    "Map `None` to the `null` singleton; any other value passes through"
    if x is None: return null
    return x
# %% ../nbs/01_basics.ipynb
def get_class(nm, *fld_names, sup=None, doc=None, funcs=None, anno=None, **flds):
    """Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`.

    `fld_names` default to `None` and `flds` give fields with defaults. `funcs` become methods,
    `anno` supplies annotations (missing fields get `typing.Any`), and `doc` sets `__doc__`."""
    attrs = {}
    # Copy, so the caller's `anno` dict is never mutated.
    anno = dict(anno or {})
    for f in fld_names:
        attrs[f] = None
        anno.setdefault(f, typing.Any)
    for f in listify(funcs): attrs[f.__name__] = f
    attrs.update(flds)
    sup = ifnone(sup, ())
    if not isinstance(sup, tuple): sup=(sup,)
    fields = [*fld_names, *flds.keys()]
    def _init(self, *args, **kwargs):
        # Positional args fill the declared fields in order. Previously they were mapped onto all
        # class attrs, so extras could overwrite methods or `_fields`.
        if len(args) > len(fields):
            raise TypeError(f'{nm}() takes at most {len(fields)} positional arguments ({len(args)} given)')
        for k,v in zip(fields, args): setattr(self, k, v)
        for k,v in kwargs.items(): setattr(self,k,v)
    attrs['_fields'] = fields
    def _eq(self,b):
        return all([getattr(self,k)==getattr(b,k) for k in self._fields])
    if not sup: attrs['__repr__'] = basic_repr(attrs['_fields'])
    attrs['__init__'] = _init
    attrs['__eq__'] = _eq
    if anno: attrs['__annotations__'] = anno
    res = type(nm, sup, attrs)
    if doc is not None: res.__doc__ = doc
    return res
# %% ../nbs/01_basics.ipynb
def mk_class(nm, *fld_names, sup=None, doc=None, funcs=None, mod=None, anno=None, **flds):
    "Build a class with `get_class` and bind it as `nm` in `mod` (default: the caller's `f_locals`)"
    # At module level the caller's f_locals is the module namespace.
    target = sys._getframe(1).f_locals if mod is None else mod
    target[nm] = get_class(nm, *fld_names, sup=sup, doc=doc, funcs=funcs, anno=anno, **flds)
# %% ../nbs/01_basics.ipynb
def wrap_class(nm, *fld_names, sup=None, doc=None, funcs=None, **flds):
    "Decorator: makes function a method of a new class `nm` passing parameters to `mk_class`"
    def _inner(f):
        all_funcs = listify(funcs) + [f]
        # The new class is registered in the globals of the module that defines `f`.
        mk_class(nm, *fld_names, sup=sup, doc=doc, funcs=all_funcs, mod=f.__globals__, **flds)
        return f
    return _inner
# %% ../nbs/01_basics.ipynb
class ignore_exceptions:
    "Context manager that suppresses any exception raised inside its block"
    def __enter__(self): return None
    def __exit__(self, *args):
        # A truthy return tells Python to suppress the in-flight exception.
        return True
# %% ../nbs/01_basics.ipynb
def exec_local(code, var_name):
    "Run `code` with `exec` in a fresh local namespace and return the variable `var_name` it defines"
    # NOTE: `exec` runs arbitrary code -- only pass trusted source.
    namespace = {}
    exec(code, globals(), namespace)
    return namespace[var_name]
# %% ../nbs/01_basics.ipynb
# Sentinel meaning "no `obj` supplied" (curry), distinct from an explicit `obj=None`
_risinstance_no_obj = object()
def risinstance(types, obj=_risinstance_no_obj):
    """Curried `isinstance` but with args reversed.

    `types` may contain classes and/or class-name strings; names match anywhere in `type(obj).__mro__`.
    Called with only `types`, returns a predicate."""
    types = tuplify(types)
    # Previously `None` marked "curry", so `risinstance(int)(None)` returned a truthy partial.
    if obj is _risinstance_no_obj: return partial(risinstance,types)
    names = {t for t in types if isinstance(t,str)}
    classes = tuple(t for t in types if not isinstance(t,str))
    # Check class entries too. Previously, any string entry caused class entries to be ignored.
    if isinstance(obj, classes): return True
    return any(t.__name__ in names for t in type(obj).__mro__)
# %% ../nbs/01_basics.ipynb
def ver2tuple(v:str)->tuple:
    "Parse the first `major[.minor[.patch]]` in `v` into a 3-tuple of ints, missing parts as 0"
    m = re.search(r'(\d+)(?:\.(\d+))?(?:\.(\d+))?', v)
    # Fail clearly instead of with `AttributeError` on `None.groups()`.
    if m is None: raise ValueError(f'No version number found in {v!r}')
    return tuple(int(o or 0) for o in m.groups())
# %% ../nbs/01_basics.ipynb
class _InfMeta(type):
@property
def count(self): return itertools.count()
@property
def zeros(self): return itertools.cycle([0])
@property
def ones(self): return itertools.cycle([1])
@property
def nones(self): return itertools.cycle([None])
# %% ../nbs/01_basics.ipynb
class Inf(metaclass=_InfMeta):
    "Infinite lists: `Inf.count`, `Inf.zeros`, `Inf.ones` and `Inf.nones` are endless iterators"
# %% ../nbs/01_basics.ipynb
_dumobj = object()
def _oper(op,a,b=_dumobj): return (lambda o:op(o,a)) if b is _dumobj else op(a,b)
def _mk_op(nm, mod):
"Create an operator using `oper` and add to the caller's module"
op = getattr(operator,nm)
def _inner(a, b=_dumobj): return _oper(op, a,b)
_inner.__name__ = _inner.__qualname__ = nm
_inner.__doc__ = f'Same as `operator.{nm}`, or returns partial if 1 arg'
mod[nm] = _inner
# %% ../nbs/01_basics.ipynb
def in_(x, a):
    "`True` if `x in a` (i.e. `operator.contains(a, x)`)"
    return operator.contains(a, x)
# Exposed on `operator` so `_mk_op('in_', ...)` can look it up there.
setattr(operator, 'in_', in_)
# %% ../nbs/01_basics.ipynb
_all_ = ['lt','gt','le','ge','eq','ne','add','sub','mul','truediv','is_','is_not','in_', 'mod']
# %% ../nbs/01_basics.ipynb
# Define a module-level curried wrapper for each operator name in `_all_`
# (note: the loop variable `op` stays bound as a module global afterwards)
for op in _all_: _mk_op(op, globals())
# %% ../nbs/01_basics.ipynb
def ret_true(*args, **kwargs):
    "Predicate that ignores its arguments and always returns `True`"
    return True
# %% ../nbs/01_basics.ipynb
def ret_false(*args, **kwargs):
    "Predicate that ignores its arguments and always returns `False`"
    return False
# %% ../nbs/01_basics.ipynb
def stop(e=StopIteration):
    "Raise `e`, an exception class or instance (`StopIteration` by default); usable inside expressions"
    raise e
# %% ../nbs/01_basics.ipynb
def gen(func, seq, cond=ret_true):
    "Lazily yield `func(o)` for `o` in `seq`, stopping at the first result failing `cond` or when `func` raises `StopIteration`"
    mapped = map(func, seq)
    return itertools.takewhile(cond, mapped)
# %% ../nbs/01_basics.ipynb
def chunked(it, chunk_sz=None, drop_last=False, n_chunks=None):
    "Yield lists of `chunk_sz` items from `it` (or split it into `n_chunks` lists); `drop_last` drops a short final chunk"
    assert bool(chunk_sz) ^ bool(n_chunks)
    # `n_chunks` requires `len(it)`, so it only works for sized inputs.
    if n_chunks: chunk_sz = max(math.ceil(len(it)/n_chunks), 1)
    src = it if isinstance(it, Iterator) else iter(it)
    while True:
        batch = list(itertools.islice(src, chunk_sz))
        full = len(batch) == chunk_sz
        if batch and (full or not drop_last): yield batch
        if not full: return
# %% ../nbs/01_basics.ipynb
def otherwise(x, tst, y):
    "Return `y` when `tst(x)` is truthy, else `x`"
    if tst(x): return y
    return x
# %% ../nbs/01_basics.ipynb
def custom_dir(c, add):
    "Implement custom `__dir__`: the default `object.__dir__(c)` names plus `add` (listified)"
    base = object.__dir__(c)
    return base + listify(add)
# %% ../nbs/01_basics.ipynb
class AttrDict(dict):
    "`dict` subclass that also provides access to keys as attrs"
    def __getattr__(self, k):
        if k in self: return self[k]
        raise AttributeError(k)
    def __setattr__(self, k, v):
        # Names starting with `_` become real attributes; everything else is stored as a key.
        if k.startswith('_'): super().__setattr__(k, v)
        else: self[k] = v
    def __dir__(self): return super().__dir__() + list(self.keys())
    def _repr_markdown_(self): return f'```json\n{pprint.pformat(self, indent=2)}\n```'
    # Positional construction keeps non-string keys; `AttrDict(**self)` raised TypeError for them.
    def copy(self): return AttrDict(self)
# %% ../nbs/01_basics.ipynb
class NS(SimpleNamespace):
    "`SimpleNamespace` that also iterates over its attribute names and supports `ns[name]` get/set"
    def __iter__(self): return iter(vars(self))
    def __getitem__(self, x): return vars(self)[x]
    def __setitem__(self, x, y): vars(self)[x] = y
# %% ../nbs/01_basics.ipynb
def get_annotations_ex(obj, *, globals=None, locals=None):
    "Backport of py3.10 `get_annotations` that returns globals/locals"
    # Returns `(annotations_copy, globals, locals)` without evaluating string annotations.
    # Explicitly passed `globals`/`locals` take precedence over those derived from `obj`.
    if isinstance(obj, type):
        # Class: only the class's own `__annotations__` (read from its `__dict__`, not inherited).
        obj_dict = getattr(obj, '__dict__', None)
        if obj_dict and hasattr(obj_dict, 'get'):
            ann = obj_dict.get('__annotations__', None)
            # A getset descriptor here is not a real annotations dict.
            if isinstance(ann, types.GetSetDescriptorType): ann = None
        else: ann = None
        # Globals come from the class's defining module; locals are the class namespace.
        obj_globals = None
        module_name = getattr(obj, '__module__', None)
        if module_name:
            module = sys.modules.get(module_name, None)
            if module: obj_globals = getattr(module, '__dict__', None)
        obj_locals = dict(vars(obj))
        unwrap = obj
    elif isinstance(obj, types.ModuleType):
        ann = getattr(obj, '__annotations__', None)
        obj_globals = getattr(obj, '__dict__')
        obj_locals,unwrap = None,None
    elif callable(obj):
        ann = getattr(obj, '__annotations__', None)
        obj_globals = getattr(obj, '__globals__', None)
        obj_locals,unwrap = None,obj
    else: raise TypeError(f"{obj!r} is not a module, class, or callable.")
    if ann is None: ann = {}
    if not isinstance(ann, dict): raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
    if not ann: ann = {}
    if unwrap is not None:
        # Follow `__wrapped__` chains and `functools.partial` objects to the innermost callable;
        # its `__globals__`, if any, replace the globals found above.
        while True:
            if hasattr(unwrap, '__wrapped__'):
                unwrap = unwrap.__wrapped__
                continue
            if isinstance(unwrap, functools.partial):
                unwrap = unwrap.func
                continue
            break
        if hasattr(unwrap, "__globals__"): obj_globals = unwrap.__globals__
    if globals is None: globals = obj_globals
    if locals is None: locals = obj_locals
    # Copy so callers cannot mutate `obj`'s own annotations dict.
    return dict(ann), globals, locals
# %% ../nbs/01_basics.ipynb
def eval_type(t, glb, loc):
    "Evaluate a string annotation (or each item of a tuple/list of them) in `glb`/`loc`; non-strings pass through"
    if isinstance(t, (tuple, list)): return type(t)([eval_type(c, glb, loc) for c in t])
    if not isinstance(t, str): return t
    # `A|B` strings are split and combined with `Union`, so they also work before py3.10.
    if '|' in t: return Union[eval_type(tuple(t.split('|')), glb, loc)]
    return eval(t, glb, loc)
# %% ../nbs/01_basics.ipynb
def _eval_type(t, glb, loc):
res = eval_type(t, glb, loc)
return NoneType if res is None else res
def type_hints(f):
"Like `typing.get_type_hints` but returns `{}` if not allowed type"
if not isinstance(f, typing._allowed_types): return {}
ann,glb,loc = get_annotations_ex(f)
return {k:_eval_type(v,glb,loc) for k,v in ann.items()}
# %% ../nbs/01_basics.ipynb
def annotations(o):
    "Type hints for `o`, falling back to those of `o.__init__`, then of `type(o)`"
    # Test for `None` rather than truthiness: falsy objects (e.g. empty containers) still have
    # annotations, and truth-testing some objects (e.g. multi-element arrays) raises.
    if o is None: return {}
    res = type_hints(o)
    if not res: res = type_hints(getattr(o,'__init__',None))
    if not res: res = type_hints(type(o))
    return res
# %% ../nbs/01_basics.ipynb
def anno_ret(func):
    "Return annotation of `func` (from `annotations`), or `None` if absent or `func` is falsy"
    if not func: return None
    return annotations(func).get('return', None)
# %% ../nbs/01_basics.ipynb
def _ispy3_10(): return sys.version_info.major >=3 and sys.version_info.minor >=10
def signature_ex(obj, eval_str:bool=False):
"Backport of `inspect.signature(..., eval_str=True` to <py310"
from inspect import Signature, Parameter, signature
def _eval_param(ann, k, v):
if k not in ann: return v
return Parameter(v.name, v.kind, annotation=ann[k], default=v.default)
if not eval_str: return signature(obj)
if _ispy3_10(): return signature(obj, eval_str=eval_str)
sig = signature(obj)
if sig is None: return None
ann = type_hints(obj)
params = [_eval_param(ann,k,v) for k,v in sig.parameters.items()]
return Signature(params, return_annotation=sig.return_annotation)
# %% ../nbs/01_basics.ipynb
def union2tuple(t):
if (getattr(t, '__origin__', None) is Union
or (UnionType and isinstance(t, UnionType))): return t.__args__
return t
# %% ../nbs/01_basics.ipynb
def argnames(f, frame=False):
    "Names of the positional and keyword-only parameters of function `f` (or of frame `f` if `frame`)"
    attr = 'f_code' if frame else '__code__'
    code = getattr(f, attr)
    n_args = code.co_argcount + code.co_kwonlyargcount
    return code.co_varnames[:n_args]
# %% ../nbs/01_basics.ipynb
def with_cast(f):
    "Decorator which uses any parameter annotations as preprocessing functions"
    anno, out_anno, params = annotations(f), anno_ret(f), argnames(f)
    c_out = ifnone(out_anno, noop)
    # `params` lists positional params first, then keyword-only ones.
    n_pos = f.__code__.co_argcount
    # `__defaults__` align with the *last positional* params; keyword-only defaults live in
    # `__kwdefaults__`. Zipping all params against `__defaults__` mis-assigned defaults whenever
    # keyword-only params existed.
    defaults = dict(zip(reversed(params[:n_pos]), reversed(f.__defaults__ or ())))
    defaults.update(f.__kwdefaults__ or {})
    @functools.wraps(f)
    def _inner(*args, **kwargs):
        args = list(args)
        for i,v in enumerate(params):
            if v not in anno: continue
            c = anno[v]
            if v in kwargs: kwargs[v] = c(kwargs[v])
            # Only positional params are filled from `args`; extra positionals belong to `*args`.
            elif i<n_pos and i<len(args): args[i] = c(args[i])
            elif v in defaults: kwargs[v] = c(defaults[v])
        return c_out(f(*args, **kwargs))
    return _inner
# %% ../nbs/01_basics.ipynb
def _store_attr(self, anno, **attrs):
stored = getattr(self, '__stored_args__', None)
for n,v in attrs.items():
if n in anno: v = anno[n](v)
setattr(self, n, v)
if stored is not None: stored[n] = v
# %% ../nbs/01_basics.ipynb
def store_attr(names=None, self=None, but='', cast=False, store_args=None, **attrs):
    "Store params named in comma-separated `names` from calling context into attrs in `self`"
    fr = sys._getframe(1)
    args = argnames(fr, True)
    # Test `is not None`, not truthiness, so a falsy `self` (e.g. an empty container) is honoured.
    if self is not None: args = ('self', *args)
    else: self = fr.f_locals[args[0]]
    # By default `__stored_args__` is kept unless the object uses `__slots__`.
    if store_args is None: store_args = not hasattr(self,'__slots__')
    if store_args and not hasattr(self, '__stored_args__'): self.__stored_args__ = {}
    anno = annotations(self) if cast else {}
    if names and isinstance(names,str): names = re.split(', *', names)
    # Without `names`: `__slots__` if defined, else the caller's arguments other than `self`.
    ns = names if names is not None else getattr(self, '__slots__', args[1:])
    added = {n:fr.f_locals[n] for n in ns}
    attrs = {**attrs, **added}
    if isinstance(but,str): but = re.split(', *', but)
    attrs = {k:v for k,v in attrs.items() if k not in but}
    return _store_attr(self, anno, **attrs)
# %% ../nbs/01_basics.ipynb
def attrdict(o, *ks, default=None):
    "Map each name in `ks` to `getattr(o, name, default)`"
    return dict((k, getattr(o, k, default)) for k in ks)
# %% ../nbs/01_basics.ipynb
def properties(cls, *ps):
    "Turn each attribute of `cls` named in `ps` into a read-only `property` using it as getter"
    for name in ps:
        fget = getattr(cls, name)
        setattr(cls, name, property(fget))
# %% ../nbs/01_basics.ipynb
_c2w_re = re.compile(r'((?<=[a-z])[A-Z]|(?<!\A)[A-Z](?=[a-z]))')
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
# %% ../nbs/01_basics.ipynb
def camel2words(s, space=' '):
    "Convert CamelCase to 'spaced words', inserting `space` before each new word"
    return _c2w_re.sub(rf'{space}\1', s)
# %% ../nbs/01_basics.ipynb
def camel2snake(name):
    "Convert CamelCase to snake_case"
    first_pass = _camel_re1.sub(r'\1_\2', name)
    return _camel_re2.sub(r'\1_\2', first_pass).lower()
# %% ../nbs/01_basics.ipynb
def snake2camel(s):
    "Convert snake_case to CamelCase (each part title-cased, underscores removed)"
    return s.title().replace('_', '')
# %% ../nbs/01_basics.ipynb
def class2attr(self, cls_name):
    "Snake-cased class name of `self` with a trailing `cls_name` removed (or `cls_name` itself if nothing remains)"
    stripped = re.sub(rf'{cls_name}$', '', self.__class__.__name__)
    return camel2snake(stripped or cls_name.lower())
# %% ../nbs/01_basics.ipynb
def getcallable(o, attr):
    "`o.attr`, or `noop` when the attribute is missing"
    try: return getattr(o, attr)
    except AttributeError: return noop
# %% ../nbs/01_basics.ipynb
def getattrs(o, *attrs, default=None):
    "List of `getattr(o, a, default)` for each name in `attrs`"
    return list(map(lambda a: getattr(o, a, default), attrs))
# %% ../nbs/01_basics.ipynb
def hasattrs(o, attrs):
    "Whether `o` has every attribute named in `attrs`"
    for a in attrs:
        if not hasattr(o, a): return False
    return True
# %% ../nbs/01_basics.ipynb
def setattrs(dest, flds, src):
    """Copy fields `flds` from `src` onto `dest`.

    `flds` is a comma-separated string or an iterable of names. Dict sources use `src.get`
    (missing -> `None`); other sources use `getattr` (missing -> `AttributeError`)."""
    f = dict.get if isinstance(src, dict) else getattr
    # Generalized: also accept a list/tuple of names, not only a comma-separated string.
    if isinstance(flds, str): flds = re.split(r",\s*", flds)
    for fld in flds: setattr(dest, fld, f(src, fld))
# %% ../nbs/01_basics.ipynb
def try_attrs(obj, *attrs):
    "Return the first attribute in `attrs` that can be read from `obj`; raise `AttributeError(attrs)` if none can"
    for att in attrs:
        try: return getattr(obj, att)
        # Best-effort across candidates, but don't swallow KeyboardInterrupt/SystemExit as the bare `except` did.
        except Exception: pass
    raise AttributeError(attrs)
# %% ../nbs/01_basics.ipynb
class GetAttrBase:
    "Basic delegation of `__getattr__` and `__dir__`: missing attrs are looked up as keys of `getattr(self, self._attr)`"
    # Subclasses set `_attr` to the name of the container attribute and define `_getattr`.
    _attr=noop
    def __getattr__(self,k):
        # Private names and the container itself are not delegated; this prevents infinite recursion.
        if k.startswith('_') or k==self._attr: return super().__getattr__(k)
        # `__getattr__` must raise `AttributeError` (not `KeyError`) so `hasattr` and `getattr(..., default)` work.
        try: val = getattr(self, self._attr)[k]
        except KeyError: raise AttributeError(k) from None
        return self._getattr(val)
    def __dir__(self): return custom_dir(self, getattr(self, self._attr))
# %% ../nbs/01_basics.ipynb
class GetAttr:
    "Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`"
    _default='default'
    def _component_attr_filter(self, k):
        # Never delegate dunders, `_xtra`, or the delegate attribute itself; when `_xtra` is set, only its names.
        if k.startswith('__') or k in ('_xtra', self._default): return False
        xtra = getattr(self, '_xtra', None)
        return xtra is None or k in xtra
    def _dir(self): return list(filter(self._component_attr_filter, dir(getattr(self, self._default))))
    def __getattr__(self, k):
        target = getattr(self, self._default, None) if self._component_attr_filter(k) else None
        if target is None: raise AttributeError(k)
        return getattr(target, k)
    def __dir__(self): return custom_dir(self, self._dir())
    def __setstate__(self, data): self.__dict__.update(data)
# %% ../nbs/01_basics.ipynb
def delegate_attr(self, k, to):
    "Use in `__getattr__` to delegate to attr `to` without inheriting from `GetAttr`"
    # Private names and `to` itself are never delegated.
    if k.startswith('_') or k == to: raise AttributeError(k)
    try:
        target = getattr(self, to)
        return getattr(target, k)
    except AttributeError:
        raise AttributeError(k) from None
# %% ../nbs/01_basics.ipynb
class ShowPrint:
    "Base class whose `show` prints `str(self)`; any arguments are ignored"
    def show(self, *args, **kwargs):
        text = str(self)
        print(text)
# %% ../nbs/01_basics.ipynb
class Int(int, ShowPrint):
    "An extensible `int` (a subclass, so attributes can be attached) with `show`"
# %% ../nbs/01_basics.ipynb
class Str(str, ShowPrint):
    "An extensible `str` (a subclass, so attributes can be attached) with `show`"

class Float(float, ShowPrint):
    "An extensible `float` (a subclass, so attributes can be attached) with `show`"
# %% ../nbs/01_basics.ipynb
def partition(coll, f):
    "Split `coll` into `(matching, non_matching)` by predicate `f`; tuple inputs give results of the same type"
    ts,fs = [],[]
    # Any truthy result selects `ts`; indexing `(fs,ts)[f(o)]` failed for non-bool results like `'y'` or `2`.
    for o in coll: (ts if f(o) else fs).append(o)
    if isinstance(coll,tuple):
        typ = type(coll)
        ts,fs = typ(ts),typ(fs)
    return ts,fs
# %% ../nbs/01_basics.ipynb
def flatten(o):
    "Recursively yield the leaf items of nested iterables in `o`; strings are treated as leaves"
    for item in o:
        if isinstance(item, str):
            yield item
            continue
        # Decide iterability up front. Wrapping the whole recursion in `try/except TypeError`
        # swallowed TypeErrors raised mid-iteration and then re-emitted the partly consumed item.
        try: sub = iter(item)
        except TypeError:
            yield item
            continue
        yield from flatten(sub)
# %% ../nbs/01_basics.ipynb
def concat(colls)->list:
    "All leaf items of the (nested) collections in `colls`, as one flat list"
    return [*flatten(colls)]
# %% ../nbs/01_basics.ipynb
def strcat(its, sep:str='')->str:
    "Join `str(item)` for each item of `its`, separated by `sep`"
    return sep.join(str(item) for item in its)
# %% ../nbs/01_basics.ipynb
def detuplify(x):
    "`None` for an empty `x`; its single element if it has exactly one (and `ndim` is 1, if present); else `x`"
    if len(x) == 0: return None
    if len(x) == 1 and getattr(x, 'ndim', 1) == 1: return x[0]
    return x
# %% ../nbs/01_basics.ipynb
def replicate(item, match):
    "Tuple containing `item` repeated `len(match)` times"
    n = len(match)
    return (item,) * n
# %% ../nbs/01_basics.ipynb
def setify(o):
    "Return `o` if it is already a `set`, else `set(listify(o))`"
    if isinstance(o, set): return o
    return set(listify(o))
# %% ../nbs/01_basics.ipynb
def merge(*ds):
    "Merge all dictionaries in `ds` (later ones win), skipping `None`s"
    res = {}
    for d in ds:
        if d is None: continue
        res.update(d.items())
    return res
# %% ../nbs/01_basics.ipynb
# NOTE: a later `range_of(a, b=None, step=None)` in this module rebinds this name.
def range_of(x):
    "All indices of collection `x`, i.e. `list(range(len(x)))`"
    return [i for i in range(len(x))]
# %% ../nbs/01_basics.ipynb
def groupby(x, key, val=noop):
    """Like `itertools.groupby` but doesn't need to be sorted, and isn't lazy.

    `key`/`val` may be callables, ints (`itemgetter`) or strings (`attrgetter`)."""
    key = itemgetter(key) if isinstance(key, int) else attrgetter(key) if isinstance(key, str) else key
    val = itemgetter(val) if isinstance(val, int) else attrgetter(val) if isinstance(val, str) else val
    res = {}
    for o in x:
        res.setdefault(key(o), []).append(val(o))
    return res
# %% ../nbs/01_basics.ipynb
def last_index(x, o):
    "Index of the last occurrence of `x` in `o`, or -1 if it does not occur"
    for i in range(len(o) - 1, -1, -1):
        if o[i] == x: return i
    return -1
# %% ../nbs/01_basics.ipynb
def filter_dict(d, func):
    "New dict of the items of `d` for which `func(key, value)` is truthy"
    return dict(kv for kv in d.items() if func(*kv))
# %% ../nbs/01_basics.ipynb
def filter_keys(d, func):
    "New dict of the items of `d` whose key satisfies `func`"
    return dict(kv for kv in d.items() if func(kv[0]))
# %% ../nbs/01_basics.ipynb
def filter_values(d, func):
    "New dict of the items of `d` whose value satisfies `func`"
    return dict(kv for kv in d.items() if func(kv[1]))
# %% ../nbs/01_basics.ipynb
def cycle(o):
    "`itertools.cycle` over `listify(o)`, or an endless stream of `None` when that list is empty"
    items = listify(o)
    if items is None or len(items) == 0: items = [None]
    return itertools.cycle(items)
# %% ../nbs/01_basics.ipynb
def zip_cycle(x, *args):
    "`zip` `x` with each of `args` repeated endlessly via `cycle`, so the length follows `x`"
    cycled = map(cycle, args)
    return zip(x, *cycled)
# %% ../nbs/01_basics.ipynb
def sorted_ex(iterable, key=None, reverse=False):
    "Like `sorted`, but a str `key` sorts by that attribute (missing -> 0) and an int `key` by that item"
    if isinstance(key, str):
        attr = key
        def k(o): return getattr(o, attr, 0)
    elif isinstance(key, int): k = itemgetter(key)
    else: k = key
    return sorted(iterable, key=k, reverse=reverse)
# %% ../nbs/01_basics.ipynb
def not_(f):
    "Wrap `f` in a function returning the boolean negation of its result"
    def _f(*args, **kwargs):
        result = f(*args, **kwargs)
        return not result
    return _f
# %% ../nbs/01_basics.ipynb
def argwhere(iterable, f, negate=False, **kwargs):
    "Indices of items in `iterable` satisfying `f` (with `kwargs` bound; inverted if `negate`)"
    pred = partial(f, **kwargs) if kwargs else f
    if negate: pred = not_(pred)
    return [i for i, o in enumerate(iterable) if pred(o)]
# %% ../nbs/01_basics.ipynb
def filter_ex(iterable, f=noop, negate=False, gen=False, **kwargs):
    "Like `filter`, but passing `kwargs` to `f`, defaulting `f` to `noop`, and adding `negate` and `gen`"
    # `f=None` means "keep everything".
    pred = (lambda _: True) if f is None else f
    if kwargs: pred = partial(pred, **kwargs)
    if negate: pred = not_(pred)
    res = filter(pred, iterable)
    return res if gen else list(res)
# %% ../nbs/01_basics.ipynb
def range_of(a, b=None, step=None):
    "`list(range(...))` of the given bounds; a collection `a` is replaced by its length"
    if is_coll(a): a = len(a)
    if step is not None: r = range(a, b, step)
    elif b is not None: r = range(a, b)
    else: r = range(a)
    return list(r)
# %% ../nbs/01_basics.ipynb
def renumerate(iterable, start=0):
    "Lazy `(item, index)` pairs: `enumerate` with the tuple order swapped"
    return ((item, idx) for idx, item in enumerate(iterable, start=start))
# %% ../nbs/01_basics.ipynb
def first(x, f=None, negate=False, **kwargs):
    "First element of `x` (optionally the first passing `filter_ex` with `f`/`negate`/`kwargs`), or `None`"
    it = iter(x)
    if f: it = filter_ex(it, f=f, negate=negate, gen=True, **kwargs)
    return next(it, None)
# %% ../nbs/01_basics.ipynb
def only(o):
    "Return the single item of `o`; raise `ValueError` if it has zero or several items"
    it = iter(o)
    missing = object()
    res = next(it, missing)
    if res is missing: raise ValueError('iterable has 0 items') from None
    if next(it, missing) is missing: return res
    raise ValueError('iterable has more than 1 item')
# %% ../nbs/01_basics.ipynb
def nested_attr(o, attr, default=None):
    "`getattr` that follows dotted paths like `'a.b.c'`, returning `default` on any missing link"
    try: return functools.reduce(getattr, attr.split('.'), o)
    except AttributeError: return default
# %% ../nbs/01_basics.ipynb
def nested_setdefault(o, attr, default):
    "`setdefault` along a dotted key path, creating intermediate containers of `type(o)` as needed"
    *parents, leaf = attr.split('.')
    for name in parents: o = o.setdefault(name, type(o)())
    return o.setdefault(leaf, default)
# %% ../nbs/01_basics.ipynb
def nested_callable(o, attr):
    "Follow the dotted `attr` path like `nested_attr`, returning `noop` when it is missing"
    return nested_attr(o, attr, default=noop)
# %% ../nbs/01_basics.ipynb
def _access(coll, idx):
    "Fetch `idx` from `coll` as an attribute (str), via `.get`, or as an in-range int index; else `None`"
    if isinstance(idx,str) and hasattr(coll, idx): return getattr(coll, idx)
    if hasattr(coll, 'get'): return coll.get(idx, None)
    try: length = len(coll)
    except TypeError: length = 0
    # Negative indices are accepted, but only when in range (previously they could raise IndexError).
    if isinstance(idx,int) and -length <= idx < length: return coll[idx]
    return None
def _nested_idx(coll, *idxs):
    "Walk `coll` through all but the last of `idxs`; return `(container, last_idx)`, or `(None, None)` on a dead end"
    *idxs,last_idx = idxs
    for idx in idxs:
        if isinstance(idx,str) and hasattr(coll, idx): coll = getattr(coll, idx)
        else:
            if isinstance(coll,str) or not isinstance(coll, typing.Collection): return None,None
            # Same lookup rules as the final step (non-int/out-of-range index -> None, not an exception).
            coll = _access(coll, idx)
    return coll,last_idx
# %% ../nbs/01_basics.ipynb
def nested_idx(coll, *idxs):
    "Index into nested collections, dicts, etc, with `idxs`"
    if not idxs or coll is None: return coll
    coll,idx = _nested_idx(coll, *idxs)
    # Only a missing level short-circuits. Empty containers are still indexed, yielding `None`
    # instead of the empty container itself, as `if not coll` used to return.
    if coll is None: return None
    return _access(coll, idx)
# %% ../nbs/01_basics.ipynb
def set_nested_idx(coll, value, *idxs):
    "Set `value` at the position addressed by `idxs`, resolved like `nested_idx`"
    parent, last = _nested_idx(coll, *idxs)
    parent[last] = value
# %% ../nbs/01_basics.ipynb
def val2idx(x):
    "Dict from each value of `x` to its index (the last index wins for duplicates)"
    return dict(zip(x, itertools.count()))
# %% ../nbs/01_basics.ipynb
def uniqueify(x, sort=False, bidir=False, start=None):
    "Unique elements of `x` in first-seen order, optionally prefixed by `start`, sorted, and/or with a value->index dict"
    res = list(dict.fromkeys(x))
    if start is not None: res = listify(start) + res
    if sort: res.sort()
    if bidir: return res, val2idx(res)
    return res
# %% ../nbs/01_basics.ipynb
# looping functions from https://github.com/willmcgugan/rich/blob/master/rich/_loop.py
def loop_first_last(values):
    "Yield `(is_first, is_last, value)` for each item of `values`"
    it = iter(values)
    try: prev = next(it)
    except StopIteration: return
    is_first = True
    # Look one item ahead so the final item can be flagged as last.
    for cur in it:
        yield is_first, False, prev
        is_first, prev = False, cur
    yield is_first, True, prev
# %% ../nbs/01_basics.ipynb
def loop_first(values):
"Iterate and generate a tuple with a flag for first value."
return ((b,o) for b,_,o in loop_first_last(values))
# %% ../nbs/01_basics.ipynb
def loop_last(values):
"Iterate and generate a tuple with a flag for last value."
return ((b,o) for _,b,o in loop_first_last(values))
# %% ../nbs/01_basics.ipynb
def first_match(lst, f, default=None):
    "Index of the first element of `lst` matching predicate `f`, or `default` if none"
    return next((i for i,o in enumerate(lst) if f(o)), default)
# %% ../nbs/01_basics.ipynb
def last_match(lst, f, default=None):
    "Index of the last element of `lst` matching predicate `f`, or `default` if none"
    # scan indices from the end; `lst` must support `len` and indexing
    return next((i for i in range(len(lst)-1, -1, -1) if f(lst[i])), default)
# %% ../nbs/01_basics.ipynb
# Numeric-protocol dunder names: plain binary/unary operators ...
num_methods = [
    '__add__', '__sub__', '__mul__', '__matmul__', '__truediv__', '__floordiv__', '__mod__', '__divmod__', '__pow__',
    '__lshift__', '__rshift__', '__and__', '__xor__', '__or__', '__neg__', '__pos__', '__abs__',
]
# ... their reflected (right-operand) forms ...
rnum_methods = [
    '__radd__', '__rsub__', '__rmul__', '__rmatmul__', '__rtruediv__', '__rfloordiv__', '__rmod__', '__rdivmod__',
    '__rpow__', '__rlshift__', '__rrshift__', '__rand__', '__rxor__', '__ror__',
]
# ... and their augmented-assignment (in-place) forms
inum_methods = [
    '__iadd__', '__isub__', '__imul__', '__imatmul__', '__itruediv__',
    '__ifloordiv__', '__imod__', '__ipow__', '__ilshift__', '__irshift__', '__iand__', '__ixor__', '__ior__',
]
# %% ../nbs/01_basics.ipynb
class fastuple(tuple):
    "A `tuple` with elementwise ops and more friendly __init__ behavior"
    def __new__(cls, x=None, *rest):
        # fastuple() -> (); fastuple(a, b, ...) -> (a, b, ...); fastuple(iterable) -> tuple(iterable); fastuple(scalar) -> (scalar,)
        if x is None: x = ()
        elif not isinstance(x, tuple):
            if rest: x = (x,)
            else:
                try: x = tuple(iter(x))
                except TypeError: x = (x,)
        return super().__new__(cls, x + rest)
    def _op(self,op,*args):
        # also callable on plain tuples (e.g. `fastuple._op(t, ...)`): coerce first
        if not isinstance(self,fastuple): self = fastuple(self)
        # each extra operand goes through `cycle` (defined elsewhere), presumably so scalars broadcast
        cycled = map(cycle, args)
        return type(self)(map(op, self, *cycled))
    def mul(self,*args):
        "`*` is already defined in `tuple` for replicating, so use `mul` instead"
        return fastuple._op(self, operator.mul, *args)
    def add(self,*args):
        "`+` is already defined in `tuple` for concat, so use `add` instead"
        return fastuple._op(self, operator.add, *args)
def _get_op(op):
    "Build a `fastuple` method applying `op` (an `operator` name or a callable) elementwise via `_op`"
    fn = getattr(operator, op) if isinstance(op, str) else op
    def _f(self,*args): return self._op(fn,*args)
    return _f
for n in num_methods:
    # only fill in operators that `tuple` lacks and that `operator` actually provides
    if hasattr(operator, n) and not hasattr(fastuple, n): setattr(fastuple, n, _get_op(n))
for n in 'eq ne lt le gt ge'.split(): setattr(fastuple, n, _get_op(n))
fastuple.__invert__ = _get_op('__not__')
fastuple.max = _get_op(max)
fastuple.min = _get_op(min)
# %% ../nbs/01_basics.ipynb
class _Arg:
    "Placeholder for the `i`-th positional argument of a later call (see `bind`)"
    def __init__(self,i): self.i = i
arg0 = _Arg(0)
arg1 = _Arg(1)
arg2 = _Arg(2)
arg3 = _Arg(3)
arg4 = _Arg(4)
# %% ../nbs/01_basics.ipynb
class bind:
    "Same as `partial`, except you can use `arg0` `arg1` etc param placeholders"
    def __init__(self, func, *pargs, **pkwargs):
        self.func,self.pargs,self.pkwargs = func,pargs,pkwargs
        # Highest placeholder index used positionally: call args up to it are consumed, later ones appended
        self.maxi = max((x.i for x in pargs if isinstance(x, _Arg)), default=-1)
    def __call__(self, *args, **kwargs):
        kwargs = {**self.pkwargs,**kwargs}
        # Resolve keyword placeholders against the *original* call args. Popping them (as before) shifted
        # the indices positional placeholders refer to. Record which args they consumed instead.
        used = set()
        for k,v in kwargs.items():
            if isinstance(v,_Arg):
                kwargs[k] = args[v.i]
                used.add(v.i)
        fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs]
        # Remaining args: those past the last positional placeholder that no keyword placeholder took
        fargs += [o for i,o in enumerate(args) if i>self.maxi and i not in used]
        return self.func(*fargs, **kwargs)
# %% ../nbs/01_basics.ipynb
def mapt(func, *iterables):
    "Like `map`, but eagerly collected into a `tuple`"
    results = map(func, *iterables)
    return tuple(results)
# %% ../nbs/01_basics.ipynb
def map_ex(iterable, f, *args, gen=False, **kwargs):
    "Like `map`, but use `bind`, and supports `str` and indexing"
    # callables get `bind`; format strings use `str.format`; anything else is indexed into
    if callable(f): fn = bind(f,*args,**kwargs)
    elif isinstance(f,str): fn = f.format
    else: fn = f.__getitem__
    mapped = map(fn, iterable)
    return mapped if gen else list(mapped)
# %% ../nbs/01_basics.ipynb
def compose(*funcs, order=None):
    "Create a function that composes all functions in `funcs`, passing along remaining `*args` and `**kwargs` to all"
    funcs = listify(funcs)
    # trivial cases: identity-like `noop`, or the single function itself
    if not funcs: return noop
    if len(funcs)==1: return funcs[0]
    if order is not None: funcs = sorted_ex(funcs, key=order)
    def _inner(x, *args, **kwargs):
        # thread `x` through each function in turn
        for fn in funcs: x = fn(x, *args, **kwargs)
        return x
    return _inner
# %% ../nbs/01_basics.ipynb
def maps(*args, retain=noop):
    "Like `map`, except funcs are composed first"
    # every argument but the last is a function; the last is the iterable
    fn = compose(*args[:-1])
    return map(lambda b: retain(fn(b), b), args[-1])
# %% ../nbs/01_basics.ipynb
def partialler(f, *args, order=None, **kwargs):
    "Like `functools.partial` but also copies over docstring"
    res = partial(f, *args, **kwargs)
    res.__doc__ = f.__doc__
    # an explicit `order` wins; otherwise inherit `f.order` when `f` has one
    if order is not None: res.order = order
    elif hasattr(f, 'order'): res.order = f.order
    return res
# %% ../nbs/01_basics.ipynb
def instantiate(t):
    "Instantiate `t` if it's a type, otherwise do nothing"
    if isinstance(t, type): return t()
    return t
# %% ../nbs/01_basics.ipynb
def _using_attr(f, attr, x): return f(getattr(x,attr))
# %% ../nbs/01_basics.ipynb
def using_attr(f, attr):
"Construct a function which applies `f` to the argument's attribute `attr`"
return partial(_using_attr, f, attr)
# %% ../nbs/01_basics.ipynb
class _Self:
    "An alternative to `lambda` for calling methods on passed object."
    # Records a chain of attribute accesses and calls, then replays it against the object passed later.
    # `nms[i]` is the i-th attribute name. `args[i]`/`kwargs[i]` are the arguments it is called with,
    # or None when that attribute is only accessed, not called. `ready` is False right after an
    # attribute access that has not yet been paired with its call.
    def __init__(self): self.nms,self.args,self.kwargs,self.ready = [],[],[],True
    def __repr__(self): return f'self: {self.nms}({self.args}, {self.kwargs})'
    def __call__(self, *args, **kwargs):
        if self.ready:
            # Chain is complete: replay it against the passed object
            x = args[0]
            for n,a,k in zip(self.nms,self.args,self.kwargs):
                x = getattr(x,n)
                # Only call when arguments were recorded for this step (and the attribute is callable)
                if callable(x) and a is not None: x = x(*a, **k)
            return x
        else:
            # Still recording: these are the arguments for the pending attribute
            self.args.append(args)
            self.kwargs.append(kwargs)
            self.ready = True
            return self
    def __getattr__(self,k):
        # Only reached for names not found normally, i.e. names being recorded
        if not self.ready:
            # The previous attribute was accessed without a call: record "no call" for it
            self.args.append(None)
            self.kwargs.append(None)
        self.nms.append(k)
        self.ready = False
        return self
    def _call(self, *args, **kwargs):
        # Record a single step that replays as `x.__call__(*args, **kwargs)`
        self.args,self.kwargs,self.nms = [args],[kwargs],['__call__']
        self.ready = True
        return self
# %% ../nbs/01_basics.ipynb
class _SelfCls:
    "Factory behind `Self`: each attribute access, item lookup or call starts a fresh `_Self` recording"
    def __getattr__(self, k):
        recorder = _Self()
        return getattr(recorder, k)
    def __getitem__(self, i):
        method = self.__getattr__('__getitem__')
        return method(i)
    def __call__(self, *args, **kwargs):
        starter = self.__getattr__('_call')
        return starter(*args, **kwargs)
Self = _SelfCls()
# %% ../nbs/01_basics.ipynb
# NOTE(review): `_all_` (single underscores) is not Python's `__all__`; presumably nbdev reads it to add `Self` to the exported names — confirm
_all_ = ['Self']
# %% ../nbs/01_basics.ipynb
def copy_func(f):
    "Copy a non-builtin function (NB `copy.copy` does not work for this)"
    if not isinstance(f,FunctionType): return copy(f)
    fn = FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    fn.__kwdefaults__ = f.__kwdefaults__
    fn.__dict__.update(f.__dict__)
    fn.__annotations__.update(f.__annotations__)
    fn.__qualname__ = f.__qualname__
    # `__doc__` and `__module__` live outside `__dict__`. A docstring assigned after definition
    # (e.g. by `add_docs`) would otherwise revert to whatever the code object carries.
    fn.__doc__ = f.__doc__
    fn.__module__ = f.__module__
    return fn
# %% ../nbs/01_basics.ipynb
class _clsmethod:
def __init__(self, f): self.f = f
def __get__(self, _, f_cls): return MethodType(self.f, f_cls)
# %% ../nbs/01_basics.ipynb
def patch_to(cls, as_prop=False, cls_method=False):
    "Decorator: add `f` to `cls`"
    # `cls` may be a single class or a tuple/list of classes; `f` is copied onto each one
    if not isinstance(cls, (tuple,list)): cls=(cls,)
    def _inner(f):
        for c_ in cls:
            # Each class gets its own copy, so per-class `__qualname__` below doesn't clash
            nf = copy_func(f)
            nm = f.__name__
            # Copy the standard wrapper attributes by hand instead of calling `functools.update_wrapper`;
            # the original note says this is needed when passing patched functions to `Pipeline`
            # (NOTE(review): the exact failure isn't visible here)
            for o in functools.WRAPPER_ASSIGNMENTS: setattr(nf, o, getattr(f,o))
            nf.__qualname__ = f"{c_.__name__}.{nm}"
            if cls_method: setattr(c_, nm, _clsmethod(nf))
            else:
                if as_prop: setattr(c_, nm, property(nf))
                else:
                    # Keep the first pre-existing attribute under `_orig_<name>` (never overwritten by later patches)
                    onm = '_orig_'+nm
                    if hasattr(c_, nm) and not hasattr(c_, onm): setattr(c_, onm, getattr(c_, nm))
                    setattr(c_, nm, nf)
        # Avoid clobbering existing functions
        # The decorator's return value rebinds the decorated name in the caller. This returns what *this*
        # module's globals (or builtins) bind to that name, else None. NOTE(review): `nm` is only bound
        # inside the loop, so an empty `cls` would raise NameError here.
        return globals().get(nm, builtins.__dict__.get(nm, None))
    return _inner
# %% ../nbs/01_basics.ipynb
def patch(f=None, *, as_prop=False, cls_method=False):
    "Decorator: add `f` to the first parameter's class (based on f's type annotations)"
    # Used with options only (`@patch(as_prop=True)`): return a decorator still waiting for `f`
    if f is None: return partial(patch, as_prop=as_prop, cls_method=cls_method)
    ann,glb,loc = get_annotations_ex(f)
    # class methods name their target via the `cls` annotation; otherwise the first annotation is used
    target = ann.pop('cls') if cls_method else next(iter(ann.values()))
    cls = union2tuple(eval_type(target, glb, loc))
    return patch_to(cls, as_prop=as_prop, cls_method=cls_method)(f)
# %% ../nbs/01_basics.ipynb
def patch_property(f):
    "Deprecated; use `patch(as_prop=True)` instead"
    warnings.warn("`patch_property` is deprecated and will be removed; use `patch(as_prop=True)` instead")
    # the target class is the first annotation on `f` (normally the annotated `self`)
    owner = next(iter(f.__annotations__.values()))
    return patch_to(owner, as_prop=True)(f)
# %% ../nbs/01_basics.ipynb
def compile_re(pat):
    "Compile `pat` if it's not None"
    if pat is None: return None
    return re.compile(pat)
# %% ../nbs/01_basics.ipynb
class ImportEnum(enum.Enum):
    "An `Enum` whose members can be injected into the caller's namespace"
    @classmethod
    def imports(cls):
        # Frame 1 is whoever called `imports`. At module level its `f_locals` is the module namespace.
        ns = sys._getframe(1).f_locals
        for member in cls: ns[member.name] = member
# %% ../nbs/01_basics.ipynb
class StrEnum(str,ImportEnum):
    "An `ImportEnum` that behaves like a `str`"
    def __str__(self):
        # `str()` gives the member name
        return self.name
# %% ../nbs/01_basics.ipynb
def str_enum(name, *vals):
    "Simplified creation of `StrEnum` types"
    # each value doubles as its own member name
    return StrEnum(name, dict(zip(vals, vals)))
# %% ../nbs/01_basics.ipynb
class ValEnum(str,ImportEnum):
    "An `ImportEnum` that stringifies using values"
    def __str__(self):
        # show the member's value rather than its name
        return self.value
# %% ../nbs/01_basics.ipynb
class Stateful:
    "A base class/mixin for objects that should not serialize all their state"
    _stateattrs=()
    def __init__(self,*args,**kwargs):
        self._init_state()
        # keep the MRO chain going so `Stateful` works as a mixin
        super().__init__(*args,**kwargs)
    def __getstate__(self):
        # drop the transient attributes (and `_state` itself) from the pickled state
        return {name: val for name, val in self.__dict__.items()
                if name not in self._stateattrs+('_state',)}
    def __setstate__(self, state):
        self.__dict__.update(state)
        # transient state isn't pickled, so rebuild it after loading
        self._init_state()
    def _init_state(self):
        "Override for custom init and deserialization logic"
        self._state = {}
# %% ../nbs/01_basics.ipynb
class NotStr(GetAttr):
    "Behaves like a `str`, but isn't an instance of one"
    _default = 's'
    def __init__(self, s): self.s = s.s if isinstance(s, NotStr) else s
    def __repr__(self): return repr(self.s)
    def __str__(self): return self.s
    def __add__(self, b): return NotStr(self.s+str(b))
    def __mul__(self, b): return NotStr(self.s*b)
    def __len__(self): return len(self.s)
    # Compare the wrapped string against the other operand (unwrapped if it is a `NotStr` too).
    # Returning the operand itself would make any truthy value compare "equal".
    def __eq__(self, b): return self.s==(b.s if isinstance(b, NotStr) else b)
    def __lt__(self, b): return self.s<(b.s if isinstance(b, NotStr) else b)
    def __hash__(self): return hash(self.s)
    def __bool__(self): return bool(self.s)
    def __contains__(self, b): return b in self.s
    def __iter__(self): return iter(self.s)
# %% ../nbs/01_basics.ipynb
class PrettyString(str):
    "A `str` whose `repr` is the raw text itself, so notebooks display it without quotes or escapes."
    def __repr__(s): return s
# %% ../nbs/01_basics.ipynb
def even_mults(start, stop, n):
    "Build log-stepped array from `start` to `stop` in `n` steps."
    # NB: a single step returns `stop` itself, not a one-item list
    if n==1: return stop
    step = (stop/start)**(1/(n-1))
    return [start*step**i for i in range(n)]
# %% ../nbs/01_basics.ipynb
def num_cpus():
    "Number of CPUs this process may run on (falls back to the machine total where affinity isn't available)"
    try: allowed = os.sched_getaffinity(0)
    except AttributeError: return os.cpu_count()
    return len(allowed)
# Record the CPU count once, at import time, on `defaults`
defaults.cpus = num_cpus()
# %% ../nbs/01_basics.ipynb
def add_props(f, g=None, n=2):
    "Create `n` properties; property `i` reads via `f(i, self)` and, if `g` is given, writes via `g(i, self, value)`"
    return (property(partial(f,i), None if g is None else partial(g,i)) for i in range(n))
# %% ../nbs/01_basics.ipynb
def _typeerr(arg, val, typ): return TypeError(f"{arg}=={val} not {typ}")
# %% ../nbs/01_basics.ipynb
def typed(f):
    "Decorator to check param and return types at runtime"
    import inspect
    code = f.__code__
    n_pos = code.co_argcount
    # `co_varnames` lists positional params, then keyword-only params, then `*args`/`**kwargs`, then locals.
    # Only the first `co_argcount` names line up with positional call arguments. Indexing it by arbitrary
    # position raised IndexError for `*args` functions, or matched extra args to the wrong names.
    names = code.co_varnames[:n_pos]
    var_name = code.co_varnames[n_pos+code.co_kwonlyargcount] if code.co_flags & inspect.CO_VARARGS else None
    anno = annotations(f)
    ret = anno.pop('return',None)
    def _f(*args,**kwargs):
        if len(anno) > 0:
            checks = list(kwargs.items()) + list(zip(names, args))
            # surplus positional args belong to `*args`: check each against that parameter's annotation
            if var_name is not None: checks += [(var_name, a) for a in args[n_pos:]]
            for k,v in checks:
                if k in anno and not isinstance(v,anno[k]): raise _typeerr(k, v, anno[k])
        res = f(*args,**kwargs)
        if ret is not None and not isinstance(res,ret): raise _typeerr("return", res, ret)
        return res
    return functools.update_wrapper(_f, f)
# %% ../nbs/01_basics.ipynb
def exec_new(code):
    "Execute `code` in a new environment and return it"
    # outside `__main__`, the current directory's name is used as the package for relative imports
    pkg = Path().cwd().name if __name__!='__main__' else None
    env = dict(__name__=__name__, __package__=pkg)
    exec(code, env)
    return env
# %% ../nbs/01_basics.ipynb
def exec_import(mod, sym):
    "Import `sym` from `mod` in a new environment"
    stmt = f'from {mod} import {sym}'
    return exec_new(stmt)
# %% ../nbs/01_basics.ipynb
def str2bool(s):
    "Case-insensitive convert string `s` to a bool (`y`,`yes`,`t`,`true`,`on`,`1`->`True`; `n`,`no`,`f`,`false`,`off`,`0`->`False`)"
    if not isinstance(s,str): return bool(s)
    if not s: return False
    s = s.lower()
    # Return real bools (previously 1/0) so results match the documented `True`/`False`
    if s in ('y', 'yes', 't', 'true', 'on', '1'): return True
    if s in ('n', 'no', 'f', 'false', 'off', '0'): return False
    raise ValueError(f"invalid truth value {s!r}")
---
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_foundation.ipynb.
# %% auto 0
# Public API of this module (generated along with the rest of this file from the notebook)
__all__ = ['working_directory', 'add_docs', 'docs', 'coll_repr', 'is_bool', 'mask2idxs', 'cycle', 'zip_cycle', 'is_indexer',
           'CollBase', 'L', 'save_config_file', 'read_config_file', 'Config']
# %% ../nbs/02_foundation.ipynb
from .imports import *
from .basics import *
from functools import lru_cache
from contextlib import contextmanager
from copy import copy
from configparser import ConfigParser
import random,pickle,inspect
# %% ../nbs/02_foundation.ipynb
@contextmanager
def working_directory(path):
    "Temporarily `chdir` into `path`; the previous working directory is restored on exit, even on error."
    saved = Path.cwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(saved)
# %% ../nbs/02_foundation.ipynb
def add_docs(cls, cls_doc=None, **docs):
    "Copy values from `docs` to `cls` docstrings, and confirm all public methods are documented"
    if cls_doc is not None: cls.__doc__ = cls_doc
    for name, doc in docs.items():
        attr = getattr(cls, name)
        # bound class methods keep the underlying function on `__func__`
        if hasattr(attr, '__func__'): attr = attr.__func__
        attr.__doc__ = doc
    # every public callable defined directly on the class must now have a docstring
    undocumented = [o for nm, o in vars(cls).items()
                    if callable(o) and not nm.startswith('_') and o.__doc__ is None]
    assert not undocumented, f"Missing docs: {undocumented}"
    assert cls.__doc__ is not None, f"Missing class docs: {cls}"
# %% ../nbs/02_foundation.ipynb
def docs(cls):
    "Decorator version of `add_docs`, using `_docs` dict"
    add_docs(cls, **cls._docs)
    return cls
# %% ../nbs/02_foundation.ipynb
def coll_repr(c, max_n=10):
    "String repr of up to `max_n` items of (possibly lazy) collection `c`"
    n = len(c)
    # only the first `max_n` items are ever repr'd
    shown = ','.join(itertools.islice(map(repr, c), max_n))
    suffix = '...' if n > max_n else ''
    return f'(#{n}) [{shown}{suffix}]'
# %% ../nbs/02_foundation.ipynb
def is_bool(x):
    "Check whether `x` is a bool or None (including array-library `bool_` scalars)"
    if isinstance(x, (bool, NoneType)): return True
    return risinstance('bool_', x)
# %% ../nbs/02_foundation.ipynb
def mask2idxs(mask):
    "Convert bool mask or index list to a list of indices (slices pass through unchanged)"
    if isinstance(mask,slice): return mask
    mask = list(mask)
    if not mask: return []
    # the first element decides whether this is a mask or an index list; `.item()` unwraps array scalars
    head = mask[0]
    if hasattr(head,'item'): head = head.item()
    if is_bool(head): return [i for i,m in enumerate(mask) if m]
    return [int(i) for i in mask]
# %% ../nbs/02_foundation.ipynb
def cycle(o):
    "Like `itertools.cycle` except creates list of `None`s if `o` is empty"
    items = listify(o)
    # an empty (or missing) input becomes an endless stream of `None`
    if items is None or len(items) == 0: items = [None]
    return itertools.cycle(items)
# %% ../nbs/02_foundation.ipynb
def zip_cycle(x, *args):
    "Like `itertools.zip_longest` but `cycle`s through elements of all but first argument"
    # `x` alone bounds the length; every other argument repeats as needed
    return zip(x, *(cycle(a) for a in args))
# %% ../nbs/02_foundation.ipynb
def is_indexer(idx):
    "Test whether `idx` will index a single item in a list (an `int`, or a 0-dimensional array scalar)"
    if isinstance(idx, int): return True
    return not getattr(idx, 'ndim', 1)
# %% ../nbs/02_foundation.ipynb
class CollBase:
    "Base class for composing a list of `items`"
    def __init__(self, items): self.items = items
    def __len__(self): return len(self.items)
    def __getitem__(self, k):
        # a `CollBase` key is expanded to a plain list of keys
        if isinstance(k, CollBase): k = list(k)
        return self.items[k]
    def __setitem__(self, k, v):
        if isinstance(k, CollBase): k = list(k)
        self.items[k] = v
    def __delitem__(self, i): del self.items[i]
    def __repr__(self): return self.items.__repr__()
    def __iter__(self): return self.items.__iter__()
# %% ../nbs/02_foundation.ipynb
class _L_Meta(type):
def __call__(cls, x=None, *args, **kwargs):
if not args and not kwargs and x is not None and isinstance(x,cls): return x
return super().__call__(x, *args, **kwargs)
# %% ../nbs/02_foundation.ipynb
class L(GetAttr, CollBase, metaclass=_L_Meta):
    "Behaves like a list of `items` but can also index with list of indices or masks"
    # Attribute name that `GetAttr` (defined elsewhere) presumably forwards unknown attribute lookups to
    _default='items'
    def __init__(self, items=None, *rest, use_list=False, match=None):
        # `use_list=None` with an array-like `items` keeps it as-is (this is how `_new` avoids re-listifying);
        # anything else is normalised by `listify`
        if (use_list is not None) or not is_array(items):
            items = listify(items, *rest, use_list=use_list, match=match)
        super().__init__(items)
    # NOTE(review): always None here; presumably consulted by `GetAttr` — confirm its meaning there
    @property
    def _xtra(self): return None
    # New instance of the same (sub)class, passing `use_list=None` so array-like items aren't re-listified
    def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    # NOTE(review): multi-item results are wrapped in plain `L`, not `type(self)` as `_new` would give
    def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    def copy(self): return self._new(self.items.copy())
    def _get(self, i):
        # Single index or slice: index directly, positionally via `.iloc` when `items` has one (presumably pandas)
        if is_indexer(i) or isinstance(i,slice): return getattr(self.items,'iloc',self.items)[i]
        # Otherwise `i` is a list of indices or a boolean mask: normalise to indices, then gather
        i = mask2idxs(i)
        return (self.items.iloc[list(i)] if hasattr(self.items,'iloc')
                else self.items.__array__()[(i,)] if hasattr(self.items,'__array__')
                else [self.items[i_] for i_ in i])
    def __setitem__(self, idx, o):
        "Set `idx` (can be list of indices, or mask, or int) items to `o` (which is broadcast if not iterable)"
        if isinstance(idx, int): self.items[idx] = o
        else:
            idx = idx if isinstance(idx,L) else listify(idx)
            if not is_iter(o): o = [o]*len(idx)
            for i,o_ in zip(idx,o): self.items[i] = o_
    def __eq__(self,b):
        # Never equal to None, non-iterables, strings, dicts or callables; `ndarray` operands use
        # `array_equal`, everything else `all_equal`
        if b is None: return False
        if not hasattr(b, '__iter__'): return False
        if risinstance('ndarray', b): return array_equal(b, self)
        if isinstance(b, (str,dict)) or callable(b): return False
        return all_equal(b,self)
    def sorted(self, key=None, reverse=False): return self._new(sorted_ex(self, key=key, reverse=reverse))
    # `.iloc` containers iterate via `itertuples()` (one tuple per row); everything else iterates `items` directly
    def __iter__(self): return iter(self.items.itertuples() if hasattr(self.items,'iloc') else self.items)
    def __contains__(self,b): return b in self.items
    def __reversed__(self): return self._new(reversed(self.items))
    # Elementwise logical not of every item
    def __invert__(self): return self._new(not i for i in self)
    def __repr__(self): return repr(self.items)
    # Pretty-printer hook: array-like items use their own repr, others the truncated `coll_repr`
    def _repr_pretty_(self, p, cycle):
        p.text('...' if cycle else repr(self.items) if is_array(self.items) else coll_repr(self))
    def __mul__ (a,b): return a._new(a.items*b)
    def __add__ (a,b): return a._new(a.items+listify(b))
    def __radd__(a,b): return a._new(b)+a
    # NOTE(review): Python's in-place add hook is `__iadd__`, not `__addi__`, so `a += b` falls back to
    # `__add__` (returning a new `L`) and this method only runs when called by name
    def __addi__(a,b):
        a.items += list(b)
        return a
    @classmethod
    def split(cls, s, sep=None, maxsplit=-1): return cls(s.split(sep,maxsplit))
    @classmethod
    def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
    def map(self, f, *args, **kwargs): return self._new(map_ex(self, f, *args, gen=False, **kwargs))
    def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
    def argfirst(self, f, negate=False):
        if negate: f = not_(f)
        return first(i for i,o in self.enumerate() if f(o))
    def filter(self, f=noop, negate=False, **kwargs):
        return self._new(filter_ex(self, f=f, negate=negate, gen=False, **kwargs))
    def enumerate(self): return L(enumerate(self))
    def renumerate(self): return L(renumerate(self))
    # NOTE(review): with `bidir=True`, `uniqueify` returns a `(list, dict)` pair, which ends up wrapped as a 2-item `L`
    def unique(self, sort=False, bidir=False, start=None): return L(uniqueify(self, sort=sort, bidir=bidir, start=start))
    def val2idx(self): return val2idx(self)
    def cycle(self): return cycle(self)
    def map_dict(self, f=noop, *args, **kwargs): return {k:f(k, *args,**kwargs) for k in self}
    def map_first(self, f=noop, g=noop, *args, **kwargs):
        return first(self.map(f, *args, **kwargs), g)
    def itemgot(self, *idxs):
        # apply each index in turn, so `itemgot(0, 1)` gets `o[0][1]` from every item
        x = self
        for idx in idxs: x = x.map(itemgetter(idx))
        return x
    def attrgot(self, k, default=None):
        return self.map(lambda o: o.get(k,default) if isinstance(o, dict) else nested_attr(o,k,default))
    def starmap(self, f, *args, **kwargs): return self._new(itertools.starmap(partial(f,*args,**kwargs), self))
    def zip(self, cycled=False): return self._new((zip_cycle if cycled else zip)(*self))
    def zipwith(self, *rest, cycled=False): return self._new([self, *rest]).zip(cycled=cycled)
    def map_zip(self, f, *args, cycled=False, **kwargs): return self.zip(cycled=cycled).starmap(f, *args, **kwargs)
    def map_zipwith(self, f, *rest, cycled=False, **kwargs): return self.zipwith(*rest, cycled=cycled).starmap(f, **kwargs)
    def shuffle(self):
        # shuffle a shallow copy so `self` is left untouched
        it = copy(self.items)
        random.shuffle(it)
        return self._new(it)
    def concat(self): return self._new(itertools.chain.from_iterable(self.map(L)))
    def reduce(self, f, initial=None): return reduce(f, self) if initial is None else reduce(f, self, initial)
    def sum(self): return self.reduce(operator.add, 0)
    def product(self): return self.reduce(operator.mul, 1)
    # List comprehension used only for its side effect; returns None
    def setattrs(self, attr, val): [setattr(o,attr,val) for o in self]
# %% ../nbs/02_foundation.ipynb
# Attach docstrings to `L`'s public methods (`add_docs` also asserts none are left undocumented)
add_docs(L,
         __getitem__="Retrieve `idx` (can be list of indices, or mask, or int) items",
         range="Class Method: Same as `range`, but returns `L`. Can pass collection for `a`, to use `len(a)`",
         split="Class Method: Same as `str.split`, but returns an `L`",
         copy="Same as `list.copy`, but returns an `L`",
         sorted="New `L` sorted by `key`. If key is str use `attrgetter`; if int use `itemgetter`",
         unique="Unique items, in stable order",
         val2idx="Dict from value to index",
         filter="Create new `L` filtered by predicate `f`, passing `args` and `kwargs` to `f`",
         argwhere="Like `filter`, but return indices for matching items",
         argfirst="Return index of first matching item",
         map="Create new `L` with `f` applied to all `items`, passing `args` and `kwargs` to `f`",
         map_first="First element of `map`, optionally filtered by `g`",
         map_dict="Like `map`, but creates a dict from `items` to function results",
         starmap="Like `map`, but use `itertools.starmap`",
         itemgot="Create new `L` with item `idx` of all `items`",
         attrgot="Create new `L` with attr `k` (or value `k` for dicts) of all `items`.",
         cycle="Same as `itertools.cycle`",
         enumerate="Same as `enumerate`",
         renumerate="Same as `renumerate`",
         zip="Create new `L` with `zip(*items)`",
         zipwith="Create new `L` with `self` zip with each of `*rest`",
         map_zip="Combine `zip` and `starmap`",
         map_zipwith="Combine `zipwith` and `starmap`",
         concat="Concatenate all elements of list",
         shuffle="Same as `random.shuffle`, but not inplace",
         reduce="Wrapper for `functools.reduce`",
         sum="Sum of the items",
         product="Product of the items",
         setattrs="Call `setattr` on all items"
        )
# %% ../nbs/02_foundation.ipynb
# `_L_Meta.__call__(cls, x=None, *args, **kwargs)` shadows `L.__init__`, so `inspect.signature(L)` would report the
# metaclass signature; publish the real constructor signature explicitly instead
# (https://stackoverflow.com/questions/49740290/call-from-metaclass-shadows-signature-of-init).
def _f(items=None, *rest, use_list=False, match=None): ...
L.__signature__ = inspect.Signature.from_callable(_f)
# %% ../nbs/02_foundation.ipynb
# Declare `L` a virtual subclass of `Sequence`, so `isinstance(x, Sequence)` accepts it
Sequence.register(L)
# %% ../nbs/02_foundation.ipynb
def save_config_file(file, d, **kwargs):
    "Write settings dict to a new config file, or overwrite the existing one."
    config = ConfigParser(**kwargs)
    config['DEFAULT'] = d
    # Close the handle deterministically (it used to be left for GC). Write UTF-8 to match the
    # `encoding='utf8'` used by the readers in this module.
    with open(file, 'w', encoding='utf8') as f: config.write(f)
# %% ../nbs/02_foundation.ipynb
def read_config_file(file, **kwargs):
    "Read UTF-8 ini `file` and return its `DEFAULT` section; `kwargs` go to `ConfigParser` (a missing file gives an empty section)"
    parser = ConfigParser(**kwargs)
    parser.read(file, encoding='utf8')
    return parser['DEFAULT']
# %% ../nbs/02_foundation.ipynb
class Config:
    "Reading and writing `ConfigParser` ini files"
    def __init__(self, cfg_path, cfg_name, create=None, save=True, extra_files=None, types=None):
        self.types = types or {}
        base = Path(cfg_path).expanduser().absolute()
        self.config_path = base
        self.config_file = base/cfg_name
        self._cfg = ConfigParser()
        self.d = self._cfg['DEFAULT']
        # files are read in order, so the main config file (read last) overrides `extra_files`
        read_ok = self._cfg.read(L(extra_files)+[self.config_file], encoding='utf8')
        found = [Path(o) for o in read_ok]
        if create is None or self.config_file in found: return
        # no config file yet: seed settings from `create`, and optionally write them out
        self._cfg.read_dict({'DEFAULT':create})
        if save:
            base.mkdir(exist_ok=True, parents=True)
            save_config_file(self.config_file, create)
    def __repr__(self): return repr(dict(self._cfg.items('DEFAULT', raw=True)))
    def __setitem__(self,k,v): self.d[k] = str(v)
    def __contains__(self,k): return k in self.d
    def save(self): save_config_file(self.config_file,self.d)
    def __getattr__(self,k):
        # `d` is guarded so an instance without `self.d` yet (e.g. before `__init__` ran) can't recurse here
        if k=='d' or k not in self.d: return stop(AttributeError(k))
        return self.get(k)
    def __getitem__(self,k):
        if k not in self.d: return stop(IndexError(k))
        return self.get(k)
    def get(self,k,default=None):
        raw = self.d.get(k, default)
        if raw is None: return None
        # convert via `types[k]`: bools through `str2bool`, paths relative to the config dir, untyped as `str`
        conv = self.types.get(k, None)
        if conv==bool: return str2bool(raw)
        if not conv: return str(raw)
        return self.config_path/raw if conv==Path else conv(raw)
    def path(self,k,default=None):
        v = self.get(k, default)
        return None if v is None else self.config_path/v
---
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_xtras.ipynb.
# %% ../nbs/03_xtras.ipynb 1
from __future__ import annotations
# %% auto 0
# Public API of this module (generated along with the rest of this file from the notebook)
__all__ = ['spark_chars', 'UNSET', 'walk', 'globtastic', 'maybe_open', 'mkdir', 'image_size', 'bunzip', 'loads', 'loads_multi',
           'dumps', 'untar_dir', 'repo_details', 'run', 'open_file', 'save_pickle', 'load_pickle', 'parse_env',
           'expand_wildcards', 'dict2obj', 'obj2dict', 'repr_dict', 'is_listy', 'mapped', 'IterLen',
           'ReindexCollection', 'get_source_link', 'truncstr', 'sparkline', 'modify_exception', 'round_multiple',
           'set_num_threads', 'join_path_file', 'autostart', 'EventTimer', 'stringfmt_names', 'PartialFormatter',
           'partial_format', 'utc2local', 'local2utc', 'trace', 'modified_env', 'ContextManagers', 'shufflish',
           'console_help', 'hl_md', 'type2str', 'dataclass_src', 'Unset', 'nullable_dc', 'make_nullable', 'flexiclass',
           'asdict', 'is_typeddict', 'is_namedtuple', 'flexicache', 'time_policy', 'mtime_policy', 'timed_cache']
# %% ../nbs/03_xtras.ipynb
from .imports import *
from .foundation import *
from .basics import *
from importlib import import_module
from functools import wraps
import string,time,dataclasses
from enum import Enum
from contextlib import contextmanager,ExitStack
from datetime import datetime, timezone
from time import sleep,time,perf_counter
from os.path import getmtime
from dataclasses import dataclass, field, fields, is_dataclass, MISSING, make_dataclass
# %% ../nbs/03_xtras.ipynb
def walk(
    path:Path|str, # path to start searching
    symlinks:bool=True, # follow symlinks?
    keep_file:callable=ret_true, # function that returns True for wanted files
    keep_folder:callable=ret_true, # function that returns True for folders to enter
    skip_folder:callable=ret_false, # function that returns True for folders to skip
    func:callable=os.path.join, # function to apply to each matched file
    ret_folders:bool=False # return folders, not just files
):
    "Generator version of `os.walk`, using functions to filter files and folders"
    for root,dirs,files in os.walk(path, followlinks=symlinks):
        # `keep_folder` only decides whether this folder's entries are yielded; descent is pruned by `skip_folder`
        if keep_folder(root,''):
            if ret_folders: yield func(root, '')
            yield from (func(root, name) for name in files if keep_file(root,name))
        # prune in place so `os.walk` won't descend into skipped folders
        dirs[:] = [name for name in dirs if not skip_folder(root,name)]
# %% ../nbs/03_xtras.ipynb
def globtastic(
    path:Path|str, # path to start searching
    recursive:bool=True, # search subfolders
    symlinks:bool=True, # follow symlinks?
    file_glob:str=None, # Only include files matching glob
    file_re:str=None, # Only include files matching regex
    folder_re:str=None, # Only enter folders matching regex
    skip_file_glob:str=None, # Skip files matching glob
    skip_file_re:str=None, # Skip files matching regex
    skip_folder_re:str=None, # Skip folders matching regex
    func:callable=os.path.join, # function to apply to each matched file
    ret_folders:bool=False # return folders, not just files
)->L: # Paths to matched files
    "A more powerful `glob`, including regex matches, symlink handling, and skip parameters"
    from fnmatch import fnmatch
    path = Path(path)
    if path.is_file(): return L([path])
    # non-recursive == skip every subfolder ('.' matches any non-empty name)
    if not recursive: skip_folder_re = '.'
    file_re, folder_re = compile_re(file_re), compile_re(folder_re)
    skip_file_re, skip_folder_re = compile_re(skip_file_re), compile_re(skip_folder_re)
    def _keep_file(root, name):
        if file_glob and not fnmatch(name, file_glob): return False
        if file_re and not file_re.search(name): return False
        if skip_file_glob and fnmatch(name, skip_file_glob): return False
        return not (skip_file_re and skip_file_re.search(name))
    def _keep_folder(root, name):
        # matched against the full folder path, unlike the name-only file/skip filters
        return not folder_re or folder_re.search(os.path.join(root, name))
    def _skip_folder(root, name):
        return skip_folder_re and skip_folder_re.search(name)
    matches = walk(path, symlinks=symlinks, keep_file=_keep_file, keep_folder=_keep_folder,
                   skip_folder=_skip_folder, func=func, ret_folders=ret_folders)
    return L(matches)
# %% ../nbs/03_xtras.ipynb
@contextmanager
def maybe_open(f, mode='r', **kwargs):
    "Context manager: open `f` if it is a path (and close on exit); otherwise yield it unchanged"
    if not isinstance(f, (str,os.PathLike)):
        # already a file-like object: the caller owns it, so don't close it
        yield f
        return
    with open(f, mode, **kwargs) as fh: yield fh
# %% ../nbs/03_xtras.ipynb
def mkdir(path, exist_ok=False, parents=False, overwrite=False, **kwargs):
    "Creates and returns a directory defined by `path`, optionally removing previous existing directory if `overwrite` is `True`"
    import shutil
    path = Path(path)
    if overwrite and path.exists():
        # `rmtree` only handles real directories (it raises on files and on symlinks),
        # so anything else in the way is unlinked instead
        if path.is_dir() and not path.is_symlink(): shutil.rmtree(path)
        else: path.unlink()
    path.mkdir(exist_ok=exist_ok, parents=parents, **kwargs)
    return path
# %% ../nbs/03_xtras.ipynb
def image_size(fn):
    "Tuple of (w,h) for png, gif, or jpg; `None` otherwise"
    from fastcore import imghdr
    import struct
    def _jpg_size(f):
        # Walk JPEG segments until a start-of-frame marker (0xC0-0xCF), then read its height/width
        size,ftype = 2,0
        while not 0xc0 <= ftype <= 0xcf:
            f.seek(size, 1)
            byte = f.read(1)
            while ord(byte) == 0xff: byte = f.read(1)
            ftype = ord(byte)
            size = struct.unpack('>H', f.read(2))[0] - 2
        f.seek(1, 1)  # `precision'
        h,w = struct.unpack('>HH', f.read(4))
        return w,h
    def _gif_size(f):
        # GIF header: 6-byte signature, then little-endian uint16 width and height.
        # (These readers previously used an undefined `head`, raising NameError.)
        head = f.read(10)
        return struct.unpack('<HH', head[6:10])
    def _png_size(f):
        # PNG: 8-byte signature, then the IHDR chunk with big-endian width/height at bytes 16-24
        head = f.read(24)
        assert struct.unpack('>i', head[4:8])[0]==0x0d0a1a0a
        return struct.unpack('>ii', head[16:24])
    d = dict(png=_png_size, gif=_gif_size, jpeg=_jpg_size)
    with maybe_open(fn, 'rb') as f:
        # NOTE(review): assumes `imghdr.what` leaves `f` at the start (as stdlib imghdr does); the jpg reader relies on it too
        fmt = imghdr.what(f)
        return d[fmt](f) if fmt in d else None
# %% ../nbs/03_xtras.ipynb
def bunzip(fn):
    "Decompress `fn` next to itself (dropping the last suffix); asserts the input exists and the output doesn't"
    import bz2
    src_path = Path(fn)
    assert src_path.exists(), f"{src_path} doesn't exist"
    dst_path = src_path.with_suffix('')
    assert not dst_path.exists(), f"{dst_path} already exists"
    with bz2.BZ2File(src_path, 'rb') as src, dst_path.open('wb') as dst:
        # stream 1 MiB at a time until `read` returns b''
        read_chunk = lambda: src.read(1024*1024)
        for chunk in iter(read_chunk, b''): dst.write(chunk)
# %% ../nbs/03_xtras.ipynb
def loads(s, **kw):
    "Same as `json.loads` (via `ujson` if installed), but any falsy input such as `None` or `''` gives `{}`"
    if not s: return {}
    try:
        import ujson as json
    except ModuleNotFoundError:
        import json
    return json.loads(s, **kw)
# %% ../nbs/03_xtras.ipynb
def loads_multi(s:str):
    "Generator of >=0 decoded json dicts, possibly with non-json ignored text at start and end"
    import json
    decoder = json.JSONDecoder()
    while True:
        start = s.find('{')
        if start < 0: return
        # decode the object starting at the next '{', then continue after it
        s = s[start:]
        obj,end = decoder.raw_decode(s)
        if not end: raise ValueError(f'no JSON object found at {end}')
        yield obj
        s = s[end:]
# %% ../nbs/03_xtras.ipynb
def dumps(obj, **kw):
    "Same as `json.dumps`, but uses `ujson` if available"
    try:
        import ujson as json
    except ModuleNotFoundError:
        import json
    else:
        # ujson escapes '/' by default; turn that off (stdlib json doesn't escape it)
        kw['escape_forward_slashes'] = False
    return json.dumps(obj, **kw)
# %% ../nbs/03_xtras.ipynb
def _unpack(fname, out):
    "Extract archive `fname` into `out`; return the single top-level entry if there is exactly one, else `out`"
    import shutil
    shutil.unpack_archive(str(fname), str(out))
    contents = out.ls()
    if len(contents) == 1: return contents[0]
    return out
# %% ../nbs/03_xtras.ipynb
def untar_dir(fname, dest, rename=False, overwrite=False):
    "Extract archive `fname` into folder `dest`; with `rename`, the result is named after the archive rather than its root entry"
    import tempfile,shutil
    with tempfile.TemporaryDirectory() as d:
        out = Path(d)/remove_suffix(Path(fname).stem, '.tar')
        out.mkdir()
        # with `rename` the target name is known up front; otherwise extract first to learn the root's name
        if rename: dest = dest/out.name
        else:
            src = _unpack(fname, out)
            dest = dest/src.name
        if dest.exists():
            # an existing target is kept (and returned) unless `overwrite` is set
            if not overwrite: return dest
            if dest.is_dir(): shutil.rmtree(dest)
            else: dest.unlink()
        if rename: src = _unpack(fname, out)
        shutil.move(str(src), dest)
        return dest
# %% ../nbs/03_xtras.ipynb
def repo_details(url):
    "`[owner, name]` (a 2-item list) from an ssh or https git repo `url`"
    # drop '.git', then anything up to the last ':' (the ssh `git@host:` prefix or the URL scheme)
    path = remove_suffix(url.strip(), '.git').split(':')[-1]
    return path.split('/')[-2:]
# %% ../nbs/03_xtras.ipynb
def run(cmd, *rest, same_in_win=False, ignore_ex=False, as_bytes=False, stderr=False):
    "Pass `cmd` (splitting with `shlex` if string) to `subprocess.run`; return `stdout`; raise `IOError` if fails"
    import subprocess
    # With `same_in_win` on Windows the command is spelled the same but must run through `cmd /c`
    wrap = sys.platform == 'win32' and same_in_win
    if rest: cmd = ('cmd', '/c', cmd, *rest) if wrap else (cmd,)+rest
    elif isinstance(cmd, str):
        import shlex
        cmd = shlex.split('cmd /c ' + cmd if wrap else cmd)
    elif isinstance(cmd, list) and wrap: cmd = ['cmd', '/c'] + cmd
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out = proc.stdout
    if stderr and proc.stderr: out += b' ;; ' + proc.stderr
    if not as_bytes: out = out.decode().strip()
    if ignore_ex: return (proc.returncode, out)
    if proc.returncode: raise IOError(out)
    return out
# %% ../nbs/03_xtras.ipynb
def open_file(fn, mode='r', **kwargs):
    "Open a file, with optional compression if gz or bz2 suffix; already-open file objects are returned as-is"
    if isinstance(fn, io.IOBase): return fn
    import bz2,gzip,zipfile
    path = Path(fn)
    # suffix-specific opener, falling back to the builtin `open`
    openers = {'.bz2': bz2.BZ2File, '.gz': gzip.GzipFile, '.zip': zipfile.ZipFile}
    opener = openers.get(path.suffix, open)
    return opener(path, mode, **kwargs)
# %% ../nbs/03_xtras.ipynb
def save_pickle(fn, o):
    "Pickle `o` to `fn`, which may be a file name (compressed per `open_file`'s suffix rules) or an open file"
    import pickle
    with open_file(fn, 'wb') as fh:
        pickle.dump(o, fh)
# %% ../nbs/03_xtras.ipynb
def load_pickle(fn):
    "Unpickle and return the object stored in `fn`, a file name (compressed per `open_file`) or an open file"
    import pickle
    # NOTE: unpickling can execute arbitrary code - only load files from trusted sources
    with open_file(fn, 'rb') as fh:
        obj = pickle.load(fh)
    return obj
# %% ../nbs/03_xtras.ipynb
def parse_env(s:str=None, fn:Union[str,Path]=None) -> dict:
    """Parse a shell-style environment string `s` or file `fn` (exactly one must be given) into a dict.

    Blank lines and `#` comment lines are skipped. A leading `export`, matching single/double quotes
    around the value, and a trailing `# comment` are stripped.
    Raises `ValueError` for any other line that is not a `NAME=value` assignment."""
    assert bool(s)^bool(fn), "Must pass exactly one of `s` or `fn`"
    if fn: s = Path(fn).read_text()
    pat = re.compile(r'^\s*(?:export\s+)?(\w+)\s*=\s*(["\']?)(.*?)(\2)\s*(?:#.*)?$')
    def _f(line):
        m = pat.match(line)
        # report the offending line instead of failing with an AttributeError on `None.groups()`
        if m is None: raise ValueError(f'Invalid environment line: {line!r}')
        return m.group(1), m.group(3)
    return dict(_f(o.strip()) for o in s.splitlines() if o.strip() and not re.match(r'\s*#', o))
# %% ../nbs/03_xtras.ipynb
def expand_wildcards(code):
    "Expand all wildcard imports in the given code string into explicit imports of the names the code uses."
    import ast,importlib
    tree = ast.parse(code)
    def _replace_node(code, old_node, new_node):
        "Replace the source lines spanned by `old_node` in `code` with `new_node`, keeping its indentation."
        lines = code.splitlines()
        lnum = old_node.lineno
        indent = ' ' * (len(lines[lnum-1]) - len(lines[lnum-1].lstrip()))
        new_lines = [indent+line for line in ast.unparse(new_node).splitlines()]
        lines[lnum-1 : old_node.end_lineno] = new_lines
        return '\n'.join(lines)
    def _expand_import(node, mod, existing):
        "Explicit import of the names from `mod` that `tree` uses (minus `existing`), or `node` if none are used."
        mod_all = getattr(mod, '__all__', None)
        available_names = set(mod_all) if mod_all is not None else set(dir(mod))
        used_names = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name) and n.id in available_names} - existing
        if not used_names: return node
        names = [ast.alias(name=name, asname=None) for name in sorted(used_names)]
        return ast.ImportFrom(module=node.module, names=names, level=node.level)
    # Names already bound by explicit imports - the *local* name, i.e. the `as` alias when there is one
    existing = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.names[0].name != '*':
            existing.update(n.asname or n.name for n in node.names)
        elif isinstance(node, ast.Import):
            existing.update(n.asname or n.name.split('.')[0] for n in node.names)
    wildcards = [node for node in ast.walk(tree)
                 if isinstance(node, ast.ImportFrom) and any(n.name == '*' for n in node.names)]
    # Rewrite bottom-up so a replacement with a different line count can't invalidate later line numbers
    for node in sorted(wildcards, key=lambda n: n.lineno, reverse=True):
        new_import = _expand_import(node, importlib.import_module(node.module), existing)
        code = _replace_node(code, node, new_import)
    return code
# %% ../nbs/03_xtras.ipynb
def dict2obj(d, list_func=L, dict_func=AttrDict):
    "Convert (possibly nested) dicts (or lists of dicts) to `AttrDict`, applying `list_func`/`dict_func` at every level"
    # pass the converters down so nested containers honour a custom `list_func`/`dict_func` too
    if isinstance(d, (L,list)): return list_func([dict2obj(o, list_func, dict_func) for o in d])
    if not isinstance(d, dict): return d
    return dict_func(**{k:dict2obj(v, list_func, dict_func) for k,v in d.items()})
# %% ../nbs/03_xtras.ipynb
def obj2dict(d):
    "Convert (possibly nested) AttrDicts (or lists of AttrDicts) to plain `dict`s and `list`s"
    if isinstance(d, (L,list)): return [obj2dict(o) for o in d]
    if not isinstance(d, dict): return d
    # build the dict directly: the previous `dict(**...)` raised TypeError for non-string keys
    return {k:obj2dict(v) for k,v in d.items()}
# %% ../nbs/03_xtras.ipynb
def _repr_dict(d, lvl):
    # scalars render inline; dicts and lists become a bulleted block indented by `lvl`
    if isinstance(d, dict): items = [f"{k}: {_repr_dict(v, lvl+1)}" for k,v in d.items()]
    elif isinstance(d, (list,L)): items = [_repr_dict(o, lvl+1) for o in d]
    else: return str(d)
    bullet = " "*(lvl*2) + "- "
    return '\n' + '\n'.join(bullet + o for o in items)
# %% ../nbs/03_xtras.ipynb
def repr_dict(d):
    "Print nested dicts and lists, such as returned by `dict2obj`"
    rendered = _repr_dict(d, 0)
    return rendered.strip()
# %% ../nbs/03_xtras.ipynb
def is_listy(x):
    "`isinstance(x, (tuple,list,L,slice,Generator))`"
    listy_types = (tuple, list, L, slice, Generator)
    return isinstance(x, listy_types)
# %% ../nbs/03_xtras.ipynb
def mapped(f, it):
    "map `f` over `it`, unless it's not listy, in which case return `f(it)`"
    if not is_listy(it): return f(it)
    return L(it).map(f)
# %% ../nbs/03_xtras.ipynb
@patch
def readlines(self:Path, hint=-1, encoding='utf8'):
    "Return the lines of `self` as a list (newlines kept), passing `hint` through to `file.readlines`"
    with self.open(encoding=encoding) as fh:
        lines = fh.readlines(hint)
    return lines
# %% ../nbs/03_xtras.ipynb
@patch
def read_json(self:Path, encoding=None, errors=None):
    "Same as `read_text` followed by `loads`"
    text = self.read_text(encoding=encoding, errors=errors)
    return loads(text)
# %% ../nbs/03_xtras.ipynb
@patch
def mk_write(self:Path, data, encoding=None, errors=None, mode=511):
    "Create all missing parent dirs of `self` (with permission `mode`, 511 == 0o777), then write text `data`"
    parent = self.parent
    parent.mkdir(parents=True, exist_ok=True, mode=mode)
    self.write_text(data, encoding=encoding, errors=errors)
# %% ../nbs/03_xtras.ipynb
@patch
def relpath(self:Path, start=None):
    "Same as `os.path.relpath`, but returns a `Path`, and resolves symlinks; `start` defaults to the current dir"
    # `Path(None)` raises, so honour the `None` default the way `os.path.relpath` does
    base = Path.cwd() if start is None else Path(start)
    return Path(os.path.relpath(self.resolve(), base.resolve()))
# %% ../nbs/03_xtras.ipynb
@patch
def ls(self:Path, n_max=None, file_type=None, file_exts=None):
    "Contents of path as a list, optionally filtered by `file_exts` and/or mimetype prefix `file_type`, capped at `n_max`"
    import mimetypes
    extns = L(file_exts)
    if file_type:
        prefix = file_type + '/'
        extns += L(ext for ext,mime in mimetypes.types_map.items() if mime.startswith(prefix))
    # no extensions requested means no filtering at all
    unfiltered = len(extns) == 0
    entries = (p for p in self.iterdir() if unfiltered or p.suffix in extns)
    if n_max is not None: entries = itertools.islice(entries, n_max)
    return L(entries)
# %% ../nbs/03_xtras.ipynb
@patch
def __repr__(self:Path):
    "Show the path as `Path('...')`, relative to `Path.BASE_PATH` when that is set and contains it"
    b = getattr(Path, 'BASE_PATH', None)
    if b:
        # a path outside BASE_PATH (ValueError) or an unusable BASE_PATH (TypeError) keeps the full path
        try: self = self.relative_to(b)
        except (ValueError, TypeError): pass
    return f"Path({self.as_posix()!r})"
# %% ../nbs/03_xtras.ipynb
@patch
def delete(self:Path):
    "Delete a file, symlink, or directory tree; does nothing if `self` does not exist"
    import shutil
    # Check for a symlink first: `exists()`/`is_dir()` follow links. A dangling link would otherwise be
    # skipped, and a link to a directory would be handed to `rmtree`, which refuses symlinks.
    if self.is_symlink(): self.unlink()
    elif self.is_dir(): shutil.rmtree(self)
    elif self.exists(): self.unlink()
# %% ../nbs/03_xtras.ipynb
class IterLen:
    "Base class to add iteration to anything supporting `__len__` and `__getitem__`"
    def __iter__(self):
        # indices are fixed when iteration starts, then items are fetched lazily
        n = len(self)
        return (self[i] for i in range(n))
# %% ../nbs/03_xtras.ipynb
@docs
class ReindexCollection(GetAttr, IterLen):
    "Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`"
    _default='coll'
    def __init__(self, coll, idxs=None, cache=None, tfm=noop):
        if idxs is None: idxs = L.range(coll)
        store_attr()
        self._wrap_cache()
    def _wrap_cache(self):
        # per-instance LRU cache around `_get`; also re-applied after unpickling (the wrapper isn't pickled)
        if self.cache is not None: self._get = functools.lru_cache(maxsize=self.cache)(self._get)
    def _get(self, i): return self.tfm(self.coll[i])
    def __getitem__(self, i): return self._get(self.idxs[i])
    # length of the index view, so iteration via `IterLen` never reads past the end of `idxs`
    def __len__(self): return len(self.idxs)
    def reindex(self, idxs): self.idxs = idxs
    def shuffle(self):
        import random
        random.shuffle(self.idxs)
    def cache_clear(self): self._get.cache_clear()
    def __getstate__(self): return {'coll': self.coll, 'idxs': self.idxs, 'cache': self.cache, 'tfm': self.tfm}
    def __setstate__(self, s):
        self.coll,self.idxs,self.cache,self.tfm = s['coll'],s['idxs'],s['cache'],s['tfm']
        self._wrap_cache()
    _docs = dict(reindex="Replace `self.idxs` with idxs",
                shuffle="Randomly shuffle indices",
                cache_clear="Clear LRU cache")
# %% ../nbs/03_xtras.ipynb
def _is_type_dispatch(x): return type(x).__name__ == "TypeDispatch"
def _unwrapped_type_dispatch_func(x): return x.first() if _is_type_dispatch(x) else x
def _is_property(x): return type(x)==property
def _has_property_getter(x): return _is_property(x) and hasattr(x, 'fget') and hasattr(x.fget, 'func')
def _property_getter(x): return x.fget.func if _has_property_getter(x) else x
def _unwrapped_func(x):
    "Strip a `TypeDispatch` wrapper and/or a property whose getter has a `.func`, to reach the underlying function"
    x = _unwrapped_type_dispatch_func(x)
    x = _property_getter(x)
    return x
def get_source_link(func):
    """Return link to `func` in source code, as `<module/path>.py#L<line>`.

    The link is prefixed with `git_url` from the top-level package's `_nbdev` module when that can be imported.
    Returns `''` when the source location cannot be determined."""
    import inspect
    from importlib import import_module
    func = _unwrapped_func(func)
    try: line = inspect.getsourcelines(func)[1]
    except Exception: return ''
    mod = inspect.getmodule(func)
    if mod is None: return ''  # no owning module: previously crashed with AttributeError on `None.__name__`
    module = mod.__name__.replace('.', '/') + '.py'
    try:
        nbdev_mod = import_module(mod.__package__.split('.')[0] + '._nbdev')
        return f"{nbdev_mod.git_url}{module}#L{line}"
    except Exception: return f"{module}#L{line}"
# %% ../nbs/03_xtras.ipynb
def truncstr(s:str, maxlen:int, suf:str='…', space='')->str:
    "Truncate `s` to length `maxlen`, adding suffix `suf` if truncated; otherwise append `space`"
    if len(s) + len(space) <= maxlen: return s + space
    return s[:maxlen-len(suf)] + suf
# %% ../nbs/03_xtras.ipynb
spark_chars = '▁▂▃▅▆▇'
# %% ../nbs/03_xtras.ipynb
def _ceil(x, lim=None): return x if (lim is None or x <= lim) else lim  # `lim=0` is a real limit too
def _sparkchar(x, mn, mx, incr, empty_zero):
    "Map `x` to one of `spark_chars`; `None` (and 0 when `empty_zero`) map to a blank"
    if x is None or (empty_zero and not x): return ' '
    if incr == 0: return spark_chars[0]
    res = int((_ceil(x,mx)-mn)/incr-0.5)
    # clamp: a value below a caller-supplied `mn` gave a negative index that wrapped round to a tall bar
    return spark_chars[min(max(res, 0), len(spark_chars)-1)]
# %% ../nbs/03_xtras.ipynb
def sparkline(data, mn=None, mx=None, empty_zero=False):
    "Sparkline for `data`, with `None`s (and zero, if `empty_zero`) shown as empty column"
    valid = [o for o in data if o is not None]
    if not valid: return ' '
    mn = min(valid) if mn is None else mn
    mx = max(valid) if mx is None else mx
    n = len(spark_chars)
    res = [_sparkchar(x=o, mn=mn, mx=mx, incr=(mx-mn)/n, empty_zero=empty_zero) for o in data]
    return ''.join(res)
# %% ../nbs/03_xtras.ipynb
def modify_exception(
    e:Exception, # An exception
    msg:str=None, # A custom message
    replace:bool=False, # Whether to replace e.args with [msg]
) -> Exception:
    "Modifies `e` with a custom message attached"
    if replace or not e.args: e.args = [msg]
    else: e.args = [f'{e.args[0]} {msg}']
    return e
# %% ../nbs/03_xtras.ipynb
def round_multiple(x, mult, round_down=False):
    "Round `x` (a number or listy of numbers) to nearest multiple of `mult`, or truncate via `int` if `round_down`"
    rounder = int if round_down else round
    res = L(x).map(lambda v: rounder(v/mult)*mult)
    if is_listy(x): return res
    return res[0]
# %% ../nbs/03_xtras.ipynb
def set_num_threads(nt):
    "Get numpy (and others) to use `nt` threads, via `mkl`/`torch` when importable and the usual env vars"
    # best-effort: either library may be missing or reject the call; a plain `except:` also swallowed Ctrl-C
    try:
        import mkl
        mkl.set_num_threads(nt)
    except Exception: pass
    try:
        import torch
        torch.set_num_threads(nt)
    except Exception: pass
    os.environ['IPC_ENABLE']='1'
    for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:
        os.environ[o] = str(nt)
# %% ../nbs/03_xtras.ipynb
def join_path_file(file, path, ext=''):
    "Return `path/file` (creating `path` if needed) if file is a string or a `Path`, file otherwise"
    if not isinstance(file, (str, Path)): return file
    path = Path(path)  # also accept a `str` directory; previously `.mkdir` failed on it
    path.mkdir(parents=True, exist_ok=True)
    return path/f'{file}{ext}'
# %% ../nbs/03_xtras.ipynb
def autostart(g):
    "Decorator that automatically starts a generator, advancing it to its first `yield`"
    @functools.wraps(g)
    def f(*args, **kwargs):
        # forward arguments: the wrapper previously accepted none, so parameterised generators couldn't be used
        r = g(*args, **kwargs)
        next(r)
        return r
    return f
# %% ../nbs/03_xtras.ipynb
class EventTimer:
    "An event timer with history of `store` items of time `span`"
    def __init__(self, store=5, span=60):
        import collections
        self.hist = collections.deque(maxlen=store)
        self.span = span
        self.last = perf_counter()
        self._reset()
    def _reset(self):
        # a new window begins at the time of the most recent event
        self.start = self.last
        self.events = 0
    def add(self, n=1):
        "Record `n` events"
        if self.duration > self.span:
            # the current window is over: archive its rate and open a new window
            self.hist.append(self.freq)
            self._reset()
        self.events += n
        self.last = perf_counter()
    @property
    def duration(self):
        "Seconds elapsed since the current window started"
        return perf_counter() - self.start
    @property
    def freq(self):
        "Events per second in the current window"
        return self.events / self.duration
# %% ../nbs/03_xtras.ipynb
_fmt = string.Formatter()
# %% ../nbs/03_xtras.ipynb
def stringfmt_names(s:str)->list:
    "Unique brace-delimited names in `s`, in order of first appearance"
    # `Formatter.parse` yields (literal_text, field_name, format_spec, conversion) tuples
    names = (field for _, field, _, _ in _fmt.parse(s) if field)
    return uniqueify(names)
# %% ../nbs/03_xtras.ipynb
class PartialFormatter(string.Formatter):
    "A `string.Formatter` that doesn't error on missing fields, and tracks missing fields and unused args"
    def __init__(self):
        self.missing = set()
        super().__init__()
    def get_field(self, nm, args, kwargs):
        "Resolve field `nm`; if it is missing, record it and keep the literal `{nm}` placeholder"
        try: found = super().get_field(nm, args, kwargs)
        except KeyError:
            self.missing.add(nm)
            found = ('{'+nm+'}', nm)
        return found
    def check_unused_args(self, used, args, kwargs):
        "Store keyword args not used by any field in `self.xtra` instead of validating them"
        is_unused = lambda o: o not in used
        self.xtra = filter_keys(kwargs, is_unused)
# %% ../nbs/03_xtras.ipynb
def partial_format(s:str, **kwargs):
    "string format `s`, ignoring missing field errors, returning missing and extra fields"
    formatter = PartialFormatter()
    formatted = formatter.format(s, **kwargs)
    return formatted, list(formatter.missing), formatter.xtra
# %% ../nbs/03_xtras.ipynb
def utc2local(dt:datetime)->datetime:
    "Convert `dt` from UTC to local time (any existing tzinfo is replaced by UTC first)"
    as_utc = dt.replace(tzinfo=timezone.utc)
    return as_utc.astimezone(tz=None)
# %% ../nbs/03_xtras.ipynb
def local2utc(dt:datetime)->datetime:
    "Convert `dt` from local to UTC time (any tzinfo is dropped, so `dt` is read as local wall time)"
    naive_local = dt.replace(tzinfo=None)
    return naive_local.astimezone(tz=timezone.utc)
# %% ../nbs/03_xtras.ipynb
def trace(f):
    "Add `set_trace` to an existing function `f`; tracing an already-traced function returns it unchanged"
    from pdb import set_trace
    if getattr(f, '_traced', False): return f
    @functools.wraps(f)  # keep `f`'s name and docstring so the traced function stays recognisable
    def _inner(*args,**kwargs):
        set_trace()
        return f(*args,**kwargs)
    _inner._traced = True
    return _inner
# %% ../nbs/03_xtras.ipynb
@contextmanager
def modified_env(*delete, **replace):
    "Context manager temporarily modifying `os.environ` by deleting `delete` and replacing `replace`"
    saved = os.environ.copy()
    try:
        os.environ.update(replace)
        for name in delete: os.environ.pop(name, None)
        yield
    finally:
        # restore the exact original environment, dropping anything added inside the block
        os.environ.clear()
        os.environ.update(saved)
# %% ../nbs/03_xtras.ipynb
class ContextManagers(GetAttr):
    "Wrapper for `contextlib.ExitStack` which enters a collection of context managers"
    def __init__(self, mgrs):
        self.default = L(mgrs)
        self.stack = ExitStack()
    def __enter__(self):
        for mgr in self.default: self.stack.enter_context(mgr)
    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)
# %% ../nbs/03_xtras.ipynb
def shufflish(x, pct=0.04):
    "Randomly relocate items of `x` up to `pct` of `len(x)` from their starting location"
    n = len(x)
    import random
    # each index gets random jitter of at most `n*pct`, so items only move a little when sorted by it
    def _jittered(i): return i + n*(1 + random.random()*pct)
    order = sorted(range_of(x), key=_jittered)
    return L(x[i] for i in order)
# %% ../nbs/03_xtras.ipynb
def console_help(
    libname:str): # name of library for console script listing
    "Show help for all console scripts from `libname`"
    from fastcore.style import S
    from pkg_resources import iter_entry_points as ep
    prefix = libname + '.'
    for entry in ep('console_scripts'):
        # only scripts defined in `libname` or one of its submodules
        if entry.module_name != libname and not entry.module_name.startswith(prefix): continue
        nm = S.bold.light_blue(entry.name)
        print(f'{nm:45}{entry.load().__doc__}')
# %% ../nbs/03_xtras.ipynb
def hl_md(s, lang='xml', show=True):
    "Syntax highlight `s` using `lang`: returns the fenced markdown, or an IPython `Markdown` object if `show`"
    fenced = f'```{lang}\n{s}\n```'
    if show:
        try:
            from IPython import display
            return display.Markdown(fenced)
        except ImportError:
            # no IPython: just print the raw text
            print(s)
        return None
    return fenced
# %% ../nbs/03_xtras.ipynb
def type2str(typ:type)->str:
"Stringify `typ`"
if typ is None or typ is NoneType: return 'None'
if hasattr(typ, '__origin__'):
args = ", ".join(type2str(arg) for arg in typ.__args__)
if typ.__origin__ is Union: return f"Union[{args}]"
return f"{typ.__origin__.__name__}[{args}]"
elif isinstance(typ, type): return typ.__name__
return str(typ)
# %% ../nbs/03_xtras.ipynb
def dataclass_src(cls):
    "Return source text for a `@dataclass` declaration with the same fields, types and defaults as dataclass `cls`"
    parts = [f"@dataclass\nclass {cls.__name__}:\n"]
    for fld in dataclasses.fields(cls):
        default = "" if fld.default is dataclasses.MISSING else f" = {fld.default!r}"
        parts.append(f" {fld.name}: {type2str(fld.type)}{default}\n")
    return "".join(parts)
# %% ../nbs/03_xtras.ipynb
class Unset(Enum):
    "Single-member sentinel enum: `UNSET` is falsy and both `repr` and `str` give 'UNSET'"
    _Unset=''
    def __repr__(self): return 'UNSET'
    __str__ = __repr__
    def __bool__(self): return False
UNSET = Unset._Unset
# %% ../nbs/03_xtras.ipynb
def nullable_dc(cls):
    "Like `dataclass`, but default of `UNSET` added to fields without defaults"
    annotations = get_annotations_ex(cls)[0]
    no_default = [name for name in annotations if not hasattr(cls, name)]
    for name in no_default: setattr(cls, name, field(default=UNSET))
    return dataclass(cls)
# %% ../nbs/03_xtras.ipynb
def make_nullable(clas):
    """Make every field of dataclass `clas` optional; modifies `clas` in place and returns it.

    Fields with neither a default nor a default factory get `None` as their `Field.default`. The wrapped
    `__init__` passes `UNSET` for any omitted field whose default is `None`. Calling it again is a no-op."""
    if hasattr(clas, '_nullable'): return clas  # already processed: return the class (previously returned None)
    clas._nullable = True
    original_init = clas.__init__
    def __init__(self, *args, **kwargs):
        flds = fields(clas)
        dargs = {k.name:v for k,v in zip(flds, args)}  # positional args matched to fields in order
        for f in flds:
            nm = f.name
            if nm not in dargs and nm not in kwargs and f.default is None and f.default_factory is MISSING:
                kwargs[nm] = UNSET
        original_init(self, *args, **kwargs)
    clas.__init__ = __init__
    for f in fields(clas):
        if f.default is MISSING and f.default_factory is MISSING: f.default = None
    return clas
# %% ../nbs/03_xtras.ipynb
def flexiclass(cls):
    "Convert `cls` into a `dataclass` like `make_nullable`; annotated attributes without a value default to `UNSET`"
    if is_dataclass(cls): return make_nullable(cls)
    for name in get_annotations_ex(cls)[0]:
        # absent attributes and explicit `MISSING` placeholders both get an `UNSET` default
        if getattr(cls, name, MISSING) is MISSING: setattr(cls, name, field(default=UNSET))
    return dataclass(cls, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False)
# %% ../nbs/03_xtras.ipynb
def asdict(o)->dict:
    """Convert `o` to a `dict`, supporting dataclasses, namedtuples, iterables, and `__dict__` attrs.

    Dicts are returned unchanged. Otherwise entries whose value is `UNSET` or `MISSING` are dropped.
    Raises `TypeError` if `o` cannot be converted."""
    if isinstance(o, dict): return o
    if is_dataclass(o): r = dataclasses.asdict(o)
    elif hasattr(o, '_asdict'): r = o._asdict()
    else:
        r = None
        if hasattr(o, '__iter__'):
            # not every iterable yields key/value pairs; on failure fall through instead of leaving `r` unbound
            try: r = dict(o)
            except (TypeError, ValueError): pass
        if r is None:
            if hasattr(o, '__dict__'): r = o.__dict__
            else: raise TypeError(f'Can not convert {o} to a dict')
    return {k:v for k,v in r.items() if v not in (UNSET,MISSING)}
# %% ../nbs/03_xtras.ipynb
def is_typeddict(cls:type)->bool:
    "Check if `cls` is a `TypedDict` (a class with `__annotations__`, `__required_keys__` and `__optional_keys__`)"
    if not isinstance(cls, type): return False
    return all(hasattr(cls, f'__{attr}__') for attr in ('annotations', 'required_keys', 'optional_keys'))
# %% ../nbs/03_xtras.ipynb
def is_namedtuple(cls):
    "`True` if `cls` is a namedtuple type; `False` for anything else, including instances and non-class values"
    # guard with `isinstance`: `issubclass` raises TypeError for non-classes
    return isinstance(cls, type) and issubclass(cls, tuple) and hasattr(cls, '_fields')
# %% ../nbs/03_xtras.ipynb
def flexicache(*funcs, maxsize=128):
    """Like `lru_cache`, but customisable with policy `funcs`.

    Each policy is called as `f(None)` when a result is stored, and its return value is kept as that
    policy's state. On every lookup it is called as `f(state)`; a truthy return marks the entry stale, so it
    is recomputed. If recomputing a stale entry raises an exception synchronously, the stale result is
    returned instead. At most `maxsize` entries are kept, evicting the least recently used.
    Works for both regular and async functions."""
    import asyncio
    def _f(func):
        cache = {}  # insertion order doubles as recency order: least recently used first
        def _cache_logic(key, execute_func):
            if key in cache:
                result,states = cache[key]
                if not any(f(state) for f,state in zip(funcs, states)):
                    cache[key] = cache.pop(key)  # mark as most recently used
                    return result
                try: newres = execute_func()
                except Exception:
                    # refresh failed: keep serving the stale value (with its old states) instead of losing it
                    cache.pop(key, None)
                    cache[key] = (result, states)
                    return result
            else: newres = execute_func()
            cache.pop(key, None)  # (re)insert at the most-recent end
            cache[key] = (newres, [f(None) for f in funcs])
            # `dict.popitem()` is LIFO and evicted the entry just added; drop the oldest instead
            if len(cache) > maxsize: del cache[next(iter(cache))]
            return newres
        @wraps(func)
        def wrapper(*args, **kwargs):
            return _cache_logic(f"{args} // {kwargs}", lambda: func(*args, **kwargs))
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            # the cached value is the task itself; awaiting a finished task returns its result again
            return await _cache_logic(f"{args} // {kwargs}", lambda: asyncio.ensure_future(func(*args, **kwargs)))
        return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper
    return _f
# %% ../nbs/03_xtras.ipynb
def time_policy(seconds):
    "A `flexicache` policy that expires cached items after `seconds` have passed"
    def policy(last_time):
        # returns the new timestamp when (re)storing or expired, else None (still fresh)
        now = time()
        expired = last_time is None or now - last_time > seconds
        return now if expired else None
    return policy
# %% ../nbs/03_xtras.ipynb
def mtime_policy(filepath):
    "A `flexicache` policy that expires cached items after `filepath` modified-time changes"
    def policy(mtime):
        current_mtime = getmtime(filepath)
        # treat any change as expiry, including an older mtime (e.g. a file restored from backup);
        # previously only a newer mtime was detected
        return current_mtime if mtime is None or current_mtime != mtime else None
    return policy
# %% ../nbs/03_xtras.ipynb
def timed_cache(seconds=60, maxsize=128):
    "Like `lru_cache`, but also with time-based eviction: entries older than `seconds` are recomputed"
    policy = time_policy(seconds)
    return flexicache(policy, maxsize=maxsize)
---
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03a_parallel.ipynb.
# %% auto 0
__all__ = ['threaded', 'startthread', 'startproc', 'parallelable', 'ThreadPoolExecutor', 'ProcessPoolExecutor', 'parallel',
'add_one', 'run_procs', 'parallel_gen']
# %% ../nbs/03a_parallel.ipynb
from .imports import *
from .basics import *
from .foundation import *
from .meta import *
from .xtras import *
from functools import wraps
import concurrent.futures,time
from multiprocessing import Process,Queue,Manager,set_start_method,get_all_start_methods,get_context
from threading import Thread
try:
    # notebooks on macOS need the "fork" start method for multiprocessing
    if sys.platform == 'darwin' and IN_NOTEBOOK: set_start_method("fork")
except Exception: pass  # best-effort (e.g. a start method was already fixed); don't swallow Ctrl-C/SystemExit
# %% ../nbs/03a_parallel.ipynb
def threaded(process=False):
    "Run `f` in a `Thread` (or `Process` if `process=True`), and returns it"
    def _r(f):
        def g(_obj_td, *args, **kwargs):
            # store the return value on the Thread/Process object itself as `.result`
            _obj_td.result = f(*args, **kwargs)
        @wraps(f)
        def _f(*args, **kwargs):
            worker_cls = (Thread,Process)[process]
            res = worker_cls(target=g, args=args, kwargs=kwargs)
            # hand the worker object itself to `g` as its first argument
            res._args = (res,)+res._args
            res.start()
            return res
        return _f
    if callable(process):
        # used as a bare `@threaded` decorator: `process` is really the function
        func = process
        process = False
        return _r(func)
    return _r
# %% ../nbs/03a_parallel.ipynb
def startthread(f):
    "Like `threaded`, but start thread immediately: calls `f()` in a new `Thread` and returns that thread"
    launcher = threaded(f)
    return launcher()
# %% ../nbs/03a_parallel.ipynb
def startproc(f):
    "Like `threaded(True)`, but start Process immediately: calls `f()` in a new `Process` and returns it"
    launcher = threaded(True)(f)
    return launcher()
# %% ../nbs/03a_parallel.ipynb
def _call(lock, pause, n, g, item):
l = False
if pause:
try:
l = lock.acquire(timeout=pause*(n+2))
time.sleep(pause)
finally:
if l: lock.release()
return g(item)
# %% ../nbs/03a_parallel.ipynb
def parallelable(param_name, num_workers, f=None):
    """Return whether multiprocessing with `num_workers` workers can be used.

    On Windows inside a notebook, with `num_workers > 0` and `f` either omitted or defined in `__main__`,
    prints a warning naming `param_name` and returns `False`; otherwise returns `True`."""
    f_in_main = f is None or sys.modules[f.__module__].__name__ == "__main__"
    if sys.platform == "win32" and IN_NOTEBOOK and num_workers > 0 and f_in_main:
        print("Due to IPython and Windows limitation, python multiprocessing isn't available now.")
        print(f"So `{param_name}` has to be changed to 0 to avoid getting stuck")
        return False
    return True
# %% ../nbs/03a_parallel.ipynb
class ThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    "Same as Python's ThreadPoolExecutor, except can pass `max_workers==0` for serial execution"
    def __init__(self, max_workers=defaults.cpus, on_exc=print, pause=0, **kwargs):
        if max_workers is None: max_workers=defaults.cpus
        store_attr()
        self.not_parallel = max_workers==0
        if self.not_parallel: max_workers=1
        super().__init__(max_workers, **kwargs)
    def map(self, f, items, *args, timeout=None, chunksize=1, **kwargs):
        """Map `f` over `items`, passing extra `args`/`kwargs` to every call.

        Runs serially with the builtin `map` when constructed with `max_workers=0`. Otherwise each call goes
        through `_call`, which (if `pause` is set) sleeps while holding a shared lock. Exceptions raised
        while submitting the work are passed to `on_exc`."""
        g = partial(f, *args, **kwargs)
        if self.not_parallel: return map(g, items)
        # All workers are threads in this process, so a plain `threading.Lock` suffices. The previous
        # `multiprocessing.Manager().Lock()` started a separate manager server process on every call.
        from threading import Lock
        self.lock = Lock()
        _g = partial(_call, self.lock, self.pause, self.max_workers, g)
        try: return super().map(_g, items, timeout=timeout, chunksize=chunksize)
        except Exception as e: self.on_exc(e)
# %% ../nbs/03a_parallel.ipynb
@delegates()
class ProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor):
    "Same as Python's ProcessPoolExecutor, except can pass `max_workers==0` for serial execution"
    def __init__(self, max_workers=defaults.cpus, on_exc=print, pause=0, **kwargs):
        if max_workers is None: max_workers=defaults.cpus
        store_attr()
        self.not_parallel = max_workers==0
        if self.not_parallel: max_workers=1
        super().__init__(max_workers, **kwargs)
    def map(self, f, items, *args, timeout=None, chunksize=1, **kwargs):
        "Map `f` over `items` in worker processes, or serially when multiprocessing is disabled or unavailable"
        # fall back to serial execution where multiprocessing would hang (see `parallelable`)
        if not parallelable('max_workers', self.max_workers, f): self.max_workers = 0
        self.not_parallel = self.max_workers == 0
        if self.not_parallel:
            self.max_workers = 1
            return map(partial(f, *args, **kwargs), items)
        self.lock = Manager().Lock()
        g = partial(f, *args, **kwargs)
        _g = partial(_call, self.lock, self.pause, self.max_workers, g)
        try: return super().map(_g, items, timeout=timeout, chunksize=chunksize)
        except Exception as e: self.on_exc(e)
# %% ../nbs/03a_parallel.ipynb
try: from fastprogress import progress_bar
except Exception: progress_bar = None  # optional dependency: any import failure just disables progress bars
# %% ../nbs/03a_parallel.ipynb
def parallel(f, items, *args, n_workers=defaults.cpus, total=None, progress=None, pause=0,
             method=None, threadpool=False, timeout=None, chunksize=1, **kwargs):
    "Applies `f` in parallel to `items`, using `n_workers` (threads if `threadpool`, else processes); returns an `L`"
    kwpool = {}
    if threadpool:
        pool = ThreadPoolExecutor
    else:
        pool = ProcessPoolExecutor
        # default to "fork" on macOS; an explicit `method` selects the multiprocessing context
        if not method and sys.platform == 'darwin': method = 'fork'
        if method: kwpool['mp_context'] = get_context(method)
    with pool(n_workers, pause=pause, **kwpool) as ex:
        r = ex.map(f, items, *args, timeout=timeout, chunksize=chunksize, **kwargs)
        if progress and progress_bar:
            if total is None: total = len(items)
            r = progress_bar(r, total=total, leave=False)
        return L(r)
# %% ../nbs/03a_parallel.ipynb
def add_one(x, a=1):
    "Return `x+a` after sleeping a random fraction of 1/80 of a second"
    # this import is necessary for multiprocessing in notebook on windows
    import random
    delay = random.random()/80
    time.sleep(delay)
    return x + a
# %% ../nbs/03a_parallel.ipynb
def run_procs(f, f_done, args):
    "Call `f` for each item in `args` in parallel, yielding `f_done`"
    # one process per item; each item is the tuple of positional args for `f`
    procs = L(args).map(lambda a: Process(target=f, args=a))
    for p in procs: p.start()
    yield from f_done()
    for p in procs: p.join()
# %% ../nbs/03a_parallel.ipynb
def _f_pg(obj, queue, batch, start_idx):
for i,b in enumerate(obj(batch)): queue.put((start_idx+i,b))
def _done_pg(queue, items): return (queue.get() for _ in items)
# %% ../nbs/03a_parallel.ipynb
def parallel_gen(cls, items, n_workers=defaults.cpus, **kwargs):
    "Instantiate `cls` in `n_workers` procs & call each on a subset of `items` in parallel."
    if not parallelable('n_workers', n_workers): n_workers = 0
    if n_workers == 0:
        # serial fallback: a single instance handles every item
        yield from enumerate(list(cls(**kwargs)(items)))
        return
    batches = L(chunked(items, n_chunks=n_workers))
    # starting global index of each batch: 0, len(b0), len(b0)+len(b1), ...
    starts = L(itertools.accumulate(0 + batches.map(len)))
    queue = Queue()
    if progress_bar: items = progress_bar(items, leave=False)
    worker = partial(_f_pg, cls(**kwargs), queue)
    collect = partial(_done_pg, queue, items)
    yield from run_procs(worker, collect, L(batches, starts).zip())
---
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/12_py2pyi.ipynb.
# %% auto 0
__all__ = ['functypes', 'imp_mod', 'has_deco', 'sig2str', 'ast_args', 'create_pyi', 'py2pyi', 'replace_wildcards']
# %% ../nbs/12_py2pyi.ipynb
import ast, sys, inspect, re, os, importlib.util, importlib.machinery
from ast import parse, unparse
from inspect import signature, getsource
from .utils import *
from .meta import delegates
# %% ../nbs/12_py2pyi.ipynb
def imp_mod(module_path, package=None):
    "Import the Python source file at `module_path` as a fresh module (named after the file), optionally setting `__package__`"
    path = str(module_path)
    name = os.path.splitext(os.path.basename(path))[0]
    spec = importlib.machinery.ModuleSpec(name, None, origin=path)
    module = importlib.util.module_from_spec(spec)
    # the loader is attached only after the module object exists, then used to execute the source
    spec.loader = importlib.machinery.SourceFileLoader(name, path)
    if package is not None: module.__package__ = package
    module.__file__ = os.path.abspath(path)
    spec.loader.exec_module(module)
    return module
# %% ../nbs/12_py2pyi.ipynb
def _get_tree(mod):
return parse(getsource(mod))
# %% ../nbs/12_py2pyi.ipynb
@patch
def __repr__(self:ast.AST):
    "Show an AST node as the Python source it represents"
    src = unparse(self)
    return src
@patch
def _repr_markdown_(self:ast.AST):
    "Render an AST node as a fenced Python code block for IPython/Jupyter rich display"
    return f"```python\n{self!r}\n```"
# %% ../nbs/12_py2pyi.ipynb
functypes = (ast.FunctionDef,ast.AsyncFunctionDef)
# %% ../nbs/12_py2pyi.ipynb
def _deco_id(d:Union[ast.Name,ast.Attribute])->bool:
"Get the id for AST node `d`"
return d.id if isinstance(d, ast.Name) else d.func.id
def has_deco(node:Union[ast.FunctionDef,ast.AsyncFunctionDef], name:str)->bool:
"Check if a function node `node` has a decorator named `name`"
return any(_deco_id(d)==name for d in getattr(node, 'decorator_list', []))
# %% ../nbs/12_py2pyi.ipynb
def _get_proc(node):
    "Pick the handler for AST `node`: classes recurse; functions are stubbed, or re-signatured if `@delegates`"
    if isinstance(node, ast.ClassDef): return _proc_class
    if isinstance(node, functypes):
        if not has_deco(node, 'delegates'): return _proc_body
        return _proc_patched if has_deco(node, 'patch') else _proc_func
    return None
# %% ../nbs/12_py2pyi.ipynb
def _proc_tree(tree, mod):
    "Apply the matching handler to each direct child of `tree`, resolving symbols against `mod`"
    for node in tree.body:
        handler = _get_proc(node)
        if handler: handler(node, mod)
# %% ../nbs/12_py2pyi.ipynb
def _proc_mod(mod):
    "Parse `mod`'s source, rewrite it in place into stub form, and return the modified tree"
    module_tree = _get_tree(mod)
    _proc_tree(module_tree, mod)
    return module_tree
# %% ../nbs/12_py2pyi.ipynb
def sig2str(sig):
    "Render signature `sig` as source text: `<class 'X'>` becomes `X` and `dynamic_module.` prefixes are dropped"
    text = re.sub(r"<class '(.*?)'>", r'\1', str(sig))
    return re.sub(r"dynamic_module\.", "", text)
# %% ../nbs/12_py2pyi.ipynb
def ast_args(func):
    "Return an `ast.arguments` node equivalent to `func`'s runtime signature"
    stub = f"def _{sig2str(signature(func))}: ..."
    return ast.parse(stub).body[0].args
# %% ../nbs/12_py2pyi.ipynb
def _body_ellip(n: ast.AST):
stidx = 1 if isinstance(n.body[0], ast.Expr) and isinstance(n.body[0].value, ast.Str) else 0
n.body[stidx:] = [ast.Expr(ast.Constant(...))]
# %% ../nbs/12_py2pyi.ipynb
def _update_func(node, sym):
    """Turn function node `node` into a stub: take its parameters from the runtime signature of `sym`,
    replace the body with `...` (keeping any docstring), and remove decorators named 'delegates'"""
    node.args = ast_args(sym)
    _body_ellip(node)
    keep = lambda d: _deco_id(d) != 'delegates'
    node.decorator_list = list(filter(keep, node.decorator_list))
# %% ../nbs/12_py2pyi.ipynb
def _proc_body(node, mod):
    "Stub out `node`'s body without touching its signature"
    _body_ellip(node)
# %% ../nbs/12_py2pyi.ipynb
def _proc_func(node, mod):
    "Re-signature function `node` from the runtime object of the same name in `mod`"
    _update_func(node, getattr(mod, node.name))
# %% ../nbs/12_py2pyi.ipynb
def _proc_patched(node, mod):
    "Re-signature a patched method from the class named by the annotation of its first parameter"
    ann = node.args.args[0].annotation
    # a tuple annotation (e.g. `self:(A,B)`): use its first class
    if hasattr(ann, 'elts'): ann = ann.elts[0]
    target_cls = getattr(mod, ann.id)
    _update_func(node, getattr(target_cls, node.name))
# %% ../nbs/12_py2pyi.ipynb
def _proc_class(node, mod):
    "Process the body of class `node` against the runtime class of the same name in `mod`"
    _proc_tree(node, getattr(mod, node.name))
# %% ../nbs/12_py2pyi.ipynb
def create_pyi(fn, package=None):
    "Convert `fname.py` to `fname.pyi` by removing function bodies and expanding `delegates` kwargs"
    src_path = Path(fn)
    tree = _proc_mod(imp_mod(src_path, package=package))
    stub = unparse(tree)
    src_path.with_suffix('.pyi').write_text(stub)
# %% ../nbs/12_py2pyi.ipynb
from .script import call_parse
# %% ../nbs/12_py2pyi.ipynb
@call_parse
def py2pyi(fname:str, # The file name to convert
           package:str=None # The parent package
          ):
    "Convert `fname.py` to `fname.pyi` by removing function bodies and expanding `delegates` kwargs"
    create_pyi(fn=fname, package=package)
# %% ../nbs/12_py2pyi.ipynb
@call_parse
def replace_wildcards(
    # Path to the Python file to process
    path: str):
    "Expand wildcard imports in the specified Python file."
    target = Path(path)
    new_code = expand_wildcards(target.read_text())
    target.write_text(new_code)
----
import sys,os,re,typing,itertools,operator,functools,math,warnings,functools,io,enum
from operator import itemgetter,attrgetter
from warnings import warn
from typing import Iterable,Generator,Sequence,Iterator,List,Set,Dict,Union,Optional,Tuple
from functools import partial,reduce
from pathlib import Path
try:
from types import WrapperDescriptorType,MethodWrapperType,MethodDescriptorType
except ImportError:
WrapperDescriptorType = type(object.__init__)
MethodWrapperType = type(object().__str__)
MethodDescriptorType = type(str.join)
from types import BuiltinFunctionType,BuiltinMethodType,MethodType,FunctionType,SimpleNamespace
# The type of `None`, for identity/`isinstance` checks against `None`'s type
NoneType = type(None)
# `str` and `bytes`: the string-like types
string_classes = (str,bytes)
def is_iter(o):
    "Test whether `o` can be used in a `for` loop (returns a truthy/falsy value, not necessarily a `bool`)"
    #Rank 0 tensors in PyTorch are not really iterable
    return isinstance(o, (Iterable,Generator)) and getattr(o,'ndim',1)
def is_coll(o):
    "Test whether `o` is a collection (i.e. has a usable `len`)"
    #Rank 0 tensors in PyTorch do not have working `len`
    return hasattr(o, '__len__') and getattr(o,'ndim',1)
def all_equal(a,b):
    "Compares whether `a` and `b` are the same length and have the same contents"
    if not is_iter(b): return a==b
    # Pad with a private sentinel. Padding with `None` made `[1,None]` compare equal to `[1]`.
    _pad = object()
    return all(a_ is not _pad and b_ is not _pad and equals(a_,b_)
               for a_,b_ in itertools.zip_longest(a,b, fillvalue=_pad))
def noop (x=None, *args, **kwargs):
    "Do nothing"
    return x
def noops(self, x=None, *args, **kwargs):
    "Do nothing (method)"
    return x
def any_is_instance(t, *args):
    "True if any of `args` is an instance of `t`"
    return any(isinstance(a,t) for a in args)
def isinstance_str(x, cls_name):
    "Like `isinstance`, except takes a type name instead of a type"
    return cls_name in [t.__name__ for t in type(x).__mro__]
def array_equal(a,b):
    "Equality for array-likes: elementwise `.all()` when both are ndarrays, else `all_equal`"
    if hasattr(a, '__array__'): a = a.__array__()
    if hasattr(b, '__array__'): b = b.__array__()
    if isinstance_str(a, 'ndarray') and isinstance_str(b, 'ndarray'): return (a==b).all()
    return all_equal(a,b)
def df_equal(a,b):
    "Equality via the `.equals` method of whichever argument is a pandas `NDFrame`"
    return a.equals(b) if isinstance_str(a, 'NDFrame') else b.equals(a)
def equals(a,b):
    "Compares `a` and `b` for equality; supports sublists, tensors and arrays too"
    if (a is None) ^ (b is None): return False
    if any_is_instance(type,a,b): return a==b
    if hasattr(a, '__array_eq__'): return a.__array_eq__(b)
    if hasattr(b, '__array_eq__'): return b.__array_eq__(a)
    cmp = (array_equal if isinstance_str(a, 'ndarray') or isinstance_str(b, 'ndarray') else
           array_equal if isinstance_str(a, 'Tensor')  or isinstance_str(b, 'Tensor') else
           df_equal    if isinstance_str(a, 'NDFrame') or isinstance_str(b, 'NDFrame') else
           operator.eq if any_is_instance((str,dict,set), a, b) else
           all_equal   if is_iter(a) or is_iter(b) else
           operator.eq)
    return cmp(a,b)
def ipython_shell():
"Same as `get_ipython` but returns `False` if not in IPython"
try: return get_ipython()
except NameError: return False
def in_ipython():
"Check if code is running in some kind of IPython environment"
return bool(ipython_shell())
def in_colab():
"Check if the code is running in Google Colaboratory"
return 'google.colab' in sys.modules
def in_jupyter():
    "Check if the code is running in a jupyter notebook"
    # Identified by the class name of the active IPython shell
    return in_ipython() and ipython_shell().__class__.__name__ == 'ZMQInteractiveShell'
def in_notebook():
    "Check if the code is running in a notebook (Google Colab or Jupyter)"
    if in_colab(): return True
    return in_jupyter()
# Environment flags, evaluated once when this module is imported
IN_IPYTHON = in_ipython()
IN_JUPYTER = in_jupyter()
IN_COLAB = in_colab()
IN_NOTEBOOK = in_notebook()
def remove_prefix(text, prefix):
    "Return `text` without a leading `prefix`, if present (stand-in for py39's `str.removeprefix`)"
    if not text.startswith(prefix): return text
    return text[len(prefix):]
def remove_suffix(text, suffix):
    "Return `text` without a trailing `suffix`, if present (stand-in for py39's `str.removesuffix`)"
    # An empty suffix must be a no-op: `text[:-0]` is `text[:0]`, i.e. the empty string
    if suffix and text.endswith(suffix): return text[:-len(suffix)]
    return text
----
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_dispatch.ipynb.
# %% ../nbs/04_dispatch.ipynb 1
from __future__ import annotations
from .imports import *
from .foundation import *
from .utils import *
from collections import defaultdict
# %% auto 0
# Public names of this module (generated from the notebook along with the rest of this file)
__all__ = ['typedispatch', 'lenient_issubclass', 'sorted_topologically', 'TypeDispatch', 'DispatchReg', 'retain_meta',
           'default_set_meta', 'cast', 'retain_type', 'retain_types', 'explode_types']
# %% ../nbs/04_dispatch.ipynb
def lenient_issubclass(cls, types):
    "If possible return whether `cls` is a subclass of `types`, otherwise return False."
    # Treat `object` as the top of every hierarchy: it only matches `object` itself
    if cls is object and types is not object: return False
    # `cls` may also be an instance, hence the `isinstance` check first
    try: return isinstance(cls, types) or issubclass(cls, types)
    # Invalid arguments raise TypeError and the like; don't swallow KeyboardInterrupt/SystemExit
    except Exception: return False
# %% ../nbs/04_dispatch.ipynb
def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):
    "Return a new list containing all items from the iterable sorted topologically"
    # Selection sort over a partial order: each round scans the remaining items, switching to a
    # candidate whenever `cmp(candidate, current)` holds, and emits the item the scan ends on.
    # Uses O(n^2) calls to `cmp`.
    pending, res = list(iterable), []
    while pending:
        best = pending[0]
        for cand in pending[1:]:
            if cmp(cand, best): best = cand
        res.append(best)
        pending.remove(best)
    return res[::-1] if reverse else res
# %% ../nbs/04_dispatch.ipynb
def _chk_defaults(f, ann):
    "Placeholder: currently a no-op that returns `None`"
    # The former check warned ("... has default params. These will be ignored.") when any of `f`'s
    # first `min(len(ann),2)` parameters had a default. It relied on `inspect.signature` and was
    # removed until it can be done without the `inspect` module.
    return None
# %% ../nbs/04_dispatch.ipynb
def _p2_anno(f):
    "Annotations of `f`'s first two annotated parameters, padded with `object` when fewer exist"
    anns = [typ for name, typ in type_hints(f).items() if name != 'return']
    if callable(f): _chk_defaults(f, anns)
    anns.extend([object] * (2 - len(anns)))
    return anns[:2]
# %% ../nbs/04_dispatch.ipynb
class _TypeDict:
    "Mapping keyed by types whose lookups also match registered superclasses of the queried type"
    def __init__(self):
        self.d = {}      # registered type -> value; re-ordered by `_reset` after every `add`
        self.cache = {}  # queried type -> list of matching values; emptied by `_reset`

    def _reset(self):
        # Re-order keys so more specific types come first, then drop stale lookups
        order = sorted_topologically(self.d, cmp=lenient_issubclass)
        self.d = {typ: self.d[typ] for typ in order}
        self.cache = {}

    def add(self, t, f):
        "Store `f` under type `t`; a tuple `t` stores `f` under each of its members"
        # Non-tuple keys go through `union2tuple`, presumably splitting a union into member types
        keys = t if isinstance(t, tuple) else tuple(L(union2tuple(t)))
        for key in keys: self.d[key] = f
        self._reset()

    def all_matches(self, k):
        "Values of every registered type that `k` leniently subclasses, in `d`'s order (memoized)"
        if k in self.cache: return self.cache[k]
        found = [self.d[typ] for typ in self.d if lenient_issubclass(k, typ)]
        self.cache[k] = found
        return found

    def __getitem__(self, k):
        "First (most specific) value matching `k`, or `None` when nothing matches"
        found = self.all_matches(k)
        return found[0] if found else None

    def __repr__(self): return repr(self.d)

    def first(self):
        "Value stored under the first key of `d`"
        return first(self.d.values())
# %% ../nbs/04_dispatch.ipynb
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    # Dispatches on the types of (up to) the first two call arguments: `self.funcs` is a `_TypeDict`
    # keyed by the first parameter's annotation, whose values are `_TypeDict`s keyed by the second's.
    def __init__(self, funcs=(), bases=()):
        # `bases`: other dispatchers consulted, in order, when nothing registered here matches
        self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))
        for o in L(funcs): self.add(o)
        # Filled in by `__get__` when accessed as an attribute; used by `__call__` to bind the function
        self.inst = None
        self.owner = None

    def add(self, f):
        "Register `f`, keyed by the annotations of its first two parameters (`object` where missing)"
        # staticmethods are stored wrapped, but their annotations live on the underlying function
        if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)
        else: a0,a1 = _p2_anno(f)
        # Reuse the inner table for the first-argument type if there is one, otherwise create it
        t = self.funcs.d.get(a0)
        if t is None:
            t = _TypeDict()
            self.funcs.add(a0, t)
        t.add(a1, f)

    def first(self):
        "Get first function in ordered dict of type:func."
        return self.funcs.first().first()

    def returns(self, x):
        "Get the return type of annotation of `x`."
        # NOTE(review): presumably the return annotation (via `anno_ret`) of the function handling `type(x)`
        return anno_ret(self[type(x)])

    def _attname(self,k): return getattr(k,'__name__',str(k))
    def __repr__(self):
        # One "(type0,type1) -> function name" line per registration, followed by each base's repr
        r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", type(v).__name__)}'
             for k in self.funcs.d for l,v in self.funcs[k].d.items()]
        r = r + [o.__repr__() for o in self.bases]
        return '\n'.join(r)

    def __call__(self, *args, **kwargs):
        # Look up by the types of the first two positional arguments
        ts = L(args).map(type)[:2]
        f = self[tuple(ts)]
        # No match: return the first argument unchanged
        # NOTE(review): raises IndexError if called with no positional args and nothing matches
        if not f: return args[0]
        # Unwrap staticmethods; otherwise bind to the instance, or else the class, captured by `__get__`
        if isinstance(f, staticmethod): f = f.__func__
        elif self.inst is not None: f = MethodType(f, self.inst)
        elif self.owner is not None: f = MethodType(f, self.owner)
        return f(*args, **kwargs)

    def __get__(self, inst, owner):
        # NOTE(review): the binding is stored on this shared object and persists until the next access
        self.inst = inst
        self.owner = owner
        return self

    def __getitem__(self, k):
        "Find first matching type that is a super-class of `k`"
        # `k` is a type or a tuple of types; missing positions are treated as `object`
        k = L(k)
        while len(k)<2: k.append(object)
        r = self.funcs.all_matches(k[0])
        for t in r:
            o = t[k[1]]
            if o is not None: return o
        # Nothing registered here matches: fall back to the base dispatchers
        for base in self.bases:
            res = base[k]
            if res is not None: return res
        return None
# %% ../nbs/04_dispatch.ipynb
class DispatchReg:
    "A global registry for `TypeDispatch` objects keyed by function name"
    def __init__(self):
        # One `TypeDispatch` per qualified name, created on first registration
        self.d = defaultdict(TypeDispatch)

    def __call__(self, f):
        "Register `f` under its `__qualname__` and return the `TypeDispatch` that now holds it"
        # classmethod/staticmethod wrappers carry the real function (and its name) in `__func__`
        inner = f.__func__ if isinstance(f, (classmethod, staticmethod)) else f
        key = f'{inner.__qualname__}'
        # classmethods are registered unwrapped; staticmethods keep their wrapper
        if isinstance(f, classmethod): f = inner
        disp = self.d[key]
        disp.add(f)
        return disp
# Module-level registry, used as the `@typedispatch` decorator (e.g. on `cast` in this module)
typedispatch = DispatchReg()
# %% ../nbs/04_dispatch.ipynb
# NOTE(review): presumably an nbdev export hint so `cast` (a `TypeDispatch`, not a plain def) lands in `__all__`
_all_=['cast']
# %% ../nbs/04_dispatch.ipynb
def retain_meta(x, res, as_copy=False):
    "Let `res` pull metadata from `x` through its own `set_meta`, when it defines one; always returns `res`"
    if not hasattr(res, 'set_meta'): return res
    res.set_meta(x, as_copy=as_copy)
    return res
# %% ../nbs/04_dispatch.ipynb
def default_set_meta(self, x, as_copy=False):
    "Copy `x._meta` onto `self` (shallow-copied if `as_copy`) unless `self` already has `_meta`; returns `self`"
    if not hasattr(x, '_meta') or hasattr(self, '_meta'): return self
    self._meta = copy(x._meta) if as_copy else x._meta
    return self
# %% ../nbs/04_dispatch.ipynb
@typedispatch
def cast(x, typ):
    "cast `x` to type `typ` (may also change `x` inplace), then carry `x`'s metadata over via `retain_meta`"
    # An optional `typ._before_cast` hook may pre-process `x`
    res = typ._before_cast(x) if hasattr(typ, '_before_cast') else x
    if risinstance('ndarray', res): res = res.view(typ)
    elif hasattr(res, 'as_subclass'): res = res.as_subclass(typ)
    else:
        # Prefer re-classing in place; fall back to constructing a new `typ`.
        # Catch `Exception` only, so KeyboardInterrupt/SystemExit still propagate.
        try: res.__class__ = typ
        except Exception: res = typ(res)
    return retain_meta(x, res)
# %% ../nbs/04_dispatch.ipynb
def retain_type(new, old=None, typ=None, as_copy=False):
    "Cast `new` to `typ`, or to the type of `old` when that type is a subclass of `new`'s type"
    if new is None: return None
    assert old is not None or typ is not None
    if typ is None:
        # e.g. `old` is a TensorImage and `new` a plain Tensor; unrelated types are left alone
        if not isinstance(old, type(new)): return new
        typ = type(old) if not isinstance(old, type) else old
    # Nothing to do when the target is NoneType or `new` already is a `typ`
    already_ok = typ==NoneType or isinstance(new, typ)
    if already_ok: return new
    casted = cast(new, typ)
    return retain_meta(old, casted, as_copy=as_copy)
# %% ../nbs/04_dispatch.ipynb
def retain_types(new, old=None, typs=None):
    "Recursively apply `retain_type` to `new`, matching its items against `old` (and optional `typs`)"
    if not is_listy(new): return retain_type(new, old, typs)
    if typs is None:
        # Keep `old`'s container type when it is an instance of `new`'s type, else keep `new`'s
        t = type(old) if old is not None and isinstance(old,type(new)) else type(new)
    elif isinstance(typs, dict):
        # Shape produced by `explode_types`: {container type: [item types]}
        t = first(typs.keys())
        typs = typs[t]
    else:
        t,typs = typs,None
    return t(L(new, old, typs).map_zip(retain_types, cycled=True))
# %% ../nbs/04_dispatch.ipynb
def explode_types(o):
    "`type(o)`, or for listy `o` a single-key dict mapping `type(o)` to the exploded types of its items"
    return {type(o): list(map(explode_types, o))} if is_listy(o) else type(o)
---
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/06_docments.ipynb.
# %% ../nbs/06_docments.ipynb 2
from __future__ import annotations
import re
from tokenize import tokenize,COMMENT
from ast import parse,FunctionDef,AsyncFunctionDef,AnnAssign
from io import BytesIO
from textwrap import dedent
from types import SimpleNamespace
from inspect import getsource,isfunction,ismethod,isclass,signature,Parameter
from dataclasses import dataclass, is_dataclass
from .utils import *
from .meta import delegates
from . import docscrape
from inspect import isclass,getdoc
# %% auto 0
# Public names of this module (generated from the notebook along with the rest of this file)
__all__ = ['empty', 'docstring', 'parse_docstring', 'isdataclass', 'get_dataclass_source', 'get_source', 'get_name', 'qual_name',
           'docments']
# %% ../nbs/06_docments.ipynb
def docstring(sym):
    "Get docstring for `sym` for functions and classes; strings are returned as-is, missing docs give ''"
    if isinstance(sym, str): return sym
    doc = getdoc(sym)
    if doc: return doc
    # Undocumented classes fall back to their constructor's docstring
    if isclass(sym): doc = getdoc(sym.__init__)
    return doc or ""
# %% ../nbs/06_docments.ipynb
def parse_docstring(sym):
    "Parse a numpy-style docstring in `sym` into an `AttrDict` of its sections"
    # `docstring` handles strings, functions and classes (falling back to `__init__`'s docstring);
    # compute it once and reuse it instead of calling it a second time
    docs = docstring(sym)
    return AttrDict(**docscrape.NumpyDocString(docs))
# %% ../nbs/06_docments.ipynb
def isdataclass(s):
    "True only for a dataclass *type*; instances of a dataclass give False"
    return isclass(s) and is_dataclass(s)
# %% ../nbs/06_docments.ipynb
def get_dataclass_source(s):
    "Source code of dataclass `s`, or '' when `s` was defined in `__main__`"
    # NOTE(review): presumably because `getsource` can't locate code typed into an interactive session
    if getattr(s, "__module__") == '__main__': return ""
    return getsource(s)
# %% ../nbs/06_docments.ipynb
def get_source(s):
    "Source text for a function/method or dataclass `s`; anything else (e.g. a source string) is returned unchanged"
    if isfunction(s) or ismethod(s): return getsource(s)
    if isdataclass(s): return get_dataclass_source(s)
    return s
# %% ../nbs/06_docments.ipynb
def _parses(s):
    "`ast.Module` for `s` (source text, function or dataclass); dedented so indented method source parses"
    src = dedent(get_source(s))
    return parse(src)
def _tokens(s):
    "Token stream (from `tokenize`) for `s`, which may be source text, a function or a dataclass"
    src_bytes = get_source(s).encode('utf-8')
    return tokenize(BytesIO(src_bytes).readline)
_clean_re = re.compile(r'^\s*#(.*)\s*$')
def _clean_comment(s):
    "Text following `#` in comment `s` (leading whitespace before `#` ignored), or `None` if `s` isn't a comment"
    # The pattern is anchored at the start, so a single `search` finds the only possible match
    m = _clean_re.search(s)
    return m.group(1) if m else None
def _param_locs(s, returns=True):
    "`dict` of parameter line numbers to names (plus `'return'`), or `None` if `s` isn't a single def/dataclass"
    body = _parses(s).body
    if len(body)!=1: return None
    defn = body[0]
    if isinstance(defn, (FunctionDef, AsyncFunctionDef)):
        args = defn.args
        # Positional-only and keyword-only parameters can carry docments too, not just `args.args`.
        # `posonlyargs` only exists on Python 3.8+ ASTs.
        named = getattr(args, 'posonlyargs', []) + args.args + args.kwonlyargs
        # When several params share a source line, the last one on that line wins
        res = {arg.lineno:arg.arg for arg in named}
        if returns and defn.returns: res[defn.returns.lineno] = 'return'
        return res
    if isdataclass(s):
        # Dataclass fields are the annotated assignments in the class body
        return {arg.lineno:arg.target.id for arg in defn.body if isinstance(arg, AnnAssign)}
    return None
# %% ../nbs/06_docments.ipynb
# `inspect`'s sentinel for "no annotation" / "no default"; compared against in `_get_full`
empty = Parameter.empty
# %% ../nbs/06_docments.ipynb
def _get_comment(line, arg, comments, parms):
    "Docment for the parameter on `line`: its same-line comment, else the run of comment lines directly above it"
    # `comments` maps line numbers to comment text; `parms` maps line numbers to parameter names,
    # so the upward scan stops at another parameter's line. `arg` itself is not used.
    if line in comments: return comments[line].strip()
    above = []
    cur = line - 1
    while cur and cur in comments and cur not in parms:
        above.insert(0, comments[cur])
        cur -= 1
    return dedent('\n'.join(above)) if above else None
def _get_full(anno, name, default, docs):
    "Bundle one parameter's docment (`docs[name]`), annotation and default into an `AttrDict`"
    # Without an annotation, fall back to the type of the default value when there is one
    infer = anno==empty and default!=empty
    return AttrDict(docment=docs.get(name), anno=type(default) if infer else anno, default=default)
# %% ../nbs/06_docments.ipynb
def _merge_doc(dm, npdoc):
    "Fill a missing `anno`/`docment` in `dm` from numpydoc entry `npdoc`; `dm` is updated in place and returned"
    if npdoc:
        missing_anno = not dm.anno or dm.anno==empty
        if missing_anno: dm.anno = npdoc.type
        if not dm.docment: dm.docment = '\n'.join(npdoc.desc)
    return dm
def _merge_docs(dms, npdocs):
    "Merge numpy-docstring `Parameters` (and `Returns`, for a 'return' entry) into the docment dicts `dms`"
    np_params = npdocs['Parameters']
    merged = {}
    for nm, dm in dms.items(): merged[nm] = _merge_doc(dm, np_params.get(nm, None))
    if 'return' in dms: merged['return'] = _merge_doc(dms['return'], npdocs['Returns'])
    return merged
# %% ../nbs/06_docments.ipynb
def _get_property_name(p):
    "Name of property `p`: its getter's `__qualname__`, else the last dotted part of the quoted name in `str(p)`"
    if not hasattr(p, 'fget'):
        quoted = re.findall(r'\'(.*)\'', str(p))
        return next(iter(quoted)).split('.')[-1]
    getter = p.fget
    # Getters wrapped in e.g. a `functools.partial` expose the real function as `.func`
    target = getter.func if hasattr(getter, 'func') else getter
    return target.__qualname__
# %% ../nbs/06_docments.ipynb
def get_name(obj):
    "Get the name of `obj`, trying `__name__`, `_name`, `__origin__`, property getters, then `str(obj)`"
    def _last_dotted(o): return str(o).split('.')[-1]
    if hasattr(obj, '__name__'): return obj.__name__
    if getattr(obj, '_name', False): return obj._name
    # For types that expose an `__origin__` (e.g. parameterized generics)
    if hasattr(obj, '__origin__'): return _last_dotted(obj.__origin__)
    if type(obj) == property: return _get_property_name(obj)
    return _last_dotted(obj)
# %% ../nbs/06_docments.ipynb
def qual_name(obj):
    "Get the qualified name of `obj`"
    if hasattr(obj,'__qualname__'): return obj.__qualname__
    # A bound method whose underlying callable lacks `__qualname__`: build "<self name>.<callable name>".
    # (Previously referenced an undefined `fn`, raising NameError.)
    if ismethod(obj): return f"{get_name(obj.__self__)}.{get_name(obj.__func__)}"
    return get_name(obj)
# %% ../nbs/06_docments.ipynb
def _docments(s, returns=True, eval_str=False):
    "`dict` of parameter names to 'docment-style' comments in function or string `s`"
    # Each value is an `AttrDict` with `docment`, `anno` and `default` (built by `_get_full`),
    # then merged with any numpy-style docstring entries.
    # Parse the numpy-style docstring while `s` may still be a class
    nps = parse_docstring(s)
    if isclass(s) and not is_dataclass(s): s = s.__init__ # Constructor for a class
    # Source line number -> comment text, for every comment token in `s`'s source
    comments = {o.start[0]:_clean_comment(o.string) for o in _tokens(s) if o.type==COMMENT}
    # Source line number -> parameter name (and 'return'); empty when `s` isn't a single def/dataclass
    parms = _param_locs(s, returns=returns) or {}
    docs = {arg:_get_comment(line, arg, comments, parms) for line,arg in parms.items()}
    # NOTE(review): `eval` executes arbitrary code; only pass source strings from trusted callers.
    # The evaluated result is what `signature` inspects below.
    if isinstance(s,str): s = eval(s)
    sig = signature(s)
    res = {arg:_get_full(p.annotation, p.name, p.default, docs) for arg,p in sig.parameters.items()}
    if returns: res['return'] = _get_full(sig.return_annotation, 'return', empty, docs)
    res = _merge_docs(res, nps)
    if eval_str:
        # Replace the signature's annotations with those returned by `type_hints`
        # (presumably resolving string annotations)
        hints = type_hints(s)
        for k,v in res.items():
            if k in hints: v['anno'] = hints.get(k)
    return res
# %% ../nbs/06_docments.ipynb
@delegates(_docments)
def docments(elt, full=False, **kwargs):
    "Generates a `docment`"
    # Returns an `AttrDict` mapping `elt`'s parameter names (plus 'return') to their docment strings,
    # or to the full `_docments` entries (docment/anno/default) when `full=True`.
    r = {}
    # Only names in `elt`'s own signature are reported
    params = set(signature(elt).parameters)
    params.add('return')
    def _update_docments(f, r):
        # Process a wrapped function first (via `__delwrap__`, presumably set by `delegates`), so that
        # `f`'s own non-empty docments then take precedence over the wrapped function's
        if hasattr(f, '__delwrap__'): _update_docments(f.__delwrap__, r)
        # Keep an entry if it has a docment, or if nothing has documented that name yet
        r.update({k:v for k,v in _docments(f, **kwargs).items() if k in params
                  and (v.get('docment', None) or not nested_idx(r, k, 'docment'))})
    _update_docments(elt, r)
    if not full: r = {k:v['docment'] for k,v in r.items()}
    return AttrDict(r)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment