Skip to content

Instantly share code, notes, and snippets.

@innat
Last active March 11, 2025 08:12
Show Gist options
  • Save innat/1104962d0ea62a52e73ffeb4c1ad18d0 to your computer and use it in GitHub Desktop.
Save innat/1104962d0ea62a52e73ffeb4c1ad18d0 to your computer and use it in GitHub Desktop.
Depth interpolation for 3D volume
import tensorflow as tf
def linear_interpolation(volume, target_depth, depth_axis=0):
    """Resample `volume` along one axis using linear interpolation.

    Sample positions are spaced evenly between the first and the last
    slice of `depth_axis`; every output slice is a weighted blend of the
    two input slices surrounding its sample position.

    Args:
        volume: Tensor-like input whose rank is statically known.
        target_depth: Number of slices to produce along `depth_axis`.
        depth_axis: Axis to resample. Negative values count from the end.

    Returns:
        A float32 tensor shaped like `volume`, except that `depth_axis`
        has `target_depth` entries.

    Raises:
        ValueError: If the rank of `volume` is not statically known.
    """
    volume = tf.convert_to_tensor(volume)

    # Use the static rank instead of len(tf.shape(volume)): len() is not
    # defined for symbolic tensors, so the latter fails inside tf.function
    # and tf.data pipelines.
    rank = volume.shape.rank
    if rank is None:
        raise ValueError(
            'linear_interpolation requires a volume with a statically known rank.'
        )

    # Size of the axis being resampled (may be a dynamic value).
    original_depth = tf.shape(volume)[depth_axis]

    # Fractional source positions for every output slice.
    positions = tf.linspace(
        0.0, tf.cast(original_depth - 1, tf.float32), target_depth
    )

    # Integer part selects the lower neighbour; the fractional part is the
    # blend weight towards the upper neighbour.
    lower_indices = tf.cast(tf.floor(positions), tf.int32)
    alpha = positions - tf.cast(lower_indices, tf.float32)

    # Shape alpha as [1, ..., target_depth, ..., 1] so it broadcasts
    # against the gathered slices along `depth_axis` only.
    alpha_shape = [1] * rank
    alpha_shape[depth_axis] = -1
    alpha = tf.reshape(alpha, alpha_shape)

    # Clamp both neighbours to the valid index range; at the last slice
    # the upper neighbour collapses onto the lower one.
    lower_indices = tf.maximum(lower_indices, 0)
    upper_indices = tf.minimum(lower_indices + 1, original_depth - 1)

    # Work in float32 so integer volumes can be blended with alpha.
    lower_slices = tf.cast(
        tf.gather(volume, lower_indices, axis=depth_axis), tf.float32
    )
    upper_slices = tf.cast(
        tf.gather(volume, upper_indices, axis=depth_axis), tf.float32
    )

    return (1.0 - alpha) * lower_slices + alpha * upper_slices
def cubic_interpolation(volume, target_depth, depth_axis=0):
    """Resample `volume` along one axis using cubic interpolation.

    Sample positions are spaced evenly between the first and the last
    slice of `depth_axis`. Every output slice is a weighted sum of the
    four input slices around its sample position; neighbours that would
    fall outside the volume are clamped to the nearest valid slice.

    Args:
        volume: Tensor-like input whose rank is statically known.
        target_depth: Number of slices to produce along `depth_axis`.
        depth_axis: Axis to resample. Negative values count from the end.

    Returns:
        A float32 tensor shaped like `volume`, except that `depth_axis`
        has `target_depth` entries. Because two of the cubic weights can
        be negative, results may overshoot the input value range.

    Raises:
        ValueError: If the rank of `volume` is not statically known.
    """
    volume = tf.convert_to_tensor(volume)

    # Use the static rank instead of len(tf.shape(volume)): len() is not
    # defined for symbolic tensors, so the latter fails inside tf.function
    # and tf.data pipelines.
    rank = volume.shape.rank
    if rank is None:
        raise ValueError(
            'cubic_interpolation requires a volume with a statically known rank.'
        )

    # Size of the axis being resampled (may be a dynamic value).
    original_depth = tf.shape(volume)[depth_axis]
    last_index = original_depth - 1

    # Fractional source positions for every output slice.
    positions = tf.linspace(0.0, tf.cast(last_index, tf.float32), target_depth)

    # Integer part selects the slice just below the sample position; the
    # fractional part is the offset within that interval.
    lower_indices = tf.cast(tf.floor(positions), tf.int32)
    alpha = positions - tf.cast(lower_indices, tf.float32)

    # Shape alpha as [1, ..., target_depth, ..., 1] so it broadcasts
    # against the gathered slices along `depth_axis` only.
    alpha_shape = [1] * rank
    alpha_shape[depth_axis] = -1
    alpha = tf.reshape(alpha, alpha_shape)

    # Four-slice neighbourhood (offsets -1, 0, +1, +2), clamped at the ends.
    neighbour_indices = (
        tf.maximum(lower_indices - 1, 0),
        lower_indices,
        tf.minimum(lower_indices + 1, last_index),
        tf.minimum(lower_indices + 2, last_index),
    )

    # Work in float32 so integer volumes can be combined with the weights.
    slices = [
        tf.cast(tf.gather(volume, idx, axis=depth_axis), tf.float32)
        for idx in neighbour_indices
    ]

    # Cubic weights for the four neighbours; they sum to 1 for any alpha.
    alpha_sq = alpha ** 2
    alpha_cu = alpha ** 3
    weights = (
        -0.5 * alpha_cu + 1.0 * alpha_sq - 0.5 * alpha,
        1.5 * alpha_cu - 2.5 * alpha_sq + 1.0,
        -1.5 * alpha_cu + 2.0 * alpha_sq + 0.5 * alpha,
        0.5 * alpha_cu - 0.5 * alpha_sq,
    )

    # Accumulate in the same left-to-right order as the weights.
    interpolated_volume = weights[0] * slices[0]
    for weight, neighbour in zip(weights[1:], slices[1:]):
        interpolated_volume = interpolated_volume + weight * neighbour
    return interpolated_volume
def nearest_interpolation(volume, target_depth, depth_axis=0):
    """Resample `volume` along one axis by picking the nearest input slice.

    Sample positions are spaced evenly between the first and the last
    slice of `depth_axis`; each output slice is a copy of the input slice
    whose index is closest to its sample position.

    Args:
        volume: Tensor-like input to resample.
        target_depth: Number of slices to produce along `depth_axis`.
        depth_axis: Axis to resample.

    Returns:
        A tensor with the same dtype as `volume` whose `depth_axis` has
        `target_depth` entries.
    """
    last_index = tf.cast(tf.shape(volume)[depth_axis] - 1, tf.float32)

    # Fractional source positions for every output slice.
    positions = tf.linspace(0.0, last_index, target_depth)

    # Round before casting: a bare float -> int cast truncates, which
    # would bias every sample towards the lower slice instead of the
    # nearest one. tf.round resolves exact .5 ties to the even index.
    nearest_indices = tf.cast(tf.round(positions), tf.int32)

    # Copy the selected slices; no blending, so the dtype is preserved.
    return tf.gather(volume, nearest_indices, axis=depth_axis)
def depth_interpolation(volume, target_depth, depth_axis=0, method='linear'):
    """Resample `volume` along `depth_axis` with the chosen method.

    Args:
        volume: Input volume, forwarded unchanged to the selected routine.
        target_depth: Number of slices to produce along `depth_axis`.
        depth_axis: Axis to resample.
        method: One of 'linear', 'nearest' or 'cubic'.

    Returns:
        Whatever the selected interpolation routine returns.

    Raises:
        ValueError: If `method` is not one of the supported names.
    """
    SUPPORTED_METHOD = ('linear', 'nearest', 'cubic')

    # Validate before building the dispatch table so that a bad method
    # name is reported as a ValueError and nothing else.
    if method not in SUPPORTED_METHOD:
        raise ValueError(
            f'Supported interpolation methods are {SUPPORTED_METHOD}, '
            f'but got {method!r}'
        )

    methods = {
        'linear': linear_interpolation,
        'nearest': nearest_interpolation,
        'cubic': cubic_interpolation,
    }
    return methods[method](volume, target_depth, depth_axis)
# Smoke test: resample a random (2, 5, 5) integer volume to t[0] + 2 = 7
# slices, exercising every method on one of the axes.
random_tensor = tf.random.uniform(
    shape=(2, 5, 5), minval=1, maxval=6, dtype=tf.int32
)
t = (5, 4, 4)
linear = depth_interpolation(
    random_tensor, target_depth=t[0] + 2, depth_axis=2, method='linear'
)
nearest = depth_interpolation(
    random_tensor, target_depth=t[0] + 2, depth_axis=1, method='nearest'
)
cubic = depth_interpolation(
    random_tensor, target_depth=t[0] + 2, depth_axis=1, method='cubic'
)
# Notebook-style display of the three resulting shapes.
linear.shape, nearest.shape, cubic.shape
from matplotlib import pyplot as plt
def plot_slices(original, nearest, linear, cubic, target_depth):
    """Show matching slices of four volumes side by side.

    Draws a grid with one row per slice index and one column per volume
    (original, nearest, linear, cubic), slicing every volume along its
    first axis. All images share a fixed colour scale of 1 to 6.

    Args:
        original: Volume before interpolation.
        nearest: Result of nearest-neighbour interpolation.
        linear: Result of linear interpolation.
        cubic: Result of cubic interpolation.
        target_depth: Upper bound on the number of rows to draw; the row
            count is further limited by the shortest first axis among
            the four volumes.
    """
    columns = (
        ("Original", original),
        ("Nearest", nearest),
        ("Linear", linear),
        ("Cubic", cubic),
    )

    # Never index past the shortest volume. Converting each size to a
    # plain int keeps the comparison and the figure size in Python ints.
    n_rows = min(
        int(target_depth),
        *(int(tf.shape(stack)[0]) for _, stack in columns),
    )

    # squeeze=False keeps `axes` two-dimensional even when only a single
    # row is drawn; otherwise axes[row, col] raises IndexError.
    fig, axes = plt.subplots(
        n_rows, len(columns), figsize=(12, 3 * n_rows), squeeze=False
    )
    fig.suptitle("Original vs Nearest-Neighbor vs Linear vs Cubic Interpolation", fontsize=16)

    for row in range(n_rows):
        for col, (label, stack) in enumerate(columns):
            ax = axes[row, col]
            ax.imshow(stack[row], cmap='viridis', vmin=1, vmax=6)
            ax.set_title(f"{label} Slice {row}")
            # Tick marks carry no information for these images.
            ax.set_xticks([])
            ax.set_yticks([])

    plt.tight_layout()
    plt.show()
# Example usage: stretch the middle axis of a random volume and compare
# the three interpolation methods visually.
random_tensor = tf.random.uniform(
    shape=(5, 2, 5), minval=1, maxval=6, dtype=tf.int32
)

# Input shape is (5, 2, 5); resampling axis 1 to 5 entries gives (5, 5, 5).
linear = depth_interpolation(
    random_tensor, target_depth=5, depth_axis=1, method='linear'
)
nearest = depth_interpolation(
    random_tensor, target_depth=5, depth_axis=1, method='nearest'
)
cubic = depth_interpolation(
    random_tensor, target_depth=5, depth_axis=1, method='cubic'
)

# Report the shape of the input and of every result.
for label, tensor in (
    ('original shape ', random_tensor),
    ('nearest shape ', nearest),
    ('linear shape ', linear),
    ('cubic shape ', cubic),
):
    print(label, tensor.shape)

# Plot the results as NumPy arrays.
plot_slices(
    original=random_tensor.numpy(),
    nearest=nearest.numpy(),
    linear=linear.numpy(),
    cubic=cubic.numpy(),
    target_depth=5,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment