Skip to content

Instantly share code, notes, and snippets.

@younesbelkada
Created January 30, 2023 18:57
Show Gist options
  • Select an option

  • Save younesbelkada/0d67445f5e8a874e9c239cd7789c5bdb to your computer and use it in GitHub Desktop.

Select an option

Save younesbelkada/0d67445f5e8a874e9c239cd7789c5bdb to your computer and use it in GitHub Desktop.
Extracting image patches: porting from `tf` to `torch`
import tensorflow as tf
import torch
import math
import torch.nn.functional as F
# adapted from: https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/8
def torch_extract_patches(
    x, patch_height, patch_width, padding=None
):
    """Cut a channels-first image into non-overlapping patches.

    Intended as the PyTorch counterpart of ``tf.image.extract_patches`` when
    ``sizes`` and ``strides`` both equal the patch size and ``rates`` are 1.

    Args:
        x: Image tensor laid out as ``[channels, height, width]``.
        patch_height: Patch size (and stride) along the height axis.
        patch_width: Patch size (and stride) along the width axis.
        padding: ``"SAME"`` zero-pads height and width up to the next multiple
            of the patch size, putting the odd extra pixel on the bottom/right
            as TensorFlow does. Any other value (the default ``None``) applies
            no padding, so trailing rows/columns that do not fill a whole
            patch are dropped.

    Returns:
        Tensor of shape
        ``[1, rows, columns, patch_height * patch_width * channels]``. Each
        patch is flattened in (patch row, patch column, channel) order.
    """
    # Add the batch axis: [1, channels, height, width].
    x = x.unsqueeze(0)

    if padding == "SAME":
        height, width = x.shape[-2:]
        # With stride == patch size, the total SAME padding is just the amount
        # needed to reach the next multiple of the patch size.
        pad_h = -height % patch_height
        pad_w = -width % patch_width
        pad_top = pad_h // 2
        pad_left = pad_w // 2
        # F.pad takes pads for the last dimension first: (left, right, top, bottom).
        x = F.pad(x, (pad_left, pad_w - pad_left, pad_top, pad_h - pad_top))

    # [1, channels, rows, columns, patch_height, patch_width]
    patches = x.unfold(2, patch_height, patch_height).unfold(3, patch_width, patch_width)
    # Move channels last so the flattened depth axis is ordered
    # (patch row, patch column, channel).
    patches = patches.permute(0, 2, 3, 4, 5, 1).contiguous()
    # Collapse each patch: [1, rows, columns, patch_height * patch_width * channels]
    return patches.reshape(*patches.size()[:3], -1)
def patch_sequence(
    image: tf.Tensor,
    max_patches: int,
    patch_size):
    """Resize an image to a patch grid and flatten it into a padded sequence.

    Args:
        image: Image tensor indexed as ``[height, width, channels]``.
        max_patches: Length of the returned sequence; unused rows are zero.
        patch_size: ``(patch_height, patch_width)`` pair.

    Returns:
        A tuple ``(patches_, result, original_shape)``:
          * ``patches_``: raw output of ``tf.image.extract_patches``,
            ``[1, rows, columns, patch_height * patch_width * channels]``.
          * ``result``: ``[max_patches, 2 + depth]``; each row is a 1-based
            row id, a 1-based column id, then the flattened patch.
          * ``original_shape``: stacked
            ``[rows, columns, patch_height, patch_width, channels]``.
    """
    patch_height, patch_width = patch_size

    input_shape = tf.shape(image)
    image_channels = input_shape[2]
    height_f = tf.cast(input_shape[0], tf.float32)
    width_f = tf.cast(input_shape[1], tf.float32)

    # Scale at which (height / patch_height) * (width / patch_width) equals
    # max_patches; flooring the per-axis counts below keeps the grid within
    # that budget.
    scale = tf.sqrt(
        max_patches * (patch_height / height_f) * (patch_width / width_f))

    # Patch counts per axis, clamped to [1, max_patches].
    grid_rows = tf.maximum(
        tf.minimum(tf.math.floor(scale * height_f / patch_height), max_patches),
        1)
    grid_cols = tf.maximum(
        tf.minimum(tf.math.floor(scale * width_f / patch_width), max_patches),
        1)

    # Target size is an exact multiple of the patch size on both axes.
    target_height = tf.maximum(tf.cast(grid_rows * patch_height, tf.int32), 1)
    target_width = tf.maximum(tf.cast(grid_cols * patch_width, tf.int32), 1)
    image = tf.image.resize(
        images=image,
        size=(target_height, target_width),
        preserve_aspect_ratio=False,
        antialias=True)

    # Non-overlapping windows: the stride equals the window size.
    window = [1, patch_height, patch_width, 1]
    batched = tf.expand_dims(image, 0)
    # [1, rows, columns, patch_height * patch_width * image_channels]
    patches_ = tf.image.extract_patches(
        images=batched,
        sizes=window,
        strides=window,
        rates=[1, 1, 1, 1],
        padding="SAME")

    grid_shape = tf.shape(patches_)
    rows = grid_shape[1]
    columns = grid_shape[2]
    depth = grid_shape[3]
    num_patches = rows * columns

    # [rows * columns, patch_height * patch_width * image_channels]
    flat_patches = tf.reshape(patches_, [num_patches, depth])

    # Row/column index of every patch in row-major order, each [rows * columns, 1].
    row_grid = tf.tile(tf.expand_dims(tf.range(rows), 1), [1, columns])
    col_grid = tf.tile(tf.expand_dims(tf.range(columns), 0), [rows, 1])
    # Ids start at 1 because 0 is reserved for the padding rows added below.
    row_ids = tf.reshape(row_grid, [num_patches, 1]) + 1
    col_ids = tf.reshape(col_grid, [num_patches, 1]) + 1
    # Cast to float so the ids can be concatenated with the pixel values.
    row_ids = tf.cast(row_ids, tf.float32)
    col_ids = tf.cast(col_ids, tf.float32)

    # [rows * columns, 2 + patch_height * patch_width * image_channels]
    result = tf.concat([row_ids, col_ids, flat_patches], -1)
    # Zero-pad along the sequence axis up to max_patches rows.
    result = tf.pad(result, [[0, max_patches - num_patches], [0, 0]])

    original_shape = tf.stack(
        [rows, columns, patch_height, patch_width, image_channels])
    return patches_, result, original_shape
def patch_sequence_torch(
    image: torch.Tensor,
    max_patches: int,
    patch_size):
    """Resize an image to a patch grid and flatten it into a padded sequence.

    PyTorch counterpart of ``patch_sequence``; patch extraction is delegated
    to ``torch_extract_patches``.

    Args:
        image: Floating-point image tensor indexed as
            ``[channels, height, width]``.
        max_patches: Length of the returned sequence; unused rows are zero.
        patch_size: ``(patch_height, patch_width)`` pair.

    Returns:
        A tuple ``(patches_, result, original_shape)``:
          * ``patches_``: output of ``torch_extract_patches``,
            ``[1, rows, columns, depth]``.
          * ``result``: ``[max_patches, 2 + depth]``; each row is a 1-based
            row id, a 1-based column id, then the flattened patch.
          * ``original_shape``: tensor holding
            ``[rows, columns, patch_height, patch_width, channels]``.
    """
    patch_height, patch_width = patch_size
    image_channels = image.shape[0]
    image_height = float(image.shape[1])
    image_width = float(image.shape[2])

    # Scale at which (height / patch_height) * (width / patch_width) equals
    # max_patches; flooring the per-axis counts below keeps the grid within
    # that budget.
    scale = math.sqrt(
        max_patches * (patch_height / image_height) * (patch_width / image_width))

    # Patch counts per axis, clamped to [1, max_patches].
    num_feasible_rows = max(
        min(math.floor(scale * image_height / patch_height), max_patches), 1)
    num_feasible_cols = max(
        min(math.floor(scale * image_width / patch_width), max_patches), 1)

    # Target size is an exact multiple of the patch size on both axes.
    resized_height = max(num_feasible_rows * patch_height, 1)
    resized_width = max(num_feasible_cols * patch_width, 1)

    # antialias=True matches the TF reference, which resizes with
    # antialias=True; it only changes the result when downscaling.
    image = torch.nn.functional.interpolate(
        image.unsqueeze(0),
        size=(resized_height, resized_width),
        mode='bilinear',
        align_corners=False,
        antialias=True).squeeze(0)

    # [1, rows, columns, patch_height * patch_width * image_channels]
    patches_ = torch_extract_patches(image, patch_height, patch_width, padding=None)

    rows, columns, depth = patches_.shape[1:4]
    num_patches = rows * columns

    # [rows * columns, patch_height * patch_width * image_channels]
    patches = patches_.reshape([num_patches, depth])

    # Row/column index of every patch in row-major order, each [rows * columns, 1].
    # Built on the image's device so the concatenation below also works when
    # the image is not on the default device.
    device = image.device
    row_ids = torch.arange(rows, device=device).reshape([rows, 1]).repeat(1, columns)
    col_ids = torch.arange(columns, device=device).reshape([1, columns]).repeat(rows, 1)
    # Ids start at 1 because 0 is reserved for the padding rows added below.
    row_ids = row_ids.reshape([num_patches, 1]) + 1
    col_ids = col_ids.reshape([num_patches, 1]) + 1
    # Cast to float so the ids can be concatenated with the pixel values.
    row_ids = row_ids.to(torch.float32)
    col_ids = col_ids.to(torch.float32)

    # [rows * columns, 2 + patch_height * patch_width * image_channels]
    result = torch.cat([row_ids, col_ids, patches], -1)
    # Zero-pad along the sequence axis up to max_patches rows. F.pad lists
    # pads for the last dimension first, hence the leading (0, 0).
    result = torch.nn.functional.pad(result, [0, 0, 0, max_patches - num_patches])

    original_shape = torch.tensor([rows, columns, patch_height, patch_width, image_channels])
    return patches_, result, original_shape
# Parity check: run the TF and torch implementations on the same random image.
# Random tf tensor with shape (256, 256, 3), i.e. H x W x C (no batch axis).
image_tf = tf.random.uniform(shape=(256, 256, 3))
# Same values as a torch tensor, built from the tf tensor's numpy array
image_torch = torch.from_numpy(image_tf.numpy())
patch_size = (16, 16)
max_patches = 2048
# TF reference
tf_patches, result, original_shape = patch_sequence(image_tf, max_patches, patch_size)
# Permute the image to the torch convention: H x W x C -> C x H x W
image_torch = image_torch.permute(2, 0, 1)
torch_patches, torch_result, torch_original_shape = patch_sequence_torch(image_torch, max_patches, patch_size)
# Convert the TF patches to a torch tensor so both sides can be compared
tf_patches = torch.from_numpy(tf_patches.numpy())
# Compare the raw patches and the padded sequences within a 1e-3 tolerance
assert torch.allclose(tf_patches, torch_patches, atol=1e-3, rtol=1e-3)
assert torch.allclose(torch.from_numpy(result.numpy()), torch_result, atol=1e-3, rtol=1e-3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment