Last active
January 7, 2023 17:00
-
-
Save diaoenmao/23f18fd78ac8da9615c347905e64fc78 to your computer and use it in GitHub Desktop.
Extract patches from images and recover the original images from patches
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def extract_patches_2d(img, patch_shape, step=(1.0, 1.0), batch_first=False):
    """Split a batch of images into (possibly overlapping) 2-D patches.

    Args:
        img: 4-D tensor of shape ``(B, C, H, W)``.
        patch_shape: ``(patch_H, patch_W)``.
        step: per-axis stride ``(step_H, step_W)``. A ``float`` is interpreted
            as a fraction of the patch size (``1.0`` -> stride equal to the
            patch, i.e. no overlap; smaller values -> overlapping patches) and
            is rounded down, but never below one pixel. Any other value is
            used directly as a stride in pixels.
        batch_first: if True, return ``(B, N, C, patch_H, patch_W)`` instead
            of ``(N, B, C, patch_H, patch_W)``.

    Returns:
        Tensor of ``N`` patches ordered row-major over the patch grid. When
        the stride does not tile an axis exactly, one extra row/column of
        patches aligned with the bottom/right edge is appended so the whole
        image is covered. Axes shorter than the patch are first zero-padded,
        centred.
    """
    patch_H, patch_W = patch_shape[0], patch_shape[1]

    def _step_in_pixels(s, size):
        # Floats are fractions of the patch size; clamp so a tiny fraction
        # cannot produce a zero stride, which unfold rejects.
        return max(1, int(size * s)) if isinstance(s, float) else s

    # Centred zero-padding for spatial axes shorter than the patch.
    pad_H = max(patch_H - img.size(2), 0)
    pad_W = max(patch_W - img.size(3), 0)
    if pad_H or pad_W:
        img = torch.nn.functional.pad(
            img, (pad_W // 2, pad_W - pad_W // 2, pad_H // 2, pad_H - pad_H // 2)
        )

    step_H = _step_in_pixels(step[0], patch_H)
    step_W = _step_in_pixels(step[1], patch_W)

    # Windows along H: (B, C, nH, W, patch_H)
    patches_fold_H = img.unfold(2, patch_H, step_H)
    if (img.size(2) - patch_H) % step_H != 0:
        # unfold drops the leftover rows; add one window flush with the bottom edge.
        last_row = img[:, :, -patch_H:, :].permute(0, 1, 3, 2).unsqueeze(2)
        patches_fold_H = torch.cat((patches_fold_H, last_row), dim=2)

    # Windows along W: (B, C, nH, nW, patch_H, patch_W)
    patches_fold_HW = patches_fold_H.unfold(3, patch_W, step_W)
    if (img.size(3) - patch_W) % step_W != 0:
        # Same for leftover columns: one window flush with the right edge.
        last_col = patches_fold_H[:, :, :, -patch_W:, :].permute(0, 1, 2, 4, 3).unsqueeze(3)
        patches_fold_HW = torch.cat((patches_fold_HW, last_col), dim=3)

    # (nH, nW, B, C, pH, pW) -> (nH * nW, B, C, pH, pW), row-major grid order.
    patches = patches_fold_HW.permute(2, 3, 0, 1, 4, 5)
    patches = patches.reshape(-1, img.size(0), img.size(1), patch_H, patch_W)
    if batch_first:
        patches = patches.permute(1, 0, 2, 3, 4)
    return patches
def reconstruct_from_patches_2d(patches, img_shape, step=(1.0, 1.0), batch_first=False):
    """Reassemble images from a grid of 2-D patches, averaging overlaps.

    Args:
        patches: 5-D tensor ``(N, B, C, patch_H, patch_W)`` (or
            ``(B, N, C, patch_H, patch_W)`` when ``batch_first``), with the
            ``N`` patches in row-major grid order, as returned by
            ``extract_patches_2d`` for the same ``img_shape`` and ``step``.
        img_shape: ``(H, W)`` of the original images.
        step: per-axis stride; a ``float`` is a fraction of the patch size
            (rounded down, never below one pixel), anything else is a stride
            in pixels.
        batch_first: whether ``patches`` has the batch axis first.

    Returns:
        Tensor ``(B, C, H, W)``. Overlapping pixels are averaged. Pixels not
        covered by any patch (stride larger than the patch) are 0. The dtype
        is that of ``patches`` when it is floating point, otherwise the
        default float dtype.
    """
    if batch_first:
        patches = patches.permute(1, 0, 2, 3, 4)
    patch_H, patch_W = patches.size(3), patches.size(4)
    batch, channels = patches.size(1), patches.size(2)
    # The working canvas is at least one patch in each dimension, matching the
    # centred zero-padding applied during extraction.
    canvas_H = max(img_shape[0], patch_H)
    canvas_W = max(img_shape[1], patch_W)

    def _step_in_pixels(s, size):
        # Floats are fractions of the patch size; clamp to avoid a zero stride.
        return max(1, int(size * s)) if isinstance(s, float) else s

    def _starts(length, size, stride):
        # Regular window offsets, plus one edge-aligned window when the stride
        # leaves a remainder (mirrors the extra patch added during extraction).
        offsets = list(range(0, length - size + 1, stride))
        if (length - size) % stride != 0:
            offsets.append(length - size)
        return offsets

    row_starts = _starts(canvas_H, patch_H, _step_in_pixels(step[0], patch_H))
    col_starts = _starts(canvas_W, patch_W, _step_in_pixels(step[1], patch_W))
    patches = patches.reshape(
        len(row_starts), len(col_starts), batch, channels, patch_H, patch_W
    )

    # Keep the caller's floating precision (e.g. float64); integer patches
    # fall back to the default float dtype so the averaging is fractional.
    dtype = patches.dtype if patches.is_floating_point() else torch.get_default_dtype()
    img = torch.zeros((batch, channels, canvas_H, canvas_W), dtype=dtype, device=patches.device)
    # The coverage count is identical for every image/channel, so a 2-D map
    # broadcasts. It uses the default float dtype so counts stay exact even
    # when the patches are half precision.
    overlap_counter = torch.zeros((canvas_H, canvas_W), device=patches.device)
    for i, top in enumerate(row_starts):
        for j, left in enumerate(col_starts):
            img[:, :, top:top + patch_H, left:left + patch_W] += patches[i, j]
            overlap_counter[top:top + patch_H, left:left + patch_W] += 1
    # Uncovered pixels would otherwise be 0 / 0 = NaN.
    img /= overlap_counter.clamp(min=1)

    # Crop away the centred padding added when the image was smaller than a patch.
    if img_shape[0] < patch_H:
        top = (patch_H - img_shape[0]) // 2
        img = img[:, :, top:top + img_shape[0], :]
    if img_shape[1] < patch_W:
        left = (patch_W - img_shape[1]) // 2
        img = img[:, :, :, left:left + img_shape[1]]
    return img
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Is it possible to write the code for 3D images?
For example, if the image size is 172×220×156 and you want to create patches of 32×32×32 or so?