Last active
December 22, 2021 06:39
-
-
Save rish-16/4ec691c0f41340f4dc44eb8f51c91bbc to your computer and use it in GitHub Desktop.
Visualise selected patches from an image for comparison / sanity checks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import transforms | |
import numpy as np | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from patchify import patchify, unpatchify # pip install patchify | |
def viz(image, selector, pw=10, ph=10):
    '''
    Plot an image next to a copy in which the selected patches are blanked out.

    The image is cut into non-overlapping (ph, pw) tiles, the tiles are handed
    to `selector`, and every tile whose index comes back is overwritten with
    the value 1 before the tiles are stitched back into an image.

    Params
    - image: a PyTorch Tensor of shape (H, W, C); `.numpy()` is called on it
    - selector: the patch selector of choice; it receives the patches as an
      (N, C, ph, pw) tensor and should return K patch indices
    - pw: patch width in pixels
    - ph: patch height in pixels

    Returns
    None | just plots the image and processed image side-by-side

    Raises
    ValueError | if H is not a multiple of ph or W is not a multiple of pw
    '''
    def pad(img):
        # implement your padding function if the image is not an exact
        # multiple of the patch size
        raise NotImplementedError

    # image = pad(image) # enable once pad() is implemented

    image_np = image.numpy()
    height, width, channels = image_np.shape
    if height % ph or width % pw:
        raise ValueError(
            f"image of size {height}x{width} cannot be tiled by {ph}x{pw} "
            "patches; pad the image first"
        )

    # Stepping by the full patch size yields non-overlapping tiles laid out as
    # (rows, cols, 1, ph, pw, C).
    patches = patchify(image_np, (ph, pw, channels), step=(ph, pw, channels))

    # .copy() gives a private C-contiguous buffer, so blanking tiles below can
    # never write through to the caller's image.
    flat = torch.from_numpy(patches.reshape(-1, ph, pw, channels).copy())

    # Hand the selector a real channel-first tensor (a permute, not a raw
    # reshape, so pixel/channel pairing is preserved).
    patch_ids = selector(flat.permute(0, 3, 1, 2).contiguous())

    for pid in patch_ids:
        flat[int(pid)] = 1

    # Restore the 6-D patch grid so unpatchify can rebuild the image.
    reformed_np = flat.numpy().reshape(patches.shape)
    reformed_img = unpatchify(reformed_np, image_np.shape)

    fig = plt.figure()

    fig.add_subplot(121)
    plt.imshow(image_np)
    plt.title("Original")

    fig.add_subplot(122)
    plt.imshow(reformed_img)
    plt.title("Patches")

    plt.show()
def dummy_clip(patches, out_dim=512):
    '''
    Stand-in encoder: project every row through a freshly initialised linear layer.

    Params
    - patches: a 2-D PyTorch Tensor of shape (N, dim)
    - out_dim: width of the produced embedding (512 by default)

    Returns
    Tensor of shape (N, out_dim) | the layer weights are re-drawn on every
    call, so values are only reproducible under a fixed torch seed
    '''
    _, dim = patches.shape  # unpacking also rejects inputs that are not 2-D
    layer = nn.Linear(dim, out_dim)
    # nn.Linear acts on the last dimension, so one batched call covers all
    # rows without a Python loop or an uninitialised output buffer.
    return layer(patches)
def dummy_selector(patches, k=10):
    '''
    Pick up to k distinct patch indices uniformly at random.

    Params
    - patches: a PyTorch Tensor whose first dimension indexes the patches
    - k: how many indices to return (10 by default)

    Returns
    1-D integer Tensor of min(k, N) indices, each in [0, N)
    '''
    # Size the permutation from the input so indices are always in range,
    # instead of assuming a fixed 16*21 patch grid.
    n_patches = patches.shape[0]
    return torch.randperm(n_patches)[:k]
# viz(img_new, dummy_selector)  # example call; viz takes (image, selector, pw, ph)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment