Skip to content

Instantly share code, notes, and snippets.

@amaarora
Last active November 6, 2019 21:42
Show Gist options
  • Select an option

  • Save amaarora/6fb9b7ca43b78ec17c7abd75c370b9cd to your computer and use it in GitHub Desktop.

Select an option

Save amaarora/6fb9b7ca43b78ec17c7abd75c370b9cd to your computer and use it in GitHub Desktop.
#Cell
@docs
@delegates(TfmdList)
class DataSource(FilteredBase):
    "A dataset that creates a tuple from each `tfms`, each applied to the same `items` by its own `TfmdList`"
    def __init__(self, items=None, tfms=None, tls=None, n_inp=None, dl_type=None, **kwargs):
        super().__init__(dl_type=dl_type)
        # A truthy `tls` is used as-is; otherwise build one `TfmdList` per entry of `tfms`
        # (`tfms=None` becomes `[None]`, i.e. a single `TfmdList` with `None` as its transforms).
        self.tls = L(tls if tls else [TfmdList(items, t, **kwargs) for t in L(ifnone(tfms,[None]))])
        # Default `n_inp`: 1 when there is a single `TfmdList`, else every list but the last.
        self.n_inp = (1 if len(self.tls)==1 else len(self.tls)-1) if n_inp is None else n_inp
    def __getitem__(self, it):
        res = tuple([tl[it] for tl in self.tls])
        # Presumably `is_indexer` is true for a single-item index: return the tuple directly.
        # Otherwise each `tl[it]` is a collection, so transpose into a list of per-item tuples.
        return res if is_indexer(it) else list(zip(*res))
    # Only called when normal lookup fails: such attributes are gathered from the `TfmdList`s in `tls`.
    def __getattr__(self,k): return gather_attrs(self, k, 'tls')
    def __dir__(self): return super().__dir__() + gather_attr_names(self, 'tls')
    # Length of the first `TfmdList`; the lists in `tls` are presumably all the same length.
    def __len__(self): return len(self.tls[0])
    def __iter__(self): return (self[i] for i in range(len(self)))
    def __repr__(self): return coll_repr(self)
    def decode(self, o, full=True): return tuple(tl.decode(o_, full=full) for o_,tl in zip(o,tuplify(self.tls, match=o)))
    # NOTE(review): only `tls` and `n_inp` are carried over; the `dl_type` given to `__init__`
    # is not forwarded to the new instance — confirm this is intended.
    def subset(self, i): return type(self)(tls=L(tl.subset(i) for tl in self.tls), n_inp=self.n_inp)
    # `self.tfms` is never assigned in this class, so it presumably resolves through `__getattr__`
    # to the transforms of `tls`. `*args` is accepted but not forwarded to `super()._new`.
    def _new(self, items, *args, **kwargs): return super()._new(items, tfms=self.tfms, do_setup=False, **kwargs)
    # Delegates to the first `TfmdList`.
    def overlapping_splits(self): return self.tls[0].overlapping_splits()
    @property
    def splits(self): return self.tls[0].splits
    @property
    def split_idx(self): return self.tls[0].tfms.split_idx
    @property
    def items(self): return self.tls[0].items
    @items.setter
    def items(self, v):
        # Keep every `TfmdList` pointing at the same items.
        for tl in self.tls: tl.items = v
    def show(self, o, ctx=None, **kwargs):
        # Each element of `o` is shown by its matching `TfmdList`; the returned `ctx` is threaded
        # into the next call so all elements end up in the same context.
        for o_,tl in zip(o,self.tls): ctx = tl.show(o_, ctx=ctx, **kwargs)
        return ctx
    def new_empty(self):
        # NOTE(review): as in `subset`, the `dl_type` given to `__init__` is not forwarded.
        tls = [tl._new([], split_idx=tl.split_idx) for tl in self.tls]
        return type(self)(tls=tls, n_inp=self.n_inp)
    _docs=dict(
        decode="Decode each element of `o` with the corresponding `TfmdList` in `tls`",
        show="Show item `o` in `ctx`",
        databunch="Get a `DataBunch`",
        overlapping_splits="All splits that are in more than one split",
        subset="New `DataSource` that only includes subset `i`",
        new_empty="Create a new empty version of `self`, keeping only the transforms")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment