Created
September 24, 2023 13:32
-
-
Save cloneofsimo/6d1d8e98ce25fb88cd95565cef18ddf4 to your computer and use it in GitHub Desktop.
Got confused by the Unfold operation yet again, LOL.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch.nn as nn | |
# Example image; the default path points at the original author's machine.
IMAGE_PATH = "/home/simo/just_dl_stuff/vid2data/vender_1.png"
# Target size passed to PIL's Image.resize as (width, height).
IMAGE_SIZE = (512, 512)
# Side length of each square patch. It is used as both the Unfold kernel and
# stride, so there is a single source of truth (the original kept two copies).
PATCH_SIZE = 128


def load_image_tensor(image_path=IMAGE_PATH, size=IMAGE_SIZE):
    """Load an image as a float tensor of shape (1, 3, H, W).

    The image is converted to RGB and resized to ``size`` before conversion.
    Pixel values are kept in [0, 255] (no normalization).
    """
    img = Image.open(image_path).convert("RGB").resize(size)
    # (H, W, C) uint8 array -> (1, C, H, W) float tensor.
    return torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float()


def patchify(img_tensor, patch_size=PATCH_SIZE):
    """Split a batch of images into non-overlapping square patches.

    Args:
        img_tensor: Tensor of shape (B, C, H, W) with H, W >= patch_size.
        patch_size: Side length of each patch. It is also the stride, so
            patches do not overlap. Trailing rows/columns that do not fill a
            whole patch are dropped, as nn.Unfold does.

    Returns:
        Tensor of shape (B, C, patch_size, patch_size, L), where
        L = (H // patch_size) * (W // patch_size). The patch at grid row ``i``
        and column ``j`` is at last-axis index ``i * (W // patch_size) + j``.
    """
    B, C, _, _ = img_tensor.shape
    unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
    # nn.Unfold yields (B, C * patch_size * patch_size, L) with the channel
    # index varying slowest. Reshaping with the real channel count C (rather
    # than a hard-coded 3) recovers (C, kh, kw) per patch for any input.
    unfolded = unfold(img_tensor)
    return unfolded.reshape(B, C, patch_size, patch_size, -1)


def show_patches(img_patched, grid_h, grid_w, figsize=(8, 8)):
    """Display the patches of the first image in the batch on a grid_h x grid_w grid.

    ``img_patched`` is the output of :func:`patchify`, with pixel values in
    [0, 255].
    """
    # squeeze=False keeps ``axs`` 2-D even when the grid has a single row or
    # column; otherwise ``axs[i, j]`` would fail in that case.
    _, axs = plt.subplots(grid_h, grid_w, figsize=figsize, squeeze=False)
    for i in range(grid_h):
        for j in range(grid_w):
            patch = img_patched[0, :, :, :, i * grid_w + j]
            # (C, P, P) -> (P, P, C) uint8 for imshow.
            axs[i, j].imshow(patch.permute(1, 2, 0).numpy().astype(np.uint8))
            axs[i, j].axis("off")
    plt.tight_layout()
    plt.show()


def main():
    """Load the example image, split it into patches, and plot them."""
    img_tensor = load_image_tensor()
    _, _, H, W = img_tensor.shape
    img_patched = patchify(img_tensor)
    print(img_patched.shape)  # torch.Size([1, 3, 128, 128, 16])
    show_patches(img_patched, H // PATCH_SIZE, W // PATCH_SIZE)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Vender Vender