Skip to content

Instantly share code, notes, and snippets.

@cloneofsimo
Created September 24, 2023 13:32
Show Gist options
  • Save cloneofsimo/6d1d8e98ce25fb88cd95565cef18ddf4 to your computer and use it in GitHub Desktop.
Save cloneofsimo/6d1d8e98ce25fb88cd95565cef18ddf4 to your computer and use it in GitHub Desktop.
Got confused by the Unfold operation yet again, LOL.
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
# Demo configuration: example image (the author's local path -- point it at an
# image on your machine), the square size it is resized to, and the patch side.
IMAGE_PATH = "/home/simo/just_dl_stuff/vid2data/vender_1.png"
IMAGE_SIZE = 512
PATCH_PIX = 128


def patchify(img_tensor, patch_pix):
    """Split images into non-overlapping square patches using ``nn.Unfold``.

    Args:
        img_tensor: floating-point tensor of shape (B, C, H, W) with
            H >= patch_pix and W >= patch_pix.
        patch_pix: side length of each square patch. It is also used as the
            stride, so patches tile the image without overlapping; trailing
            rows/columns that do not fill a whole patch are dropped.

    Returns:
        Tensor of shape (B, C, patch_pix, patch_pix, L) with
        L = (H // patch_pix) * (W // patch_pix). Patches are ordered
        row-major over the grid: index ``l`` is grid row
        ``l // (W // patch_pix)``, column ``l % (W // patch_pix)``.
    """
    batch, channels = img_tensor.shape[:2]
    # A single patch size drives both kernel and stride, so the unfold and the
    # reshape below can never disagree.
    unfold = nn.Unfold(kernel_size=patch_pix, stride=patch_pix)
    # Unfold yields (B, C * patch_pix * patch_pix, L). Its middle dimension is
    # flattened channel-first, then kernel row, then kernel column, so it
    # reshapes straight into (C, patch_pix, patch_pix) with no permute needed.
    cols = unfold(img_tensor)
    return cols.reshape(batch, channels, patch_pix, patch_pix, -1)


def show_patches(img_patched, grid_h, grid_w, figsize=(8, 8)):
    """Plot the patches of the first image in ``img_patched`` on a grid.

    Args:
        img_patched: tensor of shape (B, C, P, P, L) as returned by
            ``patchify``; values are cast to uint8 for display, so they are
            expected in [0, 255]. imshow needs C == 3 (RGB) or 4 (RGBA).
        grid_h: number of patch rows.
        grid_w: number of patch columns; grid_h * grid_w should equal L.
        figsize: matplotlib figure size in inches.
    """
    # squeeze=False keeps ``axs`` 2-D even for a 1-row or 1-column grid,
    # where plain ``axs[i, j]`` indexing would fail.
    _, axs = plt.subplots(grid_h, grid_w, figsize=figsize, squeeze=False)
    # axs.flat walks the grid row-major, matching patchify's patch order.
    for idx, ax in enumerate(axs.flat):
        patch = img_patched[0, :, :, :, idx].permute(1, 2, 0).numpy().astype(np.uint8)
        ax.imshow(patch)
        ax.axis("off")
    plt.tight_layout()
    plt.show()


def main(image_path=IMAGE_PATH, size=IMAGE_SIZE, patch_pix=PATCH_PIX):
    """Load an image, cut it into patches with ``nn.Unfold`` and plot them."""
    # ``with`` closes the file handle; convert() loads the pixel data first.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img = img.resize((size, size))  # square resize; aspect ratio is not kept
    img_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float()
    _, _, height, width = img_tensor.shape

    img_patched = patchify(img_tensor, patch_pix)
    # Raw nn.Unfold layout (B, C * P * P, L), e.g. torch.Size([1, 49152, 16]).
    print(img_patched.flatten(1, 3).shape)
    print(img_patched.shape)  # e.g. torch.Size([1, 3, 128, 128, 16])

    show_patches(img_patched, height // patch_pix, width // patch_pix)


if __name__ == "__main__":
    main()
@cloneofsimo
Copy link
Author

image
Vender Vender

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment