Skip to content

Instantly share code, notes, and snippets.

@amaarora
Created November 10, 2019 00:14
Show Gist options
  • Save amaarora/e20ad0f06974c29f696ccdcdc590884a to your computer and use it in GitHub Desktop.
Save amaarora/e20ad0f06974c29f696ccdcdc590884a to your computer and use it in GitHub Desktop.
#Cell
class TfmdList(FilteredBase, L, GetAttr):
    "A `Pipeline` of `tfms` applied on access to each of `items`, optionally partitioned by `splits`"
    # `GetAttr` presumably forwards attribute lookups that miss on this object to the
    # attribute named here (the pipeline) — confirm against fastcore's `GetAttr`.
    _default='tfms'

    def __init__(self, items, tfms, use_list=None, do_setup=True, as_item=True, split_idx=None, train_setup=True, splits=None):
        super().__init__(items, use_list=use_list)
        # Default splits: `slice(None)` (every item) followed by an empty list.
        split_spec = [slice(None),[]] if splits is None else splits
        self.splits = L(split_spec).map(mask2idxs)
        # Given another TfmdList, reuse its pipeline instead of wrapping the list itself.
        if isinstance(tfms,TfmdList): tfms = tfms.tfms
        # A ready-made `Pipeline` is never set up again here, regardless of `do_setup`.
        prebuilt = isinstance(tfms,Pipeline)
        self.tfms = Pipeline(tfms, as_item=as_item, split_idx=split_idx)
        if do_setup and not prebuilt:
            self.setup(train_setup=train_setup)

    def _new(self, items, **kwargs):
        "Build a sibling list over `items` that shares `self.tfms` and skips setup"
        shared = self.tfms
        return super()._new(items, tfms=shared, do_setup=False, **kwargs)

    def subset(self, i):
        "Return a new list over the items of split `i`, passing `split_idx=i` through to it"
        chosen = self._get(self.splits[i])
        return self._new(chosen, split_idx=i)

    def _after_item(self, o):
        "Run `self.tfms` on one raw item `o`"
        return self.tfms(o)

    def __repr__(self): return f"{self.__class__.__name__}: {self.items}\ntfms - {self.tfms.fs}"

    def __iter__(self):
        # Iterate by indexing so every element goes through `__getitem__`, and therefore the tfms.
        n_items = len(self)
        return (self[i] for i in range(n_items))

    def show(self, o, **kwargs):
        "Display `o` with the pipeline's `show`"
        return self.tfms.show(o, **kwargs)

    def decode(self, o, **kwargs):
        "Undo the pipeline's transforms on `o` via its `decode`"
        return self.tfms.decode(o, **kwargs)

    def __call__(self, o, **kwargs):
        "Apply the pipeline to an arbitrary `o` (not necessarily one of `items`)"
        return self.tfms.__call__(o, **kwargs)

    def setup(self, train_setup=True):
        "Set up `self.tfms` on the `train` view when `train_setup` (falling back to `self`), else on `self`"
        # NOTE(review): `train` is not defined in this class; presumably FilteredBase provides it.
        source = getattr(self,'train',self) if train_setup else self
        self.tfms.setup(source)

    def overlapping_splits(self):
        "Occurrence counts of indices repeated across `self.splits`; empty when no index repeats"
        # Returns the counts themselves, not the offending indices.
        # Presumably `gt(1)` keeps counts greater than 1 — confirm against fastcore.
        occurrences = Counter(self.splits.concat())
        return L(occurrences.values()).filter(gt(1))

    def __getitem__(self, idx):
        "Fetch via the parent class, then transform: directly for a single index, element-wise otherwise"
        raw = super().__getitem__(idx)
        post = self._after_item
        # Setting `_after_item` to None (e.g. in a subclass) disables the transform step.
        if post is None: return raw
        # `is_indexer` presumably identifies a single-element index such as an int.
        if is_indexer(idx): return post(raw)
        return raw.map(post)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment