Created
August 15, 2023 06:09
-
-
Save Lauler/ccfa87faf006144209b3d4eda6b042fe to your computer and use it in GitHub Desktop.
Preprocess SRT subtitle files and bucket them into ~30-second chunks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import os | |
import pandas as pd | |
import pysrt | |
import argparse | |
from tqdm import tqdm | |
# Command-line options: where to read recordings from and where to write
# the preprocessed parquet file. Both flags are plain strings with defaults.
parser = argparse.ArgumentParser()
for flag, default, help_text in [
    (
        "--data_dir",
        "kb_exempel_2",
        "Directory containing subdirectories with audio and srt files.",
    ),
    (
        "--output_file",
        "subs_preprocessed.parquet",
        "Name of output file.",
    ),
]:
    parser.add_argument(flag, type=str, default=default, help=help_text)
args = parser.parse_args()
# Each subdirectory of data_dir holds one recording: a .wav file and its
# matching .srt subtitle file. Build one row per recording.
df = pd.DataFrame({"subdir": os.listdir(args.data_dir)})

def _first_with_suffix(names, suffix):
    # Same selection as indexing [0] on a filtered list: raises IndexError
    # if the subdirectory has no file with the given suffix.
    return [name for name in names if name.endswith(suffix)][0]

listing = {
    subdir: os.listdir(os.path.join(args.data_dir, subdir))
    for subdir in df["subdir"]
}
df["audio"] = df["subdir"].map(lambda d: _first_with_suffix(listing[d], ".wav"))
df["srt"] = df["subdir"].map(lambda d: _first_with_suffix(listing[d], ".srt"))
# Read every srt file referenced in df and flatten it into one row per
# subtitle block, keeping subdir/srt/audio so each block can be traced
# back to its recording.
per_file_frames = []
for subdir, srt, audio in df[["subdir", "srt", "audio"]].itertuples(index=False):
    blocks = pysrt.open(os.path.join(args.data_dir, subdir, srt))
    per_file_frames.append(
        pd.DataFrame(
            [
                {
                    "subdir": subdir,
                    "start": block.start,
                    "end": block.end,
                    "text": block.text,
                    "srt": srt,
                    "audio": audio,
                }
                for block in blocks
            ]
        )
    )
df_subs = pd.concat(per_file_frames).reset_index(drop=True)
def _srt_time_to_ms(t):
    """Convert a pysrt SubRipTime-like value to total milliseconds.

    Accepts any object exposing integer ``hours``, ``minutes``, ``seconds``
    and ``milliseconds`` attributes (as pysrt's timestamps do).
    """
    return t.hours * 3600000 + t.minutes * 60000 + t.seconds * 1000 + t.milliseconds

# Convert srt timestamps to milliseconds. A single named helper replaces
# the two duplicated lambdas, so the conversion formula lives in one place.
df_subs["start_ms"] = df_subs["start"].map(_srt_time_to_ms)
df_subs["end_ms"] = df_subs["end"].map(_srt_time_to_ms)
# Duration of each individual subtitle block in seconds.
df_subs["duration_s"] = (df_subs["end_ms"] - df_subs["start_ms"]) / 1000
# Divide the subtitle blocks of each audio file into ~30 second buckets:
# a new bucket (observation_nr) is opened as soon as the span from the
# current bucket's start to the current block's end reaches 30 seconds.
df_groups = []
grouped = df_subs.groupby("audio")  # hoisted: computed once, not twice
for group, df_group in tqdm(grouped, total=grouped.ngroups):
    # Work on an explicit copy so the column assignments below never touch
    # a view of df_subs.
    df_group = df_group.copy()
    start = df_group["start_ms"].iloc[0]
    bucket_nr = 0
    bucket_cumsum = []
    bucket_nrs = []
    for i, end in enumerate(df_group["end_ms"]):
        if ((end - start) / 1000) >= 30:
            # This block pushes the running span past 30 s: open a new
            # bucket that starts at this block's start time.
            # FIX: a stray no-op debug f-string referencing
            # prev_segment_length used to sit here; it raised NameError
            # whenever the very first block alone spanned >= 30 s.
            bucket_nr += 1
            start = df_group["start_ms"].iloc[i]
        # Running span (seconds) from the current bucket's start to this block.
        prev_segment_length = (end - start) / 1000
        bucket_cumsum.append(prev_segment_length)
        bucket_nrs.append(bucket_nr)
    df_group["observation_nr"] = bucket_nrs
    df_group["bucket_cumsum"] = bucket_cumsum
    df_groups.append(df_group)
df_groups = pd.concat(df_groups)
df_groups = df_groups.reset_index(drop=True)
# Buckets are identified by (audio, observation_nr): observation_nr restarts
# at 0 for every audio file, so grouping on observation_nr alone (as before)
# incorrectly merged buckets from different recordings.
bucket_key = ["audio", "observation_nr"]
# Maximum cumulative span within a bucket is that bucket's total duration.
df_groups["bucket_duration_s"] = df_groups.groupby(bucket_key)["bucket_cumsum"].transform("max")
# Start/end of each subtitle block relative to the start of its bucket.
bucket_start = df_groups.groupby(bucket_key)["start_ms"].transform("min")
df_groups["start_relative"] = df_groups["start_ms"] - bucket_start
df_groups["end_relative"] = df_groups["end_ms"] - bucket_start
# Round to nearest 20 ms (Whisper quantizes its timestamps to 20 ms),
# then convert to seconds.
df_groups["start_relative"] = (np.round(df_groups["start_relative"] / 20) * 20) / 1000
df_groups["end_relative"] = (np.round(df_groups["end_relative"] / 20)) * 20 / 1000
# Absolute start/end (ms) of the bucket each block belongs to.
df_groups["start_bucket"] = bucket_start
df_groups["end_bucket"] = df_groups.groupby(bucket_key)["end_ms"].transform("max")
def format_timestamp(timestamp):
    """Render a time in seconds as a Whisper-style token, e.g. ``<|1.50|>``."""
    return f"<|{timestamp:.2f}|>"
# Wrap each block's text in its relative start/end timestamp tokens.
df_groups["start_timestamp"] = df_groups["start_relative"].map(format_timestamp)
df_groups["end_timestamp"] = df_groups["end_relative"].map(format_timestamp)
df_groups["text_timestamps"] = (
    df_groups["start_timestamp"] + df_groups["text"] + df_groups["end_timestamp"]
)
# Join the timestamped texts of all blocks in a bucket into one string.
# FIX: group on (audio, observation_nr) — observation_nr restarts at 0 for
# every audio file, so grouping on it alone concatenated text from
# different recordings into the same bucket.
df_groups["text_timestamps_bucket"] = df_groups.groupby(
    ["audio", "observation_nr"]
)["text_timestamps"].transform(lambda x: " ".join(x))
# Persist only the columns downstream consumers need.
df_groups[
    [
        "subdir",
        "audio",
        "observation_nr",
        "start_ms",
        "end_ms",
        "start_bucket",
        "end_bucket",
        "text",
        "text_timestamps_bucket",
        "start_relative",
        "end_relative",
        "bucket_duration_s",
    ]
].to_parquet(args.output_file, index=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment