Created
April 5, 2024 23:44
-
-
Save pzelasko/168dbc53bf6b80f97bc20abb828b5e51 to your computer and use it in GitHub Desktop.
Analyze where the most errors are found in ASR transcripts using a NeMo manifest with `text` and `pred_text` keys.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Make sure to first run:
$ pip install click pandas lhotse kaldialign
"""
import click | |
import pandas as pd | |
from lhotse.serialization import load_jsonl | |
from kaldialign import align, bootstrap_wer_ci | |
# Placeholder symbol handed to `align`; an aligned pair whose reference side
# equals it is counted as an insertion, and whose hypothesis side equals it
# is counted as a deletion.
EPS = "*"
def _ratio(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


@click.command()
@click.argument("manifest", type=click.Path(exists=True))
@click.option("-c", "--cer", is_flag=True, help="If true, we won't split the text by whitespace and treat each character as a separate symbol..")
@click.option("-n", "--num-splits", default=3, type=int, help="Number of position splits for detailed analysis. You may increase this number if your data consists mostly of long utterances.")
def analyse_results(manifest: str, cer: bool, num_splits: int) -> None:
    """Report where recognition errors occur in a JSONL manifest.

    Every manifest item must provide a reference under the `text` key and a
    hypothesis under the `pred_text` key. The command prints the bootstrap
    error rate with its 95% confidence interval, the rate of each error kind
    (deletion / insertion / substitution), and how the errors are distributed
    over equally sized relative-position bins of the alignment.
    """
    items = [(item["text"], item["pred_text"]) for item in load_jsonl(manifest)]
    if not items:
        # Unpacking zip(*items) below would fail with an opaque ValueError.
        raise click.ClickException(f"No utterances found in manifest: {manifest}")

    # Tokenize once and reuse the result for both the alignment and the
    # bootstrap estimate, so that --cer affects the headline number as well.
    tokenize = list if cer else str.split
    pairs = [(tokenize(ref), tokenize(hyp)) for ref, hyp in items]

    # NOTE(review): a literal EPS token in the data is indistinguishable from
    # an alignment gap and would be counted as an insertion/deletion.
    stats = []
    tot_sym = 0
    for uttidx, (ref, hyp) in enumerate(pairs):
        tot_sym += len(ref)
        ali = align(ref, hyp, EPS)
        tot = len(ali)
        for pos, (r, h) in enumerate(ali):
            if r == h:
                continue
            if r == EPS:
                kind = "ins"
            elif h == EPS:
                kind = "del"
            else:
                kind = "sub"
            stats.append(
                {
                    "ref": r,
                    "hyp": h,
                    "pos": pos,
                    "tot": tot,
                    # Position relative to the alignment length, always in [0, 1).
                    "relpos": pos / tot,
                    "uttidx": uttidx,
                    "kind": kind,
                }
            )

    refs, hyps = zip(*pairs)
    ans = bootstrap_wer_ci(list(refs), list(hyps))
    metric = "CER" if cer else "WER"
    click.echo(
        f"Bootstrap {metric}={ans['wer']:.2%}+/-{ans['ci95']:.2%} "
        f"[ci95min={ans['ci95min']:.2%} - ci95max={ans['ci95max']:.2%}]"
    )

    # Explicit columns keep the frame usable when no errors were collected.
    df = pd.DataFrame(
        stats, columns=["ref", "hyp", "pos", "tot", "relpos", "uttidx", "kind"]
    )
    tot_err = len(df)
    KINDS = "del ins sub".split()
    tot_kind = {kind: int((df["kind"] == kind).sum()) for kind in KINDS}
    click.echo(
        "\t* "
        + "".join(f"{kind}={_ratio(val, tot_sym):.2%} " for kind, val in tot_kind.items())
    )

    if tot_err == 0:
        click.echo("No errors found; skipping error location analysis.")
        return

    click.echo("Error location analysis [relative to utterance length]:")
    for split in range(num_splits):
        b = split / num_splits
        if split == num_splits - 1:
            e = 1.0001  # last loop iter, include last symbol pos
        else:
            e = (split + 1) / num_splits
        subdf = df[(b <= df["relpos"]) & (df["relpos"] < e)]
        num_err = len(subdf)
        click.echo(f"[{b:.2f} - {e:.2f}]")
        click.echo(f"\t* {_ratio(num_err, tot_err):.1%} of all errors.")
        for kind in KINDS:
            num_kind = int((subdf["kind"] == kind).sum())
            # A bin may hold no errors and a kind may never occur; report 0.0%
            # for those cases instead of dividing by zero.
            click.echo(
                f"\t* {_ratio(num_kind, num_err):.1%} are of type '{kind}' "
                f"(this constitutes {_ratio(num_kind, tot_kind[kind]):.1%} of all '{kind}')."
            )
# Script entry point: invoke the click command only when the file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    analyse_results()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment