spaCy for whisper-timestamped
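"""Convert a whisper-timestamped JSON transcript into SRT subtitles, breaking
lines at grammatically sensible points (punctuation, conjunctions, clause
boundaries and long verbal pauses) found with spaCy.

Example invocation (file names are placeholders; the .srt goes to stdout):

    python subwisp.py transcript.json -m en_core_web_trf -w 42 -l 2 > transcript.srt
"""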
import os
import argparse
import logging
import json
from collections.abc import Iterator

from more_itertools import chunked
import spacy
from spacy.language import Language
from spacy.tokens import Doc, Span, Token
from spacy.matcher import Matcher
from whisper.utils import format_timestamp

# Define custom token extensions used by the fragmenter component
Token.set_extension("can_fragment_after", default=False, force=True)
Token.set_extension("fragment_reason", default="", force=True)
def get_time_span(span: Span, timing: dict):
    """Return (start, end) times for a span, skipping punctuation and untimed tokens."""
    start_token = span[0]
    end_token = span[-1]
    # Walk backwards past punctuation and tokens without a timing entry
    while start_token.is_punct or not timing.get(start_token.idx, None):
        start_token = start_token.nbor(-1)
    while end_token.is_punct or not timing.get(end_token.idx, None):
        end_token = end_token.nbor(-1)
    start_index = start_token.idx
    end_index = end_token.idx
    start, _ = timing[start_index]
    _, end = timing.get(end_index, (None, None))
    if not end:
        logging.debug("Timing alignment error: %s %d", span.text, end_token.idx)
    return (start, end)


Span.set_extension("get_time_span", method=get_time_span, force=True)
@Language.factory("fragmenter")
class FragmenterComponent:
    """Pipeline component that marks tokens after which a subtitle line may break."""

    def __init__(self, nlp: Language, name: str, verbal_pauses: list):
        self.nlp = nlp
        self.name = name
        self.pauses = set(verbal_pauses)

    def __call__(self, doc: Doc) -> Doc:
        return self.fragmenter(doc)

    def fragmenter(self, doc: Doc) -> Doc:
        matcher = Matcher(self.nlp.vocab)
        # Define patterns: mid-sentence punctuation, conjunctions, subordinate clauses
        punct_pattern = [{"IS_PUNCT": True, "ORTH": {"IN": [",", ":", ";"]}}]
        conj_pattern = [{"POS": {"IN": ["CCONJ", "SCONJ"]}}]
        clause_pattern = [{"DEP": {"IN": ["advcl", "relcl", "acl", "acl:relcl"]}}]
        # Add patterns to matcher
        matcher.add("punct", [punct_pattern])
        matcher.add("conj", [conj_pattern])
        matcher.add("clause", [clause_pattern])
        # Find matches
        matches = matcher(doc)
        for match_id, start, end in matches:
            rule = doc.vocab.strings[match_id]
            matched_span = doc[start:end]
            if rule == "punct":
                # Break after the punctuation mark itself
                matched_span[0]._.can_fragment_after = True
                matched_span[0]._.fragment_reason = "punctuation"
            elif rule == "conj":
                # Break before the conjunction, i.e. after the preceding token
                if start > 0:
                    doc[start - 1]._.can_fragment_after = True
                    doc[start - 1]._.fragment_reason = "conjunction"
            elif rule == "clause":
                # Break before the clause head, i.e. after the preceding token
                if start > 0:
                    doc[start - 1]._.can_fragment_after = True
                    doc[start - 1]._.fragment_reason = "clause"
        # Mark tokens followed by a long silence; the pause list holds
        # character offsets, so compare against token.idx
        for token in doc:
            if token.idx in self.pauses:
                token._.can_fragment_after = True
                token._.fragment_reason = "verbal pause"
        return doc
def load_whisper_json(file: str) -> tuple[str, dict]:
    """Rebuild the transcript text and a char-offset -> (start, end) timing map."""
    doc_timing = {}
    doc_text = ""
    with open(file) as js:
        jsdata = json.load(js)
        for s in jsdata['segments']:
            for word_timed in s['words']:
                if word_timed['text'] == '[*]':
                    continue  # Skip non-speech segments
                word = word_timed['text']
                if len(doc_text) == 0:
                    word = word.lstrip()
                doc_text += word + ' '  # Add space between words
                start_index = len(doc_text) - len(word) - 1  # Account for added space
                doc_timing[start_index] = (word_timed['start'], word_timed['end'])
    return doc_text.strip(), doc_timing
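# Illustrative input shape, inferred from the fields read above (values made up):
#
#   {"segments": [
#       {"words": [
#           {"text": "Hello", "start": 0.32, "end": 0.61},
#           {"text": "world.", "start": 0.61, "end": 1.05},
#           {"text": "[*]", "start": 1.05, "end": 1.90}
#       ]}
#   ]}
#
# "[*]" entries are non-speech/disfluency markers and are skipped above.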
def scan_for_pauses(doc_text: str, timing: dict) -> list[int]:
    """Return character offsets of words followed by a silence longer than 0.5 s."""
    pauses = []
    sorted_timings = sorted(timing.items())
    for i in range(len(sorted_timings) - 1):
        (k1, (_, end)), (k2, (start, _)) = sorted_timings[i], sorted_timings[i + 1]
        gap = start - end
        if gap > 0.5:
            pauses.append(k1)
    return pauses
def divide_span(span: Span, args) -> Iterator[Span]:
    """Recursively split a sentence span into fragments no wider than args.width."""
    max_width = args.width
    if span.end_char - span.start_char <= max_width:
        yield span
        return
    for i, token in enumerate(span):
        if token._.can_fragment_after:
            first_part = span[:i + 1]
            if first_part.end_char - first_part.start_char <= max_width:
                yield first_part
                yield from divide_span(span[i + 1:], args)
                return
    # No suitable breakpoint found: split at the last token boundary that still
    # fits within max_width characters (always at least one token, so the
    # recursion makes progress)
    cut = 1
    for i, token in enumerate(span):
        if token.idx + len(token.text) - span.start_char > max_width:
            break
        cut = i + 1
    yield span[:cut]
    yield from divide_span(span[cut:], args)
def iterate_document(doc: Doc, timing: dict, args):
    max_lines = args.lines
    for sentence in doc.sents:
        for chunk in chunked(divide_span(sentence, args), max_lines):
            subtitle = '\n'.join(line.text for line in chunk)
            sub_start, _ = chunk[0]._.get_time_span(timing)
            _, sub_end = chunk[-1]._.get_time_span(timing)
            yield sub_start, sub_end, subtitle
def write_srt(doc, timing, args):
    comma: str = ','
    for i, (start, end, text) in enumerate(iterate_document(doc, timing, args), start=1):
        ts1 = format_timestamp(start, always_include_hours=True, decimal_marker=comma)
        ts2 = format_timestamp(end, always_include_hours=True, decimal_marker=comma)
        print(f"{i}\n{ts1} --> {ts2}\n{text}\n")
def configure_spaCy(model: str, entities: str, pauses: list = []):
    """Load a spaCy pipeline and append the fragmenter (and optional entity ruler)."""
    if model.startswith('xx'):
        raise NotImplementedError("spaCy multilanguage models are not currently supported")
    nlp = spacy.load(model)
    nlp.add_pipe("fragmenter", config={"verbal_pauses": pauses}, last=True)
    if entities:
        ruler = nlp.add_pipe("entity_ruler", config={"overwrite_ents": True})
        ruler.from_disk(entities)
    return nlp
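# The optional entities file is a spaCy EntityRuler patterns file in .jsonl
# format, one pattern per line; the labels and strings below are only examples:
#
#   {"label": "PERSON", "pattern": "Ada Lovelace"}
#   {"label": "ORG", "pattern": [{"LOWER": "acme"}, {"LOWER": "corp"}]}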
def main():
    parser = argparse.ArgumentParser(
        prog='subwisp',
        description='Convert a whisper .json transcript into .srt subtitles with sentences, grammatically separated where possible.')
    parser.add_argument('input_file')
    parser.add_argument('-m', '--model', help='specify spaCy model', default="en_core_web_trf")
    parser.add_argument('-e', '--entities', help='optional custom entities for spaCy (.jsonl format)', default="")
    parser.add_argument('-w', '--width', help='maximum line width', default=42, type=int)
    parser.add_argument('-l', '--lines', help='maximum lines per subtitle', default=2, type=int, choices=range(1, 4))
    parser.add_argument('-d', '--debug', help='print debug information',
                        action="store_const", dest="loglevel", const=logging.DEBUG, default=logging.WARNING)
    parser.add_argument('--verbose', help='be verbose',
                        action="store_const", dest="loglevel", const=logging.INFO)
    args = parser.parse_args()
    logging.basicConfig(level=args.loglevel)

    if not os.path.isfile(args.input_file):
        logging.error("File not found: %s", args.input_file)
        exit(-1)
    if not args.model:
        logging.error("No spaCy model specified")
        exit(-1)
    if len(args.entities) > 0 and not os.path.isfile(args.entities):
        logging.error("Entities file not found: %s", args.entities)
        exit(-1)

    wtext, word_timing = load_whisper_json(args.input_file)
    verbal_pauses = scan_for_pauses(wtext, word_timing)
    nlp = configure_spaCy(args.model, args.entities, verbal_pauses)
    doc = nlp(wtext)
    write_srt(doc, word_timing, args)


if __name__ == '__main__':
    main()