Skip to content

Instantly share code, notes, and snippets.

@alexlib
Created April 26, 2025 06:38
Show Gist options
  • Save alexlib/15a2334103f42dd0edbf11c7071f3bc4 to your computer and use it in GitHub Desktop.
Save alexlib/15a2334103f42dd0edbf11c7071f3bc4 to your computer and use it in GitHub Desktop.
John D'Errico's inpaint_nans3 (3-D NaN inpainting) with a Python counterpart using SciPy, plus tests. The MATLAB function is taken from PIVsuite v0.8.3.
test_inpaint_equivalence.py
..
----------------------------------------------------------------------
Ran 2 tests in 1.908s

OK
(.conda) user@user-NUC8i7BEH:~/Documents/repos/OpenOpticalFlow_PIV_v1$ /home/user/Documents/repos/OpenOpticalFlow_PIV_v1/.conda/bin/python /home/user/Documents/repos/OpenOpticalFlow_PIV_v1/test_inpaint_equivalence.py
..
----------------------------------------------------------------------
Ran 2 tests in 1.504s

OK
import numpy as np
from scipy.interpolate import griddata
def inpaint_nans_3d(array):
    """Fill the NaN entries of an array by interpolating from its valid entries.

    Valid (non-NaN) entries are treated as scattered samples located at their
    integer index coordinates. NaN entries are filled by linear interpolation
    (``scipy.interpolate.griddata``). Linear interpolation is undefined
    outside the convex hull of the valid samples (e.g. a NaN sitting on a
    corner of the array), so those entries fall back to the value of the
    nearest valid sample instead of staying NaN.

    Parameters
    ----------
    array : array_like
        Input data, NaN marking the entries to fill. It is not modified.

    Returns
    -------
    numpy.ndarray
        Copy of ``array`` with every NaN replaced.

    Raises
    ------
    ValueError
        If ``array`` contains no valid value to interpolate from.
    """
    array = np.asarray(array)
    valid_mask = ~np.isnan(array)
    result = array.copy()

    # Nothing to fill: skip the (expensive) triangulation entirely.
    if valid_mask.all():
        return result
    if not valid_mask.any():
        raise ValueError(
            "inpaint_nans_3d: array contains no valid (non-NaN) values"
        )

    # Index coordinates of known samples and of the holes, one row per
    # element, in the same C order that boolean-mask indexing uses.
    coords = np.argwhere(valid_mask)
    values = array[valid_mask]
    nan_coords = np.argwhere(~valid_mask)

    filled_values = griddata(coords, values, nan_coords, method='linear')

    # griddata returns NaN for query points outside the convex hull of the
    # known samples; extrapolate those with the nearest known value.
    outside_hull = np.isnan(filled_values)
    if outside_hull.any():
        filled_values[outside_hull] = griddata(
            coords, values, nan_coords[outside_hull], method='nearest'
        )

    result[~valid_mask] = filled_values
    return result
function B=inpaint_nans3(A,method)
% INPAINT_NANS3: in-paints over nans in a 3-D array
% usage: B=INPAINT_NANS3(A) % default method (0)
% usage: B=INPAINT_NANS3(A,method) % specify method used
%
% Solves approximation to a boundary value problem to
% interpolate and extrapolate holes in a 3-D array.
%
% Note that if the array is large, and there are many NaNs
% to be filled in, this may take a long time, or run into
% memory problems.
%
% arguments (input):
% A - n1 x n2 x n3 array with some NaNs to be filled in
% NOTE(review): the code indexes NA(3) where NA = size(A),
% so A must be reported by size() as 3-D; a 2-D array
% (including n1 x n2 x 1) looks unsupported - confirm.
%
% method - (OPTIONAL) scalar numeric flag - specifies
% which approach (or physical metaphor) to use
% for the interpolation. All methods are capable
% of extrapolation, some are better than others.
% There are also speed differences, as well as
% accuracy differences for smooth surfaces.
%
% method 0 uses a simple plate metaphor.
% method 1 uses a spring metaphor.
%
% method == 0 --> (DEFAULT) Solves the Laplacian
% equation over the set of nan elements in the
% array.
% Extrapolation behavior is roughly linear.
%
% method == 1 --> Uses a spring metaphor. Assumes
% springs (with a nominal length of zero)
% connect each node with its immediate neighbors
% along each of the three array dimensions (the
% spring list built in this 3-D code has no
% diagonal springs).
% Since each node tries to be like its neighbors,
% extrapolation is roughly a constant function where
% this is consistent with the neighboring nodes.
%
% There are only two different methods in this code,
% chosen as the most useful ones (IMHO) from my
% original inpaint_nans code.
%
%
% arguments (output):
% B - n1xn2xn3 array with NaNs replaced
%
%
% Example:
% % A linear function of 3 independent variables,
% % used to test whether inpainting will interpolate
% % the missing elements correctly.
% [x,y,z] = ndgrid(-10:10,-10:10,-10:10);
% W = x + y + z;
%
% % Pick a set of distinct random elements to NaN out.
% ind = unique(ceil(rand(3000,1)*numel(W)));
% Wnan = W;
% Wnan(ind) = NaN;
%
% % Do inpainting
% Winp = inpaint_nans3(Wnan,0);
%
% % Show that the inpainted values are essentially
% % within eps of the originals.
% std(Winp(ind) - W(ind))
% ans =
% 4.3806e-15
%
%
% See also: griddatan, inpaint_nans
%
% Author: John D'Errico
% e-mail address: [email protected]
% Release: 1
% Release date: 8/21/08
% Need to know which elements are NaN, and
% what size is the array. Unroll A for the
% inpainting, although inpainting will be done
% fully in 3-d.
NA = size(A);
A = A(:);
nt = prod(NA);
k = isnan(A(:));
% list the nodes which are known, and which will
% be interpolated
nan_list=find(k);
known_list=find(~k);
% how many nans overall
nan_count=length(nan_list);
% convert NaN linear indices to 3-d subscript form
% nan_list==find(k) are the unrolled (linear) indices
% (n1,n2,n3) are the matching subscripts
[n1,n2,n3]=ind2sub(NA,nan_list);
% both forms of index for all the nan elements in one array:
% column 1 == unrolled index
% column 2 == index 1
% column 3 == index 2
% column 4 == index 3
nan_list=[nan_list,n1,n2,n3];
% supply default method
if (nargin<2) || isempty(method)
method = 0;
elseif ~ismember(method,[0 1])
error 'If supplied, method must be one of: {0,1}.'
end
% alternative methods
switch method
case 0
% Discrete Laplacian (sum of second differences along
% each dimension), assembled only for those elements
% which are NaN, or at least touch a NaN.
% face-adjacent neighbors only (no diagonals); each row
% of talks_to is a subscript offset [d1 d2 d3]
talks_to = [-1 0 0;1 0 0;0 -1 0;0 1 0;0 0 -1;0 0 1];
neighbors_list=identify_neighbors(NA,nan_list,talks_to);
% list of all nodes we have identified
all_list=[nan_list;neighbors_list];
% generate sparse array with second partials on row
% variable for each element in either list, but only
% for those nodes which have a row index > 1 and < NA(1)
% (a [1 -2 1] stencil needs a neighbor on both sides)
L = find((all_list(:,2) > 1) & (all_list(:,2) < NA(1)));
nL=length(L);
if nL>0
fda=sparse(repmat(all_list(L,1),1,3), ...
repmat(all_list(L,1),1,3)+repmat([-1 0 1],nL,1), ...
repmat([1 -2 1],nL,1),nt,nt);
else
fda=spalloc(nt,nt,size(all_list,1)*7);
end
% 2nd partials on column index
% (a step of one column is NA(1) in linear index)
L = find((all_list(:,3) > 1) & (all_list(:,3) < NA(2)));
nL=length(L);
if nL>0
fda=fda+sparse(repmat(all_list(L,1),1,3), ...
repmat(all_list(L,1),1,3)+repmat([-NA(1) 0 NA(1)],nL,1), ...
repmat([1 -2 1],nL,1),nt,nt);
end
% 2nd partials on third index
% (a step along dimension 3 is NA(1)*NA(2) in linear index)
L = find((all_list(:,4) > 1) & (all_list(:,4) < NA(3)));
nL=length(L);
if nL>0
ntimesm = NA(1)*NA(2);
fda=fda+sparse(repmat(all_list(L,1),1,3), ...
repmat(all_list(L,1),1,3)+repmat([-ntimesm 0 ntimesm],nL,1), ...
repmat([1 -2 1],nL,1),nt,nt);
end
% eliminate knowns: move the contribution of the known
% values to the right hand side
rhs=-fda(:,known_list)*A(known_list);
% keep only those equations (rows) which involve at least
% one NaN element
k=find(any(fda(:,nan_list(:,1)),2));
% and solve...
% (there are typically more equations than unknowns here,
% in which case backslash returns a least squares solution)
B=A;
B(nan_list(:,1))=fda(k,nan_list(:,1))\rhs(k);
case 1
% Spring analogy
% interpolating operator.
% list of all springs between a node and a neighbor along
% one of the three dimensions. Each row of hv_list is an
% offset laid out like a row of nan_list:
% column 1 == offset in unrolled (linear) index
% columns 2:4 == offsets in the three subscripts
hv_list=[-1 -1 0 0;1 1 0 0;-NA(1) 0 -1 0;NA(1) 0 1 0; ...
-NA(1)*NA(2) 0 0 -1;NA(1)*NA(2) 0 0 1];
hv_springs=[];
for i=1:size(hv_list,1)
hvs=nan_list+repmat(hv_list(i,:),nan_count,1);
% keep only the springs whose far end lies inside the array
k=(hvs(:,2)>=1) & (hvs(:,2)<=NA(1)) & ...
(hvs(:,3)>=1) & (hvs(:,3)<=NA(2)) & ...
(hvs(:,4)>=1) & (hvs(:,4)<=NA(3));
hv_springs=[hv_springs;[nan_list(k,1),hvs(k,1)]];
end
% delete replicate springs
% (sorting each pair makes a-b and b-a identical rows)
hv_springs=unique(sort(hv_springs,2),'rows');
% build sparse matrix of connections
% one row per spring: +1 at one end node, -1 at the other
nhv=size(hv_springs,1);
springs=sparse(repmat((1:nhv)',1,2),hv_springs, ...
repmat([1 -1],nhv,1),nhv,prod(NA));
% eliminate knowns: move the contribution of the known
% values to the right hand side
rhs=-springs(:,known_list)*A(known_list);
% and solve...
B=A;
B(nan_list(:,1))=springs(:,nan_list(:,1))\rhs;
end
% all done, make sure that B is the same shape as
% A was when we came in.
B=reshape(B,NA);
% ====================================================
% end of main function
% ====================================================
% ====================================================
% begin subfunctions
% ====================================================
function neighbors_list=identify_neighbors(NA,nan_list,talks_to)
% identify_neighbors: lists the non-NaN nodes that are
%   neighbors of at least one node in nan_list
%
% arguments (input):
%  NA - 1x3 vector = size(A), where A is the
%       array to be interpolated
%  nan_list - array with one row per NaN element of A
%       nan_list(i,1) == linear index of i'th nan element
%       nan_list(i,2) == index 1 of i'th nan element
%       nan_list(i,3) == index 2 of i'th nan element
%       nan_list(i,4) == index 3 of i'th nan element
%  talks_to - px3 array of subscript offsets, one row per
%       neighbor direction. talks_to(i,d) is the offset
%       of the i'th neighbor along dimension d.
%
%       For example,
%       talks_to = [-1 0 0;1 0 0;0 -1 0;0 1 0;0 0 -1;0 0 1]
%       means that each node talks only to its immediate
%       neighbors along each of the three dimensions.
%
% arguments(output):
%  neighbors_list - array in the same 4 column format as
%       nan_list, holding every in-bounds neighbor of the
%       nodes in nan_list which is not itself in nan_list.
%       Empty ([]) when nan_list is empty.

% no NaNs, so no neighbors either
if isempty(nan_list)
  neighbors_list=[];
  return
end

n_nan=size(nan_list,1);
n_offsets=size(talks_to,1);
nan_subs=nan_list(:,2:4);

% stack one shifted copy of the NaN subscripts per offset
candidates=zeros(n_nan*n_offsets,3);
for t=1:n_offsets
  rows=(t-1)*n_nan+(1:n_nan);
  candidates(rows,:)=nan_subs+repmat(talks_to(t,:),n_nan,1);
end

% keep only the candidates which fall inside the bounds of
% the original array
inside = (candidates(:,1)>=1) & (candidates(:,1)<=NA(1)) & ...
  (candidates(:,2)>=1) & (candidates(:,2)<=NA(2)) & ...
  (candidates(:,3)>=1) & (candidates(:,3)<=NA(3));
candidates=candidates(inside,:);

% same 4 column layout as nan_list: [linear index, subscripts]
lin_idx=sub2ind(NA,candidates(:,1),candidates(:,2),candidates(:,3));
neighbors_list=unique([lin_idx,candidates],'rows');

% a neighbor which is itself a NaN is not a neighbor we want
neighbors_list=setdiff(neighbors_list,nan_list,'rows');
% MATLAB half of the Python/MATLAB comparison: reads the
% array written by the Python side, fills its NaNs with
% inpaint_nans3 (method 0) and writes the result back.

% test_data.mat provides the variable data_with_nans
load('test_data.mat');

% Report the size of what was read
dims_in = size(data_with_nans);
fprintf('Input data size: %dx%dx%d\n', dims_in(1), dims_in(2), dims_in(3));

% Fill the NaNs using the default (Laplacian) method
inpainted_data = inpaint_nans3(data_with_nans, 0);

% Report the size of what was produced
dims_out = size(inpainted_data);
fprintf('Output data size: %dx%dx%d\n', dims_out(1), dims_out(2), dims_out(3));

% The Python side looks the result up under the name inpainted_data
save('matlab_result.mat', 'inpainted_data');
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from inpaint3d import inpaint_nans_3d
def generate_test_data(shape=(20, 20, 20), seed=42):
    """Build a smooth 3-D test field and a copy with random NaN holes.

    The field is ``sin(x) * cos(y) * exp(-0.1 * (x**2 + y**2 + z**2))``
    sampled on ``[-2, 2]`` along every axis. ``y`` varies along axis 0,
    ``x`` along axis 1 and ``z`` along axis 2, which reproduces the values
    the previous default-indexed ``meshgrid`` call gave for cubic shapes
    while making the output shape equal to ``shape`` for any shape.

    Parameters
    ----------
    shape : tuple of 3 ints
        Shape of the generated arrays.
    seed : int or None
        Seed for the NaN mask. The default (42) makes the mask reproducible,
        so a previously saved MATLAB result still corresponds to the data
        generated on a later run. Pass ``None`` for a different mask on
        every call.

    Returns
    -------
    data : numpy.ndarray
        The complete field.
    data_with_nans : numpy.ndarray
        Copy of ``data`` with roughly 30% of the entries set to NaN.
    nan_mask : numpy.ndarray of bool
        True where ``data_with_nans`` is NaN.
    """
    axis0, axis1, axis2 = (np.linspace(-2, 2, n) for n in shape)
    # 'ij' indexing keeps the output shape equal to `shape`; the default
    # 'xy' indexing would swap the first two dimensions.
    y, x, z = np.meshgrid(axis0, axis1, axis2, indexing='ij')
    data = np.sin(x) * np.cos(y) * np.exp(-0.1 * (x**2 + y**2 + z**2))

    # Local generator: reproducible without touching numpy's global state.
    rng = np.random.RandomState(seed)
    nan_mask = rng.random_sample(data.shape) < 0.3

    data_with_nans = data.copy()
    data_with_nans[nan_mask] = np.nan
    return data, data_with_nans, nan_mask
def compare_results(original, matlab_result, python_result):
    """Print error metrics for both implementations and plot their middle slices.

    Prints the NaN-ignoring mean absolute error of each result against
    ``original`` and the mean absolute difference between the two results,
    then shows a 2x2 figure with the middle slice (along the last axis) of
    the original, the MATLAB result, the Python result and their absolute
    difference.
    """

    def mean_abs_diff(a, b):
        # NaNs (entries an implementation failed to fill) are ignored.
        return np.nanmean(np.abs(a - b))

    matlab_error = mean_abs_diff(original, matlab_result)
    python_error = mean_abs_diff(original, python_result)
    difference = mean_abs_diff(matlab_result, python_result)

    print(f"Mean Absolute Error (MATLAB): {matlab_error:.6f}")
    print(f"Mean Absolute Error (Python): {python_error:.6f}")
    print(f"Mean Difference between implementations: {difference:.6f}")

    # One panel per volume, all cut at the same index of the last axis.
    mid_slice = original.shape[2] // 2
    panels = [
        ((0, 0), original, 'Original'),
        ((0, 1), matlab_result, 'MATLAB Result'),
        ((1, 0), python_result, 'Python Result'),
        ((1, 1), np.abs(matlab_result - python_result), 'Absolute Difference'),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    for position, volume, title in panels:
        axes[position].imshow(volume[:, :, mid_slice])
        axes[position].set_title(title)

    plt.tight_layout()
    plt.show()
def main():
    """Run the Python inpainting and compare it with a saved MATLAB result.

    Workflow: the generated data is written to ``test_data.mat`` for the
    MATLAB script to process; the MATLAB script writes ``matlab_result.mat``,
    which is read back here. If that file does not exist yet, or holds a
    result of a different shape, a hint is printed and no comparison is made.
    """
    original, data_with_nans, _ = generate_test_data()

    # Hand the holed array to MATLAB.
    sio.savemat('test_data.mat', {'data_with_nans': data_with_nans})

    python_result = inpaint_nans_3d(data_with_nans)

    # Only the file read is guarded, so an unrelated FileNotFoundError
    # raised during the comparison is not misreported as a missing result.
    try:
        matlab_data = sio.loadmat('matlab_result.mat')
    except FileNotFoundError:
        print("MATLAB results not found. Please run the MATLAB script first.")
        return

    matlab_result = matlab_data['inpainted_data']
    if matlab_result.shape != python_result.shape:
        # A result left over from a run on differently shaped data cannot
        # be compared element by element.
        print(
            f"MATLAB result has shape {matlab_result.shape}, expected "
            f"{python_result.shape}. Please re-run the MATLAB script."
        )
        return

    compare_results(original, matlab_result, python_result)


if __name__ == "__main__":
    main()
import unittest
import numpy as np
import scipy.io as sio
from inpaint3d import inpaint_nans_3d
def generate_test_data(shape=(20, 20, 20), seed=42):
    """Generate a reproducible 3-D test field plus a copy with NaN holes.

    The field is ``sin(x) * cos(y) * exp(-0.1 * (x**2 + y**2 + z**2))`` with
    every coordinate sampled on ``[-2, 2]``; ``y`` runs along axis 0, ``x``
    along axis 1 and ``z`` along axis 2. For cubic shapes this gives exactly
    the values of the earlier default-indexed ``meshgrid`` version, and the
    default seed gives exactly the mask ``np.random.seed(42)`` gave, so
    previously generated MATLAB results stay valid.

    Parameters
    ----------
    shape : tuple of 3 ints
        Shape of the returned arrays.
    seed : int or None
        Seed of the NaN mask (default 42). ``None`` draws a fresh mask.

    Returns
    -------
    data : numpy.ndarray
        The complete field.
    data_with_nans : numpy.ndarray
        Copy of ``data`` with roughly 30% of the entries set to NaN.
    nan_mask : numpy.ndarray of bool
        True where ``data_with_nans`` is NaN.
    """
    axis0, axis1, axis2 = (np.linspace(-2, 2, n) for n in shape)
    # With the default 'xy' indexing the result would have shape
    # (shape[1], shape[0], shape[2]); 'ij' keeps it equal to `shape`.
    y, x, z = np.meshgrid(axis0, axis1, axis2, indexing='ij')
    data = np.sin(x) * np.cos(y) * np.exp(-0.1 * (x**2 + y**2 + z**2))

    # A private RandomState yields the same stream as seeding the global
    # generator, without changing random state for the rest of the process.
    rng = np.random.RandomState(seed)
    nan_mask = rng.random_sample(data.shape) < 0.3

    data_with_nans = data.copy()
    data_with_nans[nan_mask] = np.nan
    return data, data_with_nans, nan_mask
class TestInpaintEquivalence(unittest.TestCase):
    """Compare the Python inpainting with a previously saved MATLAB result.

    Both tests are skipped when ``matlab_result.mat`` is missing, i.e. when
    the MATLAB script has not been run on ``test_data.mat`` yet.
    """

    def setUp(self):
        # Same shape as the one used in test_inpaint_comparison.py
        self.shape = (20, 20, 20)
        self.original, self.data_with_nans, _ = generate_test_data(self.shape)

        # Written on every run so the MATLAB script has input to work on
        sio.savemat('test_data.mat', {'data_with_nans': self.data_with_nans})

    def _matlab_result_or_skip(self):
        """Return the array saved by MATLAB, skipping the test if it is absent."""
        try:
            return sio.loadmat('matlab_result.mat')['inpainted_data']
        except FileNotFoundError:
            self.skipTest("MATLAB results file not found. Run MATLAB script first.")

    def test_results_close_to_original(self):
        """Both implementations stay close to the original data."""
        python_result = inpaint_nans_3d(self.data_with_nans)
        matlab_result = self._matlab_result_or_skip()

        self.assertEqual(matlab_result.shape, self.original.shape,
                         "MATLAB result shape doesn't match original data shape")

        # NaN-ignoring mean absolute error must stay below 0.05
        python_error = np.nanmean(np.abs(self.original - python_result))
        matlab_error = np.nanmean(np.abs(self.original - matlab_result))
        self.assertLess(python_error, 0.05)
        self.assertLess(matlab_error, 0.05)

    def test_implementations_equivalent(self):
        """Both implementations give similar results."""
        python_result = inpaint_nans_3d(self.data_with_nans)
        matlab_result = self._matlab_result_or_skip()

        self.assertEqual(matlab_result.shape, python_result.shape,
                         "MATLAB and Python results have different shapes")

        # NaN-ignoring mean absolute difference must stay below 0.01
        difference = np.nanmean(np.abs(matlab_result - python_result))
        self.assertLess(difference, 0.01)


if __name__ == '__main__':
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment