@jvasilakes
Created September 6, 2022 12:00
A generic function to uncollate a batch in PyTorch
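A collate function merges a list of per-sample records into one batch, and `uncollate` below goes the other way. For reference, here is a minimal pure-Python sketch of the forward direction for dict samples. It assumes leaves are plain Python values and uses a list as a stand-in for a stacked tensor. `collate_dicts` is a hypothetical helper, not part of this gist, Lightning Flash, or PyTorch; PyTorch's own collation also converts numeric data to tensors.

```python
from typing import Any, Dict, List


def collate_dicts(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical inverse of ``uncollate`` for a list of dict samples.

    Nested dict values are collated recursively; any other value is
    gathered into a plain list (standing in for tensor stacking).
    """
    if not samples:
        return {}
    keys = list(samples[0].keys())
    # Every sample must share the same keys, or the batch would be ragged.
    if any(set(sample.keys()) != set(keys) for sample in samples):
        raise ValueError("All samples must have the same keys.")
    batch = {}
    for key in keys:
        values = [sample[key] for sample in samples]
        if all(isinstance(value, dict) for value in values):
            # Recurse so that nested metadata becomes a dict of lists.
            batch[key] = collate_dicts(values)
        else:
            batch[key] = values
    return batch


samples = [
    {"x": [0., 1., 1.], "metadata": {"example_id": 0}},
    {"x": [1., 1., 0.], "metadata": {"example_id": 1}},
]
batch = collate_dicts(samples)
# batch == {"x": [[0., 1., 1.], [1., 1., 0.]], "metadata": {"example_id": [0, 1]}}
```

Collating these two samples reproduces the `batch` used in the `uncollate` docstring, so on this example the two functions round-trip.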
import torch


def _is_list_like_excluding_str(x):
    """Return True if ``x`` is iterable and is not a ``str``."""
    if isinstance(x, str):
        return False
    try:
        iter(x)
    except TypeError:
        return False
    return True


def uncollate(batch):
    """
    Modified from
    https://lightning-flash.readthedocs.io/en/stable/_modules/flash/core/data/batch.html#default_uncollate
    to work with arbitrarily nested batches.

    >>> batch = {'x': [[0., 1., 1.], [1., 1., 0.]], "metadata": {"example_id": [0, 1]}}
    >>> uncollate(batch)  # doctest: +NORMALIZE_WHITESPACE
    [{'x': [0.0, 1.0, 1.0], 'metadata': {'example_id': 0}},
     {'x': [1.0, 1.0, 0.0], 'metadata': {'example_id': 1}}]

    This function is used to uncollate a batch into samples.
    The following conditions are used:

    - if the ``batch`` is a ``dict``, the result will be a list of dicts
    - if the ``batch`` is a ``list``, ``tuple``, or ``torch.Tensor``, the
      result is guaranteed to be a list split along the first dimension

    Args:
        batch: The batch of outputs to be uncollated.

    Returns:
        The uncollated list of samples.

    Raises:
        ValueError: If input ``dict`` values are not all list-like.
        ValueError: If input ``dict`` values are not all the same length.
        ValueError: If the input is not a ``dict``, ``list``, ``tuple``, or ``torch.Tensor``.
    """
    if isinstance(batch, dict):
        # Every value must be list-like; nested dicts qualify because they are iterable.
        if any(not _is_list_like_excluding_str(sub_batch)
               for sub_batch in batch.values()):
            raise ValueError("When uncollating a dict, all sub-batches (values) are expected to be list-like.")  # noqa
        # Recursively uncollate each value into a per-sample list.
        uncollated_vals = [uncollate(val) for val in batch.values()]
        # zip() silently truncates to the shortest input, so check lengths explicitly.
        if len({len(v) for v in uncollated_vals}) > 1:
            raise ValueError("When uncollating a dict, all sub-batches (values) are expected to have the same length.")  # noqa
        # Transpose per-key lists into per-sample tuples, then re-attach the keys.
        elements = list(zip(*uncollated_vals))
        return [dict(zip(batch.keys(), element)) for element in elements]
    if isinstance(batch, (list, tuple, torch.Tensor)):
        # Split along the first (batch) dimension.
        return list(batch)
    raise ValueError(
        "The batch of outputs to be uncollated is expected to be a `dict` or list-like "  # noqa
        f"(e.g. `Tensor`, `list`, `tuple`, etc.), but got input of type: {type(batch)}"  # noqa
    )
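The core of the dict branch is the transpose near the end: `zip(*uncollated_vals)` regroups the per-key lists into per-sample tuples, which are then paired back up with the keys. The stand-alone snippet below replays that step on the docstring's example. The intermediate values are written out by hand (they are what the recursive `uncollate` calls would return), so it runs without PyTorch. It also illustrates why the explicit length check is there: `zip` truncates to its shortest input instead of raising.

```python
# Per-key values that the recursive uncollate(...) calls would return for the
# docstring's example batch (written by hand so this runs without PyTorch).
keys = ["x", "metadata"]
uncollated_vals = [
    [[0., 1., 1.], [1., 1., 0.]],            # per-sample values of "x"
    [{"example_id": 0}, {"example_id": 1}],  # per-sample values of "metadata"
]

# Transpose: one tuple per sample, holding that sample's value for each key.
elements = list(zip(*uncollated_vals))
samples = [dict(zip(keys, element)) for element in elements]
# samples == [{"x": [0., 1., 1.], "metadata": {"example_id": 0}},
#             {"x": [1., 1., 0.], "metadata": {"example_id": 1}}]

# zip() stops at the shortest input rather than raising, so without the
# length check a mismatched batch would silently lose a sample.
truncated = list(zip([1, 2, 3], ["a", "b"]))
# truncated == [(1, "a"), (2, "b")]
```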