Last active
January 10, 2022 21:46
-
-
Save mjm522/c263db0f2b535222d748f8dfa8d260da to your computer and use it in GitHub Desktop.
Compare two 2D arrays and return the indices of rows in the first array that match some row of the second array within a given threshold.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np


def two_d_array_compare(array_a, array_b, thresh):
    """Return indices of rows in ``array_a`` that closely match some row of ``array_b``.

    Row ``i`` of ``array_a`` is reported when at least one row ``r`` of
    ``array_b`` satisfies ``r - |thresh| <= array_a[i] <= r + |thresh|``
    element-wise (every column must be within tolerance).

    :param array_a: array-like of shape (m, k); a 1-D input of shape (k,)
        is treated as a single row.
    :param array_b: array-like of shape (n, k); a 1-D input of shape (k,)
        is treated as a single row.
    :param thresh: maximum allowed per-element difference; its sign is ignored.
    :return: sorted list of (Python int) row indices of ``array_a`` that have a
        similar row in ``array_b``; an empty list if nothing matches.
    """
    thresh = abs(thresh)
    # Promote 1-D vectors to a single row, the same as ``x[None, :]``.
    # This also accepts plain Python sequences.
    array_a = np.atleast_2d(array_a)
    array_b = np.atleast_2d(array_b)

    # Keep one flag per row of array_a. OR-ing in each row_b's result avoids
    # an (n, m, k) intermediate and de-duplicates without a set.
    matched = np.zeros(array_a.shape[0], dtype=bool)
    for row_b in array_b:
        # Use the two-sided comparison rather than abs(a - b) <= thresh.
        # This keeps the original boundary and +/-inf behaviour.
        within = (row_b - thresh <= array_a) & (array_a <= row_b + thresh)
        # all(axis=1) is well-defined for any m >= 0 and k >= 1. The old
        # squeeze()+prod() broke for single-column, single-row and empty input.
        matched |= within.all(axis=1)
    return np.flatnonzero(matched).tolist()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment