Created
March 31, 2021 23:02
-
-
Save fauxneticien/47bbc67d9b4737f77e0682a8205e1fc6 to your computer and use it in GitHub Desktop.
Segmental DTW implementation in Cython
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build script for the dtw_cython extension module.
#
# distutils was removed from the standard library in Python 3.12 (PEP 632),
# so prefer setuptools and only fall back to distutils on old interpreters
# where setuptools is not installed.
try:
    from setuptools import setup, Extension
except ImportError:
    from distutils.core import setup, Extension

from Cython.Build import cythonize
import numpy

# The .pyx file cimports numpy, so the C compiler needs the NumPy headers.
# Attaching them to the Extension keeps the include path scoped to this module.
extensions = [
    Extension(
        "dtw_cython",
        sources=["dtw_cython.pyx"],
        include_dirs=[numpy.get_include()],
    )
]

setup(
    # annotate=True also asks Cython for its annotated HTML report.
    ext_modules=cythonize(extensions, annotate=True),
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
cimport numpy as np | |
from dtw import dtw | |
cpdef segdtw(np.ndarray[np.float64_t, ndim=2] distance_matrix, int win_step = 2):
    """Return a similarity score for a query against a longer reference.

    A window of 1.5 x the query length is slid along the reference axis of
    ``distance_matrix`` in steps of ``win_step`` columns. Each window is
    aligned with open-end DTW, and the lowest normalized distance among the
    accepted alignments is converted to a similarity (higher is better).

    Parameters
    ----------
    distance_matrix : 2-D float64 array
        Rows index the query, columns index the reference.
    win_step : int
        Number of reference columns between successive window starts.

    Returns
    -------
    float
        ``max(0, 1 - best_distance)``. The result is 0.0 when no window yields
        an acceptable alignment, including when the reference is too short to
        hold even one window start.
    """
    cdef int query_length = distance_matrix.shape[0]
    cdef int reference_length = distance_matrix.shape[1]

    # Reject alignments that cover less than half of the query length
    # or more than 1.5 times the query length.
    # double (not float) so the score is not rounded to single precision.
    cdef double min_match_ratio = 0.5
    cdef double max_match_ratio = 1.5

    cdef int window_size = int(query_length * max_match_ratio)
    cdef int last_segment_start = int(reference_length - (min_match_ratio * query_length))

    cdef int r_i
    cdef int segment_end
    cdef double match_ratio
    cdef double candidate
    # Distance of 1 is the baseline used when no window is accepted.
    cdef double best_dist = 1.0

    # Reference shorter than half the query: there is no valid window start.
    # Return 0 rather than failing on an empty or negative-sized buffer.
    if last_segment_start <= 0:
        return 0.0

    for r_i in range(0, last_segment_start, win_step):
        segment_end = min(r_i + window_size, reference_length)

        dtw_obj = dtw(distance_matrix[:, r_i:segment_end],
            step_pattern = "symmetricP1", # See Sakoe & Chiba (1978) for definition of step pattern
            open_end = True,              # Let alignment end anywhere along the segment (need not be at lower corner)
            distance_only = True          # Speed up dtw(), no backtracing for alignment path
        )

        # Cast so the ratio is always computed in floating point,
        # whatever division semantics the module is compiled with.
        # NOTE(review): jmin looks like a 0-based column index, so the matched
        # length may really be jmin + 1 -- confirm against the dtw package.
        match_ratio = <double>dtw_obj.jmin / query_length

        # Only accepted alignments may lower the running minimum.
        if match_ratio >= min_match_ratio and match_ratio <= max_match_ratio:
            candidate = dtw_obj.normalizedDistance
            if candidate < best_dist:
                best_dist = candidate

    # Convert distance (lower is better) to similarity, clamped at 0.
    return max(0.0, 1.0 - best_dist)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import timeit | |
import dtw_cython | |
import numpy as np | |
from dtw import dtw | |
# Synthetic query-by-reference distance matrix (25 rows x 100 columns):
# every cell has distance 1, except a zero-distance diagonal running through
# columns 25-49 that simulates one occurrence of the query in the reference.
qr_matrix = np.full((25, 100), 1.0)
qr_matrix[np.arange(25), np.arange(25, 50)] = 0
def py_segdtw(distance_matrix, win_step=2):
    """Return a similarity score for a query against a longer reference.

    A window of 1.5 x the query length is slid along the reference axis of
    ``distance_matrix`` in steps of ``win_step`` columns. Each window is
    aligned with open-end DTW, and the lowest normalized distance found is
    converted to a similarity score (higher is better), which makes it easier
    to compare with CNN output probabilities.

    Args:
        distance_matrix: 2-D array; rows index the query, columns index the
            reference.
        win_step: Number of reference columns between successive window starts.

    Returns:
        ``max(0, 1 - min_distance)``, or 0 when the reference is too short to
        produce any window. The score is clamped at 0 so it matches the Cython
        ``segdtw`` implementation and never goes negative.
    """
    query_length, reference_length = distance_matrix.shape

    # Reject an alignment if it covers less than half of the query length
    # or more than 1.5 times the query length.
    min_match_ratio, max_match_ratio = 0.5, 1.5

    window_size = int(query_length * max_match_ratio)
    last_segment_end = int(reference_length - (min_match_ratio * query_length))

    segdtw_dists = []
    for segment_start in range(0, last_segment_end, win_step):
        segment_end = min(segment_start + window_size, reference_length)

        dtw_obj = dtw(
            distance_matrix[:, segment_start:segment_end],
            step_pattern="symmetricP1",  # See Sakoe & Chiba (1978) for definition of step pattern
            open_end=True,               # Let alignment end anywhere along the segment (need not be at lower corner)
            distance_only=True,          # Speed up dtw(), no backtracing for alignment path
        )

        # NOTE(review): jmin looks like a 0-based column index, so the matched
        # length may really be jmin + 1 -- confirm against the dtw package.
        match_ratio = dtw_obj.jmin / query_length

        if min_match_ratio <= match_ratio <= max_match_ratio:
            segdtw_dists.append(dtw_obj.normalizedDistance)
        else:
            # Rejected windows contribute the baseline distance of 1.
            segdtw_dists.append(1)

    # No windows at all (reference shorter than half the query): no match.
    if not segdtw_dists:
        return 0

    # Convert distance (lower is better) to similarity (higher is better).
    return max(0, 1 - min(segdtw_dists))
# Benchmark both implementations on the same matrix: for each one, print its
# label and then the total time timeit reports for 1000 calls.
for label, statement in (
    ("Cython segdtw", 'dtw_cython.segdtw(qr_matrix)'),
    ("Python segdtw", 'py_segdtw(qr_matrix)'),
):
    print(label)
    print(timeit.timeit(statement, number=1000, globals=globals()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment