Created
August 19, 2024 05:25
-
-
Save EmbraceLife/87f69df227eafef3f4fc3774184404e2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/07_meta.ipynb. | |
# %% auto 0 | |
__all__ = ['test_sig', 'FixSigMeta', 'PrePostInitMeta', 'AutoInit', 'NewChkMeta', 'BypassNewMeta', 'empty2none', 'anno_dict', | |
'use_kwargs_dict', 'use_kwargs', 'delegates', 'method', 'funcs_kwargs'] | |
# %% ../nbs/07_meta.ipynb | |
from .imports import * | |
from .test import * | |
from contextlib import contextmanager | |
from copy import copy | |
import inspect | |
# %% ../nbs/07_meta.ipynb | |
def test_sig(f, b):
    "Test the signature of an object"
    sig = str(inspect.signature(f))
    test_eq(sig, b)
# %% ../nbs/07_meta.ipynb | |
def _rm_self(sig):
    "Return a copy of `sig` with its `self` parameter removed"
    params = dict(sig.parameters)
    del params['self']
    return sig.replace(parameters=params.values())
# %% ../nbs/07_meta.ipynb | |
class FixSigMeta(type):
    "A metaclass that fixes the signature on classes that override `__new__`"
    def __new__(cls, name, bases, dict):
        res = super().__new__(cls, name, bases, dict)
        init = res.__init__
        # Only derive a signature when the class (or a base) defines a real `__init__`
        if init is not object.__init__:
            res.__signature__ = _rm_self(inspect.signature(init))
        return res
# %% ../nbs/07_meta.ipynb | |
class PrePostInitMeta(FixSigMeta):
    "A metaclass that calls optional `__pre_init__` and `__post_init__` methods"
    def __call__(cls, *args, **kwargs):
        res = cls.__new__(cls)
        # Run the init hooks only when `__new__` produced an instance of this class
        if type(res) == cls:
            if hasattr(res, '__pre_init__'):
                res.__pre_init__(*args, **kwargs)
            res.__init__(*args, **kwargs)
            if hasattr(res, '__post_init__'):
                res.__post_init__(*args, **kwargs)
        return res
# %% ../nbs/07_meta.ipynb | |
class AutoInit(metaclass=PrePostInitMeta):
    "Same as `object`, but no need for subclasses to call `super().__init__`"
    # `PrePostInitMeta` invokes `__pre_init__` before `__init__`, so the base
    # chain's `__init__` runs automatically for every subclass
    def __pre_init__(self, *args, **kwargs): super().__init__(*args, **kwargs)
# %% ../nbs/07_meta.ipynb | |
class NewChkMeta(FixSigMeta):
    "Metaclass to avoid recreating object passed to constructor"
    def __call__(cls, x=None, *args, **kwargs):
        # If the sole argument is already an instance of `cls`, hand it back as-is
        if x is not None and not args and not kwargs and isinstance(x, cls):
            return x
        return super().__call__(x, *args, **kwargs)
# %% ../nbs/07_meta.ipynb | |
class BypassNewMeta(FixSigMeta):
    "Metaclass: casts `x` to this class if it's of type `cls._bypass_type`"
    def __call__(cls, x=None, *args, **kwargs):
        # Optional hook: `_new_meta` pre-processes `x` before any casting
        if hasattr(cls, '_new_meta'): x = cls._new_meta(x, *args, **kwargs)
        # Otherwise build a fresh instance, unless `x` is already a `_bypass_type`
        # and no extra construction args were supplied
        elif not isinstance(x,getattr(cls,'_bypass_type',object)) or len(args) or len(kwargs):
            x = super().__call__(*((x,)+args), **kwargs)
        # Re-tag the existing object as an instance of `cls` in place -- no copy is made
        if cls!=x.__class__: x.__class__ = cls
        return x
# %% ../nbs/07_meta.ipynb | |
def empty2none(p):
    "Replace `Parameter.empty` with `None`"
    if p == inspect.Parameter.empty: return None
    return p
# %% ../nbs/07_meta.ipynb | |
def anno_dict(f):
    "`__annotations__` dict with `empty` cast to `None`; `{}` if it doesn't exist"
    anno = getattr(f, '__annotations__', {})
    return {name: empty2none(v) for name, v in anno.items()}
# %% ../nbs/07_meta.ipynb | |
def _mk_param(n,d=None): return inspect.Parameter(n, inspect.Parameter.KEYWORD_ONLY, default=d) | |
# %% ../nbs/07_meta.ipynb | |
def use_kwargs_dict(keep=False, **kwargs):
    "Decorator: replace `**kwargs` in signature with the names/defaults in `kwargs`"
    # NOTE: the original docstring mentioned `names`, which belongs to `use_kwargs`;
    # this variant takes name->default pairs directly as keyword args.
    def _f(f):
        sig = inspect.signature(f)
        sigd = dict(sig.parameters)
        k = sigd.pop('kwargs')  # the VAR_KEYWORD parameter being replaced
        # Add one keyword-only param per item, skipping names already in the signature
        s2 = {n:_mk_param(n,d) for n,d in kwargs.items() if n not in sigd}
        sigd.update(s2)
        if keep: sigd['kwargs'] = k  # optionally keep a trailing `**kwargs`
        f.__signature__ = sig.replace(parameters=sigd.values())
        return f
    return _f
# %% ../nbs/07_meta.ipynb | |
def use_kwargs(names, keep=False):
    "Decorator: replace `**kwargs` in signature with `names` params"
    def _f(f):
        sig = inspect.signature(f)
        params = dict(sig.parameters)
        var_kw = params.pop('kwargs')
        # One keyword-only parameter (default None) per requested name
        params.update({n: _mk_param(n) for n in names if n not in params})
        if keep: params['kwargs'] = var_kw
        f.__signature__ = sig.replace(parameters=params.values())
        return f
    return _f
# %% ../nbs/07_meta.ipynb | |
def delegates(to:FunctionType=None, # Delegatee
              keep=False, # Keep `kwargs` in decorated function?
              but:list=None): # Exclude these parameters from signature
    "Decorator: replace `**kwargs` in signature with params from `to`"
    if but is None: but = []
    def _f(f):
        # With no `to`, assume `f` is a class: delegate its `__init__` to the base's
        if to is None: to_f,from_f = f.__base__.__init__,f.__init__
        else: to_f,from_f = to.__init__ if isinstance(to,type) else to,f
        # Unwrap bound/class methods to the underlying functions
        from_f = getattr(from_f,'__func__',from_f)
        to_f = getattr(to_f,'__func__',to_f)
        # Already delegated previously -- don't process twice
        if hasattr(from_f,'__delwrap__'): return f
        sig = inspect.signature(from_f)
        sigd = dict(sig.parameters)
        k = sigd.pop('kwargs')
        # Take `to_f`'s defaulted params that `from_f` doesn't already declare
        s2 = {k:v.replace(kind=inspect.Parameter.KEYWORD_ONLY) for k,v in inspect.signature(to_f).parameters.items()
              if v.default != inspect.Parameter.empty and k not in sigd and k not in but}
        # Carry matching annotations across as well
        anno = {k:v for k,v in getattr(to_f, "__annotations__", {}).items() if k not in sigd and k not in but}
        sigd.update(s2)
        if keep: sigd['kwargs'] = k
        else: from_f.__delwrap__ = to_f  # mark as delegated
        from_f.__signature__ = sig.replace(parameters=sigd.values())
        if hasattr(from_f, '__annotations__'): from_f.__annotations__.update(anno)
        return f
    return _f
# %% ../nbs/07_meta.ipynb | |
def method(f):
    "Mark `f` as a method"
    # Bind to the dummy instance `1`, since Py3's MethodType no longer accepts `None`
    bound = MethodType(f, 1)
    return bound
# %% ../nbs/07_meta.ipynb | |
def _funcs_kwargs(cls, as_method):
    "Patch `cls.__init__` so callables named in `cls._methods` can be passed as kwargs"
    old_init = cls.__init__
    def _init(self, *args, **kwargs):
        for k in cls._methods:
            arg = kwargs.pop(k,None)
            if arg is not None:
                # Optionally treat the plain callable as an instance method
                if as_method: arg = method(arg)
                # Re-bind any method object to this particular instance
                if isinstance(arg,MethodType): arg = MethodType(arg.__func__, self)
                setattr(self, k, arg)
        old_init(self, *args, **kwargs)
    functools.update_wrapper(_init, old_init)
    # Expose the `_methods` names as keyword-only params of the new `__init__`
    cls.__init__ = use_kwargs(cls._methods)(_init)
    # Keep the class-level signature in sync with the patched `__init__`
    if hasattr(cls, '__signature__'): cls.__signature__ = _rm_self(inspect.signature(cls.__init__))
    return cls
# %% ../nbs/07_meta.ipynb | |
def funcs_kwargs(as_method=False):
    "Replace methods in `cls._methods` with those from `kwargs`"
    if callable(as_method):
        # Used bare as `@funcs_kwargs`: the decorated class itself arrives here
        return _funcs_kwargs(as_method, False)
    return partial(_funcs_kwargs, as_method=as_method)
--- | |
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_basics.ipynb. | |
# %% auto 0 | |
__all__ = ['defaults', 'null', 'num_methods', 'rnum_methods', 'inum_methods', 'arg0', 'arg1', 'arg2', 'arg3', 'arg4', 'Self', | |
'ifnone', 'maybe_attr', 'basic_repr', 'is_array', 'listify', 'tuplify', 'true', 'NullType', 'tonull', | |
'get_class', 'mk_class', 'wrap_class', 'ignore_exceptions', 'exec_local', 'risinstance', 'ver2tuple', 'Inf', | |
'in_', 'ret_true', 'ret_false', 'stop', 'gen', 'chunked', 'otherwise', 'custom_dir', 'AttrDict', 'NS', | |
'get_annotations_ex', 'eval_type', 'type_hints', 'annotations', 'anno_ret', 'signature_ex', 'union2tuple', | |
'argnames', 'with_cast', 'store_attr', 'attrdict', 'properties', 'camel2words', 'camel2snake', 'snake2camel', | |
'class2attr', 'getcallable', 'getattrs', 'hasattrs', 'setattrs', 'try_attrs', 'GetAttrBase', 'GetAttr', | |
'delegate_attr', 'ShowPrint', 'Int', 'Str', 'Float', 'partition', 'flatten', 'concat', 'strcat', 'detuplify', | |
'replicate', 'setify', 'merge', 'range_of', 'groupby', 'last_index', 'filter_dict', 'filter_keys', | |
'filter_values', 'cycle', 'zip_cycle', 'sorted_ex', 'not_', 'argwhere', 'filter_ex', 'renumerate', 'first', | |
'only', 'nested_attr', 'nested_setdefault', 'nested_callable', 'nested_idx', 'set_nested_idx', 'val2idx', | |
'uniqueify', 'loop_first_last', 'loop_first', 'loop_last', 'first_match', 'last_match', 'fastuple', 'bind', | |
'mapt', 'map_ex', 'compose', 'maps', 'partialler', 'instantiate', 'using_attr', 'copy_func', 'patch_to', | |
'patch', 'patch_property', 'compile_re', 'ImportEnum', 'StrEnum', 'str_enum', 'ValEnum', 'Stateful', | |
'NotStr', 'PrettyString', 'even_mults', 'num_cpus', 'add_props', 'typed', 'exec_new', 'exec_import', | |
'str2bool', 'lt', 'gt', 'le', 'ge', 'eq', 'ne', 'add', 'sub', 'mul', 'truediv', 'is_', 'is_not', 'mod'] | |
# %% ../nbs/01_basics.ipynb | |
from .imports import * | |
import builtins,types,typing | |
import pprint | |
from copy import copy | |
try: from types import UnionType | |
except ImportError: UnionType = None | |
# %% ../nbs/01_basics.ipynb | |
defaults = SimpleNamespace() | |
# %% ../nbs/01_basics.ipynb | |
def ifnone(a, b):
    "`b` if `a` is None else `a`"
    if a is None: return b
    return a
# %% ../nbs/01_basics.ipynb | |
def maybe_attr(o, attr):
    "Value of `o.attr` if it exists, otherwise `o` itself"
    try: return getattr(o, attr)
    except AttributeError: return o
# %% ../nbs/01_basics.ipynb | |
def basic_repr(flds=None):
    "Minimal `__repr__`"
    if isinstance(flds, str): flds = re.split(', *', flds)
    flds = list(flds or [])
    def _f(self):
        cls = type(self)
        header = f'{cls.__module__}.{cls.__name__}'
        # With no fields, fall back to an angle-bracketed class path
        if not flds: return f'<{header}>'
        body = ', '.join(f'{fld}={getattr(self, fld)!r}' for fld in flds)
        return f'{header}({body})'
    return _f
# %% ../nbs/01_basics.ipynb | |
def is_array(x):
    "`True` if `x` supports `__array__` or `iloc`"
    return any(hasattr(x, a) for a in ('__array__', 'iloc'))
# %% ../nbs/01_basics.ipynb | |
def listify(o=None, *rest, use_list=False, match=None):
    "Convert `o` to a `list`"
    if rest: o = (o,) + rest
    if use_list: res = list(o)
    elif o is None: res = []
    elif isinstance(o, list): res = o
    elif isinstance(o, str) or isinstance(o, bytes) or is_array(o): res = [o]
    elif is_iter(o): res = list(o)
    else: res = [o]
    if match is not None:
        # `match` may be a collection (use its length) or an explicit count
        n = len(match) if is_coll(match) else match
        if len(res) == 1: res = res * n
        else: assert len(res) == n, 'Match length mismatch'
    return res
# %% ../nbs/01_basics.ipynb | |
def tuplify(o, use_list=False, match=None):
    "Make `o` a tuple"
    res = listify(o, use_list=use_list, match=match)
    return tuple(res)
# %% ../nbs/01_basics.ipynb | |
def true(x):
    "Test whether `x` is truthy; collections with >0 elements are considered `True`"
    # Deliberately broad except: fall back to plain truthiness for unsized objects
    try: return len(x) > 0
    except: return bool(x)
# %% ../nbs/01_basics.ipynb | |
class NullType:
    "An object that is `False` and can be called, chained, and indexed"
    def __bool__(self): return False
    # Every access, call, or index collapses back to the `null` singleton
    def __getattr__(self, *args): return null
    def __getitem__(self, *args): return null
    def __call__(self, *args, **kwargs): return null

null = NullType()
# %% ../nbs/01_basics.ipynb | |
def tonull(x):
    "Convert `None` to `null`"
    return x if x is not None else null
# %% ../nbs/01_basics.ipynb | |
def get_class(nm, *fld_names, sup=None, doc=None, funcs=None, anno=None, **flds):
    "Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`"
    attrs = {}
    if not anno: anno = {}
    # Each named field defaults to None and is annotated `Any` unless overridden
    for f in fld_names:
        attrs[f] = None
        if f not in anno: anno[f] = typing.Any
    # Callables become methods; remaining kwargs become class attributes
    for f in listify(funcs): attrs[f.__name__] = f
    for k,v in flds.items(): attrs[k] = v
    sup = ifnone(sup, ())
    if not isinstance(sup, tuple): sup=(sup,)
    def _init(self, *args, **kwargs):
        # Positional args fill attrs in declaration order
        # NOTE(review): this indexes `attrs` including funcs/flds, not only
        # `fld_names` -- confirm positional use with funcs present
        for i,v in enumerate(args): setattr(self, list(attrs.keys())[i], v)
        for k,v in kwargs.items(): setattr(self,k,v)
    attrs['_fields'] = [*fld_names,*flds.keys()]
    def _eq(self,b):
        return all([getattr(self,k)==getattr(b,k) for k in self._fields])
    # Only add a default repr when there's no superclass to supply one
    if not sup: attrs['__repr__'] = basic_repr(attrs['_fields'])
    attrs['__init__'] = _init
    attrs['__eq__'] = _eq
    if anno: attrs['__annotations__'] = anno
    res = type(nm, sup, attrs)
    if doc is not None: res.__doc__ = doc
    return res
# %% ../nbs/01_basics.ipynb | |
def mk_class(nm, *fld_names, sup=None, doc=None, funcs=None, mod=None, anno=None, **flds):
    "Create a class using `get_class` and add to the caller's module"
    # Default target namespace is the caller's frame locals
    if mod is None: mod = sys._getframe(1).f_locals
    mod[nm] = get_class(nm, *fld_names, sup=sup, doc=doc, funcs=funcs, anno=anno, **flds)
# %% ../nbs/01_basics.ipynb | |
def wrap_class(nm, *fld_names, sup=None, doc=None, funcs=None, **flds):
    "Decorator: makes function a method of a new class `nm` passing parameters to `mk_class`"
    def _inner(f):
        all_funcs = listify(funcs) + [f]
        mk_class(nm, *fld_names, sup=sup, doc=doc, funcs=all_funcs, mod=f.__globals__, **flds)
        return f
    return _inner
# %% ../nbs/01_basics.ipynb | |
class ignore_exceptions:
    "Context manager to ignore exceptions"
    def __enter__(self):
        pass
    def __exit__(self, *args):
        # Returning True tells Python to swallow any exception from the body
        return True
# %% ../nbs/01_basics.ipynb | |
def exec_local(code, var_name):
    "Call `exec` on `code` and return the var `var_name`"
    # NOTE: `exec` runs arbitrary code -- only pass trusted input
    scope = {}
    exec(code, globals(), scope)
    return scope[var_name]
# %% ../nbs/01_basics.ipynb | |
def risinstance(types, obj=None):
    "Curried `isinstance` but with args reversed"
    types = tuplify(types)
    # With no `obj`, return a predicate for later application
    if obj is None: return partial(risinstance, types)
    # String type names are matched against the MRO by name
    if any(isinstance(t, str) for t in types):
        return any(t.__name__ in types for t in type(obj).__mro__)
    return isinstance(obj, types)
# %% ../nbs/01_basics.ipynb | |
def ver2tuple(v:str)->tuple:
    "Parse a version string like '1.2.3' into a 3-tuple of ints (missing parts are 0)"
    m = re.search(r'(\d+)(?:\.(\d+))?(?:\.(\d+))?', v)
    return tuple(int(g or 0) for g in m.groups())
# %% ../nbs/01_basics.ipynb | |
class _InfMeta(type):
    "Metaclass providing infinite-iterator class properties for `Inf`"
    @property
    def count(self):
        return itertools.count()
    @property
    def zeros(self):
        return itertools.cycle([0])
    @property
    def ones(self):
        return itertools.cycle([1])
    @property
    def nones(self):
        return itertools.cycle([None])
# %% ../nbs/01_basics.ipynb | |
class Inf(metaclass=_InfMeta):
    "Infinite lists"
    # All behavior lives on the metaclass: `Inf.count`, `Inf.zeros`,
    # `Inf.ones`, `Inf.nones` each return a fresh infinite iterator
    pass
# %% ../nbs/01_basics.ipynb | |
_dumobj = object()  # sentinel meaning "second operand not supplied"
def _oper(op, a, b=_dumobj):
    "Apply `op(a,b)`, or return a partial application when `b` is missing"
    if b is _dumobj: return lambda o: op(o, a)
    return op(a, b)
def _mk_op(nm, mod):
    "Create an operator using `oper` and add to the caller's module"
    op = getattr(operator, nm)
    def _inner(a, b=_dumobj): return _oper(op, a, b)
    _inner.__name__ = _inner.__qualname__ = nm
    _inner.__doc__ = f'Same as `operator.{nm}`, or returns partial if 1 arg'
    mod[nm] = _inner
# %% ../nbs/01_basics.ipynb | |
def in_(x, a):
    "`True` if `x in a`"
    return x in a

# Register on the `operator` module so `_mk_op` can look it up by name
operator.in_ = in_
# %% ../nbs/01_basics.ipynb | |
# Operator names additionally exported by nbdev (see `__all__` above)
_all_ = ['lt','gt','le','ge','eq','ne','add','sub','mul','truediv','is_','is_not','in_', 'mod']
# %% ../nbs/01_basics.ipynb
# Generate a module-level curried function for each operator name above
for op in _all_: _mk_op(op, globals())
# %% ../nbs/01_basics.ipynb | |
def ret_true(*args, **kwargs):
    "Predicate: always `True`, ignoring all arguments"
    return True
# %% ../nbs/01_basics.ipynb | |
def ret_false(*args, **kwargs):
    "Predicate: always `False`, ignoring all arguments"
    return False
# %% ../nbs/01_basics.ipynb | |
def stop(e=StopIteration):
    "Raises exception `e` (by default `StopIteration`)"
    # Usable as an expression via `x or stop(...)` style constructs
    raise e
# %% ../nbs/01_basics.ipynb | |
def gen(func, seq, cond=ret_true):
    "Like `(func(o) for o in seq if cond(func(o)))` but handles `StopIteration`"
    mapped = map(func, seq)
    return itertools.takewhile(cond, mapped)
# %% ../nbs/01_basics.ipynb | |
def chunked(it, chunk_sz=None, drop_last=False, n_chunks=None):
    "Return batches from iterator `it` of size `chunk_sz` (or return `n_chunks` total)"
    # Exactly one of `chunk_sz` / `n_chunks` must be given
    assert bool(chunk_sz) ^ bool(n_chunks)
    if n_chunks: chunk_sz = max(math.ceil(len(it)/n_chunks), 1)
    if not isinstance(it, Iterator): it = iter(it)
    while True:
        batch = list(itertools.islice(it, chunk_sz))
        if batch and (not drop_last or len(batch) == chunk_sz): yield batch
        # A short (or empty) batch means the source is exhausted
        if len(batch) < chunk_sz: return
# %% ../nbs/01_basics.ipynb | |
def otherwise(x, tst, y):
    "`y if tst(x) else x`"
    if tst(x): return y
    return x
# %% ../nbs/01_basics.ipynb | |
def custom_dir(c, add):
    "Implement custom `__dir__`, adding `add` to `cls`"
    base = object.__dir__(c)
    return base + listify(add)
# %% ../nbs/01_basics.ipynb | |
class AttrDict(dict):
    "`dict` subclass that also provides access to keys as attrs"
    def __getattr__(self, k):
        if k in self: return self[k]
        stop(AttributeError(k))
    def __setattr__(self, k, v):
        # Names starting with '_' are real attributes; everything else is an item
        if k[0] == '_': super().__setattr__(k, v)
        else: self[k] = v
    def __dir__(self):
        return super().__dir__() + list(self.keys())
    def _repr_markdown_(self):
        return f'```json\n{pprint.pformat(self, indent=2)}\n```'
    def copy(self):
        return AttrDict(**self)
# %% ../nbs/01_basics.ipynb | |
class NS(SimpleNamespace):
    "`SimpleNamespace` subclass that also adds `iter` and `dict` support"
    def __iter__(self): return iter(vars(self))
    def __getitem__(self, k): return vars(self)[k]
    def __setitem__(self, k, v): vars(self)[k] = v
# %% ../nbs/01_basics.ipynb | |
def get_annotations_ex(obj, *, globals=None, locals=None):
    "Backport of py3.10 `get_annotations` that returns globals/locals"
    # Classes: read `__annotations__` straight off the class `__dict__`
    if isinstance(obj, type):
        obj_dict = getattr(obj, '__dict__', None)
        if obj_dict and hasattr(obj_dict, 'get'):
            ann = obj_dict.get('__annotations__', None)
            # A descriptor here means no real annotations were set
            if isinstance(ann, types.GetSetDescriptorType): ann = None
        else: ann = None
        obj_globals = None
        # Globals come from the module the class was defined in
        module_name = getattr(obj, '__module__', None)
        if module_name:
            module = sys.modules.get(module_name, None)
            if module: obj_globals = getattr(module, '__dict__', None)
        obj_locals = dict(vars(obj))
        unwrap = obj
    # Modules: annotations plus the module dict as globals
    elif isinstance(obj, types.ModuleType):
        ann = getattr(obj, '__annotations__', None)
        obj_globals = getattr(obj, '__dict__')
        obj_locals,unwrap = None,None
    # Callables: annotations plus the function's own globals
    elif callable(obj):
        ann = getattr(obj, '__annotations__', None)
        obj_globals = getattr(obj, '__globals__', None)
        obj_locals,unwrap = None,obj
    else: raise TypeError(f"{obj!r} is not a module, class, or callable.")
    if ann is None: ann = {}
    if not isinstance(ann, dict): raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
    if not ann: ann = {}
    # Peel off decorator wrappers/partials to find the真 function's globals
    if unwrap is not None:
        while True:
            if hasattr(unwrap, '__wrapped__'):
                unwrap = unwrap.__wrapped__
                continue
            if isinstance(unwrap, functools.partial):
                unwrap = unwrap.func
                continue
            break
        if hasattr(unwrap, "__globals__"): obj_globals = unwrap.__globals__
    # Caller-supplied globals/locals take precedence
    if globals is None: globals = obj_globals
    if locals is None: locals = obj_locals
    return dict(ann), globals, locals
# %% ../nbs/01_basics.ipynb | |
def eval_type(t, glb, loc):
    "`eval` a type or collection of types, if needed, for annotations in py3.10+"
    if isinstance(t, str):
        # 'A|B' union strings: evaluate each side, then build the Union
        if '|' in t: return Union[eval_type(tuple(t.split('|')), glb, loc)]
        return eval(t, glb, loc)
    if isinstance(t, (tuple, list)):
        return type(t)([eval_type(part, glb, loc) for part in t])
    return t
# %% ../nbs/01_basics.ipynb | |
def _eval_type(t, glb, loc):
    # Resolve the annotation, mapping a `None` annotation to `NoneType`
    res = eval_type(t, glb, loc)
    return NoneType if res is None else res
def type_hints(f):
    "Like `typing.get_type_hints` but returns `{}` if not allowed type"
    # `typing._allowed_types` covers modules, classes, functions, methods, etc.
    if not isinstance(f, typing._allowed_types): return {}
    ann,glb,loc = get_annotations_ex(f)
    return {k:_eval_type(v,glb,loc) for k,v in ann.items()}
# %% ../nbs/01_basics.ipynb | |
def annotations(o):
    "Annotations for `o`, or `type(o)`"
    if not o: return {}
    # Try the object itself, then its `__init__`, then its class
    res = type_hints(o)
    if not res: res = type_hints(getattr(o, '__init__', None))
    if not res: res = type_hints(type(o))
    return res
# %% ../nbs/01_basics.ipynb | |
def anno_ret(func):
    "Get the return annotation of `func`"
    if not func: return None
    return annotations(func).get('return', None)
# %% ../nbs/01_basics.ipynb | |
def _ispy3_10():
    "True if running on Python 3.10 or later"
    # Tuple comparison is correct for future major versions too; the previous
    # `major>=3 and minor>=10` check would wrongly fail on e.g. Python 4.0
    return sys.version_info >= (3, 10)

def signature_ex(obj, eval_str:bool=False):
    "Backport of `inspect.signature(..., eval_str=True)` to <py310"
    from inspect import Signature, Parameter, signature
    def _eval_param(ann, k, v):
        # Swap in the evaluated annotation when one was resolved for `k`
        if k not in ann: return v
        return Parameter(v.name, v.kind, annotation=ann[k], default=v.default)
    if not eval_str: return signature(obj)
    # py3.10+ supports `eval_str` natively
    if _ispy3_10(): return signature(obj, eval_str=eval_str)
    sig = signature(obj)
    if sig is None: return None
    ann = type_hints(obj)
    params = [_eval_param(ann,k,v) for k,v in sig.parameters.items()]
    return Signature(params, return_annotation=sig.return_annotation)
# %% ../nbs/01_basics.ipynb | |
def union2tuple(t):
    "Split a `Union`/`UnionType` into a tuple of member types; other values pass through"
    is_typing_union = getattr(t, '__origin__', None) is Union
    if is_typing_union or (UnionType and isinstance(t, UnionType)):
        return t.__args__
    return t
# %% ../nbs/01_basics.ipynb | |
def argnames(f, frame=False):
    "Names of arguments to function or frame `f`"
    code = f.f_code if frame else f.__code__
    n_args = code.co_argcount + code.co_kwonlyargcount
    return code.co_varnames[:n_args]
# %% ../nbs/01_basics.ipynb | |
def with_cast(f):
    "Decorator which uses any parameter annotations as preprocessing functions"
    anno, out_anno, params = annotations(f), anno_ret(f), argnames(f)
    # Return annotation (if any) post-processes the result; default is identity
    c_out = ifnone(out_anno, noop)
    # Map each defaulted parameter name to its default value
    defaults = dict(zip(reversed(params), reversed(f.__defaults__ or {})))
    @functools.wraps(f)
    def _inner(*args, **kwargs):
        args = list(args)
        for i,v in enumerate(params):
            if v in anno:
                c = anno[v]
                # Cast whichever source supplied the value: kwarg, positional, or default
                if v in kwargs: kwargs[v] = c(kwargs[v])
                elif i<len(args): args[i] = c(args[i])
                elif v in defaults: kwargs[v] = c(defaults[v])
        return c_out(f(*args, **kwargs))
    return _inner
# %% ../nbs/01_basics.ipynb | |
def _store_attr(self, anno, **attrs):
    "Set each of `attrs` on `self`, casting via `anno` and recording in `__stored_args__`"
    stored = getattr(self, '__stored_args__', None)
    for name, val in attrs.items():
        if name in anno: val = anno[name](val)
        setattr(self, name, val)
        if stored is not None: stored[name] = val
# %% ../nbs/01_basics.ipynb | |
def store_attr(names=None, self=None, but='', cast=False, store_args=None, **attrs):
    "Store params named in comma-separated `names` from calling context into attrs in `self`"
    # Inspect the caller's frame to discover its parameters and values
    fr = sys._getframe(1)
    args = argnames(fr, True)
    if self: args = ('self', *args)
    # By convention the caller's first parameter is `self`
    else: self = fr.f_locals[args[0]]
    # `__stored_args__` bookkeeping is skipped for __slots__ classes by default
    if store_args is None: store_args = not hasattr(self,'__slots__')
    if store_args and not hasattr(self, '__stored_args__'): self.__stored_args__ = {}
    # With `cast=True`, annotations act as conversion functions
    anno = annotations(self) if cast else {}
    if names and isinstance(names,str): names = re.split(', *', names)
    # Default to the caller's own parameter names (minus `self`)
    ns = names if names is not None else getattr(self, '__slots__', args[1:])
    added = {n:fr.f_locals[n] for n in ns}
    attrs = {**attrs, **added}
    if isinstance(but,str): but = re.split(', *', but)
    attrs = {k:v for k,v in attrs.items() if k not in but}
    return _store_attr(self, anno, **attrs)
# %% ../nbs/01_basics.ipynb | |
def attrdict(o, *ks, default=None):
    "Dict from each `k` in `ks` to `getattr(o,k)`"
    res = {}
    for k in ks: res[k] = getattr(o, k, default)
    return res
# %% ../nbs/01_basics.ipynb | |
def properties(cls, *ps):
    "Change attrs in `cls` with names in `ps` to properties"
    for name in ps:
        setattr(cls, name, property(getattr(cls, name)))
# %% ../nbs/01_basics.ipynb | |
# Matches CamelCase word boundaries (an upper after a lower, or a non-initial
# upper followed by a lower) -- used by `camel2words`
_c2w_re = re.compile(r'((?<=[a-z])[A-Z]|(?<!\A)[A-Z](?=[a-z]))')
# Two-pass CamelCase -> snake_case boundary patterns -- used by `camel2snake`
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
# %% ../nbs/01_basics.ipynb | |
def camel2words(s, space=' '):
    "Convert CamelCase to 'spaced words'"
    # Insert `space` before each detected word boundary
    return _c2w_re.sub(rf'{space}\1', s)
# %% ../nbs/01_basics.ipynb | |
def camel2snake(name):
    "Convert CamelCase to snake_case"
    pass1 = _camel_re1.sub(r'\1_\2', name)
    pass2 = _camel_re2.sub(r'\1_\2', pass1)
    return pass2.lower()
# %% ../nbs/01_basics.ipynb | |
def snake2camel(s):
    "Convert snake_case to CamelCase"
    parts = s.title().split('_')
    return ''.join(parts)
# %% ../nbs/01_basics.ipynb | |
def class2attr(self, cls_name):
    "Return the snake-cased name of the class; strip ending `cls_name` if it exists."
    stripped = re.sub(rf'{cls_name}$', '', self.__class__.__name__)
    return camel2snake(stripped or cls_name.lower())
# %% ../nbs/01_basics.ipynb | |
def getcallable(o, attr):
    "Calls `getattr` with a default of `noop`"
    # Missing attributes become a do-nothing callable, so the result is always callable
    # (assuming existing attrs are callables -- callers are responsible for that)
    return getattr(o, attr, noop)
# %% ../nbs/01_basics.ipynb | |
def getattrs(o, *attrs, default=None):
    "List of all `attrs` in `o`"
    res = []
    for attr in attrs: res.append(getattr(o, attr, default))
    return res
# %% ../nbs/01_basics.ipynb | |
def hasattrs(o, attrs):
    "Test whether `o` contains all `attrs`"
    for attr in attrs:
        if not hasattr(o, attr): return False
    return True
# %% ../nbs/01_basics.ipynb | |
def setattrs(dest, flds, src):
    "Copy each comma-separated field in `flds` from `src` (dict or object) onto `dest`"
    # Dicts are read by key; anything else by attribute
    getter = dict.get if isinstance(src, dict) else getattr
    for fld in re.split(r",\s*", flds):
        setattr(dest, fld, getter(src, fld))
# %% ../nbs/01_basics.ipynb | |
def try_attrs(obj, *attrs):
    "Return first attr that exists in `obj`"
    for name in attrs:
        # Deliberately broad except: properties may raise anything on access
        try: return getattr(obj, name)
        except: pass
    raise AttributeError(attrs)
# %% ../nbs/01_basics.ipynb | |
class GetAttrBase:
    "Basic delegation of `__getattr__` and `__dir__`"
    _attr=noop  # subclasses set this to the NAME (str) of the attribute to delegate to
    def __getattr__(self,k):
        # Never delegate private names or the delegation target itself
        if k[0]=='_' or k==self._attr: return super().__getattr__(k)
        # NOTE(review): relies on subclass providing `_getattr`; it is not
        # defined in this class -- confirm against subclasses
        return self._getattr(getattr(self, self._attr)[k])
    def __dir__(self): return custom_dir(self, getattr(self, self._attr))
# %% ../nbs/01_basics.ipynb | |
class GetAttr:
    "Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`"
    _default='default'  # name of the attribute that lookups are forwarded to
    def _component_attr_filter(self,k):
        # Dunder names and the delegation machinery itself are never forwarded
        if k.startswith('__') or k in ('_xtra',self._default): return False
        xtra = getattr(self,'_xtra',None)
        # With no `_xtra` list, forward everything; otherwise only listed names
        return xtra is None or k in xtra
    def _dir(self): return [k for k in dir(getattr(self,self._default)) if self._component_attr_filter(k)]
    def __getattr__(self,k):
        if self._component_attr_filter(k):
            attr = getattr(self,self._default,None)
            if attr is not None: return getattr(attr,k)
        raise AttributeError(k)
    def __dir__(self): return custom_dir(self,self._dir())
# def __getstate__(self): return self.__dict__
    # Restore state directly, bypassing `__getattr__` during unpickling
    def __setstate__(self,data): self.__dict__.update(data)
# %% ../nbs/01_basics.ipynb | |
def delegate_attr(self, k, to):
    "Use in `__getattr__` to delegate to attr `to` without inheriting from `GetAttr`"
    # Private names and the target itself are never delegated
    if k.startswith('_') or k == to: raise AttributeError(k)
    try:
        target = getattr(self, to)
        return getattr(target, k)
    except AttributeError: raise AttributeError(k) from None
# %% ../nbs/01_basics.ipynb | |
class ShowPrint:
    "Base class that prints for `show`"
    def show(self, *args, **kwargs):
        print(str(self))
# %% ../nbs/01_basics.ipynb | |
class Int(int,ShowPrint):
    "An extensible `int`"
    # Subclassing lets extra behavior (e.g. `show` from `ShowPrint`) attach to int values
    pass
# %% ../nbs/01_basics.ipynb | |
class Str(str,ShowPrint):
    "An extensible `str`"
    # Gains `show` from `ShowPrint`
    pass
class Float(float,ShowPrint):
    "An extensible `float`"
    # Gains `show` from `ShowPrint`
    pass
# %% ../nbs/01_basics.ipynb | |
def partition(coll, f):
    "Partition a collection by a predicate"
    fs, ts = [], []
    # Predicate result indexes the pair: falsy -> fs, truthy -> ts
    for item in coll: (fs, ts)[f(item)].append(item)
    if not isinstance(coll, tuple): return ts, fs
    cast = type(coll)
    return cast(ts), cast(fs)
# %% ../nbs/01_basics.ipynb | |
def flatten(o):
    "Concatenate all collections and items as a generator"
    for item in o:
        # Strings are atomic: don't recurse into their characters
        if isinstance(item, str):
            yield item
            continue
        try: yield from flatten(item)
        except TypeError: yield item
# %% ../nbs/01_basics.ipynb | |
def concat(colls)->list:
    "Concatenate all collections and items as a list"
    return [*flatten(colls)]
# %% ../nbs/01_basics.ipynb | |
def strcat(its, sep:str='')->str:
    "Concatenate stringified items `its`"
    return sep.join(str(it) for it in its)
# %% ../nbs/01_basics.ipynb | |
def detuplify(x):
    "If `x` is a tuple with one thing, extract it"
    if len(x) == 0: return None
    # `ndim` guard keeps 1-element multi-dim arrays intact
    if len(x) == 1 and getattr(x, 'ndim', 1) == 1: return x[0]
    return x
# %% ../nbs/01_basics.ipynb | |
def replicate(item, match):
    "Create tuple of `item` copied `len(match)` times"
    return tuple(item for _ in range(len(match)))
# %% ../nbs/01_basics.ipynb | |
def setify(o):
    "Turn any list like-object into a set."
    if isinstance(o, set): return o
    return set(listify(o))
# %% ../nbs/01_basics.ipynb | |
def merge(*ds):
    "Merge all dictionaries in `ds`"
    res = {}
    # Later dicts win on key collisions; `None` entries are skipped
    for d in ds:
        if d is not None: res.update(d)
    return res
# %% ../nbs/01_basics.ipynb | |
def range_of(x):
    "All indices of collection `x` (i.e. `list(range(len(x)))`)"
    # NOTE: redefined later in this module with a more general signature
    return [*range(len(x))]
# %% ../nbs/01_basics.ipynb | |
def groupby(x, key, val=noop):
    "Like `itertools.groupby` but doesn't need to be sorted, and isn't lazy, plus some extensions"
    # int/str `key`/`val` are conveniences for item and attribute lookup
    if isinstance(key, int): key = itemgetter(key)
    elif isinstance(key, str): key = attrgetter(key)
    if isinstance(val, int): val = itemgetter(val)
    elif isinstance(val, str): val = attrgetter(val)
    res = {}
    for item in x: res.setdefault(key(item), []).append(val(item))
    return res
# %% ../nbs/01_basics.ipynb | |
def last_index(x, o):
    "Finds the last index of occurence of `x` in `o` (returns -1 if no occurence)"
    for i in reversed(range(len(o))):
        if o[i] == x: return i
    return -1
# %% ../nbs/01_basics.ipynb | |
def filter_dict(d, func):
    "Filter a `dict` using `func`, applied to keys and values"
    res = {}
    for k, v in d.items():
        if func(k, v): res[k] = v
    return res
# %% ../nbs/01_basics.ipynb | |
def filter_keys(d, func):
    "Filter a `dict` using `func`, applied to keys"
    return {k: d[k] for k in d if func(k)}
# %% ../nbs/01_basics.ipynb | |
def filter_values(d, func):
    "Filter a `dict` using `func`, applied to values"
    res = {}
    for k, v in d.items():
        if func(v): res[k] = v
    return res
# %% ../nbs/01_basics.ipynb | |
def cycle(o):
    "Like `itertools.cycle` except creates list of `None`s if `o` is empty"
    o = listify(o)
    # An empty source becomes an endless stream of `None`
    if o is not None and len(o) > 0: return itertools.cycle(o)
    return itertools.cycle([None])
# %% ../nbs/01_basics.ipynb | |
def zip_cycle(x, *args):
    "Like `itertools.zip_longest` but `cycle`s through elements of all but first argument"
    cycled = (cycle(a) for a in args)
    return zip(x, *cycled)
# %% ../nbs/01_basics.ipynb | |
def sorted_ex(iterable, key=None, reverse=False):
    "Like `sorted`, but if key is str use `attrgetter`; if int use `itemgetter`"
    if isinstance(key, str):
        # Missing attrs sort as 0 rather than raising
        k = lambda o: getattr(o, key, 0)
    elif isinstance(key, int): k = itemgetter(key)
    else: k = key
    return sorted(iterable, key=k, reverse=reverse)
# %% ../nbs/01_basics.ipynb | |
def not_(f):
    "Create new function that negates result of `f`"
    def _negated(*args, **kwargs):
        return not f(*args, **kwargs)
    return _negated
# %% ../nbs/01_basics.ipynb | |
def argwhere(iterable, f, negate=False, **kwargs):
    "Like `filter_ex`, but return indices for matching items"
    if kwargs: f = partial(f, **kwargs)
    if negate: f = not_(f)
    return [i for i, item in enumerate(iterable) if f(item)]
# %% ../nbs/01_basics.ipynb | |
def filter_ex(iterable, f=noop, negate=False, gen=False, **kwargs):
    "Like `filter`, but passing `kwargs` to `f`, defaulting `f` to `noop`, and adding `negate` and `gen`"
    if f is None: f = lambda _: True
    if kwargs: f = partial(f, **kwargs)
    if negate: f = not_(f)
    res = filter(f, iterable)
    # Return lazily or materialized depending on `gen`
    return res if gen else list(res)
# %% ../nbs/01_basics.ipynb | |
def range_of(a, b=None, step=None):
    "All indices of collection `a`, if `a` is a collection, otherwise `range`"
    if is_coll(a): a = len(a)
    if step is not None: return list(range(a, b, step))
    if b is not None: return list(range(a, b))
    return list(range(a))
# %% ../nbs/01_basics.ipynb | |
def renumerate(iterable, start=0):
    "Same as `enumerate`, but returns index as 2nd element instead of 1st"
    for i, item in enumerate(iterable, start=start):
        yield item, i
# %% ../nbs/01_basics.ipynb | |
def first(x, f=None, negate=False, **kwargs):
    "First element of `x`, optionally filtered by `f`, or None if missing"
    it = iter(x)
    if f: it = filter_ex(it, f=f, negate=negate, gen=True, **kwargs)
    return next(it, None)
# %% ../nbs/01_basics.ipynb | |
def only(o):
    "Return the only item of `o`, raise if `o` doesn't have exactly one item"
    it = iter(o)
    try: sole = next(it)
    except StopIteration: raise ValueError('iterable has 0 items') from None
    try:
        next(it)
    except StopIteration:
        return sole
    raise ValueError('iterable has more than 1 item')
# %% ../nbs/01_basics.ipynb | |
def nested_attr(o, attr, default=None):
    "Same as `getattr`, but if `attr` includes a `.`, then looks inside nested objects"
    node = o
    for name in attr.split("."):
        try: node = getattr(node, name)
        except AttributeError: return default
    return node
# %% ../nbs/01_basics.ipynb | |
def nested_setdefault(o, attr, default):
    "Same as `setdefault`, but if `attr` includes a `.`, then looks inside nested objects"
    *parents, leaf = attr.split('.')
    node = o
    # Create intermediate containers of the same type as their parent
    for name in parents: node = node.setdefault(name, type(node)())
    return node.setdefault(leaf, default)
# %% ../nbs/01_basics.ipynb | |
def nested_callable(o, attr):
    "Same as `nested_attr` but if not found will return `noop`"
    # `noop` as the default means the result can be called unconditionally
    return nested_attr(o, attr, noop)
# %% ../nbs/01_basics.ipynb | |
def _access(coll, idx): | |
if isinstance(idx,str) and hasattr(coll, idx): return getattr(coll, idx) | |
if hasattr(coll, 'get'): return coll.get(idx, None) | |
try: length = len(coll) | |
except TypeError: length = 0 | |
if isinstance(idx,int) and idx<length: return coll[idx] | |
return None | |
def _nested_idx(coll, *idxs): | |
*idxs,last_idx = idxs | |
for idx in idxs: | |
if isinstance(idx,str) and hasattr(coll, idx): coll = getattr(coll, idx) | |
else: | |
if isinstance(coll,str) or not isinstance(coll, typing.Collection): return None,None | |
coll = coll.get(idx, None) if hasattr(coll, 'get') else coll[idx] if idx<len(coll) else None | |
return coll,last_idx | |
# %% ../nbs/01_basics.ipynb | |
def nested_idx(coll, *idxs):
    "Index into nested collections, dicts, etc, with `idxs`"
    if not coll or not idxs: return coll
    node, last = _nested_idx(coll, *idxs)
    if not node: return node
    return _access(node, last)
# %% ../nbs/01_basics.ipynb | |
def set_nested_idx(coll, value, *idxs):
    "Set value indexed like `nested_idx`"
    # Walk to the parent container of the final index, then assign into it
    coll,idx = _nested_idx(coll, *idxs)
    coll[idx] = value
# %% ../nbs/01_basics.ipynb | |
def val2idx(x):
    "Dict from value to index"
    # Later duplicates overwrite earlier ones, so each value maps to its last index
    return {item: pos for pos, item in enumerate(x)}
# %% ../nbs/01_basics.ipynb | |
def uniqueify(x, sort=False, bidir=False, start=None):
    "Unique elements in `x`, optional `sort`, optional return reverse correspondence, optional prepend with elements."
    # `dict.fromkeys` preserves first-seen order while dropping duplicates
    uniq = list(dict.fromkeys(x))
    if start is not None: uniq = listify(start) + uniq
    if sort: uniq.sort()
    if bidir: return uniq, val2idx(uniq)
    return uniq
# %% ../nbs/01_basics.ipynb | |
# looping functions from https://github.com/willmcgugan/rich/blob/master/rich/_loop.py
def loop_first_last(values):
    "Iterate and generate a tuple with a flag for first and last value."
    it = iter(values)
    try: prev = next(it)
    except StopIteration: return
    is_first = True
    # Yield each item one step behind, so "last" is known when we emit it
    for cur in it:
        yield is_first, False, prev
        is_first, prev = False, cur
    yield is_first, True, prev
# %% ../nbs/01_basics.ipynb | |
def loop_first(values):
    "Iterate and generate a tuple with a flag for first value."
    return ((is_first, item) for is_first, _, item in loop_first_last(values))
# %% ../nbs/01_basics.ipynb | |
def loop_last(values):
    "Iterate and generate a tuple with a flag for last value."
    return ((is_last, item) for _, is_last, item in loop_first_last(values))
# %% ../nbs/01_basics.ipynb | |
def first_match(lst, f, default=None):
    "Index of the first element of `lst` matching predicate `f`, or `default` if none"
    # Fix: docstring previously claimed to return the element; this returns the *index*
    # (use `first` for the element itself)
    return next((i for i, o in enumerate(lst) if f(o)), default)
# %% ../nbs/01_basics.ipynb | |
def last_match(lst, f, default=None):
    "Index of the last element of `lst` matching predicate `f`, or `default` if none"
    # Fix: docstring previously claimed to return the element; this returns the *index*.
    # Scans backwards, so `lst` must support `len` and integer indexing.
    return next((i for i in range(len(lst)-1, -1, -1) if f(lst[i])), default)
# %% ../nbs/01_basics.ipynb | |
# Names of the numeric dunder methods, used below to add elementwise operators to
# `fastuple`: plain binary/unary ops, reflected ops, and in-place ops respectively.
num_methods = """
    __add__ __sub__ __mul__ __matmul__ __truediv__ __floordiv__ __mod__ __divmod__ __pow__
    __lshift__ __rshift__ __and__ __xor__ __or__ __neg__ __pos__ __abs__
""".split()
rnum_methods = """
    __radd__ __rsub__ __rmul__ __rmatmul__ __rtruediv__ __rfloordiv__ __rmod__ __rdivmod__
    __rpow__ __rlshift__ __rrshift__ __rand__ __rxor__ __ror__
""".split()
inum_methods = """
    __iadd__ __isub__ __imul__ __imatmul__ __itruediv__
    __ifloordiv__ __imod__ __ipow__ __ilshift__ __irshift__ __iand__ __ixor__ __ior__
""".split()
# %% ../nbs/01_basics.ipynb | |
class fastuple(tuple):
    "A `tuple` with elementwise ops and more friendly __init__ behavior"
    def __new__(cls, x=None, *rest):
        # Accepts: nothing (empty tuple), a tuple, any iterable (expanded),
        # or a non-iterable scalar (wrapped), plus extra positional elements
        if x is None: x = ()
        if not isinstance(x,tuple):
            if len(rest): x = (x,)
            else:
                try: x = tuple(iter(x))
                except TypeError: x = (x,)
        return super().__new__(cls, x+rest if rest else x)
    def _op(self,op,*args):
        # Apply `op` elementwise; `args` are `cycle`d so shorter ones repeat to match `self`
        if not isinstance(self,fastuple): self = fastuple(self)
        return type(self)(map(op,self,*map(cycle, args)))
    def mul(self,*args):
        "`*` is already defined in `tuple` for replicating, so use `mul` instead"
        return fastuple._op(self, operator.mul,*args)
    def add(self,*args):
        "`+` is already defined in `tuple` for concat, so use `add` instead"
        return fastuple._op(self, operator.add,*args)
def _get_op(op): | |
if isinstance(op,str): op = getattr(operator,op) | |
def _f(self,*args): return self._op(op,*args) | |
return _f | |
# Attach elementwise versions of the numeric and comparison operators to `fastuple`
# (skipping any it already defines, e.g. tuple's own `__add__`/`__mul__`)
for n in num_methods:
    if not hasattr(fastuple, n) and hasattr(operator,n): setattr(fastuple,n,_get_op(n))
for n in 'eq ne lt le gt ge'.split(): setattr(fastuple,n,_get_op(n))
setattr(fastuple,'__invert__',_get_op('__not__'))
# NOTE: `max`/`min` here are the builtins (callables, not operator names),
# giving elementwise max/min against the cycled args
setattr(fastuple,'max',_get_op(max))
setattr(fastuple,'min',_get_op(min))
# %% ../nbs/01_basics.ipynb | |
class _Arg:
    # Placeholder meaning "use the i-th positional call-time argument" in `bind`
    def __init__(self,i): self.i = i
arg0 = _Arg(0)
arg1 = _Arg(1)
arg2 = _Arg(2)
arg3 = _Arg(3)
arg4 = _Arg(4)
# %% ../nbs/01_basics.ipynb | |
class bind:
    "Same as `partial`, except you can use `arg0` `arg1` etc param placeholders"
    def __init__(self, func, *pargs, **pkwargs):
        self.func,self.pargs,self.pkwargs = func,pargs,pkwargs
        # Highest placeholder index used positionally; call-time args beyond it
        # are appended after the bound positional args
        self.maxi = max((x.i for x in pargs if isinstance(x, _Arg)), default=-1)
    def __call__(self, *args, **kwargs):
        args = list(args)
        # Call-time kwargs override bound kwargs
        kwargs = {**self.pkwargs,**kwargs}
        # NOTE(review): `args.pop` mutates `args`, so a placeholder consumed by a
        # kwarg shifts the indices seen by later placeholders — confirm intended
        for k,v in kwargs.items():
            if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
        fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
        return self.func(*fargs, **kwargs)
# %% ../nbs/01_basics.ipynb | |
def mapt(func, *iterables):
    "Tuplified `map`"
    return tuple(func(*xs) for xs in zip(*iterables)) if len(iterables) > 1 else tuple(map(func, *iterables))
# %% ../nbs/01_basics.ipynb | |
def map_ex(iterable, f, *args, gen=False, **kwargs):
    "Like `map`, but use `bind`, and supports `str` and indexing"
    # str `f` formats each item; non-callable, non-str `f` is indexed by each item
    if callable(f): mapper = bind(f, *args, **kwargs)
    elif isinstance(f, str): mapper = f.format
    else: mapper = f.__getitem__
    out = map(mapper, iterable)
    return out if gen else list(out)
# %% ../nbs/01_basics.ipynb | |
def compose(*funcs, order=None):
    "Create a function that composes all functions in `funcs`, passing along remaining `*args` and `**kwargs` to all"
    funcs = listify(funcs)
    if not funcs: return noop
    if len(funcs) == 1: return funcs[0]
    if order is not None: funcs = sorted_ex(funcs, key=order)
    def _composed(x, *args, **kwargs):
        # Thread `x` through each function in turn, left to right
        for fn in funcs:
            x = fn(x, *args, **kwargs)
        return x
    return _composed
# %% ../nbs/01_basics.ipynb | |
def maps(*args, retain=noop):
    "Like `map`, except funcs are composed first"
    # All but the last arg are functions; the last is the iterable
    composed = compose(*args[:-1])
    return map(lambda b: retain(composed(b), b), args[-1])
# %% ../nbs/01_basics.ipynb | |
def partialler(f, *args, order=None, **kwargs):
    "Like `functools.partial` but also copies over docstring"
    res = partial(f, *args, **kwargs)
    res.__doc__ = f.__doc__
    # `order` is used by `compose` to sort composed functions
    if order is not None: res.order = order
    elif hasattr(f, 'order'): res.order = f.order
    return res
# %% ../nbs/01_basics.ipynb | |
def instantiate(t):
    "Instantiate `t` if it's a type, otherwise do nothing"
    if isinstance(t, type): return t()
    return t
# %% ../nbs/01_basics.ipynb | |
def _using_attr(f, attr, x): return f(getattr(x,attr)) | |
# %% ../nbs/01_basics.ipynb | |
def using_attr(f, attr):
    "Construct a function which applies `f` to the argument's attribute `attr`"
    # Delegates to the module-level `_using_attr` via `partial`
    return partial(_using_attr, f, attr)
# %% ../nbs/01_basics.ipynb | |
class _Self:
    "An alternative to `lambda` for calling methods on passed object."
    # Records a chain of attribute names (`nms`) with per-name call args
    # (`args`/`kwargs`; `None` entries mean plain attribute access, no call).
    # `ready` is True when the next `__call__` should *execute* the chain, and
    # False while still *recording* call args for the last accessed attribute.
    def __init__(self): self.nms,self.args,self.kwargs,self.ready = [],[],[],True
    def __repr__(self): return f'self: {self.nms}({self.args}, {self.kwargs})'
    def __call__(self, *args, **kwargs):
        if self.ready:
            # Execute: walk the recorded chain against the passed object
            x = args[0]
            for n,a,k in zip(self.nms,self.args,self.kwargs):
                x = getattr(x,n)
                if callable(x) and a is not None: x = x(*a, **k)
            return x
        else:
            # Record: capture the call args for the most recent attribute
            self.args.append(args)
            self.kwargs.append(kwargs)
            self.ready = True
            return self
    def __getattr__(self,k):
        if not self.ready:
            # Previous attribute was accessed but never called: mark as plain access
            self.args.append(None)
            self.kwargs.append(None)
        self.nms.append(k)
        self.ready = False
        return self
    def _call(self, *args, **kwargs):
        # Reset the chain to a single `__call__` — used by `Self(...)`
        self.args,self.kwargs,self.nms = [args],[kwargs],['__call__']
        self.ready = True
        return self
# %% ../nbs/01_basics.ipynb | |
class _SelfCls:
    # Factory object: each attribute access, indexing, or call starts a fresh
    # `_Self` recording chain
    def __getattr__(self,k): return getattr(_Self(),k)
    def __getitem__(self,i): return self.__getattr__('__getitem__')(i)
    def __call__(self,*args,**kwargs): return self.__getattr__('_call')(*args,**kwargs)
Self = _SelfCls()
# %% ../nbs/01_basics.ipynb
_all_ = ['Self']
# %% ../nbs/01_basics.ipynb | |
def copy_func(f):
    "Copy a non-builtin function (NB `copy.copy` does not work for this)"
    if not isinstance(f, FunctionType): return copy(f)
    # Rebuild from the code object, then carry over the mutable metadata
    res = FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    res.__kwdefaults__ = f.__kwdefaults__
    res.__dict__.update(f.__dict__)
    res.__annotations__.update(f.__annotations__)
    res.__qualname__ = f.__qualname__
    return res
# %% ../nbs/01_basics.ipynb | |
class _clsmethod: | |
def __init__(self, f): self.f = f | |
def __get__(self, _, f_cls): return MethodType(self.f, f_cls) | |
# %% ../nbs/01_basics.ipynb | |
def patch_to(cls, as_prop=False, cls_method=False):
    "Decorator: add `f` to `cls`"
    # `cls` may be a single class or a tuple/list of classes to patch
    if not isinstance(cls, (tuple,list)): cls=(cls,)
    def _inner(f):
        for c_ in cls:
            # Patch a copy so each class gets its own function object
            nf = copy_func(f)
            nm = f.__name__
            # Copy wrapper attrs manually (instead of `functools.update_wrapper`)
            # when passing patched function to `Pipeline`, so we do it manually
            for o in functools.WRAPPER_ASSIGNMENTS: setattr(nf, o, getattr(f,o))
            nf.__qualname__ = f"{c_.__name__}.{nm}"
            if cls_method: setattr(c_, nm, _clsmethod(nf))
            else:
                if as_prop: setattr(c_, nm, property(nf))
                else:
                    # Preserve any pre-existing attribute under `_orig_<name>`
                    onm = '_orig_'+nm
                    if hasattr(c_, nm) and not hasattr(c_, onm): setattr(c_, onm, getattr(c_, nm))
                    setattr(c_, nm, nf)
        # Avoid clobbering existing functions: return any same-named global/builtin
        return globals().get(nm, builtins.__dict__.get(nm, None))
    return _inner
# %% ../nbs/01_basics.ipynb | |
def patch(f=None, *, as_prop=False, cls_method=False):
    "Decorator: add `f` to the first parameter's class (based on f's type annotations)"
    # Support use both as `@patch` and `@patch(...)`
    if f is None: return partial(patch, as_prop=as_prop, cls_method=cls_method)
    ann,glb,loc = get_annotations_ex(f)
    # For class methods the target comes from the `cls` annotation; otherwise from
    # the first parameter's annotation. A union annotation patches every member.
    cls = union2tuple(eval_type(ann.pop('cls') if cls_method else next(iter(ann.values())), glb, loc))
    return patch_to(cls, as_prop=as_prop, cls_method=cls_method)(f)
# %% ../nbs/01_basics.ipynb | |
def patch_property(f):
    "Deprecated; use `patch(as_prop=True)` instead"
    warnings.warn("`patch_property` is deprecated and will be removed; use `patch(as_prop=True)` instead")
    # Target class is taken from the first parameter's annotation
    cls = next(iter(f.__annotations__.values()))
    return patch_to(cls, as_prop=True)(f)
# %% ../nbs/01_basics.ipynb | |
def compile_re(pat):
    "Compile `pat` if it's not None"
    if pat is None: return None
    return re.compile(pat)
# %% ../nbs/01_basics.ipynb | |
class ImportEnum(enum.Enum):
    "An `Enum` that can have its values imported"
    @classmethod
    def imports(cls):
        # Inject every member into the caller's local namespace by name
        caller_locals = sys._getframe(1).f_locals
        for member in cls: caller_locals[member.name] = member
# %% ../nbs/01_basics.ipynb | |
class StrEnum(str,ImportEnum):
    "An `ImportEnum` that behaves like a `str`"
    # Stringifies to the member *name* (compare `ValEnum`, which uses the value)
    def __str__(self): return self.name
# %% ../nbs/01_basics.ipynb | |
def str_enum(name, *vals):
    "Simplified creation of `StrEnum` types"
    # Each value becomes a member whose name and value are identical
    return StrEnum(name, {o:o for o in vals})
# %% ../nbs/01_basics.ipynb | |
class ValEnum(str,ImportEnum):
    "An `ImportEnum` that stringifies using values"
    def __str__(self): return self.value
# %% ../nbs/01_basics.ipynb | |
class Stateful:
    "A base class/mixin for objects that should not serialize all their state"
    _stateattrs = ()  # subclass-listed attribute names excluded from pickling
    def __init__(self, *args, **kwargs):
        self._init_state()
        super().__init__(*args, **kwargs) # required for mixin usage
    def __getstate__(self):
        # Drop `_state` and any subclass-declared transient attributes
        skip = self._stateattrs + ('_state',)
        return {k: v for k, v in self.__dict__.items() if k not in skip}
    def __setstate__(self, state):
        self.__dict__.update(state)
        # Transient state is rebuilt, not restored
        self._init_state()
    def _init_state(self):
        "Override for custom init and deserialization logic"
        self._state = {}
# %% ../nbs/01_basics.ipynb | |
class NotStr(GetAttr):
    "Behaves like a `str`, but isn't an instance of one"
    _default = 's'  # `GetAttr` delegates unknown attributes to the wrapped str
    def __init__(self, s): self.s = s.s if isinstance(s, NotStr) else s
    def __repr__(self): return repr(self.s)
    def __str__(self): return self.s
    def __add__(self, b): return NotStr(self.s+str(b))
    def __mul__(self, b): return NotStr(self.s*b)
    def __len__(self): return len(self.s)
    def __eq__(self, b):
        # Fix: previously `self.s==b.s if isinstance(b, NotStr) else b`, which parses
        # as `(self.s==b.s) if ... else b` — comparing against a plain str returned
        # that str (truthy even when unequal) instead of a bool
        return self.s == (b.s if isinstance(b, NotStr) else b)
    def __lt__(self, b): return self.s<b
    def __hash__(self): return hash(self.s)
    def __bool__(self): return bool(self.s)
    def __contains__(self, b): return b in self.s
    def __iter__(self): return iter(self.s)
# %% ../nbs/01_basics.ipynb | |
class PrettyString(str):
    "Little hack to get strings to show properly in Jupyter."
    # `repr` returns the string itself, so it renders without quotes or escapes
    def __repr__(self): return self
# %% ../nbs/01_basics.ipynb | |
def even_mults(start, stop, n):
    "Build log-stepped array from `start` to `stop` in `n` steps."
    if n == 1: return stop
    # Constant multiplicative step so values are evenly spaced in log space
    ratio = (stop/start) ** (1/(n-1))
    return [start * ratio**i for i in range(n)]
# %% ../nbs/01_basics.ipynb | |
def num_cpus():
    "Get number of cpus"
    # `sched_getaffinity` honours cpu affinity masks but isn't on every platform
    if hasattr(os, 'sched_getaffinity'): return len(os.sched_getaffinity(0))
    return os.cpu_count()
defaults.cpus = num_cpus() # cache the cpu count on the shared `defaults` namespace
# %% ../nbs/01_basics.ipynb | |
def add_props(f, g=None, n=2):
    "Create properties passing each of `range(n)` to f"
    # `f` is the getter, optional `g` the setter; each gets the index bound first
    def _mk(i):
        if g is None: return property(partial(f, i))
        return property(partial(f, i), partial(g, i))
    return (_mk(i) for i in range(n))
# %% ../nbs/01_basics.ipynb | |
def _typeerr(arg, val, typ): return TypeError(f"{arg}=={val} not {typ}") | |
# %% ../nbs/01_basics.ipynb | |
def typed(f):
    "Decorator to check param and return types at runtime"
    names = f.__code__.co_varnames
    anno = annotations(f)
    # Return annotation is checked separately, after the call
    ret = anno.pop('return',None)
    def _f(*args,**kwargs):
        kw = {**kwargs}
        if len(anno) > 0:
            # Map positional args onto their parameter names so they're checked too
            for i,arg in enumerate(args): kw[names[i]] = arg
            for k,v in kw.items():
                if k in anno and not isinstance(v,anno[k]): raise _typeerr(k, v, anno[k])
        res = f(*args,**kwargs)
        if ret is not None and not isinstance(res,ret): raise _typeerr("return", res, ret)
        return res
    return functools.update_wrapper(_f, f)
# %% ../nbs/01_basics.ipynb | |
def exec_new(code):
    "Execute `code` in a new environment and return it"
    # When run as a script there is no package context
    pkg = None if __name__ == '__main__' else Path().cwd().name
    env = {'__name__': __name__, '__package__': pkg}
    exec(code, env)
    return env
# %% ../nbs/01_basics.ipynb | |
def exec_import(mod, sym):
    "Import `sym` from `mod` in a new environment"
    # pref = '' if __name__=='__main__' or mod[0]=='.' else '.'
    # Returns the fresh environment dict produced by `exec_new`
    return exec_new(f'from {mod} import {sym}')
# %% ../nbs/01_basics.ipynb | |
def str2bool(s):
    "Case-insensitive convert string `s` to a bool (`y`,`yes`,`t`,`true`,`on`,`1`->`True`)"
    # Fixes: docstring typo ("too a bool"); previously returned 1/0 ints
    # where the docstring promises bools; ValueError now carries a message.
    if not isinstance(s, str): return bool(s)
    if not s: return False
    s = s.lower()
    if s in ('y', 'yes', 't', 'true', 'on', '1'): return True
    if s in ('n', 'no', 'f', 'false', 'off', '0'): return False
    raise ValueError(f'Cannot convert {s!r} to bool')
--- | |
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_foundation.ipynb. | |
# %% auto 0 | |
__all__ = ['working_directory', 'add_docs', 'docs', 'coll_repr', 'is_bool', 'mask2idxs', 'cycle', 'zip_cycle', 'is_indexer', | |
'CollBase', 'L', 'save_config_file', 'read_config_file', 'Config'] | |
# %% ../nbs/02_foundation.ipynb | |
from .imports import * | |
from .basics import * | |
from functools import lru_cache | |
from contextlib import contextmanager | |
from copy import copy | |
from configparser import ConfigParser | |
import random,pickle,inspect | |
# %% ../nbs/02_foundation.ipynb | |
@contextmanager
def working_directory(path):
    "Change working directory to `path` and return to previous on exit."
    previous = Path.cwd()
    os.chdir(path)
    try:
        yield
    finally:
        # Always restore, even if the body raised
        os.chdir(previous)
# %% ../nbs/02_foundation.ipynb | |
def add_docs(cls, cls_doc=None, **docs):
    "Copy values from `docs` to `cls` docstrings, and confirm all public methods are documented"
    if cls_doc is not None: cls.__doc__ = cls_doc
    for name, doc in docs.items():
        member = getattr(cls, name)
        if hasattr(member, '__func__'): member = member.__func__ # required for class methods
        member.__doc__ = doc
    # List of public callables without docstring
    undocumented = [m for nm, m in vars(cls).items()
                    if callable(m) and not nm.startswith('_') and m.__doc__ is None]
    assert not undocumented, f"Missing docs: {undocumented}"
    assert cls.__doc__ is not None, f"Missing class docs: {cls}"
# %% ../nbs/02_foundation.ipynb | |
def docs(cls):
    "Decorator version of `add_docs`, using `_docs` dict"
    # Expects the class to define `_docs`, a dict of member name -> docstring
    add_docs(cls, **cls._docs)
    return cls
# %% ../nbs/02_foundation.ipynb | |
def coll_repr(c, max_n=10):
    "String repr of up to `max_n` items of (possibly lazy) collection `c`"
    # `islice` keeps this lazy: at most `max_n` items are ever repr'd
    shown = ','.join(itertools.islice(map(repr, c), max_n))
    suffix = '...' if len(c) > max_n else ''
    return f'(#{len(c)}) [{shown}{suffix}]'
# %% ../nbs/02_foundation.ipynb | |
def is_bool(x): | |
"Check whether `x` is a bool or None" | |
return isinstance(x,(bool,NoneType)) or risinstance('bool_', x) | |
# %% ../nbs/02_foundation.ipynb | |
def mask2idxs(mask):
    "Convert bool mask or index list to index `L`"
    if isinstance(mask, slice): return mask
    mask = list(mask)
    if not mask: return []
    head = mask[0]
    # Unwrap 0-d array/tensor elements so the bool check works on plain scalars
    if hasattr(head, 'item'): head = head.item()
    if is_bool(head): return [i for i, m in enumerate(mask) if m]
    return [int(i) for i in mask]
# %% ../nbs/02_foundation.ipynb | |
def cycle(o):
    "Like `itertools.cycle` except creates list of `None`s if `o` is empty"
    items = listify(o)
    if items is None or len(items) == 0: items = [None]
    return itertools.cycle(items)
# %% ../nbs/02_foundation.ipynb | |
def zip_cycle(x, *args):
    "Like `itertools.zip_longest` but `cycle`s through elements of all but first argument"
    return zip(x, *(cycle(a) for a in args))
# %% ../nbs/02_foundation.ipynb | |
def is_indexer(idx):
    "Test whether `idx` will index a single item in a list"
    if isinstance(idx, int): return True
    # 0-dim arrays/tensors (ndim==0) also index a single item
    return not getattr(idx, 'ndim', 1)
# %% ../nbs/02_foundation.ipynb | |
class CollBase:
    "Base class for composing a list of `items`"
    def __init__(self, items):
        self.items = items
    def __len__(self):
        return len(self.items)
    def __getitem__(self, k):
        # A `CollBase` key is expanded to a list of indices
        key = list(k) if isinstance(k, CollBase) else k
        return self.items[key]
    def __setitem__(self, k, v):
        key = list(k) if isinstance(k, CollBase) else k
        self.items[key] = v
    def __delitem__(self, i):
        del self.items[i]
    def __repr__(self):
        return repr(self.items)
    def __iter__(self):
        return iter(self.items)
# %% ../nbs/02_foundation.ipynb | |
class _L_Meta(type): | |
def __call__(cls, x=None, *args, **kwargs): | |
if not args and not kwargs and x is not None and isinstance(x,cls): return x | |
return super().__call__(x, *args, **kwargs) | |
# %% ../nbs/02_foundation.ipynb | |
class L(GetAttr, CollBase, metaclass=_L_Meta):
    "Behaves like a list of `items` but can also index with list of indices or masks"
    _default='items' # `GetAttr` delegates unknown attribute access to `self.items`
    def __init__(self, items=None, *rest, use_list=False, match=None):
        # Arrays are stored as-is (unless `use_list` forces listification)
        if (use_list is not None) or not is_array(items):
            items = listify(items, *rest, use_list=use_list, match=match)
        super().__init__(items)
    @property
    def _xtra(self): return None
    def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    def copy(self): return self._new(self.items.copy())
    def _get(self, i):
        # Single index or slice: index directly (via `.iloc` for pandas-like items)
        if is_indexer(i) or isinstance(i,slice): return getattr(self.items,'iloc',self.items)[i]
        # Otherwise `i` is a mask or a collection of indices
        i = mask2idxs(i)
        return (self.items.iloc[list(i)] if hasattr(self.items,'iloc')
                else self.items.__array__()[(i,)] if hasattr(self.items,'__array__')
                else [self.items[i_] for i_ in i])
    def __setitem__(self, idx, o):
        "Set `idx` (can be list of indices, or mask, or int) items to `o` (which is broadcast if not iterable)"
        if isinstance(idx, int): self.items[idx] = o
        else:
            idx = idx if isinstance(idx,L) else listify(idx)
            if not is_iter(o): o = [o]*len(idx) # broadcast a scalar across all indices
            for i,o_ in zip(idx,o): self.items[i] = o_
    def __eq__(self,b):
        if b is None: return False
        if not hasattr(b, '__iter__'): return False
        if risinstance('ndarray', b): return array_equal(b, self)
        # str/dict/callables iterate but aren't list-like, so never equal
        if isinstance(b, (str,dict)) or callable(b): return False
        return all_equal(b,self)
    def sorted(self, key=None, reverse=False): return self._new(sorted_ex(self, key=key, reverse=reverse))
    def __iter__(self): return iter(self.items.itertuples() if hasattr(self.items,'iloc') else self.items)
    def __contains__(self,b): return b in self.items
    def __reversed__(self): return self._new(reversed(self.items))
    def __invert__(self): return self._new(not i for i in self)
    def __repr__(self): return repr(self.items)
    def _repr_pretty_(self, p, cycle):
        p.text('...' if cycle else repr(self.items) if is_array(self.items) else coll_repr(self))
    def __mul__ (a,b): return a._new(a.items*b)
    def __add__ (a,b): return a._new(a.items+listify(b))
    def __radd__(a,b): return a._new(b)+a
    def __addi__(a,b):
        # NOTE(review): named `__addi__`, not `__iadd__`, so `+=` actually falls
        # back to `__add__`; this only runs when called explicitly — confirm intended
        a.items += list(b)
        return a
    @classmethod
    def split(cls, s, sep=None, maxsplit=-1): return cls(s.split(sep,maxsplit))
    @classmethod
    def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
    def map(self, f, *args, **kwargs): return self._new(map_ex(self, f, *args, gen=False, **kwargs))
    def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
    def argfirst(self, f, negate=False):
        if negate: f = not_(f)
        return first(i for i,o in self.enumerate() if f(o))
    def filter(self, f=noop, negate=False, **kwargs):
        return self._new(filter_ex(self, f=f, negate=negate, gen=False, **kwargs))
    def enumerate(self): return L(enumerate(self))
    def renumerate(self): return L(renumerate(self))
    def unique(self, sort=False, bidir=False, start=None): return L(uniqueify(self, sort=sort, bidir=bidir, start=start))
    def val2idx(self): return val2idx(self)
    def cycle(self): return cycle(self)
    def map_dict(self, f=noop, *args, **kwargs): return {k:f(k, *args,**kwargs) for k in self}
    def map_first(self, f=noop, g=noop, *args, **kwargs):
        return first(self.map(f, *args, **kwargs), g)
    def itemgot(self, *idxs):
        x = self
        for idx in idxs: x = x.map(itemgetter(idx))
        return x
    def attrgot(self, k, default=None):
        return self.map(lambda o: o.get(k,default) if isinstance(o, dict) else nested_attr(o,k,default))
    def starmap(self, f, *args, **kwargs): return self._new(itertools.starmap(partial(f,*args,**kwargs), self))
    def zip(self, cycled=False): return self._new((zip_cycle if cycled else zip)(*self))
    def zipwith(self, *rest, cycled=False): return self._new([self, *rest]).zip(cycled=cycled)
    def map_zip(self, f, *args, cycled=False, **kwargs): return self.zip(cycled=cycled).starmap(f, *args, **kwargs)
    def map_zipwith(self, f, *rest, cycled=False, **kwargs): return self.zipwith(*rest, cycled=cycled).starmap(f, **kwargs)
    def shuffle(self):
        # Copy first so the original is left unshuffled
        it = copy(self.items)
        random.shuffle(it)
        return self._new(it)
    def concat(self): return self._new(itertools.chain.from_iterable(self.map(L)))
    def reduce(self, f, initial=None): return reduce(f, self) if initial is None else reduce(f, self, initial)
    def sum(self): return self.reduce(operator.add, 0)
    def product(self): return self.reduce(operator.mul, 1)
    def setattrs(self, attr, val): [setattr(o,attr,val) for o in self]
# %% ../nbs/02_foundation.ipynb | |
# Attach docstrings to `L`'s public methods (and validate they're all documented)
add_docs(L,
         __getitem__="Retrieve `idx` (can be list of indices, or mask, or int) items",
         range="Class Method: Same as `range`, but returns `L`. Can pass collection for `a`, to use `len(a)`",
         split="Class Method: Same as `str.split`, but returns an `L`",
         copy="Same as `list.copy`, but returns an `L`",
         sorted="New `L` sorted by `key`. If key is str use `attrgetter`; if int use `itemgetter`",
         unique="Unique items, in stable order",
         val2idx="Dict from value to index",
         filter="Create new `L` filtered by predicate `f`, passing `args` and `kwargs` to `f`",
         argwhere="Like `filter`, but return indices for matching items",
         argfirst="Return index of first matching item",
         map="Create new `L` with `f` applied to all `items`, passing `args` and `kwargs` to `f`",
         map_first="First element of `map_filter`",
         map_dict="Like `map`, but creates a dict from `items` to function results",
         starmap="Like `map`, but use `itertools.starmap`",
         itemgot="Create new `L` with item `idx` of all `items`",
         attrgot="Create new `L` with attr `k` (or value `k` for dicts) of all `items`.",
         cycle="Same as `itertools.cycle`",
         enumerate="Same as `enumerate`",
         renumerate="Same as `renumerate`",
         zip="Create new `L` with `zip(*items)`",
         zipwith="Create new `L` with `self` zip with each of `*rest`",
         map_zip="Combine `zip` and `starmap`",
         map_zipwith="Combine `zipwith` and `starmap`",
         concat="Concatenate all elements of list",
         shuffle="Same as `random.shuffle`, but not inplace",
         reduce="Wrapper for `functools.reduce`",
         sum="Sum of the items",
         product="Product of the items",
         setattrs="Call `setattr` on all items"
        )
# %% ../nbs/02_foundation.ipynb | |
# Here we are fixing the signature of L. What happens is that the __call__ method on the MetaClass of L shadows the __init__
# giving the wrong signature (https://stackoverflow.com/questions/49740290/call-from-metaclass-shadows-signature-of-init).
def _f(items=None, *rest, use_list=False, match=None): ...
L.__signature__ = inspect.signature(_f)
# %% ../nbs/02_foundation.ipynb
# Register `L` as a virtual subclass of `Sequence`, so `isinstance(l, Sequence)` holds
Sequence.register(L);
# %% ../nbs/02_foundation.ipynb | |
def save_config_file(file, d, **kwargs):
    "Write settings dict to a new config file, or overwrite the existing one."
    config = ConfigParser(**kwargs)
    config['DEFAULT'] = d
    # Fix: the file handle was previously opened and never closed; use a context
    # manager so it's flushed and closed deterministically
    with open(file, 'w') as f: config.write(f)
# %% ../nbs/02_foundation.ipynb | |
def read_config_file(file, **kwargs):
    "Read `file` with `ConfigParser` and return its `DEFAULT` section"
    # Fix: docstring was missing. `kwargs` are forwarded to `ConfigParser`.
    config = ConfigParser(**kwargs)
    config.read(file, encoding='utf8')
    return config['DEFAULT']
# %% ../nbs/02_foundation.ipynb | |
class Config:
    "Reading and writing `ConfigParser` ini files"
    def __init__(self, cfg_path, cfg_name, create=None, save=True, extra_files=None, types=None):
        # `types` maps key -> type, used by `get` to convert stored strings
        self.types = types or {}
        cfg_path = Path(cfg_path).expanduser().absolute()
        self.config_path,self.config_file = cfg_path,cfg_path/cfg_name
        self._cfg = ConfigParser()
        self.d = self._cfg['DEFAULT']
        # `extra_files` are read first, so the main config file takes precedence
        found = [Path(o) for o in self._cfg.read(L(extra_files)+[self.config_file], encoding='utf8')]
        if self.config_file not in found and create is not None:
            # Main config file missing: seed from `create`, optionally writing to disk
            self._cfg.read_dict({'DEFAULT':create})
            if save:
                cfg_path.mkdir(exist_ok=True, parents=True)
                save_config_file(self.config_file, create)
    def __repr__(self): return repr(dict(self._cfg.items('DEFAULT', raw=True)))
    def __setitem__(self,k,v): self.d[k] = str(v)
    def __contains__(self,k): return k in self.d
    def save(self): save_config_file(self.config_file,self.d)
    # The `k=='d'` guard avoids infinite recursion before `self.d` is assigned
    def __getattr__(self,k): return stop(AttributeError(k)) if k=='d' or k not in self.d else self.get(k)
    def __getitem__(self,k): return stop(IndexError(k)) if k not in self.d else self.get(k)
    def get(self,k,default=None):
        "Value of `k`, converted via `self.types` (bools via `str2bool`; `Path`s relative to `config_path`)"
        v = self.d.get(k, default)
        if v is None: return None
        typ = self.types.get(k, None)
        if typ==bool: return str2bool(v)
        if not typ: return str(v)
        if typ==Path: return self.config_path/v
        return typ(v)
    def path(self,k,default=None):
        "Value of `k` as a path relative to `config_path`"
        v = self.get(k, default)
        return v if v is None else self.config_path/v
--- | |
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_xtras.ipynb. | |
# %% ../nbs/03_xtras.ipynb 1 | |
from __future__ import annotations | |
# %% auto 0 | |
__all__ = ['spark_chars', 'UNSET', 'walk', 'globtastic', 'maybe_open', 'mkdir', 'image_size', 'bunzip', 'loads', 'loads_multi', | |
'dumps', 'untar_dir', 'repo_details', 'run', 'open_file', 'save_pickle', 'load_pickle', 'parse_env', | |
'expand_wildcards', 'dict2obj', 'obj2dict', 'repr_dict', 'is_listy', 'mapped', 'IterLen', | |
'ReindexCollection', 'get_source_link', 'truncstr', 'sparkline', 'modify_exception', 'round_multiple', | |
'set_num_threads', 'join_path_file', 'autostart', 'EventTimer', 'stringfmt_names', 'PartialFormatter', | |
'partial_format', 'utc2local', 'local2utc', 'trace', 'modified_env', 'ContextManagers', 'shufflish', | |
'console_help', 'hl_md', 'type2str', 'dataclass_src', 'Unset', 'nullable_dc', 'make_nullable', 'flexiclass', | |
'asdict', 'is_typeddict', 'is_namedtuple', 'flexicache', 'time_policy', 'mtime_policy', 'timed_cache'] | |
# %% ../nbs/03_xtras.ipynb | |
from .imports import * | |
from .foundation import * | |
from .basics import * | |
from importlib import import_module | |
from functools import wraps | |
import string,time,dataclasses | |
from enum import Enum | |
from contextlib import contextmanager,ExitStack | |
from datetime import datetime, timezone | |
from time import sleep,time,perf_counter | |
from os.path import getmtime | |
from dataclasses import dataclass, field, fields, is_dataclass, MISSING, make_dataclass | |
# %% ../nbs/03_xtras.ipynb | |
def walk(
    path:Path|str, # path to start searching
    symlinks:bool=True, # follow symlinks?
    keep_file:callable=ret_true, # function that returns True for wanted files
    keep_folder:callable=ret_true, # function that returns True for folders to enter
    skip_folder:callable=ret_false, # function that returns True for folders to skip
    func:callable=os.path.join, # function to apply to each matched file
    ret_folders:bool=False # return folders, not just files
):
    "Generator version of `os.walk`, using functions to filter files and folders"
    from copy import copy
    for root, dirs, files in os.walk(path, followlinks=symlinks):
        if keep_folder(root, ''):
            if ret_folders: yield func(root, '')
            for name in files:
                if keep_file(root, name): yield func(root, name)
        # Prune skipped subfolders in place so `os.walk` doesn't descend into them
        for name in copy(dirs):
            if skip_folder(root, name): dirs.remove(name)
# %% ../nbs/03_xtras.ipynb | |
def globtastic(
    path:Path|str, # path to start searching
    recursive:bool=True, # search subfolders
    symlinks:bool=True, # follow symlinks?
    file_glob:str=None, # Only include files matching glob
    file_re:str=None, # Only include files matching regex
    folder_re:str=None, # Only enter folders matching regex
    skip_file_glob:str=None, # Skip files matching glob
    skip_file_re:str=None, # Skip files matching regex
    skip_folder_re:str=None, # Skip folders matching regex,
    func:callable=os.path.join, # function to apply to each matched file
    ret_folders:bool=False # return folders, not just files
)->L: # Paths to matched files
    "A more powerful `glob`, including regex matches, symlink handling, and skip parameters"
    from fnmatch import fnmatch
    path = Path(path)
    if path.is_file(): return L([path])
    # A regex of '.' matches every folder name, which disables recursion
    if not recursive: skip_folder_re='.'
    file_re,folder_re = compile_re(file_re),compile_re(folder_re)
    skip_file_re,skip_folder_re = compile_re(skip_file_re),compile_re(skip_folder_re)
    def _keep_file(root, name):
        "Apply include glob/regex then exclude glob/regex to one file name"
        if file_glob and not fnmatch(name, file_glob): return False
        if file_re and not file_re.search(name): return False
        if skip_file_glob and fnmatch(name, skip_file_glob): return False
        if skip_file_re and skip_file_re.search(name): return False
        return True
    def _keep_folder(root, name): return not folder_re or folder_re.search(os.path.join(root,name))
    def _skip_folder(root, name): return skip_folder_re and skip_folder_re.search(name)
    return L(walk(path, symlinks=symlinks, keep_file=_keep_file, keep_folder=_keep_folder,
                  skip_folder=_skip_folder, func=func, ret_folders=ret_folders))
# %% ../nbs/03_xtras.ipynb | |
@contextmanager
def maybe_open(f, mode='r', **kwargs):
    "Context manager: open `f` if it is a path (and close on exit)"
    if not isinstance(f, (str,os.PathLike)):
        # Already a file-like object: pass it through untouched (caller owns closing)
        yield f
        return
    stream = open(f, mode, **kwargs)
    try: yield stream
    finally: stream.close()
# %% ../nbs/03_xtras.ipynb | |
def mkdir(path, exist_ok=False, parents=False, overwrite=False, **kwargs):
    "Creates and returns a directory defined by `path`, optionally removing previous existing directory if `overwrite` is `True`"
    import shutil
    path = Path(path)
    # Wipe any pre-existing tree first when overwriting
    if overwrite and path.exists(): shutil.rmtree(path)
    path.mkdir(exist_ok=exist_ok, parents=parents, **kwargs)
    return path
# %% ../nbs/03_xtras.ipynb | |
def image_size(fn):
    "Tuple of (w,h) for png, gif, or jpg; `None` otherwise"
    from fastcore import imghdr
    import struct
    def _jpg_size(f):
        "Walk JPEG segments until a SOFn (0xc0-0xcf) marker, then read height/width"
        size,ftype = 2,0
        while not 0xc0 <= ftype <= 0xcf:
            f.seek(size, 1)
            byte = f.read(1)
            while ord(byte) == 0xff: byte = f.read(1)
            ftype = ord(byte)
            size = struct.unpack('>H', f.read(2))[0] - 2
        f.seek(1, 1)  # skip the `precision' byte
        h,w = struct.unpack('>HH', f.read(4))
        return w,h
    def _gif_size(f):
        # Fix: `head` was previously read from an undefined name (NameError for gifs).
        # `imghdr.what` restores the file position, so read the header here.
        head = f.read(24)
        return struct.unpack('<HH', head[6:10])
    def _png_size(f):
        head = f.read(24)  # fix: was reading from undefined `head`
        assert struct.unpack('>i', head[4:8])[0]==0x0d0a1a0a
        return struct.unpack('>ii', head[16:24])
    d = dict(png=_png_size, gif=_gif_size, jpeg=_jpg_size)
    with maybe_open(fn, 'rb') as f:
        # Return None for unrecognised types, as the docstring promises (was KeyError)
        size_fn = d.get(imghdr.what(f))
        return size_fn(f) if size_fn else None
# %% ../nbs/03_xtras.ipynb | |
def bunzip(fn):
    "bunzip `fn`, raising exception if output already exists"
    import bz2
    fn = Path(fn)
    assert fn.exists(), f"{fn} doesn't exist"
    out_fn = fn.with_suffix('')
    assert not out_fn.exists(), f"{out_fn} already exists"
    # Stream in 1MB chunks so huge archives don't need to fit in memory
    with bz2.BZ2File(fn, 'rb') as src, out_fn.open('wb') as dst:
        while True:
            chunk = src.read(1024*1024)
            if not chunk: break
            dst.write(chunk)
# %% ../nbs/03_xtras.ipynb | |
def loads(s, **kw):
    "Same as `json.loads`, but handles `None`"
    # Empty or None input decodes to an empty dict rather than raising
    if not s: return {}
    try:
        import ujson as json  # prefer the faster ujson when installed
    except ModuleNotFoundError:
        import json
    return json.loads(s, **kw)
# %% ../nbs/03_xtras.ipynb | |
def loads_multi(s:str):
    "Generator of >=0 decoded json dicts, possibly with non-json ignored text at start and end"
    import json
    dec = json.JSONDecoder()
    start = s.find('{')
    while start >= 0:
        s = s[start:]              # skip any non-JSON prefix
        obj,pos = dec.raw_decode(s)
        if not pos: raise ValueError(f'no JSON object found at {pos}')
        yield obj
        s = s[pos:]                # continue after the decoded object
        start = s.find('{')
# %% ../nbs/03_xtras.ipynb | |
def dumps(obj, **kw):
    "Same as `json.dumps`, but uses `ujson` if available"
    try:
        import ujson as json
    except ModuleNotFoundError:
        import json
    else:
        # ujson escapes '/' by default; disable to match stdlib json output
        kw['escape_forward_slashes'] = False
    return json.dumps(obj, **kw)
# %% ../nbs/03_xtras.ipynb | |
def _unpack(fname, out):
    "Unpack archive `fname` into `out`; return the single child if there is exactly one, else `out`"
    import shutil
    shutil.unpack_archive(str(fname), str(out))
    contents = out.ls()
    if len(contents) == 1: return contents[0]
    return out
# %% ../nbs/03_xtras.ipynb | |
def untar_dir(fname, dest, rename=False, overwrite=False):
    "untar `file` into `dest`, creating a directory if the root contains more than one item"
    import tempfile,shutil
    with tempfile.TemporaryDirectory() as d:
        # Staging dir named after the archive (with any '.tar' stem stripped)
        out = Path(d)/remove_suffix(Path(fname).stem, '.tar')
        out.mkdir()
        if rename: dest = dest/out.name  # destination named after the archive itself
        else:
            # Destination named after the archive's single root item (or staging dir)
            src = _unpack(fname, out)
            dest = dest/src.name
        if dest.exists():
            if overwrite: shutil.rmtree(dest) if dest.is_dir() else dest.unlink()
            else: return dest  # leave an existing destination untouched
        # With rename=True, unpacking is deliberately delayed until after the
        # overwrite check, so nothing is extracted when we bail out early
        if rename: src = _unpack(fname, out)
        shutil.move(str(src), dest)
        return dest
# %% ../nbs/03_xtras.ipynb | |
def repo_details(url):
    "Tuple of `owner,name` from ssh or https git repo `url`"
    # str.removesuffix (3.9+) replaces the project `remove_suffix` helper
    res = url.strip().removesuffix('.git')
    res = res.split(':')[-1]   # drop 'git@host:' prefix of ssh URLs
    return res.split('/')[-2:]
# %% ../nbs/03_xtras.ipynb | |
def run(cmd, *rest, same_in_win=False, ignore_ex=False, as_bytes=False, stderr=False):
    "Pass `cmd` (splitting with `shlex` if string) to `subprocess.run`; return `stdout`; raise `IOError` if fails"
    # Even when the command is identical on Windows, it must be prefixed with `cmd /c`
    import subprocess
    win = sys.platform == 'win32'
    if rest:
        # Extra positional args: build an argv tuple directly
        cmd = ('cmd', '/c', cmd, *rest) if win and same_in_win else (cmd,)+rest
    elif isinstance(cmd, str):
        if win and same_in_win: cmd = 'cmd /c ' + cmd
        import shlex
        cmd = shlex.split(cmd)
    elif isinstance(cmd, list):
        if win and same_in_win: cmd = ['cmd', '/c'] + cmd
    res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out = res.stdout
    if stderr and res.stderr: out = out + b' ;; ' + res.stderr
    if not as_bytes: out = out.decode().strip()
    if ignore_ex: return (res.returncode, out)
    if res.returncode: raise IOError(out)
    return out
# %% ../nbs/03_xtras.ipynb | |
def open_file(fn, mode='r', **kwargs):
    "Open a file, with optional compression if gz or bz2 suffix"
    if isinstance(fn, io.IOBase): return fn  # already a stream: pass through
    import bz2,gzip,zipfile
    fn = Path(fn)
    # Dispatch on suffix; anything unrecognised uses the plain builtin open
    openers = {'.bz2': bz2.BZ2File, '.gz': gzip.GzipFile, '.zip': zipfile.ZipFile}
    opener = openers.get(fn.suffix, open)
    return opener(fn, mode, **kwargs)
# %% ../nbs/03_xtras.ipynb | |
def save_pickle(fn, o):
    "Save a pickle file, to a file name or opened file"
    import pickle
    with open_file(fn, 'wb') as stream:
        pickle.dump(o, stream)
# %% ../nbs/03_xtras.ipynb | |
def load_pickle(fn):
    "Load a pickle file from a file name or opened file"
    import pickle
    with open_file(fn, 'rb') as stream:
        return pickle.load(stream)
# %% ../nbs/03_xtras.ipynb | |
def parse_env(s:str=None, fn:Union[str,Path]=None) -> dict:
    "Parse a shell-style environment string or file"
    assert bool(s)^bool(fn), "Must pass exactly one of `s` or `fn`"
    if fn: s = Path(fn).read_text()
    # name = optionally-quoted value, with optional leading `export` and trailing comment
    pat = re.compile(r'^\s*(?:export\s+)?(\w+)\s*=\s*(["\']?)(.*?)(\2)\s*(?:#.*)?$')
    res = {}
    for line in s.splitlines():
        if not line.strip() or re.match(r'\s*#', line): continue  # blank / comment lines
        name,_,val,_ = pat.match(line.strip()).groups()
        res[name] = val
    return res
# %% ../nbs/03_xtras.ipynb | |
def expand_wildcards(code):
    "Expand all wildcard imports in the given code string."
    import ast,importlib
    tree = ast.parse(code)
    def _replace_node(code, old_node, new_node):
        "Replace `old_node` in the source `code` with `new_node`."
        lines = code.splitlines()
        lnum = old_node.lineno
        # Preserve the original statement's indentation on every replacement line
        indent = ' ' * (len(lines[lnum-1]) - len(lines[lnum-1].lstrip()))
        new_lines = [indent+line for line in ast.unparse(new_node).splitlines()]
        lines[lnum-1 : old_node.end_lineno] = new_lines
        return '\n'.join(lines)
    def _expand_import(node, mod, existing):
        "Create expanded import `node` in `tree` from wildcard import of `mod`."
        mod_all = getattr(mod, '__all__', None)
        available_names = set(mod_all) if mod_all is not None else set(dir(mod))
        # Keep only names the code actually references and doesn't already import
        used_names = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name) and n.id in available_names} - existing
        if not used_names: return node  # nothing used: leave the wildcard as-is
        names = [ast.alias(name=name, asname=None) for name in sorted(used_names)]
        return ast.ImportFrom(module=node.module, names=names, level=node.level)
    # First pass: collect names already bound by explicit (non-wildcard) imports
    existing = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.names[0].name != '*': existing.update(n.name for n in node.names)
        elif isinstance(node, ast.Import): existing.update(n.name.split('.')[0] for n in node.names)
    # Second pass: rewrite each `from X import *`; note each wildcard's module is
    # actually imported here, so `code`'s imports must be importable in this process
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and any(n.name == '*' for n in node.names):
            new_import = _expand_import(node, importlib.import_module(node.module), existing)
            code = _replace_node(code, node, new_import)
    return code
# %% ../nbs/03_xtras.ipynb | |
def dict2obj(d, list_func=L, dict_func=AttrDict):
    "Convert (possibly nested) dicts (or lists of dicts) to `AttrDict`"
    if isinstance(d, (L,list)): return list_func(d).map(dict2obj)
    if isinstance(d, dict):
        return dict_func(**{key:dict2obj(val) for key,val in d.items()})
    return d  # scalars pass through unchanged
# %% ../nbs/03_xtras.ipynb | |
def obj2dict(d):
    "Convert (possibly nested) AttrDicts (or lists of AttrDicts) to `dict`"
    if isinstance(d, (L,list)): return list(L(d).map(obj2dict))
    if isinstance(d, dict):
        return dict(**{key:obj2dict(val) for key,val in d.items()})
    return d  # scalars pass through unchanged
# %% ../nbs/03_xtras.ipynb | |
def _repr_dict(d, lvl):
    "Recursive helper for `repr_dict`: render `d` as indented '- ' bullet lines"
    if isinstance(d, dict):
        items = [f"{k}: {_repr_dict(v,lvl+1)}" for k,v in d.items()]
    elif isinstance(d, (list,L)):
        items = [_repr_dict(o,lvl+1) for o in d]
    else:
        return str(d)
    pad = " "*(lvl*2)
    return '\n' + '\n'.join(pad + "- " + it for it in items)
# %% ../nbs/03_xtras.ipynb | |
def repr_dict(d):
    "Print nested dicts and lists, such as returned by `dict2obj`"
    # Strip the leading newline `_repr_dict` always emits for containers
    return _repr_dict(d, 0).strip()
# %% ../nbs/03_xtras.ipynb | |
def is_listy(x):
    "`isinstance(x, (tuple,list,L,slice,Generator))`"
    listy_types = (tuple,list,L,slice,Generator)
    return isinstance(x, listy_types)
# %% ../nbs/03_xtras.ipynb | |
def mapped(f, it):
    "map `f` over `it`, unless it's not listy, in which case return `f(it)`"
    if not is_listy(it): return f(it)
    return L(it).map(f)
# %% ../nbs/03_xtras.ipynb | |
@patch
def readlines(self:Path, hint=-1, encoding='utf8'):
    "Read the content of `self`"
    with self.open(encoding=encoding) as stream:
        return stream.readlines(hint)
# %% ../nbs/03_xtras.ipynb | |
@patch
def read_json(self:Path, encoding=None, errors=None):
    "Same as `read_text` followed by `loads`"
    txt = self.read_text(encoding=encoding, errors=errors)
    return loads(txt)
# %% ../nbs/03_xtras.ipynb | |
@patch
def mk_write(self:Path, data, encoding=None, errors=None, mode=511):
    "Make all parent dirs of `self`, and write `data`"
    # 511 == 0o777: permission bits for any newly created parent dirs
    self.parent.mkdir(parents=True, exist_ok=True, mode=mode)
    self.write_text(data, encoding=encoding, errors=errors)
# %% ../nbs/03_xtras.ipynb | |
@patch
def relpath(self:Path, start=None):
    "Same as `os.path.relpath`, but returns a `Path`, and resolves symlinks"
    # Fix: `start=None` previously crashed on `Path(None)`; default to the
    # current directory, matching `os.path.relpath`'s own default
    start = Path('.' if start is None else start).resolve()
    return Path(os.path.relpath(self.resolve(), start))
# %% ../nbs/03_xtras.ipynb | |
@patch
def ls(self:Path, n_max=None, file_type=None, file_exts=None):
    "Contents of path as a list"
    import mimetypes
    extns = L(file_exts)
    if file_type:
        # Add every extension whose mimetype starts with `file_type` (e.g. 'image')
        extns += L(k for k,v in mimetypes.types_map.items() if v.startswith(file_type+'/'))
    keep_all = len(extns)==0
    items = (o for o in self.iterdir() if keep_all or o.suffix in extns)
    if n_max is not None: items = itertools.islice(items, n_max)
    return L(items)
# %% ../nbs/03_xtras.ipynb | |
@patch
def __repr__(self:Path):
    "repr showing the path relative to `Path.BASE_PATH`, when that attribute is set"
    b = getattr(Path, 'BASE_PATH', None)
    if b:
        # `relative_to` raises ValueError when `self` isn't under `b`; keep the
        # absolute path then. (Was a bare `except:`, which also hid real errors.)
        try: self = self.relative_to(b)
        except ValueError: pass
    return f"Path({self.as_posix()!r})"
# %% ../nbs/03_xtras.ipynb | |
@patch
def delete(self:Path):
    "Delete a file, symlink, or directory tree"
    if not self.exists(): return
    if not self.is_dir():
        self.unlink()
        return
    import shutil
    shutil.rmtree(self)
# %% ../nbs/03_xtras.ipynb | |
class IterLen:
    "Base class to add iteration to anything supporting `__len__` and `__getitem__`"
    def __iter__(self):
        for i in range_of(self): yield self[i]
# %% ../nbs/03_xtras.ipynb | |
@docs
class ReindexCollection(GetAttr, IterLen):
    "Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`"
    _default='coll'  # unknown attribute access delegates to `self.coll` (via GetAttr)
    def __init__(self, coll, idxs=None, cache=None, tfm=noop):
        if idxs is None: idxs = L.range(coll)
        store_attr()
        # Wrap `_get` in an LRU cache; the cache is per-instance, not per-class
        if cache is not None: self._get = functools.lru_cache(maxsize=cache)(self._get)
    def _get(self, i): return self.tfm(self.coll[i])  # fetch from `coll`, applying `tfm`
    def __getitem__(self, i): return self._get(self.idxs[i])  # indirect through `idxs`
    def __len__(self): return len(self.coll)
    def reindex(self, idxs): self.idxs = idxs
    def shuffle(self):
        import random
        random.shuffle(self.idxs)
    def cache_clear(self): self._get.cache_clear()
    # Pickle support: the lru_cache wrapper isn't picklable, so only plain state is saved;
    # note the cache wrapper is NOT re-created on unpickling
    def __getstate__(self): return {'coll': self.coll, 'idxs': self.idxs, 'cache': self.cache, 'tfm': self.tfm}
    def __setstate__(self, s): self.coll,self.idxs,self.cache,self.tfm = s['coll'],s['idxs'],s['cache'],s['tfm']
    _docs = dict(reindex="Replace `self.idxs` with idxs",
                shuffle="Randomly shuffle indices",
                cache_clear="Clear LRU cache")
# %% ../nbs/03_xtras.ipynb | |
def _is_type_dispatch(x): return type(x).__name__ == "TypeDispatch" | |
def _unwrapped_type_dispatch_func(x): return x.first() if _is_type_dispatch(x) else x | |
def _is_property(x): return type(x)==property | |
def _has_property_getter(x): return _is_property(x) and hasattr(x, 'fget') and hasattr(x.fget, 'func') | |
def _property_getter(x): return x.fget.func if _has_property_getter(x) else x | |
def _unwrapped_func(x): | |
x = _unwrapped_type_dispatch_func(x) | |
x = _property_getter(x) | |
return x | |
def get_source_link(func):
    "Return link to `func` in source code"
    import inspect
    target = _unwrapped_func(func)
    try: lineno = inspect.getsourcelines(target)[1]
    except Exception: return ''  # builtins etc. have no retrievable source
    mod = inspect.getmodule(target)
    modpath = mod.__name__.replace('.', '/') + '.py'
    try:
        # When the package ships an nbdev `_nbdev` module, build a full git URL
        nbdev_mod = import_module(mod.__package__.split('.')[0] + '._nbdev')
        return f"{nbdev_mod.git_url}{modpath}#L{lineno}"
    except: return f"{modpath}#L{lineno}"
# %% ../nbs/03_xtras.ipynb | |
def truncstr(s:str, maxlen:int, suf:str='…', space='')->str:
    "Truncate `s` to length `maxlen`, adding suffix `suf` if truncated"
    if len(s)+len(space) > maxlen:
        return s[:maxlen-len(suf)]+suf
    return s+space  # short enough: pad with `space` instead
# %% ../nbs/03_xtras.ipynb | |
spark_chars = '▁▂▃▅▆▇'  # glyphs used by `sparkline`, ordered by increasing bar height
# %% ../nbs/03_xtras.ipynb | |
def _ceil(x, lim=None): return x if (not lim or x <= lim) else lim | |
def _sparkchar(x, mn, mx, incr, empty_zero): | |
if x is None or (empty_zero and not x): return ' ' | |
if incr == 0: return spark_chars[0] | |
res = int((_ceil(x,mx)-mn)/incr-0.5) | |
return spark_chars[res] | |
# %% ../nbs/03_xtras.ipynb | |
def sparkline(data, mn=None, mx=None, empty_zero=False):
    "Sparkline for `data`, with `None`s (and zero, if `empty_zero`) shown as empty column"
    valid = [o for o in data if o is not None]
    if not valid: return ' '
    mn,mx,n = ifnone(mn,min(valid)),ifnone(mx,max(valid)),len(spark_chars)
    incr = (mx-mn)/n  # bucket width, hoisted out of the per-value calls
    return ''.join(_sparkchar(x=o, mn=mn, mx=mx, incr=incr, empty_zero=empty_zero) for o in data)
# %% ../nbs/03_xtras.ipynb | |
def modify_exception(
    e:Exception, # An exception
    msg:str=None, # A custom message
    replace:bool=False, # Whether to replace e.args with [msg]
) -> Exception:
    "Modifies `e` with a custom message attached"
    # Either replace the args entirely, or append `msg` to the first arg
    if replace or not e.args: e.args = [msg]
    else: e.args = [f'{e.args[0]} {msg}']
    return e
# %% ../nbs/03_xtras.ipynb | |
def round_multiple(x, mult, round_down=False):
    "Round `x` to nearest multiple of `mult`"
    rounder = int if round_down else round  # int() truncates toward zero
    def _f(v): return rounder(v/mult)*mult
    res = L(x).map(_f)
    return res if is_listy(x) else res[0]
# %% ../nbs/03_xtras.ipynb | |
def set_num_threads(nt):
    "Get numpy (and others) to use `nt` threads"
    # Best-effort: mkl/torch may simply not be installed
    try: import mkl; mkl.set_num_threads(nt)
    except: pass
    try: import torch; torch.set_num_threads(nt)
    except: pass
    os.environ['IPC_ENABLE']='1'
    thread_vars = ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']
    for var in thread_vars: os.environ[var] = str(nt)
# %% ../nbs/03_xtras.ipynb | |
def join_path_file(file, path, ext=''):
    "Return `path/file` if file is a string or a `Path`, file otherwise"
    if not isinstance(file, (str, Path)): return file
    # Ensure the target directory exists before composing the result
    path.mkdir(exist_ok=True, parents=True)
    return path/f'{file}{ext}'
# %% ../nbs/03_xtras.ipynb | |
def autostart(g):
    "Decorator that automatically starts a generator"
    @functools.wraps(g)
    def f():
        gen = g()
        next(gen)  # advance to the first yield so `.send` works immediately
        return gen
    return f
# %% ../nbs/03_xtras.ipynb | |
class EventTimer:
    "An event timer with history of `store` items of time `span`"
    def __init__(self, store=5, span=60):
        import collections
        self.hist = collections.deque(maxlen=store)  # rolling history of window frequencies
        self.span = span
        self.last = perf_counter()
        self._reset()
    def _reset(self):
        "Start a fresh counting window at the time of the last event"
        self.start,self.events = self.last,0
    def add(self, n=1):
        "Record `n` events"
        # Roll the finished window into history once it exceeds `span`
        if self.duration > self.span:
            self.hist.append(self.freq)
            self._reset()
        self.events += n
        self.last = perf_counter()
    @property
    def duration(self):
        "Seconds since the current window started"
        return perf_counter()-self.start
    @property
    def freq(self):
        "Events per second over the current window"
        return self.events/self.duration
# %% ../nbs/03_xtras.ipynb | |
_fmt = string.Formatter()  # shared Formatter instance used by `stringfmt_names`
# %% ../nbs/03_xtras.ipynb | |
def stringfmt_names(s:str)->list:
    "Unique brace-delimited names in `s`"
    # Formatter.parse yields (literal, field_name, spec, conversion) tuples
    names = (parsed[1] for parsed in _fmt.parse(s) if parsed[1])
    return uniqueify(names)
# %% ../nbs/03_xtras.ipynb | |
class PartialFormatter(string.Formatter):
    "A `string.Formatter` that doesn't error on missing fields, and tracks missing fields and unused args"
    def __init__(self):
        self.missing = set()  # field names that were referenced but not supplied
        super().__init__()
    def get_field(self, nm, args, kwargs):
        "Like `Formatter.get_field`, but a missing field is left verbatim and recorded"
        try:
            return super().get_field(nm, args, kwargs)
        except KeyError:
            self.missing.add(nm)
            return '{'+nm+'}',nm
    def check_unused_args(self, used, args, kwargs):
        "Stash kwargs that were never referenced in `self.xtra`"
        self.xtra = filter_keys(kwargs, lambda o: o not in used)
# %% ../nbs/03_xtras.ipynb | |
def partial_format(s:str, **kwargs):
    "string format `s`, ignoring missing field errors, returning missing and extra fields"
    fmt = PartialFormatter()
    formatted = fmt.format(s, **kwargs)
    # (result, fields that were missing, kwargs that went unused)
    return formatted,list(fmt.missing),fmt.xtra
# %% ../nbs/03_xtras.ipynb | |
def utc2local(dt:datetime)->datetime:
    "Convert `dt` from UTC to local time"
    aware = dt.replace(tzinfo=timezone.utc)  # tag the naive datetime as UTC
    return aware.astimezone(tz=None)
# %% ../nbs/03_xtras.ipynb | |
def local2utc(dt:datetime)->datetime:
    "Convert `dt` from local to UTC time"
    naive = dt.replace(tzinfo=None)  # treat `dt` as local wall-clock time
    return naive.astimezone(tz=timezone.utc)
# %% ../nbs/03_xtras.ipynb | |
def trace(f):
    "Add `set_trace` to an existing function `f`"
    from pdb import set_trace
    from functools import wraps
    # Idempotent: a function that's already traced is returned unchanged
    if getattr(f, '_traced', False): return f
    @wraps(f)  # fix: preserve `f`'s name/docstring/signature on the wrapper
    def _inner(*args,**kwargs):
        set_trace()
        return f(*args,**kwargs)
    _inner._traced = True
    return _inner
# %% ../nbs/03_xtras.ipynb | |
@contextmanager
def modified_env(*delete, **replace):
    "Context manager temporarily modifying `os.environ` by deleting `delete` and replacing `replace`"
    saved = dict(os.environ)  # snapshot to restore on exit
    try:
        os.environ.update(replace)
        for key in delete: os.environ.pop(key, None)
        yield
    finally:
        # Restore the exact original environment, including deleted keys
        os.environ.clear()
        os.environ.update(saved)
# %% ../nbs/03_xtras.ipynb | |
class ContextManagers(GetAttr):
    "Wrapper for `contextlib.ExitStack` which enters a collection of context managers"
    def __init__(self, mgrs):
        self.default = L(mgrs)   # delegated-to collection (GetAttr default)
        self.stack = ExitStack()
    def __enter__(self):
        self.default.map(self.stack.enter_context)
    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)
# %% ../nbs/03_xtras.ipynb | |
def shufflish(x, pct=0.04):
    "Randomly relocate items of `x` up to `pct` of `len(x)` from their starting location"
    import random
    n = len(x)
    # Sort by index plus a small random offset, so items only move locally
    def _key(i): return i + n*(1+random.random()*pct)
    return L(x[i] for i in sorted(range_of(x), key=_key))
# %% ../nbs/03_xtras.ipynb | |
def console_help(
    libname:str): # name of library for console script listing
    "Show help for all console scripts from `libname`"
    from fastcore.style import S
    # pkg_resources is deprecated (and removed from recent setuptools);
    # importlib.metadata is the stdlib replacement
    from importlib.metadata import entry_points
    for e in entry_points(group='console_scripts'):
        if e.module == libname or e.module.startswith(libname+'.'):
            nm = S.bold.light_blue(e.name)
            print(f'{nm:45}{e.load().__doc__}')
# %% ../nbs/03_xtras.ipynb | |
def hl_md(s, lang='xml', show=True):
    "Syntax highlight `s` using `lang`."
    fenced = f'```{lang}\n{s}\n```'
    if not show: return fenced
    try:
        # Render as Markdown when running under IPython/Jupyter
        from IPython import display
        return display.Markdown(fenced)
    except ImportError: print(s)
# %% ../nbs/03_xtras.ipynb | |
def type2str(typ:type)->str: | |
"Stringify `typ`" | |
if typ is None or typ is NoneType: return 'None' | |
if hasattr(typ, '__origin__'): | |
args = ", ".join(type2str(arg) for arg in typ.__args__) | |
if typ.__origin__ is Union: return f"Union[{args}]" | |
return f"{typ.__origin__.__name__}[{args}]" | |
elif isinstance(typ, type): return typ.__name__ | |
return str(typ) | |
# %% ../nbs/03_xtras.ipynb | |
def dataclass_src(cls):
    "Reconstruct `@dataclass` source text for `cls` from its fields"
    lines = [f"@dataclass\nclass {cls.__name__}:\n"]
    for f in dataclasses.fields(cls):
        default = "" if f.default is dataclasses.MISSING else f" = {f.default!r}"
        lines.append(f"    {f.name}: {type2str(f.type)}{default}\n")
    return "".join(lines)
# %% ../nbs/03_xtras.ipynb | |
class Unset(Enum):
    "Sentinel enum whose single member `UNSET` means 'no value supplied' and is falsy"
    _Unset=''
    def __bool__(self): return False
    def __repr__(self): return 'UNSET'
    def __str__ (self): return 'UNSET'
UNSET = Unset._Unset  # module-level sentinel value
# %% ../nbs/03_xtras.ipynb | |
def nullable_dc(cls):
    "Like `dataclass`, but default of `UNSET` added to fields without defaults"
    annos = get_annotations_ex(cls)[0]
    for name in annos:
        # Only fields with no existing default get the UNSET sentinel
        if not hasattr(cls, name): setattr(cls, name, field(default=UNSET))
    return dataclass(cls)
# %% ../nbs/03_xtras.ipynb | |
def make_nullable(clas):
    "Make all fields of dataclass `clas` default to `UNSET` (mutates in place); returns `clas`"
    # Fix: previously this early-exit returned None, so callers using the return
    # value (e.g. `flexiclass`'s `return make_nullable(cls)`) got None on repeat calls
    if hasattr(clas, '_nullable'): return clas
    clas._nullable = True
    original_init = clas.__init__
    def __init__(self, *args, **kwargs):
        flds = fields(clas)
        # Map positional args to field names so supplied values aren't overridden
        dargs = {k.name:v for k,v in zip(flds, args)}
        for f in flds:
            nm = f.name
            # Fields whose default we rewrote to None (below) get UNSET when omitted
            if nm not in dargs and nm not in kwargs and f.default is None and f.default_factory is MISSING:
                kwargs[nm] = UNSET
        original_init(self, *args, **kwargs)
    clas.__init__ = __init__
    # Give previously-required fields a placeholder default so instantiation
    # without them is allowed (the wrapper above substitutes UNSET)
    for f in fields(clas):
        if f.default is MISSING and f.default_factory is MISSING: f.default = None
    return clas
# %% ../nbs/03_xtras.ipynb | |
def flexiclass(cls):
    "Convert `cls` into a `dataclass` like `make_nullable`"
    # Already a dataclass: just make its fields nullable
    if is_dataclass(cls): return make_nullable(cls)
    annos = get_annotations_ex(cls)[0]
    for name in annos:
        if not hasattr(cls, name) or getattr(cls, name) is MISSING:
            setattr(cls, name, field(default=UNSET))
    return dataclass(cls, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False)
# %% ../nbs/03_xtras.ipynb | |
def asdict(o)->dict:
    "Convert `o` to a `dict`, supporting dataclasses, namedtuples, iterables, and `__dict__` attrs."
    if isinstance(o, dict): return o
    r = None
    if is_dataclass(o): r = dataclasses.asdict(o)
    elif hasattr(o, '_asdict'): r = o._asdict()
    elif hasattr(o, '__iter__'):
        try: r = dict(o)
        except TypeError: pass  # iterable but not key/value pairs; try __dict__ below
    # Fix: previously a failed dict(o) left `r` unbound -> UnboundLocalError;
    # now fall back to __dict__, then raise the documented TypeError
    if r is None and hasattr(o, '__dict__'): r = o.__dict__
    if r is None: raise TypeError(f'Can not convert {o} to a dict')
    return {k:v for k,v in r.items() if v not in (UNSET,MISSING)}
# %% ../nbs/03_xtras.ipynb | |
def is_typeddict(cls:type)->bool:
    "Check if `cls` is a `TypedDict`"
    required = ('__annotations__', '__required_keys__', '__optional_keys__')
    return isinstance(cls, type) and all(hasattr(cls, a) for a in required)
# %% ../nbs/03_xtras.ipynb | |
def is_namedtuple(cls):
    "`True` if `cls` is a namedtuple type"
    if not issubclass(cls, tuple): return False
    return hasattr(cls, '_fields')  # namedtuples expose their field names here
# %% ../nbs/03_xtras.ipynb | |
def flexicache(*funcs, maxsize=128):
    "Like `lru_cache`, but customisable with policy `funcs`"
    import asyncio
    def _f(func):
        # NOTE(review): this outer `states` list is never read again — each cache
        # entry stores its own states list; presumably vestigial. TODO confirm.
        cache,states = {}, [None]*len(funcs)
        def _cache_logic(key, execute_func):
            if key in cache:
                result,states = cache[key]
                # Entry is fresh if no policy signals expiry; re-insert to mark as
                # most-recently-used (dicts preserve insertion order)
                if not any(f(state) for f,state in zip(funcs, states)):
                    cache[key] = cache.pop(key)
                    return result
                del cache[key]
            try: newres = execute_func()
            except:
                # Intended stale-value fallback on failure; NOTE(review): the stale
                # entry was deleted just above, so this branch appears unreachable —
                # TODO confirm intended behaviour
                if key not in cache: raise
                cache[key] = cache.pop(key)
                return result
            # Store the new value with a fresh state from each policy
            cache[key] = (newres, [f(None) for f in funcs])
            # NOTE(review): dict.popitem() removes the most-recently-inserted entry
            # (i.e. the one just added) — TODO confirm intended eviction order
            if len(cache) > maxsize: cache.popitem()
            return newres
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Cache key is the stringified call signature
            return _cache_logic(f"{args} // {kwargs}", lambda: func(*args, **kwargs))
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            return await _cache_logic(f"{args} // {kwargs}", lambda: asyncio.ensure_future(func(*args, **kwargs)))
        return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper
    return _f
# %% ../nbs/03_xtras.ipynb | |
def time_policy(seconds):
    "A `flexicache` policy that expires cached items after `seconds` have passed"
    def policy(last_time):
        now = time()
        # Return a fresh timestamp (truthy -> expired) or None (still valid)
        if last_time is None or now-last_time > seconds: return now
        return None
    return policy
# %% ../nbs/03_xtras.ipynb | |
def mtime_policy(filepath):
    "A `flexicache` policy that expires cached items after `filepath` modified-time changes"
    def policy(mtime):
        current = getmtime(filepath)
        # Return the new mtime (truthy -> expired) or None (still valid)
        if mtime is None or current > mtime: return current
        return None
    return policy
# %% ../nbs/03_xtras.ipynb | |
def timed_cache(seconds=60, maxsize=128):
    "Like `lru_cache`, but also with time-based eviction"
    policy = time_policy(seconds)
    return flexicache(policy, maxsize=maxsize)
--- | |
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03a_parallel.ipynb. | |
# %% auto 0 | |
__all__ = ['threaded', 'startthread', 'startproc', 'parallelable', 'ThreadPoolExecutor', 'ProcessPoolExecutor', 'parallel', | |
'add_one', 'run_procs', 'parallel_gen'] | |
# %% ../nbs/03a_parallel.ipynb | |
from .imports import * | |
from .basics import * | |
from .foundation import * | |
from .meta import * | |
from .xtras import * | |
from functools import wraps | |
import concurrent.futures,time | |
from multiprocessing import Process,Queue,Manager,set_start_method,get_all_start_methods,get_context | |
from threading import Thread | |
# macOS notebooks default to 'spawn', which breaks pickling of interactively
# defined functions; best-effort switch to 'fork' (fails harmlessly if already set)
try:
    if sys.platform == 'darwin' and IN_NOTEBOOK: set_start_method("fork")
except: pass
# %% ../nbs/03a_parallel.ipynb | |
def threaded(process=False):
    "Run `f` in a `Thread` (or `Process` if `process=True`), and returns it"
    def _r(f):
        def g(_obj_td, *args, **kwargs):
            # Stash the return value on the runner so callers can read `.result`
            _obj_td.result = f(*args, **kwargs)
        @wraps(f)
        def _f(*args, **kwargs):
            runner = (Thread,Process)[process](target=g, args=args, kwargs=kwargs)
            # Prepend the runner itself as `g`'s first arg (`_obj_td`)
            runner._args = (runner,)+runner._args
            runner.start()
            return runner
        return _f
    # Support bare `@threaded` (no parens): `process` is then the decorated function
    if callable(process):
        func,process = process,False
        return _r(func)
    return _r
# %% ../nbs/03a_parallel.ipynb | |
def startthread(f):
    "Like `threaded`, but start thread immediately"
    # Bare-decorator form of `threaded` wraps `f`; calling the wrapper starts it
    return threaded(f)()
# %% ../nbs/03a_parallel.ipynb | |
def startproc(f):
    "Like `threaded(True)`, but start Process immediately"
    # `threaded(True)` returns a decorator; apply it and start the process at once
    return threaded(True)(f)()
# %% ../nbs/03a_parallel.ipynb | |
def _call(lock, pause, n, g, item): | |
l = False | |
if pause: | |
try: | |
l = lock.acquire(timeout=pause*(n+2)) | |
time.sleep(pause) | |
finally: | |
if l: lock.release() | |
return g(item) | |
# %% ../nbs/03a_parallel.ipynb | |
def parallelable(param_name, num_workers, f=None):
    "Whether multiprocessing with `num_workers` workers is usable; warns and returns False when not"
    # `f is None` (identity) instead of `f == None` (equality)
    f_in_main = f is None or sys.modules[f.__module__].__name__ == "__main__"
    # Windows notebooks can't start worker processes for functions defined in __main__
    if sys.platform == "win32" and IN_NOTEBOOK and num_workers > 0 and f_in_main:
        print("Due to IPython and Windows limitation, python multiprocessing isn't available now.")
        print(f"So `{param_name}` has to be changed to 0 to avoid getting stuck")
        return False
    return True
# %% ../nbs/03a_parallel.ipynb | |
class ThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    "Same as Python's ThreadPoolExecutor, except can pass `max_workers==0` for serial execution"
    def __init__(self, max_workers=defaults.cpus, on_exc=print, pause=0, **kwargs):
        if max_workers is None: max_workers=defaults.cpus
        store_attr()
        self.not_parallel = max_workers==0
        if self.not_parallel: max_workers=1  # parent class requires at least one worker
        super().__init__(max_workers, **kwargs)
    def map(self, f, items, *args, timeout=None, chunksize=1, **kwargs):
        # Manager lock lets `_call` stagger worker start-up by `pause` seconds
        if self.not_parallel == False: self.lock = Manager().Lock()
        g = partial(f, *args, **kwargs)
        if self.not_parallel: return map(g, items)  # serial mode: plain builtin map
        _g = partial(_call, self.lock, self.pause, self.max_workers, g)
        try: return super().map(_g, items, timeout=timeout, chunksize=chunksize)
        except Exception as e: self.on_exc(e)  # exceptions go to the handler, result is None
# %% ../nbs/03a_parallel.ipynb | |
@delegates()
class ProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor):
    "Same as Python's ProcessPoolExecutor, except can pass `max_workers==0` for serial execution"
    def __init__(self, max_workers=defaults.cpus, on_exc=print, pause=0, **kwargs):
        if max_workers is None: max_workers=defaults.cpus
        store_attr()
        self.not_parallel = max_workers==0
        if self.not_parallel: max_workers=1  # parent class requires at least one worker
        super().__init__(max_workers, **kwargs)
    def map(self, f, items, *args, timeout=None, chunksize=1, **kwargs):
        # Drop to serial mode when multiprocessing isn't usable (e.g. Windows notebooks)
        if not parallelable('max_workers', self.max_workers, f): self.max_workers = 0
        self.not_parallel = self.max_workers==0
        if self.not_parallel: self.max_workers=1
        # Manager lock lets `_call` stagger worker start-up by `pause` seconds
        if self.not_parallel == False: self.lock = Manager().Lock()
        g = partial(f, *args, **kwargs)
        if self.not_parallel: return map(g, items)  # serial mode: plain builtin map
        _g = partial(_call, self.lock, self.pause, self.max_workers, g)
        try: return super().map(_g, items, timeout=timeout, chunksize=chunksize)
        except Exception as e: self.on_exc(e)  # exceptions go to the handler, result is None
# %% ../nbs/03a_parallel.ipynb | |
# fastprogress is optional; `parallel`/`parallel_gen` degrade gracefully without it
try: from fastprogress import progress_bar
except: progress_bar = None
# %% ../nbs/03a_parallel.ipynb | |
def parallel(f, items, *args, n_workers=defaults.cpus, total=None, progress=None, pause=0,
             method=None, threadpool=False, timeout=None, chunksize=1, **kwargs):
    "Applies `func` in parallel to `items`, using `n_workers`"
    kwpool = {}
    if threadpool:
        pool = ThreadPoolExecutor
    else:
        # macOS defaults to 'spawn'; use 'fork' unless a method was given explicitly
        if not method and sys.platform == 'darwin': method = 'fork'
        if method: kwpool['mp_context'] = get_context(method)
        pool = ProcessPoolExecutor
    with pool(n_workers, pause=pause, **kwpool) as ex:
        results = ex.map(f, items, *args, timeout=timeout, chunksize=chunksize, **kwargs)
        if progress and progress_bar:
            if total is None: total = len(items)
            results = progress_bar(results, total=total, leave=False)
        return L(results)
# %% ../nbs/03a_parallel.ipynb | |
def add_one(x, a=1):
    "Trivial worker for parallel demos/tests: sleep briefly, then return `x+a`"
    # this import is necessary for multiprocessing in notebook on windows
    import random
    time.sleep(random.random()/80)
    return x+a
# %% ../nbs/03a_parallel.ipynb | |
def run_procs(f, f_done, args):
    "Call `f` for each item in `args` in parallel, yielding `f_done`"
    # `arg0` is a fastcore placeholder (`_Arg(0)`): each element of `args` becomes the
    # `args` kwarg of its own `Process` -- TODO confirm against fastcore.basics
    processes = L(args).map(Process, args=arg0, target=f)
    for o in processes: o.start()
    # Consume results while the workers run, then wait for them all to finish
    yield from f_done()
    # `Self.join()` builds a callable invoking `.join()` on each process (fastcore idiom)
    processes.map(Self.join())
# %% ../nbs/03a_parallel.ipynb | |
def _f_pg(obj, queue, batch, start_idx): | |
for i,b in enumerate(obj(batch)): queue.put((start_idx+i,b)) | |
def _done_pg(queue, items): return (queue.get() for _ in items) | |
# %% ../nbs/03a_parallel.ipynb | |
def parallel_gen(cls, items, n_workers=defaults.cpus, **kwargs):
    "Instantiate `cls` in `n_workers` procs & call each on a subset of `items` in parallel."
    if not parallelable('n_workers', n_workers): n_workers = 0
    if n_workers==0:
        # Serial fallback: one instance processes everything in this process
        yield from enumerate(list(cls(**kwargs)(items)))
        return
    batches = L(chunked(items, n_chunks=n_workers))
    # Global start index of each batch (`0 + L` works via L.__radd__), so results keep their positions
    idx = L(itertools.accumulate(0 + batches.map(len)))
    queue = Queue()
    if progress_bar: items = progress_bar(items, leave=False)
    f=partial(_f_pg, cls(**kwargs), queue)
    done=partial(_done_pg, queue, items)
    # Each worker gets a (batch, start_idx) pair; results arrive tagged with global indices
    yield from run_procs(f, done, L(batches,idx).zip())
--- | |
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/12_py2pyi.ipynb. | |
# %% auto 0 | |
__all__ = ['functypes', 'imp_mod', 'has_deco', 'sig2str', 'ast_args', 'create_pyi', 'py2pyi', 'replace_wildcards'] | |
# %% ../nbs/12_py2pyi.ipynb | |
import ast, sys, inspect, re, os, importlib.util, importlib.machinery | |
from ast import parse, unparse | |
from inspect import signature, getsource | |
from .utils import * | |
from .meta import delegates | |
# %% ../nbs/12_py2pyi.ipynb | |
def imp_mod(module_path, package=None):
    "Import dynamically the module referenced in `fn`"
    path = str(module_path)
    # Module name is the file's base name without its extension
    mod_name = os.path.splitext(os.path.basename(path))[0]
    spec = importlib.machinery.ModuleSpec(mod_name, None, origin=path)
    mod = importlib.util.module_from_spec(spec)
    # Attach the loader after creation so exec_module can run the file's source
    spec.loader = importlib.machinery.SourceFileLoader(mod_name, path)
    if package is not None: mod.__package__ = package
    mod.__file__ = os.path.abspath(path)
    spec.loader.exec_module(mod)
    return mod
# %% ../nbs/12_py2pyi.ipynb | |
def _get_tree(mod): | |
return parse(getsource(mod)) | |
# %% ../nbs/12_py2pyi.ipynb | |
@patch
def __repr__(self:ast.AST):
    # Patched onto every AST node: show it as its round-tripped source text
    return unparse(self)
@patch
def _repr_markdown_(self:ast.AST):
    # Jupyter rich display hook: the same source, fenced as a python code block
    return f"""```python
{self!r}
```"""
# %% ../nbs/12_py2pyi.ipynb | |
functypes = (ast.FunctionDef,ast.AsyncFunctionDef) | |
# %% ../nbs/12_py2pyi.ipynb | |
def _deco_id(d:Union[ast.Name,ast.Attribute])->bool: | |
"Get the id for AST node `d`" | |
return d.id if isinstance(d, ast.Name) else d.func.id | |
def has_deco(node:Union[ast.FunctionDef,ast.AsyncFunctionDef], name:str)->bool: | |
"Check if a function node `node` has a decorator named `name`" | |
return any(_deco_id(d)==name for d in getattr(node, 'decorator_list', [])) | |
# %% ../nbs/12_py2pyi.ipynb | |
def _get_proc(node):
    "Choose the processing function for top-level AST `node` (None means skip it)"
    if isinstance(node, ast.ClassDef): return _proc_class
    if not isinstance(node, functypes): return None
    # Plain functions only get their body stubbed; `delegates` ones need signature expansion
    if not has_deco(node, 'delegates'): return _proc_body
    return _proc_patched if has_deco(node, 'patch') else _proc_func
# %% ../nbs/12_py2pyi.ipynb | |
def _proc_tree(tree, mod):
    "Apply the appropriate processor to each top-level node of `tree`"
    for child in tree.body:
        handler = _get_proc(child)
        if handler is not None: handler(child, mod)
# %% ../nbs/12_py2pyi.ipynb | |
def _proc_mod(mod):
    "Parse `mod`'s source, stub/expand every definition in place, and return the tree"
    tree = _get_tree(mod)
    _proc_tree(tree, mod)
    return tree
# %% ../nbs/12_py2pyi.ipynb | |
def sig2str(sig):
    "Render signature `sig` as source-like text (strip `<class '...'>` wrappers and `dynamic_module.` prefixes)"
    out = str(sig)
    out = re.sub(r"<class '(.*?)'>", r'\1', out)
    return re.sub(r"dynamic_module\.", "", out)
# %% ../nbs/12_py2pyi.ipynb | |
def ast_args(func):
    "Return the `ast.arguments` node for `func`'s runtime signature"
    stub = f"def _{sig2str(signature(func))}: ..."
    return ast.parse(stub).body[0].args
# %% ../nbs/12_py2pyi.ipynb | |
def _body_ellip(n: ast.AST): | |
stidx = 1 if isinstance(n.body[0], ast.Expr) and isinstance(n.body[0].value, ast.Str) else 0 | |
n.body[stidx:] = [ast.Expr(ast.Constant(...))] | |
# %% ../nbs/12_py2pyi.ipynb | |
def _update_func(node, sym):
    """Rewrite `node` in place: take its parameter list from `sym`'s runtime signature,
    reduce its body to `...`, and drop any decorator named 'delegates'"""
    node.args = ast_args(sym)
    _body_ellip(node)
    kept = [d for d in node.decorator_list if _deco_id(d) != 'delegates']
    node.decorator_list = kept
# %% ../nbs/12_py2pyi.ipynb | |
def _proc_body(node, mod): _body_ellip(node) | |
# %% ../nbs/12_py2pyi.ipynb | |
def _proc_func(node, mod):
    "Expand a `delegates`-decorated function using its runtime object from `mod`"
    _update_func(node, getattr(mod, node.name))
# %% ../nbs/12_py2pyi.ipynb | |
def _proc_patched(node, mod):
    "Expand a `@patch`-ed function: find its owner class via the first argument's annotation"
    ann = node.args.args[0].annotation
    # When patching several classes at once the annotation is a tuple; use the first target
    if hasattr(ann, 'elts'): ann = ann.elts[0]
    owner = getattr(mod, ann.id)
    _update_func(node, getattr(owner, node.name))
# %% ../nbs/12_py2pyi.ipynb | |
def _proc_class(node, mod):
    "Recursively process the methods of class `node`, resolving symbols on the class itself"
    _proc_tree(node, getattr(mod, node.name))
# %% ../nbs/12_py2pyi.ipynb | |
def create_pyi(fn, package=None):
    "Convert `fname.py` to `fname.pyi` by removing function bodies and expanding `delegates` kwargs"
    src = Path(fn)
    mod = imp_mod(src, package=package)
    stubs = unparse(_proc_mod(mod))
    src.with_suffix('.pyi').write_text(stubs)
# %% ../nbs/12_py2pyi.ipynb | |
from .script import call_parse | |
# %% ../nbs/12_py2pyi.ipynb | |
@call_parse
def py2pyi(fname:str, # The file name to convert
           package:str=None # The parent package
          ):
    "Convert `fname.py` to `fname.pyi` by removing function bodies and expanding `delegates` kwargs"
    # `call_parse` turns this into a CLI entry point; the parameter comments above
    # are docments and become the command's help text, so they must stay accurate
    create_pyi(fname, package)
# %% ../nbs/12_py2pyi.ipynb | |
@call_parse
def replace_wildcards(
    # Path to the Python file to process
    path: str):
    "Expand wildcard imports in the specified Python file."
    # Rewrites the file in place using `expand_wildcards` (imported from .utils)
    path = Path(path)
    path.write_text(expand_wildcards(path.read_text()))
---- | |
import sys,os,re,typing,itertools,operator,functools,math,warnings,functools,io,enum | |
from operator import itemgetter,attrgetter | |
from warnings import warn | |
from typing import Iterable,Generator,Sequence,Iterator,List,Set,Dict,Union,Optional,Tuple | |
from functools import partial,reduce | |
from pathlib import Path | |
try: | |
from types import WrapperDescriptorType,MethodWrapperType,MethodDescriptorType | |
except ImportError: | |
WrapperDescriptorType = type(object.__init__) | |
MethodWrapperType = type(object().__str__) | |
MethodDescriptorType = type(str.join) | |
from types import BuiltinFunctionType,BuiltinMethodType,MethodType,FunctionType,SimpleNamespace | |
NoneType = type(None) | |
string_classes = (str,bytes) | |
def is_iter(o):
    "Test whether `o` can be used in a `for` loop"
    # Rank-0 tensors in PyTorch claim to be iterable but aren't, so a zero `ndim` vetoes
    if not isinstance(o, (Iterable, Generator)): return False
    return getattr(o, 'ndim', 1)
def is_coll(o):
    "Test whether `o` is a collection (i.e. has a usable `len`)"
    # Rank-0 tensors in PyTorch define `__len__` but it fails, so a zero `ndim` vetoes
    if not hasattr(o, '__len__'): return False
    return getattr(o, 'ndim', 1)
def all_equal(a,b):
    "Compares whether `a` and `b` are the same length and have the same contents"
    if not is_iter(b): return a==b
    # zip_longest pads the shorter side with None, so differing lengths compare unequal
    pairs = itertools.zip_longest(a, b)
    return all(equals(x, y) for x, y in pairs)
def noop(x=None, *args, **kwargs):
    "Do nothing"
    return x
def noops(self, x=None, *args, **kwargs):
    "Do nothing (method)"
    return x
def any_is_instance(t, *args): return any(isinstance(a,t) for a in args) | |
def isinstance_str(x, cls_name):
    "Like `isinstance`, except takes a type name instead of a type"
    # Walk the MRO so subclass instances match their base's name too
    return any(t.__name__ == cls_name for t in type(x).__mro__)
def array_equal(a,b):
    "Array-aware equality: vectorized compare for two ndarrays, else fall back to `all_equal`"
    # Normalize array-likes (e.g. tensors) to plain ndarrays first
    if hasattr(a, '__array__'): a = a.__array__()
    if hasattr(b, '__array__'): b = b.__array__()
    both_arrays = isinstance_str(a, 'ndarray') and isinstance_str(b, 'ndarray')
    if both_arrays: return (a==b).all()
    return all_equal(a,b)
def df_equal(a,b): return a.equals(b) if isinstance_str(a, 'NDFrame') else b.equals(a) | |
def equals(a,b):
    "Compares `a` and `b` for equality; supports sublists, tensors and arrays too"
    # Exactly one side None -> unequal (both None falls through to operator.eq below)
    if (a is None) ^ (b is None): return False
    if any_is_instance(type,a,b): return a==b
    # Objects may define their own array-equality hook; either side may provide it
    if hasattr(a, '__array_eq__'): return a.__array_eq__(b)
    if hasattr(b, '__array_eq__'): return b.__array_eq__(a)
    # Dispatch order matters: arrays/tensors/frames first, then plain containers,
    # then generic iterables, with `operator.eq` as the final fallback
    cmp = (array_equal if isinstance_str(a, 'ndarray') or isinstance_str(b, 'ndarray') else
           array_equal if isinstance_str(a, 'Tensor') or isinstance_str(b, 'Tensor') else
           df_equal if isinstance_str(a, 'NDFrame') or isinstance_str(b, 'NDFrame') else
           operator.eq if any_is_instance((str,dict,set), a, b) else
           all_equal if is_iter(a) or is_iter(b) else
           operator.eq)
    return cmp(a,b)
def ipython_shell():
    "Same as `get_ipython` but returns `False` if not in IPython"
    # `get_ipython` only exists as a builtin when running under IPython
    try:
        return get_ipython()
    except NameError:
        return False
def in_ipython():
    "Check if code is running in some kind of IPython environment"
    shell = ipython_shell()
    return bool(shell)
def in_colab():
    "Check if the code is running in Google Colaboratory"
    loaded = sys.modules
    return 'google.colab' in loaded
def in_jupyter():
    "Check if the code is running in a jupyter notebook"
    if in_ipython():
        # Jupyter's kernel-backed shell class distinguishes it from terminal IPython
        return type(ipython_shell()).__name__ == 'ZMQInteractiveShell'
    return False
def in_notebook():
    "Check if the code is running in a jupyter notebook"
    if in_colab(): return True
    return in_jupyter()
IN_IPYTHON,IN_JUPYTER,IN_COLAB,IN_NOTEBOOK = in_ipython(),in_jupyter(),in_colab(),in_notebook() | |
def remove_prefix(text, prefix):
    "Temporary until py39 is a prereq"
    # Same semantics as str.removeprefix: strip only when present
    if text.startswith(prefix): return text[len(prefix):]
    return text
def remove_suffix(text, suffix):
    "Temporary until py39 is a prereq"
    # Fix: an empty suffix used to slice `text[:-0]` which wrongly returned ''.
    # Guarding on `suffix` matches str.removesuffix semantics.
    return text[:-len(suffix)] if suffix and text.endswith(suffix) else text
---- | |
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_dispatch.ipynb. | |
# %% ../nbs/04_dispatch.ipynb 1 | |
from __future__ import annotations | |
from .imports import * | |
from .foundation import * | |
from .utils import * | |
from collections import defaultdict | |
# %% auto 0 | |
__all__ = ['typedispatch', 'lenient_issubclass', 'sorted_topologically', 'TypeDispatch', 'DispatchReg', 'retain_meta', | |
'default_set_meta', 'cast', 'retain_type', 'retain_types', 'explode_types'] | |
# %% ../nbs/04_dispatch.ipynb | |
def lenient_issubclass(cls, types):
    "If possible return whether `cls` is a subclass of `types`, otherwise return False."
    if cls is object and types is not object: return False # treat `object` as highest level
    # Non-class arguments make isinstance/issubclass raise TypeError; that is the
    # "not possible" case. Catch only TypeError so genuine bugs still propagate
    # (previously a bare `except:` swallowed everything).
    try: return isinstance(cls, types) or issubclass(cls, types)
    except TypeError: return False
# %% ../nbs/04_dispatch.ipynb | |
def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):
    "Return a new list containing all items from the iterable sorted topologically"
    remaining, picked = L(list(iterable)), []
    # Repeatedly select a minimal element under `cmp`, then remove it
    for _ in range(len(remaining)):
        best = remaining.reduce(lambda acc, cand: cand if cmp(cand, acc) else acc)
        picked.append(best)
        remaining.remove(best)
    return picked[::-1] if reverse else picked
# %% ../nbs/04_dispatch.ipynb | |
def _chk_defaults(f, ann):
    "Warn when `f`'s first dispatch params declare defaults (currently a deliberate no-op)"
    pass
# Implementation removed until we can figure out how to do this without `inspect` module
# try: # Some callables don't have signatures, so ignore those errors
#     params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]
#     if any(p.default!=inspect.Parameter.empty for p in params):
#         warn(f"{f.__name__} has default params. These will be ignored.")
# except ValueError: pass
# %% ../nbs/04_dispatch.ipynb | |
def _p2_anno(f):
    "Get the 1st 2 annotations of `f`, defaulting to `object`"
    hints = type_hints(f)
    ann = [typ for nm, typ in hints.items() if nm != 'return']
    if callable(f): _chk_defaults(f, ann)
    # Pad so dispatch always sees exactly two types
    while len(ann) < 2: ann.append(object)
    return ann[:2]
# %% ../nbs/04_dispatch.ipynb | |
class _TypeDict:
    "Mapping from types to funcs, kept topologically sorted so subclasses come before bases"
    def __init__(self):
        self.d, self.cache = {}, {}
    def _reset(self):
        # Re-sort keys by subclass relation and drop all memoized lookups
        order = sorted_topologically(self.d, cmp=lenient_issubclass)
        self.d = {k: self.d[k] for k in order}
        self.cache = {}
    def add(self, t, f):
        "Add type `t` and function `f`"
        # Unions and single types are normalized to a tuple of concrete types
        ts = t if isinstance(t, tuple) else tuple(L(union2tuple(t)))
        for typ in ts: self.d[typ] = f
        self._reset()
    def all_matches(self, k):
        "Find first matching type that is a super-class of `k`"
        if k not in self.cache:
            self.cache[k] = [self.d[t] for t in self.d if lenient_issubclass(k, t)]
        return self.cache[k]
    def __getitem__(self, k):
        "Find first matching type that is a super-class of `k`"
        matches = self.all_matches(k)
        return matches[0] if len(matches) else None
    def __repr__(self): return self.d.__repr__()
    def first(self): return first(self.d.values())
# %% ../nbs/04_dispatch.ipynb | |
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, funcs=(), bases=()):
        # `bases` are fallback dispatchers consulted (in order) when no match is found here
        self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))
        for o in L(funcs): self.add(o)
        self.inst = None
        self.owner = None
    def add(self, f):
        "Add type `t` and function `f`"
        # Dispatch keys are f's first two annotations: a0 keys the outer dict, a1 the inner
        if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)
        else: a0,a1 = _p2_anno(f)
        t = self.funcs.d.get(a0)
        if t is None:
            t = _TypeDict()
            self.funcs.add(a0, t)
        t.add(a1, f)
    def first(self):
        "Get first function in ordered dict of type:func."
        return self.funcs.first().first()
    def returns(self, x):
        "Get the return type of annotation of `x`."
        return anno_ret(self[type(x)])
    def _attname(self,k): return getattr(k,'__name__',str(k))
    def __repr__(self):
        r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", type(v).__name__)}'
             for k in self.funcs.d for l,v in self.funcs[k].d.items()]
        r = r + [o.__repr__() for o in self.bases]
        return '\n'.join(r)
    def __call__(self, *args, **kwargs):
        # Dispatch on the runtime types of the first two args; identity fallback when unmatched
        ts = L(args).map(type)[:2]
        f = self[tuple(ts)]
        if not f: return args[0]
        if isinstance(f, staticmethod): f = f.__func__
        elif self.inst is not None: f = MethodType(f, self.inst)
        elif self.owner is not None: f = MethodType(f, self.owner)
        return f(*args, **kwargs)
    def __get__(self, inst, owner):
        # NOTE(review): binding state lives on the shared dispatcher object itself,
        # so interleaved access from two owners/instances can clobber it -- existing behavior
        self.inst = inst
        self.owner = owner
        return self
    def __getitem__(self, k):
        "Find first matching type that is a super-class of `k`"
        k = L(k)
        while len(k)<2: k.append(object)
        r = self.funcs.all_matches(k[0])
        for t in r:
            o = t[k[1]]
            if o is not None: return o
        # No local match: defer to base dispatchers in registration order
        for base in self.bases:
            res = base[k]
            if res is not None: return res
        return None
# %% ../nbs/04_dispatch.ipynb | |
class DispatchReg:
    "A global registry for `TypeDispatch` objects keyed by function name"
    def __init__(self):
        self.d = defaultdict(TypeDispatch)
    def __call__(self, f):
        "Register `f` under its qualified name; return the accumulated dispatcher"
        target = f.__func__ if isinstance(f, (classmethod, staticmethod)) else f
        nm = f'{target.__qualname__}'
        # classmethods are unwrapped before storage; staticmethods stay wrapped
        # (TypeDispatch unwraps them at call time)
        self.d[nm].add(f.__func__ if isinstance(f, classmethod) else f)
        return self.d[nm]
typedispatch = DispatchReg() | |
# %% ../nbs/04_dispatch.ipynb | |
_all_=['cast'] | |
# %% ../nbs/04_dispatch.ipynb | |
def retain_meta(x, res, as_copy=False):
    "Call `res.set_meta(x)`, if it exists"
    if hasattr(res, 'set_meta'):
        res.set_meta(x, as_copy=as_copy)
    return res
# %% ../nbs/04_dispatch.ipynb | |
def default_set_meta(self, x, as_copy=False):
    "Copy over `_meta` from `x` to `res`, if it's missing"
    # Only copy when the source has metadata and the target doesn't already
    if hasattr(x, '_meta') and not hasattr(self, '_meta'):
        self._meta = copy(x._meta) if as_copy else x._meta
    return self
# %% ../nbs/04_dispatch.ipynb | |
@typedispatch
def cast(x, typ):
    "cast `x` to type `typ` (may also change `x` inplace)"
    # Give the target type a chance to pre-process the value
    res = typ._before_cast(x) if hasattr(typ, '_before_cast') else x
    if risinstance('ndarray', res): res = res.view(typ)
    elif hasattr(res, 'as_subclass'): res = res.as_subclass(typ)
    else:
        # Retype in place when the class layout allows it; otherwise build a new instance.
        # Fix: was a bare `except:` -- narrow to Exception so e.g. KeyboardInterrupt propagates.
        try: res.__class__ = typ
        except Exception: res = typ(res)
    return retain_meta(x, res)
# %% ../nbs/04_dispatch.ipynb | |
def retain_type(new, old=None, typ=None, as_copy=False):
    "Cast `new` to type of `old` or `typ` if it's a superclass"
    # e.g. old is TensorImage, new is Tensor - if not subclass then do nothing
    if new is None: return
    assert old is not None or typ is not None
    if typ is None:
        # `old` must be an instance of new's type (or the type itself) for casting to make sense
        if not isinstance(old, type(new)): return new
        typ = old if isinstance(old,type) else type(old)
    # Do nothing if the new value is already an instance of the requested type (i.e. same type)
    if typ==NoneType or isinstance(new, typ): return new
    return retain_meta(old, cast(new, typ), as_copy=as_copy)
# %% ../nbs/04_dispatch.ipynb | |
def retain_types(new, old=None, typs=None):
    "Cast each item of `new` to type of matching item in `old` if it's a superclass"
    if not is_listy(new): return retain_type(new, old, typs)
    if typs is not None:
        if isinstance(typs, dict):
            # A dict spec maps the container type to the element type-specs for recursion
            t = first(typs.keys())
            typs = typs[t]
        else: t,typs = typs,None
    else: t = type(old) if old is not None and isinstance(old,type(new)) else type(new)
    # Rebuild the container as `t`, recursing elementwise (`cycled` pads shorter sequences)
    return t(L(new, old, typs).map_zip(retain_types, cycled=True))
# %% ../nbs/04_dispatch.ipynb | |
def explode_types(o):
    "Return the type of `o`, potentially in nested dictionaries for thing that are listy"
    if is_listy(o):
        # Container type maps to the recursively exploded types of its elements
        return {type(o): [explode_types(item) for item in o]}
    return type(o)
--- | |
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/06_docments.ipynb. | |
# %% ../nbs/06_docments.ipynb 2 | |
from __future__ import annotations | |
import re | |
from tokenize import tokenize,COMMENT | |
from ast import parse,FunctionDef,AsyncFunctionDef,AnnAssign | |
from io import BytesIO | |
from textwrap import dedent | |
from types import SimpleNamespace | |
from inspect import getsource,isfunction,ismethod,isclass,signature,Parameter | |
from dataclasses import dataclass, is_dataclass | |
from .utils import * | |
from .meta import delegates | |
from . import docscrape | |
from inspect import isclass,getdoc | |
# %% auto 0 | |
__all__ = ['empty', 'docstring', 'parse_docstring', 'isdataclass', 'get_dataclass_source', 'get_source', 'get_name', 'qual_name', | |
'docments'] | |
# %% ../nbs/06_docments.ipynb | |
def docstring(sym):
    "Get docstring for `sym` for functions and classes"
    if isinstance(sym, str): return sym
    doc = getdoc(sym)
    # For classes with no class-level doc, fall back to the constructor's docstring
    if not doc and isclass(sym): doc = getdoc(sym.__init__)
    return doc or ""
# %% ../nbs/06_docments.ipynb | |
def parse_docstring(sym):
    "Parse a numpy-style docstring in `sym`"
    # Fix: the docstring was fetched once into `docs` and then fetched a second
    # time in the return expression; reuse the local instead.
    docs = docstring(sym)
    return AttrDict(**docscrape.NumpyDocString(docs))
# %% ../nbs/06_docments.ipynb | |
def isdataclass(s):
    "Check if `s` is a dataclass but not a dataclass' instance"
    return isclass(s) and is_dataclass(s)
# %% ../nbs/06_docments.ipynb | |
def get_dataclass_source(s):
    "Get source code for dataclass `s`"
    # Classes defined in `__main__` (e.g. a notebook) have no retrievable source file
    if getattr(s, "__module__") == '__main__': return ""
    return getsource(s)
# %% ../nbs/06_docments.ipynb | |
def get_source(s):
    "Get source code for string, function object or dataclass `s`"
    if isfunction(s) or ismethod(s): return getsource(s)
    if isdataclass(s): return get_dataclass_source(s)
    # Strings are assumed to already be source code
    return s
# %% ../nbs/06_docments.ipynb | |
def _parses(s):
    "Parse Python code in string, function object or dataclass `s`"
    src = dedent(get_source(s))
    return parse(src)
def _tokens(s):
    "Tokenize Python code in string or function object `s`"
    src = get_source(s)
    # tokenize wants a readline over encoded bytes
    return tokenize(BytesIO(src.encode('utf-8')).readline)
_clean_re = re.compile(r'^\s*#(.*)\s*$') | |
def _clean_comment(s): | |
res = _clean_re.findall(s) | |
return res[0] if res else None | |
def _param_locs(s, returns=True):
    "`dict` of parameter line numbers to names"
    body = _parses(s).body
    # Only a single top-level definition is supported
    if len(body) != 1: return None
    defn = body[0]
    if isinstance(defn, (FunctionDef, AsyncFunctionDef)):
        locs = {a.lineno: a.arg for a in defn.args.args}
        if returns and defn.returns: locs[defn.returns.lineno] = 'return'
        return locs
    if isdataclass(s):
        # Dataclass fields are the annotated assignments in the class body
        return {a.lineno: a.target.id for a in defn.body if isinstance(a, AnnAssign)}
    return None
# %% ../nbs/06_docments.ipynb | |
empty = Parameter.empty | |
# %% ../nbs/06_docments.ipynb | |
def _get_comment(line, arg, comments, parms): | |
if line in comments: return comments[line].strip() | |
line -= 1 | |
res = [] | |
while line and line in comments and line not in parms: | |
res.append(comments[line]) | |
line -= 1 | |
return dedent('\n'.join(reversed(res))) if res else None | |
def _get_full(anno, name, default, docs):
    "Assemble the docment record (comment, annotation, default) for parameter `name`"
    # When only a default was supplied, infer the annotation from its type
    eff_anno = type(default) if anno==empty and default!=empty else anno
    return AttrDict(docment=docs.get(name), anno=eff_anno, default=default)
# %% ../nbs/06_docments.ipynb | |
def _merge_doc(dm, npdoc):
    "Fill missing annotation/docment in `dm` from numpy-docstring param info `npdoc`"
    if npdoc:
        if not dm.anno or dm.anno==empty: dm.anno = npdoc.type
        if not dm.docment: dm.docment = '\n'.join(npdoc.desc)
    return dm
def _merge_docs(dms, npdocs):
    "Merge numpy-style docstring sections `npdocs` into the docments dict `dms`"
    npparams = npdocs['Parameters']
    merged = {nm: _merge_doc(d, npparams.get(nm, None)) for nm, d in dms.items()}
    # The return entry is described by the separate 'Returns' section
    if 'return' in dms: merged['return'] = _merge_doc(dms['return'], npdocs['Returns'])
    return merged
# %% ../nbs/06_docments.ipynb | |
def _get_property_name(p): | |
"Get the name of property `p`" | |
if hasattr(p, 'fget'): | |
return p.fget.func.__qualname__ if hasattr(p.fget, 'func') else p.fget.__qualname__ | |
else: return next(iter(re.findall(r'\'(.*)\'', str(p)))).split('.')[-1] | |
# %% ../nbs/06_docments.ipynb | |
def get_name(obj):
    "Get the name of `obj`"
    if hasattr(obj, '__name__'): return obj.__name__
    if getattr(obj, '_name', False): return obj._name
    if hasattr(obj, '__origin__'): return str(obj.__origin__).split('.')[-1] # generic alias types
    if type(obj) == property: return _get_property_name(obj)
    return str(obj).split('.')[-1]
# %% ../nbs/06_docments.ipynb | |
def qual_name(obj):
    "Get the qualified name of `obj`"
    if hasattr(obj,'__qualname__'): return obj.__qualname__
    # For bound methods lacking __qualname__, qualify by the owner's name.
    # Fix: was `get_name(fn)` -- `fn` is undefined here and raised NameError when reached.
    if ismethod(obj): return f"{get_name(obj.__self__)}.{get_name(obj)}"
    return get_name(obj)
# %% ../nbs/06_docments.ipynb | |
def _docments(s, returns=True, eval_str=False):
    "`dict` of parameter names to 'docment-style' comments in function or string `s`"
    nps = parse_docstring(s)
    if isclass(s) and not is_dataclass(s): s = s.__init__ # Constructor for a class
    # Map source line numbers to cleaned comment text
    comments = {o.start[0]:_clean_comment(o.string) for o in _tokens(s) if o.type==COMMENT}
    parms = _param_locs(s, returns=returns) or {}
    docs = {arg:_get_comment(line, arg, comments, parms) for line,arg in parms.items()}
    # NOTE(review): `eval` on caller-supplied strings -- only safe for trusted input
    if isinstance(s,str): s = eval(s)
    sig = signature(s)
    res = {arg:_get_full(p.annotation, p.name, p.default, docs) for arg,p in sig.parameters.items()}
    if returns: res['return'] = _get_full(sig.return_annotation, 'return', empty, docs)
    # Numpy-style docstring sections fill in anything the comments didn't supply
    res = _merge_docs(res, nps)
    if eval_str:
        # Resolve stringified annotations (e.g. under `from __future__ import annotations`)
        hints = type_hints(s)
        for k,v in res.items():
            if k in hints: v['anno'] = hints.get(k)
    return res
# %% ../nbs/06_docments.ipynb | |
@delegates(_docments)
def docments(elt, full=False, **kwargs):
    "Generates a `docment`"
    r = {}
    # Only keep docments for params that actually appear in `elt`'s signature (plus 'return')
    params = set(signature(elt).parameters)
    params.add('return')
    def _update_docments(f, r):
        # Recurse up the `delegates` chain first so the most specific docment wins
        if hasattr(f, '__delwrap__'): _update_docments(f.__delwrap__, r)
        r.update({k:v for k,v in _docments(f, **kwargs).items() if k in params
                  and (v.get('docment', None) or not nested_idx(r, k, 'docment'))})
    _update_docments(elt, r)
    # Unless `full`, collapse each record to just its comment text
    if not full: r = {k:v['docment'] for k,v in r.items()}
    return AttrDict(r)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment