Last active
November 6, 2019 21:42
-
-
Save amaarora/6fb9b7ca43b78ec17c7abd75c370b9cd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Cell
@docs
@delegates(TfmdList)
class DataSource(FilteredBase):
    "A dataset that creates a tuple from each `tfms`, passed thru `item_tfms`"
    # Holds one `TfmdList` per entry of `tfms` in `self.tls` and indexes them in lockstep.
    # NOTE(review): `item_tfms` is never referenced in this class -- confirm the docstring is current.

    def __init__(self, items=None, tfms=None, tls=None, n_inp=None, dl_type=None, **kwargs):
        super().__init__(dl_type=dl_type)
        # Ready-made `tls` take precedence; otherwise build one `TfmdList` over `items`
        # per entry of `tfms` (a missing `tfms` is treated as the single entry `None`).
        if not tls:
            tls = [TfmdList(items, tfm, **kwargs) for tfm in L(ifnone(tfms, [None]))]
        self.tls = L(tls)
        # Default `n_inp`: 1 when there is a single list, otherwise one fewer than the number of lists.
        if n_inp is None:
            n_tls = len(self.tls)
            n_inp = 1 if n_tls == 1 else n_tls - 1
        self.n_inp = n_inp

    def __getitem__(self, it):
        # Index every list with the same `it`.
        per_tl = tuple([tl[it] for tl in self.tls])
        if is_indexer(it):
            return per_tl
        # Not a plain indexer: regroup the per-list results into a list of per-item tuples.
        return list(zip(*per_tl))

    def __getattr__(self, k):
        # Unknown attributes are looked up through `gather_attrs` on `tls`.
        return gather_attrs(self, k, 'tls')

    def __dir__(self):
        own_names = super().__dir__()
        return own_names + gather_attr_names(self, 'tls')

    def __len__(self):
        # Length is taken from the first list only.
        first_tl = self.tls[0]
        return len(first_tl)

    def __iter__(self):
        n_items = len(self)
        return (self[idx] for idx in range(n_items))

    def __repr__(self):
        return coll_repr(self)

    def decode(self, o, full=True):
        # Pair each element of `o` with a list from `tuplify(self.tls, match=o)` and decode it there.
        matched_tls = tuplify(self.tls, match=o)
        return tuple(tl.decode(part, full=full) for part, tl in zip(o, matched_tls))

    def subset(self, i):
        # Take subset `i` of every list and wrap the results in a new object of the same type.
        sub_tls = L(tl.subset(i) for tl in self.tls)
        return type(self)(tls=sub_tls, n_inp=self.n_inp)

    def _new(self, items, *args, **kwargs):
        # Extra positional `args` are accepted but not forwarded.
        return super()._new(items, tfms=self.tfms, do_setup=False, **kwargs)

    def overlapping_splits(self):
        first_tl = self.tls[0]
        return first_tl.overlapping_splits()

    # The properties below all read from the first list in `tls`.
    @property
    def splits(self):
        first_tl = self.tls[0]
        return first_tl.splits

    @property
    def split_idx(self):
        first_tl = self.tls[0]
        return first_tl.tfms.split_idx

    @property
    def items(self):
        first_tl = self.tls[0]
        return first_tl.items

    @items.setter
    def items(self, v):
        # Setting `items` assigns the same value to every list.
        for tl in self.tls:
            tl.items = v

    def show(self, o, ctx=None, **kwargs):
        # Thread `ctx` through each list's `show`, in order, and return the final one.
        for part, tl in zip(o, self.tls):
            ctx = tl.show(part, ctx=ctx, **kwargs)
        return ctx

    def new_empty(self):
        # Rebuild every list with no items, keeping its `split_idx`.
        empty_tls = [tl._new([], split_idx=tl.split_idx) for tl in self.tls]
        return type(self)(tls=empty_tls, n_inp=self.n_inp)

    # Docstring table picked up by the `@docs` class decorator.
    # NOTE(review): the `decode` text mentions `tuple_tfms` and `i`, neither of which appears in
    # `decode` above, and `databunch` is not defined in this class (presumably inherited) -- confirm.
    _docs = {
        "decode": "Compose `decode` of all `tuple_tfms` then all `tfms` on `i`",
        "show": "Show item `o` in `ctx`",
        "databunch": "Get a `DataBunch`",
        "overlapping_splits": "All splits that are in more than one split",
        "subset": "New `DataSource` that only includes subset `i`",
        "new_empty": "Create a new empty version of the `self`, keeping only the transforms",
    }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment