@vadimkantorov
Last active November 4, 2021 14:38
PyTorch useful functions
import torch
import torch.nn.functional as F
import torchvision
import base64
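# Picks one entry along dim I.ndim per batch position: for tensor of shape (*B, N, *D) and an int64 index I of shape (*B), returns shape (*B, *D) with out[b] = tensor[b, I[b]]
gather_incomplete = lambda tensor, I: tensor.gather(I.ndim, I[(...,) + (None,) * (tensor.ndim - I.ndim)].expand((-1,) * (I.ndim + 1) + tensor.shape[I.ndim + 1:])).squeeze(I.ndim)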
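# Inserts a new axis at `dim` (negative dims count from the end, as in unsqueeze) and broadcasts it to size `expand`; returns a view, no data is copied
expand_dim = lambda tensor, expand, dim: tensor.unsqueeze(dim).expand((-1, ) * (dim if dim >= 0 else tensor.ndim + dim + 1) + (expand, ) + (-1, ) * (tensor.ndim - (dim if dim >= 0 else tensor.ndim + dim + 1)))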
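# HTML <img> tag embedding img (uint8 tensor of shape (C, H, W), C in {1, 3}, on CPU) as a base64 JPEG resized to (height, width).
# The nearest-neighbour resize runs in float (exact for integer pixels), and the JPEG tensor is turned into bytes because b64encode needs a bytes-like object.
encode_image_as_html = lambda img, height, width: '<img height="{height}" width="{width}" src="data:image/jpeg;base64,{encoded}" />'.format(height = height, width = width, encoded = base64.b64encode(torchvision.io.encode_jpeg(F.interpolate(img[None].float(), (height, width)).squeeze(0).to(torch.uint8)).numpy().tobytes()).decode())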
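The following is a small worked example of `gather_incomplete`. It assumes illustrative shapes: a batch of 2 items, each with 4 candidate rows of 3 features. The function body is restated as a `def` so the snippet runs on its own. The result is cross-checked against plain advanced indexing, which is one way to read what the one-liner computes: `out[b] == tensor[b, I[b]]`.

```python
import torch

def gather_incomplete(tensor: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    # Same expression as the one-liner above, split into steps.
    k = I.ndim
    # View I as (*I.shape, 1, 1, ..., 1) and broadcast it over tensor's trailing dims.
    index = I[(...,) + (None,) * (tensor.ndim - k)].expand((-1,) * (k + 1) + tensor.shape[k + 1:])
    # Gather along dim k (one row per batch position), then drop that singleton dim.
    return tensor.gather(k, index).squeeze(k)

# Illustrative shapes: batch B=2, N=4 candidate rows, D=3 features per row.
tensor = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
I = torch.tensor([1, 3])          # int64 index of shape (B,): row 1 for item 0, row 3 for item 1
out = gather_incomplete(tensor, I)
print(out.tolist())               # [[3, 4, 5], [21, 22, 23]]

# Cross-check against advanced indexing: out[b] == tensor[b, I[b]]
ref = tensor[torch.arange(I.shape[0]), I]
assert torch.equal(out, ref)
```

The same call works with more than one batch dimension. For example, `I` of shape `(B1, B2)` indexes dim 2 of a `(B1, B2, N, D)` tensor.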
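Here is a minimal sketch of `expand_dim` in use, again with illustrative shapes. It shows where a positive and a negative `dim` place the new axis. It also shows that the result is a view whose new axis has stride 0, so no data is copied, unlike `repeat`, which materializes the copies.

```python
import torch

def expand_dim(tensor: torch.Tensor, expand: int, dim: int) -> torch.Tensor:
    # Same logic as the one-liner above: d is the position of the new axis in the result.
    d = dim if dim >= 0 else tensor.ndim + dim + 1
    # unsqueeze adds a size-1 axis at d; expand broadcasts only that axis to `expand`.
    return tensor.unsqueeze(dim).expand((-1,) * d + (expand,) + (-1,) * (tensor.ndim - d))

x = torch.randn(2, 3)
a = expand_dim(x, 5, 1)     # new axis in the middle: shape (2, 5, 3)
b = expand_dim(x, 5, -1)    # new axis at the end:    shape (2, 3, 5)

# Same values as repeat, but repeat allocates a new (2, 5, 3) tensor.
copied = x.unsqueeze(1).repeat(1, 5, 1)
print(a.shape, a.stride(), b.shape, torch.equal(a, copied))
```

Because the result shares memory with `x`, in-place writes to it are unsafe. Call `.contiguous()` or `.clone()` first if a writable copy is needed.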