Gist amaarora/30cf7233d4a2ebe3d82d7564fcd41d6f, created June 2, 2021 01:09
git_repos/fastai/nbs/03_data.core.ipynb
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\n#skip\n! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab", | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#default_exp data.core", | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\nfrom fastai.torch_basics import *\nfrom fastai.data.load import *\nfrom torch.utils.data import DataLoader", | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\nfrom nbdev.showdoc import *", | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "# Data core\n\n> Core functionality for gathering data" | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "The classes here provide functionality for applying a list of transforms to a set of items (`TfmdLists`, `Datasets`) or to a `DataLoader` (`TfmdDL`), as well as the base class used to gather the data for model training: `DataLoaders`." | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## TfmdDL -" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\n@typedispatch\ndef show_batch(x, y, samples, ctxs=None, max_n=9, **kwargs):\n    if ctxs is None: ctxs = Inf.nones\n    if hasattr(samples[0], 'show'):\n        ctxs = [s.show(ctx=c, **kwargs) for s,c,_ in zip(samples,ctxs,range(max_n))]\n    else:\n        for i in range_of(samples[0]):\n            ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]\n    return ctxs", | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "`show_batch` is a type-dispatched function that is responsible for showing decoded `samples`. `x` and `y` are the input and the target in the batch to be shown, and are passed along so that the function can dispatch on their types. For instance, there are separate implementations of `show_batch` for when `x` is a `TensorImage` or a `TensorText` (see `vision.core` or `text.data` for details). `ctxs` can be passed in, but the function is responsible for creating them if necessary. Which `kwargs` are accepted depends on the specific implementation." | |
}, | |
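To make "dispatch on the types of `x` and `y`" concrete, here is a minimal pure-Python sketch of a two-argument type dispatcher. It is not fastcore's `typedispatch`, whose lookup rules are more involved; `register`, `show_batch_sketch`, `_generic`, `_image` and `FakeImage` are hypothetical names used only for illustration.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry mapping (type of x, type of y) to an implementation.
_registry: Dict[Tuple[type, type], Callable] = {}

def register(tx: type, ty: type) -> Callable:
    "Decorator adding an implementation for inputs of type `tx` and targets of type `ty`."
    def _inner(fn: Callable) -> Callable:
        _registry[(tx, ty)] = fn
        return fn
    return _inner

def show_batch_sketch(x, y, samples):
    "Call the first registered implementation found by walking the MROs of `type(x)` and `type(y)`."
    for tx in type(x).__mro__:          # most specific input type first
        for ty in type(y).__mro__:      # then most specific target type
            impl = _registry.get((tx, ty))
            if impl is not None:
                return impl(x, y, samples)
    raise TypeError(f"no implementation for ({type(x).__name__}, {type(y).__name__})")

@register(object, object)
def _generic(x, y, samples):
    return "generic"

class FakeImage(list):
    "Hypothetical stand-in for an image tensor type."

@register(FakeImage, object)
def _image(x, y, samples):
    return "image"
```

A batch whose input is a `FakeImage` picks the image-specific implementation, while anything else falls back to the `(object, object)` one, which plays the role of the generic `show_batch` defined above.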
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\n@typedispatch\ndef show_results(x, y, samples, outs, ctxs=None, max_n=9, **kwargs):\n    if ctxs is None: ctxs = Inf.nones\n    for i in range(len(samples[0])):\n        ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]\n    for i in range(len(outs[0])):\n        ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(i),ctxs,range(max_n))]\n    return ctxs", | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "`show_results` is a type-dispatched function that is responsible for showing decoded `samples` and their corresponding `outs`. As in `show_batch`, `x` and `y` are the input and the target in the batch to be shown, and are passed along so that the function can dispatch on their types. `ctxs` can be passed in, but the function is responsible for creating them if necessary. Which `kwargs` are accepted depends on the specific implementation." | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\n_all_ = [\"show_batch\", \"show_results\"]", | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\n_batch_tfms = ('after_item','before_batch','after_batch')", | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\n_collate_types = (ndarray, Tensor, typing.Mapping, str)", | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\ndef fa_collate(t):\n    \"A replacement for PyTorch `default_collate` which maintains types and handles `Sequence`s\"\n    b = t[0]\n    return (default_collate(t) if isinstance(b, _collate_types)\n            else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)\n            else default_collate(t))", | |
"execution_count": 10, | |
"outputs": [] | |
}, | |
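The recursion in `fa_collate` can be illustrated without PyTorch. The sketch below is a hypothetical, dependency-free analogue: plain lists stand in for the tensors that `default_collate` would build, and `toy_collate` and `Pair` are made-up names. Like `fa_collate`, it transposes each `Sequence` sample with `zip(*batch)`, collates each field recursively, and rebuilds the result with the container type of the first sample.

```python
from collections.abc import Sequence

def toy_collate(batch):
    """Hypothetical, dependency-free analogue of `fa_collate`.

    Leaves (including strings) are gathered into a plain list, standing in
    for the tensor that PyTorch's `default_collate` would build; any other
    `Sequence` is transposed with `zip(*batch)`, each field is collated
    recursively, and the result is rebuilt with the type of the first sample.
    """
    first = batch[0]
    if isinstance(first, (str, bytes)) or not isinstance(first, Sequence):
        return list(batch)
    return type(first)([toy_collate(list(field)) for field in zip(*batch)])

class Pair(tuple):
    "A tuple subclass, used to show that the container type survives collation."
```

Collating `[Pair((1, 'a')), Pair((2, 'b'))]` gives a `Pair` of two fields, `[1, 2]` and `['a', 'b']`, which is the "maintains types" behaviour the `fa_collate` docstring refers to.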
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\n@delegates()\nclass TfmdDL(DataLoader, GetAttr):\n    \"Transformed `DataLoader`\"\n    _default='dataset'\n    def __init__(self, dataset, bs=64, shuffle=False, num_workers=None, collate_fn=fa_collate, verbose=False, do_setup=True, indexed=None, device=None, **kwargs):\n        if num_workers is None: num_workers = min(16, defaults.cpus)\n        for nm in _batch_tfms: kwargs[nm] = Pipeline(kwargs.get(nm,None))\n        self.after_item = kwargs.pop('after_item')\n        self.after_batch = kwargs.pop('after_batch')\n        self.before_batch = kwargs.pop('before_batch')\n        self.indexed = hasattr(dataset, '__getitem__') and not isinstance(dataset, IterableDataset)\n\n        store_attr('shuffle, device, bs')\n        if do_setup:\n            for nm in _batch_tfms:\n                pv(f\"Setting up {nm}: {getattr(self, nm)}\", verbose)\n                getattr(self, nm).setup(self)\n\n        if self.after_item is not noop:\n            dataset = TfmdDataset(dataset, after_item=self.after_item)\n        super().__init__(dataset=dataset, batch_size=bs, shuffle=shuffle, num_workers=num_workers,\n                         collate_fn=collate_fn, **kwargs)\n\n    def __iter__(self):\n        self.before_iter()\n        for batch in super().__iter__():\n            if self.device is not None: batch = to_device(batch, self.device)\n            yield self.after_batch(batch)\n\n    def before_iter(self):\n        split_idx = getattr(self.dataset, 'split_idx', None)\n        for nm in _batch_tfms:\n            f = getattr(self,nm)\n            if isinstance(f,Pipeline): f.split_idx=split_idx\n\n    def _one_pass(self):\n        its = next(iter(self))\n        self._n_inp = 1 if not isinstance(its, (list,tuple)) or len(its)==1 else len(its)-1\n        self._types = explode_types(its)\n\n    def new(self, dataset=None, cls=None, **kwargs):\n        if dataset is None: dataset=self.dataset\n        if cls is None: cls = type(self)\n        cur_kwargs = dict(dataset=dataset, num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,\n                          bs=self.batch_size, shuffle=self.shuffle, drop_last=self.drop_last, indexed=self.indexed, device=self.device)\n        for n in _batch_tfms:\n            o = getattr(self, n)\n            if not isinstance(o, MethodType): cur_kwargs[n] = o\n        res = cls(**merge(cur_kwargs, kwargs))\n\n        if not hasattr(self, '_n_inp') or not hasattr(self, '_types'):\n            try:\n                self._one_pass()\n                res._n_inp,res._types = self._n_inp,self._types\n            except: print(\"Could not do one pass in your dataloader, there is something wrong in it\")\n        else: res._n_inp,res._types = self._n_inp,self._types\n        return res\n\n    def to(self, device):\n        self.device = device\n        for tfm in self.after_batch.fs:\n            for a in L(getattr(tfm, 'parameters', None)): setattr(tfm, a, getattr(tfm, a).to(device))\n        return self\n\n    def _retain_dl(self,b):\n        if not getattr(self, '_types', None): self._one_pass()\n        return retain_types(b, typs=self._types)\n\n    def decode(self, b): return to_cpu(self.after_batch.decode(self._retain_dl(b)))\n    def decode_batch(self, b, max_n=9, full=True): return self._decode_batch(self.decode(b), max_n, full)\n\n    def _decode_batch(self, b, max_n=9, full=True):\n        f = self.after_item.decode\n        f1 = self.before_batch.decode\n        f = compose(f1, f, partial(getattr(self.dataset,'decode',noop), full = full))\n        return L(batch_to_samples(b, max_n=max_n)).map(f)\n\n    def _pre_show_batch(self, b, max_n=9):\n        \"Decode `b` to be ready for `show_batch`\"\n        b = self.decode(b)\n        if hasattr(b, 'show'): return b,None,None\n        its = self._decode_batch(b, max_n, full=False)\n        if not is_listy(b): b,its = [b],L((o,) for o in its)\n        return detuplify(b[:self.n_inp]),detuplify(b[self.n_inp:]),its\n\n    def show_batch(self, b=None, max_n=9, ctxs=None, show=True, unique=False, **kwargs):\n        if unique:\n            old_get_idxs = self.get_idxs\n            self.get_idxs = lambda: Inf.zeros\n        if b is None: b = self.one_batch()\n        if not show: return self._pre_show_batch(b, max_n=max_n)\n        show_batch(*self._pre_show_batch(b, max_n=max_n), ctxs=ctxs, max_n=max_n, **kwargs)\n        if unique: self.get_idxs = old_get_idxs\n\n    def show_results(self, b, out, max_n=9, ctxs=None, show=True, **kwargs):\n        x,y,its = self.show_batch(b, max_n=max_n, show=False)\n        b_out = type(b)(b[:self.n_inp] + (tuple(out) if is_listy(out) else (out,)))\n        x1,y1,outs = self.show_batch(b_out, max_n=max_n, show=False)\n        res = (x,x1,None,None) if its is None else (x, y, its, outs.itemgot(slice(self.n_inp,None)))\n        if not show: return res\n        show_results(*res, ctxs=ctxs, max_n=max_n, **kwargs)\n\n    def one_batch(self):\n        res = first(self)\n        return res\n\n    @property\n    def n_inp(self):\n        if hasattr(self.dataset, 'n_inp'): return self.dataset.n_inp\n        if not hasattr(self, '_n_inp'): self._one_pass()\n        return self._n_inp", | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
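`_one_pass` infers how many elements of a batch are inputs with a simple heuristic. Restated as a standalone function (the name `infer_n_inp` is hypothetical), the rule is: a batch that is not a `list`/`tuple`, or that has a single element, has one input; otherwise the last element is treated as the target and everything before it as an input.

```python
def infer_n_inp(batch) -> int:
    "Hypothetical restatement of the `n_inp` rule in `TfmdDL._one_pass`."
    if not isinstance(batch, (list, tuple)) or len(batch) == 1:
        return 1                 # a bare tensor or a 1-tuple: a single input
    return len(batch) - 1        # otherwise the last element is the target
```

So an `(x, y)` batch has one input and an `(x1, x2, y)` batch has two; a dataset with several targets would need to expose its own `n_inp`, which the `n_inp` property checks before falling back to `_one_pass`.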
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "A `TfmdDL` is a `DataLoader` that creates a `Pipeline` from a list of `Transform`s for each of the callbacks `after_item`, `before_batch` and `after_batch`. Because these pipelines can also be decoded, a `TfmdDL` can decode or show a processed `batch`." | |
}, | |
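The three event names correspond to points in batch creation. The sketch below assumes nothing from fastai (`toy_batches` and `identity` are hypothetical) and shows the intended order: `after_item` on each sample, `before_batch` on the list of samples, then collation, then `after_batch` on the collated batch. As written, the torch-`DataLoader`-based `TfmdDL` in this notebook applies `after_item` through `TfmdDataset` and `after_batch` in `__iter__`, but its `__iter__` does not appear to call `before_batch` (that pipeline is only set up and used when decoding).

```python
from typing import Any, Callable, Iterator, List

def identity(x: Any) -> Any:
    return x

def toy_batches(items: List[Any], bs: int,
                after_item: Callable = identity,
                before_batch: Callable = identity,
                collate: Callable = list,
                after_batch: Callable = identity) -> Iterator[Any]:
    "Hypothetical sketch of the order in which the three batch events run."
    for start in range(0, len(items), bs):
        samples = [after_item(o) for o in items[start:start + bs]]  # per sample
        samples = before_batch(samples)                              # on the list of samples
        yield after_batch(collate(samples))                          # on the collated batch
```

The sketch keeps the last partial batch (there is no `drop_last`) and uses `list` as a stand-in for `fa_collate`.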
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\nadd_docs(TfmdDL,\n         decode=\"Decode `b` using `tfms`\",\n         decode_batch=\"Decode `b` entirely\",\n         new=\"Create a new version of self with a few changed attributes\",\n         show_batch=\"Show `b` (defaults to `one_batch`), a list of lists of pipeline outputs (i.e. output of a `DataLoader`)\",\n         show_results=\"Show each item of `b` and `out`\",\n         one_batch=\"Return one batch from `DataLoader`.\",\n         to=\"Put self and its transforms state on `device`\",\n         before_iter=\"Set `split_idx` from the dataset on each batch `Pipeline` before iterating\")", | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class _Category(int, ShowTitle): pass", | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#Test retain type\nclass NegTfm(Transform):\n    def encodes(self, x): return torch.neg(x)\n    def decodes(self, x): return torch.neg(x)\n\ntdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=NegTfm(), bs=4, num_workers=4)\nb = tdl.one_batch()\ntest_eq(type(b[0]), TensorImage)\nb = (tensor([1.,1.,1.,1.]),)\ntest_eq(type(tdl.decode_batch(b)[0][0]), TensorImage)", | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class A(Transform):\n    def encodes(self, x): return x\n    def decodes(self, x): return TitledInt(x)\n\n@Transform\ndef f(x)->None: return fastuple((x,x))\n\nstart = torch.arange(50)\ntest_eq_type(f(2), fastuple((2,2)))", | |
"execution_count": 15, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "a = A()\ntdl = TfmdDL(start, after_item=lambda x: (a(x), f(x)), bs=4)\nx,y = tdl.one_batch()\ntest_eq(type(y), fastuple)\n\ns = tdl.decode_batch((x,y))\ntest_eq(type(s[0][1]), fastuple)", | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "tdl = TfmdDL(torch.arange(0,50), after_item=A(), after_batch=NegTfm(), bs=4)\ntest_eq(tdl.dataset[0], start[0])\ntest_eq(len(tdl), (50-1)//4+1)\ntest_eq(tdl.bs, 4)\ntest_stdout(tdl.show_batch, '0\\n1\\n2\\n3')", | |
"execution_count": 17, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class B(Transform):\n    parameters = 'a'\n    def __init__(self): self.a = torch.tensor(0.)\n    def encodes(self, x): return x\n\ntdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=B(), bs=4)\ntest_eq(tdl.after_batch.fs[0].a.device, torch.device('cpu'))\ntdl.to(default_device())\ntest_eq(tdl.after_batch.fs[0].a.device, default_device())", | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Methods" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(TfmdDL.one_batch)", | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"TfmdDL.one_batch\" class=\"doc_header\"><code>TfmdDL.one_batch</code><a href=\"__main__.py#L106\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>TfmdDL.one_batch</code>()\n\nReturn one batch from `DataLoader`." | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "tfm = NegTfm()\ntdl = TfmdDL(start, after_batch=tfm, bs=4)", | |
"execution_count": 20, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "b = tdl.one_batch()\ntest_eq(tensor([0,-1,-2,-3]), b)", | |
"execution_count": 21, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(TfmdDL.decode)", | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"TfmdDL.decode\" class=\"doc_header\"><code>TfmdDL.decode</code><a href=\"__main__.py#L72\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>TfmdDL.decode</code>(**`b`**)\n\nDecode `b` using `tfms`" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "test_eq(tdl.decode(b), tensor(0,1,2,3))", | |
"execution_count": 23, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(TfmdDL.decode_batch)", | |
"execution_count": 24, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"TfmdDL.decode_batch\" class=\"doc_header\"><code>TfmdDL.decode_batch</code><a href=\"__main__.py#L73\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>TfmdDL.decode_batch</code>(**`b`**, **`max_n`**=*`9`*, **`full`**=*`True`*)\n\nDecode `b` entirely" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "test_eq(tdl.decode_batch(b), [0,1,2,3])", | |
"execution_count": 25, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(TfmdDL.show_batch)", | |
"execution_count": 26, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"TfmdDL.show_batch\" class=\"doc_header\"><code>TfmdDL.show_batch</code><a href=\"__main__.py#L89\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>TfmdDL.show_batch</code>(**`b`**=*`None`*, **`max_n`**=*`9`*, **`ctxs`**=*`None`*, **`show`**=*`True`*, **`unique`**=*`False`*, **\\*\\*`kwargs`**)\n\nShow `b` (defaults to `one_batch`), a list of lists of pipeline outputs (i.e. output of a `DataLoader`)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(TfmdDL.to)", | |
"execution_count": 27, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"TfmdDL.to\" class=\"doc_header\"><code>TfmdDL.to</code><a href=\"__main__.py#L62\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>TfmdDL.to</code>(**`device`**)\n\nPut self and its transforms state on `device`" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## DataLoaders -" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# export\n@docs\nclass DataLoaders(GetAttr):\n    \"Basic wrapper around several `DataLoader`s.\"\n    _default='train'\n    def __init__(self, *loaders, path='.', device=None):\n        self.loaders,self.path = list(loaders),Path(path)\n        if device is not None or hasattr(loaders[0],'to'): self.device = device\n\n    def __getitem__(self, i): return self.loaders[i]\n    def new_empty(self):\n        loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]\n        return type(self)(*loaders, path=self.path, device=self.device)\n\n    def _set(i, self, v): self.loaders[i] = v\n    train ,valid = add_props(lambda i,x: x[i], _set)\n    train_ds,valid_ds = add_props(lambda i,x: x[i].dataset)\n\n    @property\n    def device(self): return self._device\n\n    @device.setter\n    def device(self, d):\n        for dl in self.loaders: dl.to(d)\n        self._device = d\n\n    def to(self, device):\n        self.device = device\n        return self\n\n    def _add_tfms(self, tfms, event, dl_idx):\n        \"Adds `tfms` to `event` on `dl`\"\n        if(isinstance(dl_idx,str)): dl_idx = 0 if(dl_idx=='train') else 1\n        dl_tfms = getattr(self[dl_idx], event)\n        apply(dl_tfms.add, tfms)\n\n    def add_tfms(self,tfms,event,loaders=None):\n        \"Adds `tfms` to `events` on `loaders`\"\n        if(loaders is None): loaders=range(len(self.loaders))\n        if not is_listy(loaders): loaders = listify(loaders)\n        for loader in loaders:\n            self._add_tfms(tfms,event,loader)\n\n    def cuda(self): return self.to(device=default_device())\n    def cpu(self): return self.to(device=torch.device('cpu'))\n\n    @classmethod\n    def from_dsets(cls, *ds, path='.', bs=64, device=None, dl_type=TfmdDL, **kwargs):\n        default = (True,) + (False,) * (len(ds)-1)\n        defaults = {'shuffle': default, 'drop_last': default}\n        for nm in _batch_tfms:\n            if nm in kwargs: kwargs[nm] = Pipeline(kwargs[nm])\n        kwargs = merge(defaults, {k: tuplify(v, match=ds) for k,v in kwargs.items()})\n        kwargs = [{k: v[i] for k,v in kwargs.items()} for i in range_of(ds)]\n        return cls(*[dl_type(d, bs=bs, **k) for d,k in zip(ds, kwargs)], path=path, device=device)\n\n    @classmethod\n    def from_dblock(cls, dblock, source, path='.', bs=64, val_bs=None, shuffle=True, device=None, **kwargs):\n        return dblock.dataloaders(source, path=path, bs=bs, val_bs=val_bs, shuffle=shuffle, device=device, **kwargs)\n\n    _docs=dict(__getitem__=\"Retrieve `DataLoader` at `i` (`0` is training, `1` is validation)\",\n               train=\"Training `DataLoader`\",\n               valid=\"Validation `DataLoader`\",\n               train_ds=\"Training `Dataset`\",\n               valid_ds=\"Validation `Dataset`\",\n               to=\"Use `device`\",\n               add_tfms=\"Add `tfms` to `loaders` for `event`\",\n               cuda=\"Use the GPU if available\",\n               cpu=\"Use the CPU\",\n               new_empty=\"Create a new empty version of `self` with the same transforms\",\n               from_dblock=\"Create a `DataLoaders` from a given `dblock`\")", | |
"execution_count": 28, | |
"outputs": [] | |
}, | |
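`from_dsets` builds one loader per dataset and splits its keyword arguments between them: `shuffle` and `drop_last` default to `True` for the first (training) dataset only, and every other argument is expanded with fastcore's `tuplify(v, match=ds)` so that the `i`-th loader gets the `i`-th value. Below is a hypothetical, dependency-free restatement of that bookkeeping (`per_dataset_kwargs` is a made-up name). Its expansion rule, where a single value or one-element list/tuple is repeated and a list/tuple of the right length is used as-is, is an assumption meant to approximate `tuplify`, not a copy of it.

```python
from typing import Any, Dict, List

def per_dataset_kwargs(n_ds: int, **kwargs: Any) -> List[Dict[str, Any]]:
    "Hypothetical sketch of the kwargs bookkeeping in `DataLoaders.from_dsets`."
    # Shuffle and drop the last partial batch for the first dataset only.
    default = (True,) + (False,) * (n_ds - 1)
    merged: Dict[str, tuple] = {'shuffle': default, 'drop_last': default}
    for k, v in kwargs.items():
        # Simplified stand-in for `tuplify(v, match=ds)`.
        if isinstance(v, (list, tuple)):
            vals = tuple(v) * n_ds if len(v) == 1 else tuple(v)
        else:
            vals = (v,) * n_ds
        assert len(vals) == n_ds, f"{k}: expected {n_ds} values, got {len(vals)}"
        merged[k] = vals  # user values override the defaults, like `merge`
    # One dict per dataset, picking the i-th value of every argument.
    return [{k: v[i] for k, v in merged.items()} for i in range(n_ds)]
```

For two datasets with no extra arguments, the first loader gets `shuffle=True, drop_last=True` and the second gets `shuffle=False, drop_last=False`.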
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "dls = DataLoaders(tdl,tdl)\nx = dls.train.one_batch()\nx2 = first(tdl)\ntest_eq(x,x2)\nx2 = dls.one_batch()\ntest_eq(x,x2)", | |
"execution_count": 29, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "dls[0].after_batch, dls[1].after_batch", | |
"execution_count": 30, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 30, | |
"data": { | |
"text/plain": "(Pipeline: NegTfm, Pipeline: NegTfm)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\n#test assignment works\ndls.train = dls.train.new(bs=4)", | |
"execution_count": 31, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "Multiple transforms can be added to multiple dataloaders using `DataLoaders.add_tfms`. You can select the dataloaders by name, e.g. `dls.add_tfms(...,'valid',...)`, by index, e.g. `dls.add_tfms(...,1,...)`, or by a list of either; by default, transforms are added to all dataloaders. Note that any name other than `'train'` is mapped to index `1`. `event` is a required argument that determines when the transform will be run; for more information on events, refer to `TfmdDL`. `tfms` is a list of `Transform`s and is also required." | |
}, | |
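The loader selection performed by `add_tfms` and `_add_tfms` can be restated as a small pure function. The sketch below (the name `resolve_loaders` is hypothetical) returns the indices that would receive the transforms. It mirrors the code in that any string other than `'train'` resolves to `1`, so a name such as `'test'` would silently target the validation loader.

```python
from typing import List, Optional, Sequence, Union

Key = Union[int, str]

def resolve_loaders(n_loaders: int,
                    loaders: Optional[Union[Key, Sequence[Key]]] = None) -> List[int]:
    "Hypothetical sketch of how `DataLoaders.add_tfms` chooses target loaders."
    if loaders is None:
        # Default: every loader.
        loaders = list(range(n_loaders))
    elif isinstance(loaders, (int, str)):
        # A single key is wrapped in a list (like `listify`).
        loaders = [loaders]
    # Names are mapped as in `_add_tfms`: 'train' -> 0, any other string -> 1.
    return [(0 if k == 'train' else 1) if isinstance(k, str) else k for k in loaders]
```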
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class _TestTfm(Transform):\n    def encodes(self, o): return torch.ones_like(o)\n    def decodes(self, o): return o\ntdl1,tdl2 = TfmdDL(start, bs=4),TfmdDL(start, bs=4)\ndls2 = DataLoaders(tdl1,tdl2)\ndls2.add_tfms([_TestTfm()],'after_batch',['valid'])\ndls2.add_tfms([_TestTfm()],'after_batch',[1])\ndls2.train.after_batch,dls2.valid.after_batch,", | |
"execution_count": 32, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 32, | |
"data": { | |
"text/plain": "(Pipeline: , Pipeline: _TestTfm -> _TestTfm)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\ntest_eq(len(dls2.train.after_batch.fs),0)\ntest_eq(len(dls2.valid.after_batch.fs),2)\ntest_eq(next(iter(dls2.valid)),tensor([1,1,1,1]))", | |
"execution_count": 33, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Methods" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(DataLoaders.__getitem__)", | |
"execution_count": 34, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"DataLoaders.__getitem__\" class=\"doc_header\"><code>DataLoaders.__getitem__</code><a href=\"__main__.py#L10\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>DataLoaders.__getitem__</code>(**`i`**)\n\nRetrieve `DataLoader` at `i` (`0` is training, `1` is validation)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "x2 = dls[0].one_batch()\ntest_eq(x,x2)", | |
"execution_count": 35, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(DataLoaders.train, name=\"DataLoaders.train\")", | |
"execution_count": 36, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"DataLoaders.train\" class=\"doc_header\"><code>DataLoaders.train</code><a href=\"__main__.py#L16\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\nTraining `DataLoader`" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(DataLoaders.valid, name=\"DataLoaders.valid\")", | |
"execution_count": 37, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"DataLoaders.valid\" class=\"doc_header\"><code>DataLoaders.valid</code><a href=\"__main__.py#L16\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\nValidation `DataLoader`" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(DataLoaders.train_ds, name=\"DataLoaders.train_ds\")", | |
"execution_count": 38, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"DataLoaders.train_ds\" class=\"doc_header\"><code>DataLoaders.train_ds</code><a href=\"__main__.py#L17\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\nTraining `Dataset`" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(DataLoaders.valid_ds, name=\"DataLoaders.valid_ds\")", | |
"execution_count": 39, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"DataLoaders.valid_ds\" class=\"doc_header\"><code>DataLoaders.valid_ds</code><a href=\"__main__.py#L17\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\nValidation `Dataset`" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## TfmdLists -" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\nclass FilteredBase:\n    \"Base class for lists with subsets\"\n    _dl_type,_dbunch_type = TfmdDL,DataLoaders\n    def __init__(self, *args, dl_type=None, **kwargs):\n        if dl_type is not None: self._dl_type = dl_type\n        self.dataloaders = delegates(self._dl_type.__init__)(self.dataloaders)\n        super().__init__(*args, **kwargs)\n\n    @property\n    def n_subsets(self): return len(self.splits)\n    def _new(self, items, **kwargs): return super()._new(items, splits=self.splits, **kwargs)\n    def subset(self, i): raise NotImplementedError\n\n    def dataloaders(self, bs=64, shuffle_train=None, shuffle=True, val_shuffle=False,n=None, path='.', dl_type=None, dl_kwargs=None,\n                    device=None,drop_last=None,val_bs=None, **kwargs):\n        if shuffle_train is not None:\n            shuffle=shuffle_train\n            warnings.warn('`shuffle_train` is deprecated. Use `shuffle` instead.',DeprecationWarning)\n        if device is None: device=default_device()\n        if dl_kwargs is None: dl_kwargs = [{}] * self.n_subsets\n        if dl_type is None: dl_type = self._dl_type\n        if drop_last is None: drop_last = shuffle\n        val_kwargs={k[4:]:v for k,v in kwargs.items() if k.startswith('val_')}\n        def_kwargs = {'bs':bs,'shuffle':shuffle,'drop_last':drop_last,'device':device}\n        dl = dl_type(self.subset(0), **merge(kwargs,def_kwargs, dl_kwargs[0]))\n        def_kwargs = {'bs':bs if val_bs is None else val_bs,'shuffle':val_shuffle,'drop_last':False}\n        dls = [dl] + [dl.new(self.subset(i), **merge(kwargs,def_kwargs,val_kwargs,dl_kwargs[i]))\n                      for i in range(1, self.n_subsets)]\n        return self._dbunch_type(*dls, path=path, device=device)\n\nFilteredBase.train,FilteredBase.valid = add_props(lambda i,x: x.subset(i))", | |
"execution_count": 40, | |
"outputs": [] | |
}, | |
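The keyword handling in `FilteredBase.dataloaders` is easiest to see on plain dictionaries. The sketch below is hypothetical: `merge` is a local re-implementation of the later-dict-wins behaviour of fastcore's `merge`, `split_kwargs` is a made-up name, and `device` and `dl_kwargs` are ignored. It computes only the settings for the training loader and for one validation loader, reproducing three details of the method: `drop_last` defaults to `shuffle`, `val_bs` falls back to `bs`, and each `val_`-prefixed argument reaches the validation loaders with the prefix stripped (the prefixed key also stays in `kwargs`, so it is forwarded to every loader unchanged).

```python
from typing import Any, Dict, Tuple

def merge(*ds: Dict[str, Any]) -> Dict[str, Any]:
    "Local stand-in for fastcore's `merge`: later dicts override earlier ones."
    out: Dict[str, Any] = {}
    for d in ds:
        out.update(d)
    return out

def split_kwargs(bs=64, shuffle=True, val_shuffle=False, drop_last=None, val_bs=None,
                 **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    "Hypothetical sketch of the train/valid settings built in `FilteredBase.dataloaders`."
    if drop_last is None:
        drop_last = shuffle  # drop the last partial batch only when shuffling
    # e.g. `val_num_workers=0` becomes `num_workers=0` for the validation loader.
    val_kwargs = {k[4:]: v for k, v in kwargs.items() if k.startswith('val_')}
    train = merge(kwargs, {'bs': bs, 'shuffle': shuffle, 'drop_last': drop_last})
    valid = merge(kwargs,
                  {'bs': bs if val_bs is None else val_bs, 'shuffle': val_shuffle, 'drop_last': False},
                  val_kwargs)
    return train, valid
```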
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(FilteredBase().dataloaders)", | |
"execution_count": 41, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"FilteredBase.dataloaders\" class=\"doc_header\"><code>FilteredBase.dataloaders</code><a href=\"__main__.py#L15\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>FilteredBase.dataloaders</code>(**`bs`**=*`64`*, **`shuffle_train`**=*`None`*, **`shuffle`**=*`True`*, **`val_shuffle`**=*`False`*, **`n`**=*`None`*, **`path`**=*`'.'`*, **`dl_type`**=*`None`*, **`dl_kwargs`**=*`None`*, **`device`**=*`None`*, **`drop_last`**=*`None`*, **`val_bs`**=*`None`*, **`num_workers`**=*`None`*, **`collate_fn`**=*`fa_collate`*, **`verbose`**=*`False`*, **`do_setup`**=*`True`*, **`indexed`**=*`None`*, **`batch_size`**:`Optional`\\[`int`\\]=*`1`*, **`sampler`**:`Optional`\\[`Sampler'>`\\[`int`\\]\\]=*`None`*, **`batch_sampler`**:`Optional`\\[`Sampler'>`\\[`Sequence`\\[`int`\\]\\]\\]=*`None`*, **`pin_memory`**:`bool`=*`False`*, **`timeout`**:`float`=*`0`*, **`worker_init_fn`**:`Optional`\\[`int`\\]=*`None`*, **`multiprocessing_context`**=*`None`*, **`generator`**=*`None`*, **`prefetch_factor`**:`int`=*`2`*, **`persistent_workers`**:`bool`=*`False`*)\n\n" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\nclass TfmdLists(FilteredBase, L, GetAttr):\n    \"A `Pipeline` of `tfms` applied to a collection of `items`\"\n    _default='tfms'\n    def __init__(self, items, tfms, use_list=None, do_setup=True, split_idx=None, train_setup=True,\n                 splits=None, types=None, verbose=False, dl_type=None):\n        super().__init__(items, use_list=use_list)\n        if dl_type is not None: self._dl_type = dl_type\n        self.splits = L([slice(None),[]] if splits is None else splits).map(mask2idxs)\n        if isinstance(tfms,TfmdLists): tfms = tfms.tfms\n        if isinstance(tfms,Pipeline): do_setup=False\n        self.tfms = Pipeline(tfms, split_idx=split_idx)\n        store_attr('types,split_idx')\n        if do_setup:\n            pv(f\"Setting up {self.tfms}\", verbose)\n            self.setup(train_setup=train_setup)\n\n    def _new(self, items, split_idx=None, **kwargs):\n        split_idx = ifnone(split_idx,self.split_idx)\n        return super()._new(items, tfms=self.tfms, do_setup=False, types=self.types, split_idx=split_idx, **kwargs)\n    def subset(self, i): return self._new(self._get(self.splits[i]), split_idx=i)\n    def _after_item(self, o): return self.tfms(o)\n    def __repr__(self): return f\"{self.__class__.__name__}: {self.items}\\ntfms - {self.tfms.fs}\"\n    def __iter__(self): return (self[i] for i in range(len(self)))\n    def show(self, o, **kwargs): return self.tfms.show(o, **kwargs)\n    def decode(self, o, **kwargs): return self.tfms.decode(o, **kwargs)\n    def __call__(self, o, **kwargs): return self.tfms.__call__(o, **kwargs)\n    def overlapping_splits(self): return L(Counter(self.splits.concat()).values()).filter(gt(1))\n    def new_empty(self): return self._new([])\n\n    def setup(self, train_setup=True):\n        self.tfms.setup(self, train_setup)\n        if len(self) != 0:\n            x = super().__getitem__(0) if self.splits is None else super().__getitem__(self.splits[0])[0]\n            self.types = []\n            for f in self.tfms.fs:\n                self.types.append(getattr(f, 'input_types', type(x)))\n                x = f(x)\n            self.types.append(type(x))\n        types = L(t if is_listy(t) else [t] for t in self.types).concat().unique()\n        self.pretty_types = '\\n'.join([f' - {t}' for t in types])\n\n    def infer_idx(self, x):\n        # TODO: check if we really need this, or can simplify\n        idx = 0\n        for t in self.types:\n            if isinstance(x, t): break\n            idx += 1\n        types = L(t if is_listy(t) else [t] for t in self.types).concat().unique()\n        pretty_types = '\\n'.join([f' - {t}' for t in types])\n        assert idx < len(self.types), f\"Expected an input of type in \\n{pretty_types}\\n but got {type(x)}\"\n        return idx\n\n    def infer(self, x):\n        return compose_tfms(x, tfms=self.tfms.fs[self.infer_idx(x):], split_idx=self.split_idx)\n\n    def __getitem__(self, idx):\n        res = super().__getitem__(idx)\n        if self._after_item is None: return res\n        return self._after_item(res) if is_indexer(idx) else res.map(self._after_item)", | |
"execution_count": 42, | |
"outputs": [] | |
}, | |
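`setup` records the type of a sample at each stage of the pipeline, and `infer_idx`/`infer` use that list to skip the transforms an input has already been through. A hypothetical standalone version could look like this; it ignores `input_types`, `split_idx` and listy types, and `record_types`, `toy_infer_idx` and `toy_infer` are illustrative functions, not the methods above.

```python
from typing import Any, Callable, List

def record_types(x: Any, tfms: List[Callable]) -> List[type]:
    "Type of `x` before each transform, plus the final output type (like the loop in `setup`)."
    types = []
    for f in tfms:
        types.append(type(x))
        x = f(x)
    types.append(type(x))
    return types

def toy_infer_idx(x: Any, types: List[type]) -> int:
    "Index of the first recorded type that `x` is an instance of."
    for idx, t in enumerate(types):
        if isinstance(x, t):
            return idx
    raise TypeError(f"unexpected input type {type(x).__name__}")

def toy_infer(x: Any, tfms: List[Callable], types: List[type]) -> Any:
    "Run only the transforms that come after the stage matching `type(x)`."
    for f in tfms[toy_infer_idx(x, types):]:
        x = f(x)
    return x
```

Because the first matching type wins, with `tfms = [float, lambda o: -o, int]` a `float` input is assumed to have just been converted and goes through negation and `int`, while an `int` input skips every transform. An input matching none of the recorded types raises an error, which plays the role of the `assert` in `infer_idx`.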
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\nadd_docs(TfmdLists,\n         setup=\"Transform setup with self\",\n         decode=\"From `Pipeline`\",\n         show=\"From `Pipeline`\",\n         overlapping_splits=\"Counts of the item indices that appear in more than one split\",\n         subset=\"New `TfmdLists` with same tfms that only includes items in `i`th split\",\n         infer_idx=\"Finds the index where `self.tfms` can be applied to `x`, depending on the type of `x`\",\n         infer=\"Apply `self.tfms` to `x` starting at the right tfm depending on the type of `x`\",\n         new_empty=\"A new version of `self` but with no items\")", | |
"execution_count": 43, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#exports\ndef decode_at(o, idx):\n \"Decoded item at `idx`\"\n return o.decode(o[idx])", | |
"execution_count": 44, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#exports\ndef show_at(o, idx, **kwargs):\n \"Show item at `idx`\"\n return o.show(o[idx], **kwargs)",
"execution_count": 45, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "A `TfmdLists` combines a collection of objects with a `Pipeline`. `tfms` can be either a `Pipeline` or a list of transforms, in which case it will be wrapped in a `Pipeline`. `use_list` is passed along to `L` together with `items`, and `split_idx` is passed to the `Pipeline`, which forwards it to each of its transforms. `do_setup` indicates whether the `Pipeline.setup` method should be called during initialization; it is ignored (and setup is skipped) when `tfms` is already a `Pipeline`."
}, | |
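{
"metadata": {},
"cell_type": "markdown",
"source": "A minimal sketch of the constructor behaviour described above. It assumes the imports at the top of this notebook and is an illustration rather than one of the original tests; `tl_fn` and `tl_pipe` are hypothetical names. Plain functions in `tfms` are wrapped in a `Pipeline`, `split_idx` is stored on that `Pipeline`, and passing an existing `Pipeline` reuses its transforms without running setup again."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Illustrative sketch: plain functions in `tfms` are wrapped in a `Pipeline`\ntl_fn = TfmdLists([1,2,3], [lambda o: o*10], split_idx=1)\nassert isinstance(tl_fn.tfms, Pipeline)\ntest_eq(tl_fn.tfms.split_idx, 1)\ntest_eq(tl_fn[0], 10)\n\n# Passing an existing `Pipeline` reuses its transforms and skips setup\ntl_pipe = TfmdLists([4,5], tl_fn.tfms)\ntest_eq(tl_pipe[1], 50)",
"execution_count": null,
"outputs": []
},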
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class _IntFloatTfm(Transform):\n def encodes(self, o): return TitledInt(o)\n def decodes(self, o): return TitledFloat(o)\nint2f_tfm=_IntFloatTfm()\n\ndef _neg(o): return -o\nneg_tfm = Transform(_neg, _neg)", | |
"execution_count": 46, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "items = L([1.,2.,3.]); tfms = [neg_tfm, int2f_tfm]\ntl = TfmdLists(items, tfms=tfms)\ntest_eq_type(tl[0], TitledInt(-1))\ntest_eq_type(tl[1], TitledInt(-2))\ntest_eq_type(tl.decode(tl[2]), TitledFloat(3.))\ntest_stdout(lambda: show_at(tl, 2), '-3')\ntest_eq(tl.types, [float, float, TitledInt])\ntl", | |
"execution_count": 47, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 47, | |
"data": { | |
"text/plain": "TfmdLists: [1.0, 2.0, 3.0]\ntfms - [_neg:\nencodes: (object,object) -> _negdecodes: (object,object) -> _neg, _IntFloatTfm:\nencodes: (object,object) -> encodes\ndecodes: (object,object) -> decodes\n]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# add splits to TfmdLists\nsplits = [[0,2],[1]]\ntl = TfmdLists(items, tfms=tfms, splits=splits)\ntest_eq(tl.n_subsets, 2)\ntest_eq(tl.train, tl.subset(0))\ntest_eq(tl.valid, tl.subset(1))\ntest_eq(tl.train.items, items[splits[0]])\ntest_eq(tl.valid.items, items[splits[1]])\ntest_eq(tl.train.tfms.split_idx, 0)\ntest_eq(tl.valid.tfms.split_idx, 1)\ntest_eq(tl.train.new_empty().split_idx, 0)\ntest_eq(tl.valid.new_empty().split_idx, 1)\ntest_eq_type(tl.splits, L(splits))\nassert not tl.overlapping_splits()", | |
"execution_count": 48, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "df = pd.DataFrame(dict(a=[1,2,3],b=[2,3,4]))\ntl = TfmdLists(df, lambda o: o.a+1, splits=[[0],[1,2]])\ntest_eq(tl[1,2], [3,4])\ntr = tl.subset(0)\ntest_eq(tr[:], [2])\nval = tl.subset(1)\ntest_eq(val[:], [3,4])", | |
"execution_count": 49, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class _B(Transform):\n def __init__(self): self.m = 0\n def encodes(self, o): return o+self.m\n def decodes(self, o): return o-self.m\n def setups(self, items): \n print(items)\n self.m = tensor(items).float().mean().item()\n\n# test for setup, which updates `self.m`\ntl = TfmdLists(items, _B())\ntest_eq(tl.m, 2)", | |
"execution_count": 50, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "TfmdLists: [1.0, 2.0, 3.0]\ntfms - []\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "Here's how we can use `TfmdLists.setup` to implement a simple category list, getting labels from a mock file list:" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class _Cat(Transform):\n order = 1\n def encodes(self, o): return int(self.o2i[o])\n def decodes(self, o): return TitledStr(self.vocab[o])\n def setups(self, items): self.vocab,self.o2i = uniqueify(L(items), sort=True, bidir=True)\ntcat = _Cat()\n\ndef _lbl(o): return TitledStr(o.split('_')[0])\n\n# Check that tfms are sorted by `order` & `_lbl` is called first\nfns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','dog_1.jpg']\ntl = TfmdLists(fns, [tcat,_lbl])\nexp_voc = ['cat','dog']\ntest_eq(tcat.vocab, exp_voc)\ntest_eq(tl.tfms.vocab, exp_voc)\ntest_eq(tl.vocab, exp_voc)\ntest_eq(tl, (1,0,0,0,1))\ntest_eq([tl.decode(o) for o in tl], ('dog','cat','cat','cat','dog'))", | |
"execution_count": 51, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#Check only the training set is taken into account for setup\ntl = TfmdLists(fns, [tcat,_lbl], splits=[[0,4], [1,2,3]])\ntest_eq(tcat.vocab, ['dog'])", | |
"execution_count": 52, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "tfm = NegTfm(split_idx=1)\ntds = TfmdLists(start, A())\ntdl = TfmdDL(tds, after_batch=tfm, bs=4)\nx = tdl.one_batch()\ntest_eq(x, torch.arange(4))\ntds.split_idx = 1\nx = tdl.one_batch()\ntest_eq(x, -torch.arange(4))\ntds.split_idx = 0\nx = tdl.one_batch()\ntest_eq(x, torch.arange(4))", | |
"execution_count": 53, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "tds = TfmdLists(start, A())\ntdl = TfmdDL(tds, after_batch=NegTfm(), bs=4)\ntest_eq(tdl.dataset[0], start[0])\ntest_eq(len(tdl), (len(tds)-1)//4+1)\ntest_eq(tdl.bs, 4)\ntest_stdout(tdl.show_batch, '0\\n1\\n2\\n3')", | |
"execution_count": 54, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(TfmdLists.subset)", | |
"execution_count": 55, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"TfmdLists.subset\" class=\"doc_header\"><code>TfmdLists.subset</code><a href=\"__main__.py#L21\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>TfmdLists.subset</code>(**`i`**)\n\nNew [`TfmdLists`](/data.core.html#TfmdLists) with same tfms that only includes items in `i`th split" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(TfmdLists.infer_idx)", | |
"execution_count": 56, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"TfmdLists.infer_idx\" class=\"doc_header\"><code>TfmdLists.infer_idx</code><a href=\"__main__.py#L43\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>TfmdLists.infer_idx</code>(**`x`**)\n\nFinds the index where `self.tfms` can be applied to `x`, depending on the type of `x`" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(TfmdLists.infer)", | |
"execution_count": 57, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"TfmdLists.infer\" class=\"doc_header\"><code>TfmdLists.infer</code><a href=\"__main__.py#L54\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>TfmdLists.infer</code>(**`x`**)\n\nApply `self.tfms` to `x` starting at the right tfm depending on the type of `x`" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def mult(x): return x*2\nmult.order = 2\n\nfns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','dog_1.jpg']\ntl = TfmdLists(fns, [_lbl,_Cat(),mult])\n\ntest_eq(tl.infer_idx('dog_45.jpg'), 0)\ntest_eq(tl.infer('dog_45.jpg'), 2)\n\ntest_eq(tl.infer_idx(4), 2)\ntest_eq(tl.infer(4), 8)\n\ntest_fail(lambda: tl.infer_idx(2.0))\ntest_fail(lambda: tl.infer(2.0))", | |
"execution_count": 58, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\n#Test input_types works on a Transform\ncat = _Cat()\ncat.input_types = (str, float)\ntl = TfmdLists(fns, [_lbl,cat,mult])\ntest_eq(tl.infer_idx(2.0), 1)\n\n#Test type annotations work on a function\ndef mult(x:(int,float)): return x*2\nmult.order = 2\ntl = TfmdLists(fns, [_lbl,_Cat(),mult])\ntest_eq(tl.infer_idx(2.0), 2)", | |
"execution_count": 59, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Datasets -" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\n@docs\n@delegates(TfmdLists)\nclass Datasets(FilteredBase):\n \"A dataset that creates a tuple from each `tfms`, passed through `item_tfms`\"\n def __init__(self, items=None, tfms=None, tls=None, n_inp=None, dl_type=None, **kwargs):\n super().__init__(dl_type=dl_type)\n self.tls = L(tls if tls else [TfmdLists(items, t, **kwargs) for t in L(ifnone(tfms,[None]))])\n self.n_inp = ifnone(n_inp, max(1, len(self.tls)-1))\n\n def __getitem__(self, it):\n res = tuple([tl[it] for tl in self.tls])\n return res if is_indexer(it) else list(zip(*res))\n\n def __getattr__(self,k): return gather_attrs(self, k, 'tls')\n def __dir__(self): return super().__dir__() + gather_attr_names(self, 'tls')\n def __len__(self): return len(self.tls[0])\n def __iter__(self): return (self[i] for i in range(len(self)))\n def __repr__(self): return coll_repr(self)\n def decode(self, o, full=True): return tuple(tl.decode(o_, full=full) for o_,tl in zip(o,tuplify(self.tls, match=o)))\n def subset(self, i): return type(self)(tls=L(tl.subset(i) for tl in self.tls), n_inp=self.n_inp)\n def _new(self, items, *args, **kwargs): return super()._new(items, tfms=self.tfms, do_setup=False, **kwargs)\n def overlapping_splits(self): return self.tls[0].overlapping_splits()\n def new_empty(self): return type(self)(tls=[tl.new_empty() for tl in self.tls], n_inp=self.n_inp)\n @property\n def splits(self): return self.tls[0].splits\n @property\n def split_idx(self): return self.tls[0].tfms.split_idx\n @property\n def items(self): return self.tls[0].items\n @items.setter\n def items(self, v):\n for tl in self.tls: tl.items = v\n\n def show(self, o, ctx=None, **kwargs):\n for o_,tl in zip(o,self.tls): ctx = tl.show(o_, ctx=ctx, **kwargs)\n return ctx\n\n @contextmanager\n def set_split_idx(self, i):\n old_split_idx = self.split_idx\n for tl in self.tls: tl.tfms.split_idx = i\n try: yield self\n finally:\n for tl in self.tls: tl.tfms.split_idx = old_split_idx\n\n _docs=dict(\n decode=\"Compose 
`decode` of all `tuple_tfms` then all `tfms` on `i`\",\n show=\"Show item `o` in `ctx`\",\n dataloaders=\"Get a `DataLoaders`\",\n overlapping_splits=\"All splits that are in more than one split\",\n subset=\"New `Datasets` that only includes subset `i`\",\n new_empty=\"Create a new empty version of the `self`, keeping only the transforms\",\n set_split_idx=\"Contextmanager to use the same `Datasets` with another `split_idx`\"\n )", | |
"execution_count": 60, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "A `Datasets` creates a tuple from `items` (typically input,target) by applying each list of `Transform` (or `Pipeline`) in `tfms` to them. Note that if `tfms` contains only one list of transforms, the items given by `Datasets` will be tuples of one element.\n\n`n_inp` is the number of elements in each tuple that should be considered part of the input. It defaults to 1 if `tfms` consists of one set of transforms, and to `len(tfms)-1` otherwise. In most cases the tuples returned by `Datasets` have 2 elements (input,target), but some tasks produce 3 (Siamese networks or tabular data, for instance), in which case we need to know where the inputs end and the targets begin."
}, | |
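{
"metadata": {},
"cell_type": "markdown",
"source": "A small sketch of how `n_inp` splits each tuple. It assumes a hypothetical Siamese-style task with two inputs and one target per item. The lambdas and `dsets_s` are placeholders and are not part of the original tests. With three pipelines in `tfms`, `n_inp` defaults to `len(tfms)-1`, which is 2 here."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Illustrative sketch: three pipelines give 3-tuples; n_inp defaults to len(tfms)-1\ndsets_s = Datasets([1,2,3], [[lambda o: o], [lambda o: o*10], [lambda o: o%2]])\ntest_eq(dsets_s.n_inp, 2)\nx = dsets_s[0]\ntest_eq(x[:dsets_s.n_inp], (1,10))  # the two inputs\ntest_eq(x[dsets_s.n_inp:], (1,))    # the target",
"execution_count": null,
"outputs": []
},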
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "items = [1,2,3,4]\ndsets = Datasets(items, [[neg_tfm,int2f_tfm], [add(1)]])\nt = dsets[0]\ntest_eq(t, (-1,2))\ntest_eq(dsets[0,1,2], [(-1,2),(-2,3),(-3,4)])\ntest_eq(dsets.n_inp, 1)\ndsets.decode(t)", | |
"execution_count": 61, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 61, | |
"data": { | |
"text/plain": "(1.0, 2)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class Norm(Transform):\n def encodes(self, o): return (o-self.m)/self.s\n def decodes(self, o): return (o*self.s)+self.m\n def setups(self, items):\n its = tensor(items).float()\n self.m,self.s = its.mean(),its.std()", | |
"execution_count": 62, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "items = [1,2,3,4]\nnrm = Norm()\ndsets = Datasets(items, [[neg_tfm,int2f_tfm], [neg_tfm,nrm]])\n\nx,y = zip(*dsets)\ntest_close(tensor(y).mean(), 0)\ntest_close(tensor(y).std(), 1)\ntest_eq(x, (-1,-2,-3,-4,))\ntest_eq(nrm.m, -2.5)\ntest_stdout(lambda:show_at(dsets, 1), '-2')\n\ntest_eq(dsets.m, nrm.m)\ntest_eq(dsets.norm.m, nrm.m)\ntest_eq(dsets.train.norm.m, nrm.m)", | |
"execution_count": 63, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\n#Check filtering is properly applied\nclass B(Transform):\n def encodes(self, x)->None: return int(x+1)\n def decodes(self, x): return TitledInt(x-1)\nadd1 = B(split_idx=1)\n\ndsets = Datasets(items, [neg_tfm, [neg_tfm,int2f_tfm,add1]], splits=[[3],[0,1,2]])\ntest_eq(dsets[1], [-2,-2])\ntest_eq(dsets.valid[1], [-2,-1])\ntest_eq(dsets.valid[[1,1]], [[-2,-1], [-2,-1]])\ntest_eq(dsets.train[0], [-4,-4])", | |
"execution_count": 64, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "test_fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','kid_1.jpg']\ntcat = _Cat()\ndsets = Datasets(test_fns, [[tcat,_lbl]], splits=[[0,1,2], [3,4]])\ntest_eq(tcat.vocab, ['cat','dog'])\ntest_eq(dsets.train, [(1,),(0,),(0,)])\ntest_eq(dsets.valid[0], (0,))\ntest_stdout(lambda: show_at(dsets.train, 0), \"dog\")", | |
"execution_count": 65, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "inp = [0,1,2,3,4]\ndsets = Datasets(inp, tfms=[None])\n\ntest_eq(*dsets[2], 2) # Retrieve one item (subset 0 is the default)\ntest_eq(dsets[1,2], [(1,),(2,)]) # Retrieve two items by index\nmask = [True,False,False,True,False]\ntest_eq(dsets[mask], [(0,),(3,)]) # Retrieve two items by mask", | |
"execution_count": 66, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "inp = pd.DataFrame(dict(a=[5,1,2,3,4]))\ndsets = Datasets(inp, tfms=attrgetter('a')).subset(0)\ntest_eq(*dsets[2], 2) # Retrieve one item (subset 0 is the default)\ntest_eq(dsets[1,2], [(1,),(2,)]) # Retrieve two items by index\nmask = [True,False,False,True,False]\ntest_eq(dsets[mask], [(5,),(3,)]) # Retrieve two items by mask", | |
"execution_count": 67, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#test n_inp\ninp = [0,1,2,3,4]\ndsets = Datasets(inp, tfms=[None])\ntest_eq(dsets.n_inp, 1)\ndsets = Datasets(inp, tfms=[[None],[None],[None]])\ntest_eq(dsets.n_inp, 2)\ndsets = Datasets(inp, tfms=[[None],[None],[None]], n_inp=1)\ntest_eq(dsets.n_inp, 1)", | |
"execution_count": 68, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# splits can be indices\ndsets = Datasets(range(5), tfms=[None], splits=[tensor([0,2]), [1,3,4]])\n\ntest_eq(dsets.subset(0), [(0,),(2,)])\ntest_eq(dsets.train, [(0,),(2,)]) # Subset 0 is aliased to `train`\ntest_eq(dsets.subset(1), [(1,),(3,),(4,)])\ntest_eq(dsets.valid, [(1,),(3,),(4,)]) # Subset 1 is aliased to `valid`\ntest_eq(*dsets.valid[2], 4)\n#assert '[(1,),(3,),(4,)]' in str(dsets) and '[(0,),(2,)]' in str(dsets)\ndsets", | |
"execution_count": 69, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 69, | |
"data": { | |
"text/plain": "(#5) [(0,),(1,),(2,),(3,),(4,)]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# splits can be boolean masks (they don't have to cover all items, but must be disjoint)\nsplits = [[False,True,True,False,True], [True,False,False,False,False]]\ndsets = Datasets(range(5), tfms=[None], splits=splits)\n\ntest_eq(dsets.train, [(1,),(2,),(4,)])\ntest_eq(dsets.valid, [(0,)])", | |
"execution_count": 70, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# apply transforms to all items\ntfm = [[lambda x: x*2,lambda x: x+1]]\nsplits = [[1,2],[0,3,4]]\ndsets = Datasets(range(5), tfm, splits=splits)\ntest_eq(dsets.train,[(3,),(5,)])\ntest_eq(dsets.valid,[(1,),(7,),(9,)])\ntest_eq(dsets.train[False,True], [(5,)])", | |
"execution_count": 71, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# only transform subset 1\nclass _Tfm(Transform):\n split_idx=1\n def encodes(self, x): return x*2\n def decodes(self, x): return TitledStr(x//2)", | |
"execution_count": 72, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "dsets = Datasets(range(5), [_Tfm()], splits=[[1,2],[0,3,4]])\ntest_eq(dsets.train,[(1,),(2,)])\ntest_eq(dsets.valid,[(0,),(6,),(8,)])\ntest_eq(dsets.train[False,True], [(2,)])\ndsets", | |
"execution_count": 73, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 73, | |
"data": { | |
"text/plain": "(#5) [(0,),(1,),(2,),(3,),(4,)]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#A context manager to change the split_idx and apply the validation transform on the training set\nds = dsets.train\nwith ds.set_split_idx(1):\n test_eq(ds,[(2,),(4,)])\ntest_eq(dsets.train,[(1,),(2,)])", | |
"execution_count": 74, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\n#Test Datasets pickles\ndsrc1 = pickle.loads(pickle.dumps(dsets))\ntest_eq(dsets.train, dsrc1.train)\ntest_eq(dsets.valid, dsrc1.valid)", | |
"execution_count": 75, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "dsets = Datasets(range(5), [_Tfm(),noop], splits=[[1,2],[0,3,4]])\ntest_eq(dsets.train,[(1,1),(2,2)])\ntest_eq(dsets.valid,[(0,0),(6,3),(8,4)])", | |
"execution_count": 76, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "start = torch.arange(0,50)\ntds = Datasets(start, [A()])\ntdl = TfmdDL(tds, after_item=NegTfm(), bs=4)\nb = tdl.one_batch()\ntest_eq(tdl.decode_batch(b), ((0,),(1,),(2,),(3,)))\ntest_stdout(tdl.show_batch, \"0\\n1\\n2\\n3\")", | |
"execution_count": 77, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# only transform subset 1\nclass _Tfm(Transform):\n split_idx=1\n def encodes(self, x): return x*2\n\ndsets = Datasets(range(8), [None], splits=[[1,2,5,7],[0,3,4,6]])\ndls = dsets.dataloaders(bs=4, after_batch=_Tfm(), shuffle=False, device=torch.device('cpu'))\ntest_eq(dls.train, [(tensor([1,2,5, 7]),)])\ntest_eq(dls.valid, [(tensor([0,6,8,12]),)])\ntest_eq(dls.n_inp, 1)", | |
"execution_count": 79, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Methods" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "items = [1,2,3,4]\ndsets = Datasets(items, [[neg_tfm,int2f_tfm]])", | |
"execution_count": 80, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide_input\n_dsrc = Datasets([1,2])\nshow_doc(_dsrc.dataloaders, name=\"Datasets.dataloaders\")", | |
"execution_count": 81, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"Datasets.dataloaders\" class=\"doc_header\"><code>Datasets.dataloaders</code><a href=\"__main__.py#L15\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>Datasets.dataloaders</code>(**`bs`**=*`64`*, **`shuffle_train`**=*`None`*, **`shuffle`**=*`True`*, **`val_shuffle`**=*`False`*, **`n`**=*`None`*, **`path`**=*`'.'`*, **`dl_type`**=*`None`*, **`dl_kwargs`**=*`None`*, **`device`**=*`None`*, **`drop_last`**=*`None`*, **`val_bs`**=*`None`*, **`num_workers`**=*`None`*, **`collate_fn`**=*`fa_collate`*, **`verbose`**=*`False`*, **`do_setup`**=*`True`*, **`indexed`**=*`None`*, **`batch_size`**:`Optional`\\[`int`\\]=*`1`*, **`sampler`**:`Optional`\\[`Sampler'>`\\[`int`\\]\\]=*`None`*, **`batch_sampler`**:`Optional`\\[`Sampler'>`\\[`Sequence`\\[`int`\\]\\]\\]=*`None`*, **`pin_memory`**:`bool`=*`False`*, **`timeout`**:`float`=*`0`*, **`worker_init_fn`**:`Optional`\\[`int`\\]=*`None`*, **`multiprocessing_context`**=*`None`*, **`generator`**=*`None`*, **`prefetch_factor`**:`int`=*`2`*, **`persistent_workers`**:`bool`=*`False`*)\n\nGet a [`DataLoaders`](/data.core.html#DataLoaders)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "Creates a `DataLoaders` with one `DataLoader` per subset. Prefix an argument with `val_` (as in `val_shuffle` or `val_bs`) to override its value for the validation `DataLoader`(s) only. For finer control, `dl_kwargs` takes a list containing one dict of keyword arguments per `DataLoader`."
}, | |
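{
"metadata": {},
"cell_type": "markdown",
"source": "A minimal sketch of the `val_` prefix. It assumes a hypothetical 8-item dataset split evenly into training and validation subsets; `dsets_v` and `dls_v` are illustrative names. `bs` sets the training batch size, while `val_bs` (from the signature above) overrides it for the validation `DataLoader`."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Illustrative sketch: hypothetical 8-item dataset, split 4/4\ndsets_v = Datasets(range(8), [None], splits=[[0,1,2,3],[4,5,6,7]])\ndls_v = dsets_v.dataloaders(bs=2, val_bs=4, shuffle=False, device=torch.device('cpu'))\ntest_eq(dls_v.train.bs, 2)  # `bs` applies to the training DataLoader\ntest_eq(dls_v.valid.bs, 4)  # `val_bs` overrides it for validation\ntest_eq(dls_v.valid.one_batch()[0], tensor([4,5,6,7]))",
"execution_count": null,
"outputs": []
},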
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(Datasets.decode)", | |
"execution_count": 82, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"Datasets.decode\" class=\"doc_header\"><code>Datasets.decode</code><a href=\"__main__.py#L20\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>Datasets.decode</code>(**`o`**, **`full`**=*`True`*)\n\nCompose `decode` of all `tuple_tfms` then all `tfms` on `i`" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "test_eq(*dsets[0], -1)\ntest_eq(*dsets.decode((-1,)), 1)", | |
"execution_count": 83, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(Datasets.show)", | |
"execution_count": 84, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"Datasets.show\" class=\"doc_header\"><code>Datasets.show</code><a href=\"__main__.py#L35\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>Datasets.show</code>(**`o`**, **`ctx`**=*`None`*, **\\*\\*`kwargs`**)\n\nShow item `o` in `ctx`" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "test_stdout(lambda:dsets.show(dsets[1]), '-2')", | |
"execution_count": 85, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "show_doc(Datasets.new_empty)", | |
"execution_count": 86, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<IPython.core.display.Markdown object>", | |
"text/markdown": "<h4 id=\"Datasets.new_empty\" class=\"doc_header\"><code>Datasets.new_empty</code><a href=\"__main__.py#L24\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n\n> <code>Datasets.new_empty</code>()\n\nCreate a new empty version of the `self`, keeping only the transforms" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "items = [1,2,3,4]\nnrm = Norm()\ndsets = Datasets(items, [[neg_tfm,int2f_tfm], [neg_tfm]])\nempty = dsets.new_empty()\ntest_eq(empty.items, [])", | |
"execution_count": 87, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\n#test it works for dataframes too\ndf = pd.DataFrame({'a':[1,2,3,4,5], 'b':[6,7,8,9,10]})\ndsets = Datasets(df, [[attrgetter('a')], [attrgetter('b')]])\nempty = dsets.new_empty()", | |
"execution_count": 88, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Add test set for inference" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class _Tfm(Transform):\n split_idx=1\n def encodes(self, x): return x*2", | |
"execution_count": 89, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# only transform subset 0\nclass _Tfm1(Transform):\n split_idx=0\n def encodes(self, x): return x*3\n\ndsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])\ntest_eq(dsets.train, [(3,),(6,),(15,),(21,)])\ntest_eq(dsets.valid, [(0,),(6,),(8,),(12,)])", | |
"execution_count": 90, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\ndef test_set(dsets, test_items, rm_tfms=None, with_labels=False):\n \"Create a test set from `test_items` using validation transforms of `dsets`\"\n if isinstance(dsets, (Datasets, TfmdDataset)):\n tls = dsets.tls if with_labels else dsets.tls[:dsets.n_inp]\n test_tls = [tl._new(test_items, split_idx=1) for tl in tls]\n if rm_tfms is None: rm_tfms = [tl.infer_idx(get_first(test_items)) for tl in test_tls]\n else: rm_tfms = tuplify(rm_tfms, match=test_tls)\n for i,j in enumerate(rm_tfms): test_tls[i].tfms.fs = test_tls[i].tfms.fs[j:]\n return Datasets(tls=test_tls)\n elif isinstance(dsets, TfmdLists):\n test_tl = dsets._new(test_items, split_idx=1)\n if rm_tfms is None: rm_tfms = dsets.infer_idx(get_first(test_items))\n test_tl.tfms.fs = test_tl.tfms.fs[rm_tfms:]\n return test_tl\n else: raise Exception(f\"This method requires using the fastai library to assemble your data. Expected a `Datasets` or a `TfmdLists` but got {dsets.__class__.__name__}\")", | |
"execution_count": 91, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class _Tfm1(Transform):\n split_idx=0\n def encodes(self, x): return x*3\n\ndsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])\ntest_eq(dsets.train, [(3,),(6,),(15,),(21,)])\ntest_eq(dsets.valid, [(0,),(6,),(8,),(12,)])\n\n# Transforms of the validation set are applied\ntst = test_set(dsets, [1,2,3])\ntest_eq(tst, [(2,),(4,),(6,)])",
"execution_count": 92, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\n#Test with different types\ntfm = _Tfm1()\ntfm.split_idx,tfm.order = None,2\ndsets = Datasets(['dog', 'cat', 'cat', 'dog'], [[_Cat(),tfm]])\n\n#With strings\ntest_eq(test_set(dsets, ['dog', 'cat', 'cat']), [(3,), (0,), (0,)])\n#With ints\ntest_eq(test_set(dsets, [1,2]), [(3,), (6,)])", | |
"execution_count": 93, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\n#Test with various input lengths\ndsets = Datasets(range(8), [[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])\ntst = test_set(dsets, [1,2,3])\ntest_eq(tst, [(2,2),(4,4),(6,6)])\n\ndsets = Datasets(range(8), [[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]], n_inp=1)\ntst = test_set(dsets, [1,2,3])\ntest_eq(tst, [(2,),(4,),(6,)])", | |
"execution_count": 94, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\n#Test with rm_tfms\ndsets = Datasets(range(8), [[_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]])\ntst = test_set(dsets, [1,2,3])\ntest_eq(tst, [(4,),(8,),(12,)])\n\ndsets = Datasets(range(8), [[_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]])\ntst = test_set(dsets, [1,2,3], rm_tfms=1)\ntest_eq(tst, [(2,),(4,),(6,)])\n\ndsets = Datasets(range(8), [[_Tfm(),_Tfm()], [_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]], n_inp=2)\ntst = test_set(dsets, [1,2,3], rm_tfms=(1,0))\ntest_eq(tst, [(2,4),(4,8),(6,12)])", | |
"execution_count": 95, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#export\n@patch\n@delegates(TfmdDL.__init__)\ndef test_dl(self:DataLoaders, test_items, rm_type_tfms=None, with_labels=False, **kwargs):\n \"Create a test dataloader from `test_items` using validation transforms of `dls`\"\n test_ds = test_set(self.valid_ds, test_items, rm_tfms=rm_type_tfms, with_labels=with_labels\n ) if isinstance(self.valid_ds, (Datasets, TfmdLists, TfmdDataset)) else test_items\n return self.valid.new(test_ds, **kwargs)", | |
"execution_count": 96, | |
"outputs": [] | |
}, | |
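{
"metadata": {},
"cell_type": "markdown",
"source": "A minimal usage sketch for `test_dl`. It assumes a hypothetical `_ValDouble` transform that only runs on the validation split (`split_idx=1`); the dataset and names are illustrative. Because `test_dl` builds its dataset with `test_set`, the validation transforms are applied to the new items. The resulting `DataLoader` is created with `self.valid.new`, so it keeps the validation settings, such as the batch size of 4 used here."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Hypothetical transform that only runs on the validation split\nclass _ValDouble(Transform):\n    split_idx=1\n    def encodes(self, x): return x*2\n\ndsets_t = Datasets(range(8), [[_ValDouble()]], splits=[[1,2,5,7],[0,3,4,6]])\ndls_t = dsets_t.dataloaders(bs=4, shuffle=False, device=torch.device('cpu'))\ntst_dl = dls_t.test_dl([1,2,3])\n# The validation transform doubles the test items\ntest_eq(tst_dl.one_batch()[0], tensor([2,4,6]))",
"execution_count": null,
"outputs": []
},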
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Export -" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#hide\nfrom nbdev.export import notebook2script\nnotebook2script()", | |
"execution_count": 97, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "Converted 00_torch_core.ipynb.\nConverted 01_layers.ipynb.\nConverted 01a_losses.ipynb.\nConverted 02_data.load.ipynb.\nConverted 02b_data.pytorch_load.ipynb.\nConverted 02c_data.pytorch.ipynb.\nConverted 03_data.core.ipynb.\nConverted 04_data.external.ipynb.\nConverted 05_data.transforms.ipynb.\nConverted 06_data.block.ipynb.\nConverted 07_vision.core.ipynb.\nConverted 08_vision.data.ipynb.\nConverted 09_vision.augment.ipynb.\nConverted 09b_vision.utils.ipynb.\nConverted 09c_vision.widgets.ipynb.\nConverted 10_tutorial.pets.ipynb.\nConverted 10b_tutorial.albumentations.ipynb.\nConverted 11_vision.models.xresnet.ipynb.\nConverted 12_optimizer.ipynb.\nConverted 13_callback.core.ipynb.\nConverted 13a_learner.ipynb.\nConverted 13b_metrics.ipynb.\nConverted 14_callback.schedule.ipynb.\nConverted 14a_callback.data.ipynb.\nConverted 15_callback.hook.ipynb.\nConverted 15a_vision.models.unet.ipynb.\nConverted 16_callback.progress.ipynb.\nConverted 17_callback.tracker.ipynb.\nConverted 18_callback.fp16.ipynb.\nConverted 18a_callback.training.ipynb.\nConverted 18b_callback.preds.ipynb.\nConverted 19_callback.mixup.ipynb.\nConverted 20_interpret.ipynb.\nConverted 20a_distributed.ipynb.\nConverted 21_vision.learner.ipynb.\nConverted 22_tutorial.imagenette.ipynb.\nConverted 23_tutorial.vision.ipynb.\nConverted 24_tutorial.image_sequence.ipynb.\nConverted 24_tutorial.siamese.ipynb.\nConverted 24_vision.gan.ipynb.\nConverted 30_text.core.ipynb.\nConverted 31_text.data.ipynb.\nConverted 32_text.models.awdlstm.ipynb.\nConverted 33_text.models.core.ipynb.\nConverted 34_callback.rnn.ipynb.\nConverted 35_tutorial.wikitext.ipynb.\nConverted 36_text.models.qrnn.ipynb.\nConverted 37_text.learner.ipynb.\nConverted 38_tutorial.text.ipynb.\nConverted 39_tutorial.transformers.ipynb.\nConverted 40_tabular.core.ipynb.\nConverted 41_tabular.data.ipynb.\nConverted 42_tabular.model.ipynb.\nConverted 43_tabular.learner.ipynb.\nConverted 44_tutorial.tabular.ipynb.\nConverted 
45_collab.ipynb.\nConverted 46_tutorial.collab.ipynb.\nConverted 50_tutorial.datablock.ipynb.\nConverted 60_medical.imaging.ipynb.\nConverted 61_tutorial.medical_imaging.ipynb.\nConverted 65_medical.text.ipynb.\nConverted 70_callback.wandb.ipynb.\nConverted 71_callback.tensorboard.ipynb.\nConverted 72_callback.neptune.ipynb.\nConverted 73_callback.captum.ipynb.\nConverted 74_callback.azureml.ipynb.\nConverted 97_test_utils.ipynb.\nConverted 99_pytorch_doc.ipynb.\nConverted dev-setup.ipynb.\nConverted index.ipynb.\nConverted quick_start.ipynb.\nConverted tutorial.ipynb.\n", | |
"name": "stdout" | |
} | |
] | |
} | |
], | |
"metadata": { | |
"jupytext": { | |
"split_at_heading": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3", | |
"language": "python" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.7.10", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"gist": { | |
"id": "", | |
"data": { | |
"description": "git_repos/fastai/nbs/03_data.core.ipynb", | |
"public": true | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |