|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from mpl_toolkits.mplot3d import Axes3D |
|
import matplotlib.patches as mpatches |
|
|
|
# Synthetic input: a stack of `layers` grids, each grid_size x grid_size,
# drawn from a Gamma(shape=2, scale=2) distribution. The global NumPy RNG is
# seeded so the figure is reproducible between runs.
layers, grid_size = 8, 5
np.random.seed(42)
data = np.random.gamma(shape=2, scale=2, size=(layers, grid_size, grid_size))
|
|
|
# Function to apply temporal moving window |
|
def apply_temporal_window(data, window_size=5): |
|
padded_data = np.pad(data, ((window_size//2, window_size//2), (0, 0), (0, 0)), mode='edge') |
|
result = np.zeros_like(data) |
|
for i in range(layers): |
|
result[i] = np.mean(padded_data[i:i+window_size], axis=0) |
|
return result |
|
|
|
# Apply temporal moving window |
|
smoothed_data = apply_temporal_window(data) |
|
|
|
# Create 3D plot |
|
fig = plt.figure(figsize=(12, 8)) |
|
ax = fig.add_subplot(111, projection='3d') |
|
|
|
# Plot each layer |
|
for i in range(layers): |
|
x, y = np.meshgrid(range(grid_size), range(grid_size)) |
|
z = np.full_like(x, i) |
|
ax.plot_surface(x, y, z, facecolors=plt.cm.Blues(data[i]/np.max(data)), alpha=0.3) |
|
|
|
# Highlight the central pixel at (4,4) in the middle layer |
|
middle_layer = layers // 2 |
|
x, y = 4, 4 |
|
|
|
# Add a red cube to highlight the entire pixel |
|
ax.bar3d(x, y, middle_layer, 1, 1, 0.1, color='red', alpha=0.7) |
|
|
|
# Show the temporal window |
|
window_size = 5 |
|
for i in range(middle_layer - window_size//2, middle_layer + window_size//2 + 1): |
|
if i != middle_layer: |
|
ax.bar3d(x, y, i, 1, 1, 0.1, color='pink', alpha=0.5) |
|
ax.plot([x+0.5, x+0.5], [y+0.5, y+0.5], [i, middle_layer], color='pink', linestyle='--', alpha=0.5) |
|
|
|
# Set labels and title |
|
ax.set_xlabel('X') |
|
ax.set_ylabel('Y') |
|
ax.set_zlabel('Layer') |
|
ax.set_title('Temporal Moving Window Visualization') |
|
|
|
# Set the viewing angle |
|
ax.view_init(elev=20, azim=45) |
|
|
|
# Create custom legend |
|
red_patch = mpatches.Patch(color='red', label='Central Pixel (4,4)') |
|
pink_patch = mpatches.Patch(color='pink', label='Temporal Window (±2 layers)') |
|
|
|
ax.legend(handles=[red_patch, pink_patch]) |
|
|
|
# Set axis limits |
|
ax.set_xlim(0, grid_size) |
|
ax.set_ylim(0, grid_size) |
|
ax.set_zlim(0, layers-1) |
|
|
|
# Set tick labels |
|
ax.set_xticks(range(grid_size)) |
|
ax.set_yticks(range(grid_size)) |
|
ax.set_zticks(range(layers)) |
|
|
|
plt.show() |