Skip to content

Instantly share code, notes, and snippets.

@JoFrhwld
Last active July 1, 2022 21:41
Show Gist options
  • Save JoFrhwld/2e7f7b4659aa8d5f28e5a880843db88d to your computer and use it in GitHub Desktop.
Save JoFrhwld/2e7f7b4659aa8d5f28e5a880843db88d to your computer and use it in GitHub Desktop.
Given two pandas DataFrames of time-stamped transcription chunks from the same audio, return the smallest possible spans of aligned chunks between the two.
import pandas as pd
def align_chunks(X_df, Y_df, A_idx_list, B_idx_list, state, **kwargs):
    """
    Grow one group of mutually overlapping transcription chunks.

    Chunks are added alternately from the two frames. A chunk from the other
    frame joins when it starts before the most recently added chunk ends.
    This continues until neither frame has a further chunk overlapping the
    group, which gives the smallest span whose two sides cover the same
    stretch of audio.

    Implemented as a loop rather than recursion, so long runs of staggered
    chunk boundaries cannot exceed Python's recursion limit.

    Both frames need ``start_ms`` and ``end_ms`` columns. They are assumed
    (not checked) to be sorted by time with an integer index, because index
    labels are incremented and compared (e.g. ``X_last + 1``).

    Parameters
    ----------
    X_df, Y_df : DataFrame
        The two transcriptions. When ``A_idx_list`` is non-empty, X_df is the
        frame that ``A_idx_list`` indexes.
    A_idx_list, B_idx_list : list
        Index labels already in the group for the A / B frame. Pass empty
        lists to start a new group. The lists are copied, never modified.
    state : int
        Counter that is incremented once per extension step.
    X_last, Y_last : int, optional (keyword-only via **kwargs)
        Label of the last chunk already consumed in X_df / Y_df. Only read
        when ``A_idx_list`` is empty. Defaults to -1 (start at the beginning).

    Returns
    -------
    tuple
        ``(A_df, B_df, A_idx_list, B_idx_list, state)``. A_df/B_df are
        X_df/Y_df, possibly swapped, and A_df is the frame that A_idx_list
        indexes. B_idx_list is empty when no B chunk could be paired.
        A_idx_list is empty when either frame has no chunks left after
        X_last / Y_last.
    """
    # Work on copies so the caller's lists are not mutated in place.
    A_idx_list = list(A_idx_list)
    B_idx_list = list(B_idx_list)
    A_df, B_df = X_df, Y_df

    if len(A_idx_list) == 0:
        # Starting a new group: the frame whose next unused chunk starts
        # earlier becomes A.
        x_next = kwargs.get("X_last", -1) + 1
        y_next = kwargs.get("Y_last", -1) + 1
        x_rest = X_df.loc[x_next:]
        y_rest = Y_df.loc[y_next:]
        if x_rest.empty or y_rest.empty:
            # One side is used up, so nothing further can be aligned.
            print("Search ended, no more alignable chunks")
            return (A_df, B_df, A_idx_list, B_idx_list, state)
        if x_rest.start_ms.min() < y_rest.start_ms.min():
            A_idx_list.append(x_next)
        else:
            A_df, B_df = Y_df, X_df
            A_idx_list.append(y_next)

    while True:
        # The highest-labelled A chunk is the one the B side must cover.
        curr_A_idx = max(A_idx_list)
        A_start = A_df.loc[curr_A_idx, "start_ms"]
        A_end = A_df.loc[curr_A_idx, "end_ms"]

        if len(B_idx_list) == 0:
            # Seed B with its first chunk that ends after A starts.
            B_possible_idx = B_df.loc[B_df.end_ms > A_start].index
            if len(B_possible_idx) == 0:
                print("Search ended, no more alignable chunks")
                return (A_df, B_df, A_idx_list, B_idx_list, state)
            B_idx_list.append(B_possible_idx[0])

        # Later B chunks that start before the current A chunk ends overlap it.
        last_B_idx = max(B_idx_list)
        B_rest = B_df.loc[last_B_idx + 1:]
        B_candidates = B_rest.loc[B_rest.start_ms < A_end].index
        if len(B_candidates) > 0:
            state += 1
            B_idx_list.extend(B_candidates)
            # The newest B chunk may now overhang A, so swap roles and
            # look for A chunks that overlap it.
            A_df, B_df = B_df, A_df
            A_idx_list, B_idx_list = B_idx_list, A_idx_list
            continue

        # No B overhang remains. Check for later A chunks that start before
        # the last B chunk ends.
        B_end = B_df.loc[last_B_idx, "end_ms"]
        A_rest = A_df.loc[curr_A_idx + 1:]
        A_candidates = A_rest.loc[A_rest.start_ms < B_end].index
        if len(A_candidates) == 0:
            # Neither side overhangs the other, so the group is complete.
            return (A_df, B_df, A_idx_list, B_idx_list, state)
        state += 1
        A_idx_list.extend(A_candidates)
import pandas as pd
def align_transcriptions(X_df, Y_df, x_name="X", y_name="Y"):
    """
    Align every overlapping group of transcription chunks in X_df and Y_df.

    Both frames need ``start_ms``, ``end_ms`` and ``text`` columns. They are
    assumed (not checked) to be sorted by time with a default 0..n-1 integer
    index, because the end-of-data check compares index labels with
    ``shape[0]``.

    Parameters
    ----------
    X_df, Y_df : DataFrame
        The two transcriptions. They are not modified.
    x_name, y_name : str
        Values written to the ``name`` column of the summary rows for each
        source.

    Returns
    -------
    DataFrame
        Two rows per aligned group (one per source), with a fresh RangeIndex
        and columns idx_start, idx_end, start_ms, end_ms, text, name,
        grouping. The frame is empty when nothing can be aligned.
    """
    columns = ["idx_start", "idx_end", "start_ms", "end_ms", "text", "name", "grouping"]
    # assign() returns copies, so the caller's frames don't gain a "name" column.
    X_df = X_df.assign(name=x_name)
    Y_df = Y_df.assign(name=y_name)
    if X_df.empty or Y_df.empty:
        return pd.DataFrame(columns=columns)

    X_start = X_df.start_ms.iloc[0]
    Y_start = Y_df.start_ms.iloc[0]
    # Skip leading chunks of the earlier-starting transcription that end
    # before the other one begins; they have nothing to align with.
    if X_start > Y_start:
        X_last = -1
        overlapping = Y_df.loc[Y_df.end_ms > X_start].index
        if len(overlapping) == 0:
            return pd.DataFrame(columns=columns)
        Y_last = overlapping[0] - 1
    else:
        Y_last = -1
        overlapping = X_df.loc[X_df.end_ms > Y_start].index
        if len(overlapping) == 0:
            return pd.DataFrame(columns=columns)
        X_last = overlapping[0] - 1

    rows = []
    grouping = 0
    A_df, B_df, a_idx, b_idx, _state = align_chunks(
        X_df, Y_df, [], [], 1, X_last=X_last, Y_last=Y_last
    )
    # An empty index list means align_chunks found nothing more to pair.
    while len(a_idx) > 0 and len(b_idx) > 0:
        rows.append(_summarize_group(A_df, a_idx, grouping))
        rows.append(_summarize_group(B_df, b_idx, grouping))
        grouping += 1
        A_last = max(a_idx)
        B_last = max(b_idx)
        # Once either frame is used up, nothing further can overlap. This
        # check runs after the group is recorded, so the final group is kept.
        if A_last + 1 >= A_df.shape[0] or B_last + 1 >= B_df.shape[0]:
            break
        # align_chunks treats its first frame as X, so A_last is its X_last.
        A_df, B_df, a_idx, b_idx, _state = align_chunks(
            A_df, B_df, [], [], 1, X_last=A_last, Y_last=B_last
        )

    if not rows:
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(rows, columns=columns)


def _summarize_group(df, idx, grouping):
    """Collapse rows ``idx`` of ``df`` into one summary record for ``grouping``."""
    sub = df.loc[idx]
    return {
        "idx_start": min(idx),
        "idx_end": max(idx),
        "start_ms": sub.start_ms.min(),
        "end_ms": sub.end_ms.max(),
        "text": " ".join(sub.text),
        "name": sub["name"].loc[idx[0]],
        "grouping": grouping,
    }
import pandas as pd
from jiwer import compute_measures, cer
def compare_transcriptions(ground, hyp):
    """
    Score a hypothesis transcription against a ground-truth transcription.

    The two frames may be chunked differently. They are first grouped into
    overlapping spans with ``align_transcriptions``. Each group's ground and
    hypothesis texts are then paired and passed to jiwer.

    Returns the dict produced by ``compute_measures``, with an extra
    ``"cer"`` entry holding the character error rate over the same pairs.
    """
    aligned = align_transcriptions(ground, hyp, x_name="ground", y_name="hyp")
    # One row per group, with a "ground" column and a "hyp" column of text.
    wide = aligned[["text", "name", "grouping"]].pivot(
        index="grouping", columns="name", values="text"
    )
    reference = wide["ground"].tolist()
    hypothesis = wide["hyp"].tolist()

    measures = compute_measures(reference, hypothesis)
    measures["cer"] = cer(reference, hypothesis)
    return measures
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment