@ankitgurua · Created July 3, 2024
spaCy post-processing for whisper-timestamped: convert a timestamped .json transcript into grammatically fragmented .srt subtitles.
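"""Convert a whisper-timestamped .json transcript into .srt subtitles,
splitting sentences at grammatically sensible points with spaCy.

Example usage (assuming this file is saved as subwisp.py and the default
en_core_web_trf model has been installed, e.g. with
`python -m spacy download en_core_web_trf`):

    python subwisp.py transcript.json > transcript.srt
"""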
import os
import argparse
import logging
import json
from more_itertools import chunked
from collections.abc import Iterator
import spacy
from spacy.language import Language
from spacy.tokens import Doc, Span, Token
from spacy.matcher import Matcher
from whisper.utils import format_timestamp
# Define custom token extensions
Token.set_extension("can_fragment_after", default=False, force=True)
Token.set_extension("fragment_reason", default="", force=True)

def get_time_span(span: Span, timing: dict):
    # Walk backwards past punctuation and sub-word tokens until a token whose
    # character offset has an entry in the timing dict is found.
    start_token = span[0]
    end_token = span[-1]
    while start_token.is_punct or not timing.get(start_token.idx, None):
        start_token = start_token.nbor(-1)
    while end_token.is_punct or not timing.get(end_token.idx, None):
        end_token = end_token.nbor(-1)
    start, _ = timing[start_token.idx]
    _, end = timing.get(end_token.idx, (None, None))
    if not end:
        logging.debug("Timing alignment error: %s %d", span.text, end_token.idx)
    return (start, end)

Span.set_extension("get_time_span", method=get_time_span, force=True)

@Language.factory("fragmenter")
class FragmenterComponent:
    def __init__(self, nlp: Language, name: str, verbal_pauses: list):
        self.nlp = nlp
        self.name = name
        self.pauses = set(verbal_pauses)

    def __call__(self, doc: Doc) -> Doc:
        return self.fragmenter(doc)

    def fragmenter(self, doc: Doc) -> Doc:
        matcher = Matcher(self.nlp.vocab)
        # Define patterns: clause-separating punctuation, coordinating and
        # subordinating conjunctions, and clause-introducing dependencies.
        punct_pattern = [{"IS_PUNCT": True, "ORTH": {"IN": [",", ":", ";"]}}]
        conj_pattern = [{"POS": {"IN": ["CCONJ", "SCONJ"]}}]
        clause_pattern = [{"DEP": {"IN": ["advcl", "relcl", "acl", "acl:relcl"]}}]
        # Add patterns to matcher
        matcher.add("punct", [punct_pattern])
        matcher.add("conj", [conj_pattern])
        matcher.add("clause", [clause_pattern])
        # Find matches
        matches = matcher(doc)
        for match_id, start, end in matches:
            rule = doc.vocab.strings[match_id]
            matched_span = doc[start:end]
            if rule == "punct":
                matched_span[0]._.can_fragment_after = True
                matched_span[0]._.fragment_reason = "punctuation"
            elif rule == "conj":
                # Fragment *before* a conjunction, i.e. after the preceding token.
                if start > 0:
                    doc[start - 1]._.can_fragment_after = True
                    doc[start - 1]._.fragment_reason = "conjunction"
            elif rule == "clause":
                if start > 0:
                    doc[start - 1]._.can_fragment_after = True
                    doc[start - 1]._.fragment_reason = "clause"
        # Verbal pauses are keyed by character offset (the keys of the timing
        # dict), so compare against token.idx rather than the token index.
        for token in doc:
            if token.idx in self.pauses:
                token._.can_fragment_after = True
                token._.fragment_reason = "verbal pause"
        return doc

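# Expected input: whisper-timestamped JSON. Only the fields read below are
# relied upon; an abbreviated sketch of the assumed shape:
#
#   {"segments": [{"words": [{"text": "Hello", "start": 0.0, "end": 0.4},
#                            {"text": "there.", "start": 0.5, "end": 0.9}]}]}
#
# A word whose text is "[*]" marks a detected non-speech interval.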
def load_whisper_json(file: str) -> tuple[str, dict]:
    doc_timing = {}
    doc_text = ""
    with open(file) as js:
        jsdata = json.load(js)
        for s in jsdata['segments']:
            for word_timed in s['words']:
                if word_timed['text'] == '[*]':
                    continue  # Skip non-speech segments
                word = word_timed['text']
                if len(doc_text) == 0:
                    word = word.lstrip()
                doc_text += word + ' '  # Separate words with a space
                # Character offset of the word's first character, which is how
                # spaCy's token.idx will address it later.
                start_index = len(doc_text) - len(word) - 1
                doc_timing[start_index] = (word_timed['start'], word_timed['end'])
    return doc_text.strip(), doc_timing

def scan_for_pauses(doc_text: str, timing: dict) -> list[int]:
    """Return character offsets of words followed by a pause of more than 0.5s."""
    pauses = []
    sorted_timings = sorted(timing.items())
    for i in range(len(sorted_timings) - 1):
        (k1, (_, end)), (_, (start, _)) = sorted_timings[i], sorted_timings[i + 1]
        gap = start - end
        if gap > 0.5:
            pauses.append(k1)
    return pauses

def divide_span(span: Span, args) -> Iterator[Span]:
    max_width = args.width
    if len(span) == 0:
        return
    if span.end_char - span.start_char <= max_width:
        yield span
        return
    # Prefer the first grammatical breakpoint whose fragment fits the line.
    for i, token in enumerate(span):
        if token._.can_fragment_after:
            first_part = span[:i + 1]
            if first_part.end_char - first_part.start_char <= max_width:
                yield first_part
                yield from divide_span(span[i + 1:], args)
                return
    # No suitable breakpoint found: hard-wrap at the last token boundary that
    # still fits within max_width characters.
    for i in range(len(span) - 1, 0, -1):
        if span[:i].end_char - span[:i].start_char <= max_width:
            yield span[:i]
            yield from divide_span(span[i:], args)
            return
    # Even the first token alone is wider than max_width: emit it on its own.
    yield span[:1]
    yield from divide_span(span[1:], args)

def iterate_document(doc: Doc, timing: dict, args):
    max_lines = args.lines
    for sentence in doc.sents:
        for chunk in chunked(divide_span(sentence, args), max_lines):
            subtitle = '\n'.join(line.text for line in chunk)
            sub_start, _ = chunk[0]._.get_time_span(timing)
            _, sub_end = chunk[-1]._.get_time_span(timing)
            yield sub_start, sub_end, subtitle

def write_srt(doc, timing, args):
    for i, (start, end, text) in enumerate(iterate_document(doc, timing, args), start=1):
        # SRT timestamps use a comma as the decimal marker.
        ts1 = format_timestamp(start, always_include_hours=True, decimal_marker=',')
        ts2 = format_timestamp(end, always_include_hours=True, decimal_marker=',')
        print(f"{i}\n{ts1} --> {ts2}\n{text}\n")

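# write_srt emits standard SRT blocks on stdout; an illustrative example
# (hypothetical timings and text):
#
#   1
#   00:00:00,000 --> 00:00:02,500
#   Hello there, welcome
#   to the show.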
def configure_spaCy(model: str, entities: str, pauses: list | None = None):
    # Reject multi-language models up front, before the (slow) model load.
    if model.startswith('xx'):
        raise NotImplementedError("spaCy multi-language models are not currently supported")
    nlp = spacy.load(model)
    nlp.add_pipe("fragmenter", config={"verbal_pauses": pauses or []}, last=True)
    if entities:
        ruler = nlp.add_pipe("entity_ruler", config={"overwrite_ents": True})
        ruler.from_disk(entities)
    return nlp

def main():
    parser = argparse.ArgumentParser(
        prog='subwisp',
        description='Convert a whisper .json transcript into .srt subtitles with sentences, grammatically separated where possible.')
    parser.add_argument('input_file')
    parser.add_argument('-m', '--model', help='specify spaCy model', default="en_core_web_trf")
    parser.add_argument('-e', '--entities', help='optional custom entities for spaCy (.jsonl format)', default="")
    parser.add_argument('-w', '--width', help='maximum line width', default=42, type=int)
    parser.add_argument('-l', '--lines', help='maximum lines per subtitle', default=2, type=int, choices=range(1, 4))
    parser.add_argument('-d', '--debug', help='print debug information',
                        action="store_const", dest="loglevel", const=logging.DEBUG, default=logging.WARNING)
    parser.add_argument('--verbose', help='be verbose',
                        action="store_const", dest="loglevel", const=logging.INFO)
    args = parser.parse_args()
    logging.basicConfig(level=args.loglevel)
    if not os.path.isfile(args.input_file):
        logging.error("File not found: %s", args.input_file)
        exit(-1)
    if not args.model:
        logging.error("No spaCy model specified")
        exit(-1)
    if len(args.entities) > 0 and not os.path.isfile(args.entities):
        logging.error("Entities file not found: %s", args.entities)
        exit(-1)
    wtext, word_timing = load_whisper_json(args.input_file)
    verbal_pauses = scan_for_pauses(wtext, word_timing)
    nlp = configure_spaCy(args.model, args.entities, verbal_pauses)
    doc = nlp(wtext)
    write_srt(doc, word_timing, args)

if __name__ == '__main__':
    main()