Last active
March 11, 2025 08:12
-
-
Save innat/1104962d0ea62a52e73ffeb4c1ad18d0 to your computer and use it in GitHub Desktop.
Depth interpolation for 3D volume
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
def linear_interpolation(volume, target_depth, depth_axis=0):
    """Resample `volume` along one axis by linear interpolation.

    Sample positions are spaced evenly from index 0 to the last index of
    `depth_axis`, so the first and last output slices coincide with the first
    and last input slices.

    Args:
        volume: Tensor (or tensor-convertible value) to resample.
        target_depth: Number of slices the output has along `depth_axis`.
        depth_axis: Axis to resample along.

    Returns:
        A float32 tensor whose size along `depth_axis` is `target_depth`.

    Raises:
        ValueError: If the rank of `volume` is not statically known.
    """
    volume = tf.convert_to_tensor(volume)

    # Use the static rank rather than len(tf.shape(volume)): calling len() on
    # a symbolic shape tensor fails when this function is traced in a graph.
    rank = volume.shape.rank
    if rank is None:
        raise ValueError('volume must have a statically known rank')

    original_depth = tf.shape(volume)[depth_axis]

    # Fractional sample positions along the source axis.
    positions = tf.linspace(
        0.0, tf.cast(original_depth - 1, tf.float32), target_depth
    )

    # Integer part selects the lower neighbour; fractional part is the blend
    # weight towards the upper neighbour.
    floor_idx = tf.cast(tf.floor(positions), tf.int32)
    alpha = positions - tf.cast(floor_idx, tf.float32)

    # Shape alpha as all-ones except along the depth axis so it broadcasts
    # against the gathered slices.
    alpha_shape = [1] * rank
    alpha_shape[depth_axis] = target_depth
    alpha = tf.reshape(alpha, alpha_shape)

    # Clamp both neighbours to valid indices; at the last position the upper
    # neighbour collapses onto the lower one.
    lower_idx = tf.maximum(floor_idx, 0)
    upper_idx = tf.minimum(lower_idx + 1, original_depth - 1)

    # Cast to float32 so the slices are type-compatible with alpha.
    lower = tf.cast(tf.gather(volume, lower_idx, axis=depth_axis), tf.float32)
    upper = tf.cast(tf.gather(volume, upper_idx, axis=depth_axis), tf.float32)

    return (1 - alpha) * lower + alpha * upper
def cubic_interpolation(volume, target_depth, depth_axis=0):
    """Resample `volume` along one axis by cubic (Catmull-Rom) interpolation.

    Each output slice is a weighted sum of the four input slices around its
    sample position. Neighbour indices that fall outside the volume are
    clamped to the first/last slice.

    Args:
        volume: Tensor (or tensor-convertible value) to resample.
        target_depth: Number of slices the output has along `depth_axis`.
        depth_axis: Axis to resample along.

    Returns:
        A float32 tensor whose size along `depth_axis` is `target_depth`.

    Raises:
        ValueError: If the rank of `volume` is not statically known.
    """
    volume = tf.convert_to_tensor(volume)

    # Use the static rank rather than len(tf.shape(volume)): calling len() on
    # a symbolic shape tensor fails when this function is traced in a graph.
    rank = volume.shape.rank
    if rank is None:
        raise ValueError('volume must have a statically known rank')

    original_depth = tf.shape(volume)[depth_axis]
    last_index = original_depth - 1

    # Fractional sample positions along the source axis.
    positions = tf.linspace(0.0, tf.cast(last_index, tf.float32), target_depth)

    # Integer part is the slice just below the position; the fractional part
    # (alpha) is the offset from it.
    base_idx = tf.cast(tf.floor(positions), tf.int32)
    alpha = positions - tf.cast(base_idx, tf.float32)

    # Shape alpha as all-ones except along the depth axis so it broadcasts
    # against the gathered slices.
    alpha_shape = [1] * rank
    alpha_shape[depth_axis] = target_depth
    alpha = tf.reshape(alpha, alpha_shape)

    # Four neighbours (base-1, base, base+1, base+2), clamped at the borders.
    neighbour_indices = (
        tf.maximum(base_idx - 1, 0),
        base_idx,
        tf.minimum(base_idx + 1, last_index),
        tf.minimum(base_idx + 2, last_index),
    )

    # Cast to float32 so the slices are type-compatible with alpha.
    slices = [
        tf.cast(tf.gather(volume, idx, axis=depth_axis), tf.float32)
        for idx in neighbour_indices
    ]

    # Catmull-Rom weights for the four neighbours; they sum to 1 for any alpha.
    alpha_sq = alpha ** 2
    alpha_cu = alpha ** 3
    weights = (
        -0.5 * alpha_cu + 1.0 * alpha_sq - 0.5 * alpha,
        1.5 * alpha_cu - 2.5 * alpha_sq + 1.0,
        -1.5 * alpha_cu + 2.0 * alpha_sq + 0.5 * alpha,
        0.5 * alpha_cu - 0.5 * alpha_sq,
    )

    interpolated_volume = (
        weights[0] * slices[0]
        + weights[1] * slices[1]
        + weights[2] * slices[2]
        + weights[3] * slices[3]
    )
    return interpolated_volume
def nearest_interpolation(volume, target_depth, depth_axis=0):
    """Resample `volume` along one axis by nearest-neighbour selection.

    Sample positions are spaced evenly from index 0 to the last index of
    `depth_axis`; each output slice is a copy of the input slice closest to
    its position.

    Args:
        volume: Tensor (or tensor-convertible value) to resample.
        target_depth: Number of slices the output has along `depth_axis`.
        depth_axis: Axis to resample along.

    Returns:
        A tensor with the same dtype as `volume` whose size along
        `depth_axis` is `target_depth`.
    """
    original_depth = tf.shape(volume)[depth_axis]

    # Fractional sample positions along the source axis.
    positions = tf.linspace(
        0.0, tf.cast(original_depth - 1, tf.float32), target_depth
    )

    # Round to the nearest index. A bare float->int cast truncates, which
    # would always pick the lower slice; floor(x + 0.5) rounds ties upwards.
    nearest_idx = tf.cast(tf.floor(positions + 0.5), tf.int32)

    # Guard against stepping past the last slice.
    nearest_idx = tf.minimum(nearest_idx, original_depth - 1)

    return tf.gather(volume, nearest_idx, axis=depth_axis)
def depth_interpolation(volume, target_depth, depth_axis=0, method='linear'):
    """Resample `volume` along `depth_axis` using the named interpolation method.

    Args:
        volume: Tensor to resample; passed through to the selected routine.
        target_depth: Number of slices the output has along `depth_axis`.
        depth_axis: Axis to resample along.
        method: One of 'linear', 'nearest' or 'cubic'.

    Returns:
        Whatever the selected routine returns for the given arguments.

    Raises:
        ValueError: If `method` is not one of the supported names.
    """
    SUPPORTED_METHOD = ('linear', 'nearest', 'cubic')

    # Validate before building the dispatch table so an unsupported name
    # fails fast with a ValueError.
    if method not in SUPPORTED_METHOD:
        raise ValueError(
            f'Support interplation methods are {SUPPORTED_METHOD} '
            f'But got {method}'
        )

    methods = {
        'linear': linear_interpolation,
        'nearest': nearest_interpolation,
        'cubic': cubic_interpolation,
    }
    return methods[method](volume, target_depth, depth_axis)
# Quick shape check: resample a random (2, 5, 5) integer volume to
# t[0] + 2 = 7 slices with each method. Linear runs along axis 2, the
# other two along axis 1.
random_tensor = tf.random.uniform(
    shape=(2, 5, 5), minval=1, maxval=6, dtype=tf.int32
)
t = (5, 4, 4)
linear = depth_interpolation(
    volume=random_tensor, target_depth=t[0] + 2, depth_axis=2, method='linear'
)
nearest = depth_interpolation(
    volume=random_tensor, target_depth=t[0] + 2, depth_axis=1, method='nearest'
)
cubic = depth_interpolation(
    volume=random_tensor, target_depth=t[0] + 2, depth_axis=1, method='cubic'
)
# Bare expression: only displayed when evaluated as the last line of a
# notebook cell or in a REPL.
linear.shape, nearest.shape, cubic.shape
from matplotlib import pyplot as plt | |
def plot_slices(original, nearest, linear, cubic, target_depth):
    """Show the original volume beside its three resampled versions.

    Draws one row per slice index along axis 0 with four columns: original,
    nearest-neighbour, linear and cubic. The colour scale is fixed to the
    range 1..6 with the 'viridis' colormap.

    Args:
        original: The un-resampled volume.
        nearest: Output of nearest-neighbour interpolation.
        linear: Output of linear interpolation.
        cubic: Output of cubic interpolation.
        target_depth: Upper bound on the number of rows to draw.
    """
    columns = (
        ("Original", original),
        ("Nearest", nearest),
        ("Linear", linear),
        ("Cubic", cubic),
    )

    # Only draw rows that exist in every volume. Converting to plain ints
    # keeps min() and range() independent of tensor comparison semantics.
    max_depth = min(
        int(target_depth),
        *(int(tf.shape(vol)[0]) for _, vol in columns),
    )

    # squeeze=False keeps `axes` two-dimensional even when only one row is
    # drawn; otherwise axes[row, col] fails for max_depth == 1.
    fig, axes = plt.subplots(
        max_depth, 4, figsize=(12, 3 * max_depth), squeeze=False
    )
    fig.suptitle(
        "Original vs Nearest-Neighbor vs Linear vs Cubic Interpolation",
        fontsize=16,
    )

    for row in range(max_depth):
        for col, (label, vol) in enumerate(columns):
            ax = axes[row, col]
            ax.imshow(vol[row], cmap='viridis', vmin=1, vmax=6)
            ax.set_title(f"{label} Slice {row}")
            # Remove axis ticks.
            ax.set_xticks([])
            ax.set_yticks([])

    plt.tight_layout()
    plt.show()
# Example usage: stretch a random (5, 2, 5) integer volume to 5 slices
# along axis 1 with each method, report the shapes, then plot the results.
random_tensor = tf.random.uniform(
    shape=(5, 2, 5), minval=1, maxval=6, dtype=tf.int32
)

# input: 5, 2, 5 : target: 5, 5, 5
linear = depth_interpolation(
    volume=random_tensor, target_depth=5, depth_axis=1, method='linear'
)
nearest = depth_interpolation(
    volume=random_tensor, target_depth=5, depth_axis=1, method='nearest'
)
cubic = depth_interpolation(
    volume=random_tensor, target_depth=5, depth_axis=1, method='cubic'
)

print('original shape ', random_tensor.shape)
print('nearest shape ', nearest.shape)
print('linear shape ', linear.shape)
print('cubic shape ', cubic.shape)

# Plot the results; plot_slices receives NumPy arrays, not tensors.
plot_slices(
    original=random_tensor.numpy(),
    nearest=nearest.numpy(),
    linear=linear.numpy(),
    cubic=cubic.numpy(),
    target_depth=5,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment