Last active: April 30, 2019 03:45
-
-
Save marcoleewow/6ac845fb000e15bd0dc5bdfb9b2f3ad0 to your computer and use it in GitHub Desktop.
Character and word error rate (CER/WER) calculation for handwriting recognition. Requires: pip install editdistance==0.4
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import editdistance
def cer(y_true: str, y_pred: str) -> float:
    """Return the character error rate (CER) of ``y_pred`` against ``y_true``.

    CER is the Levenshtein (edit) distance between the two strings divided
    by the number of characters in the reference ``y_true``. Trailing
    whitespace is stripped from both strings before comparing. The
    denominator is the length of the *stripped* reference, so the numerator
    and denominator measure the same text.

    Args:
        y_true: Ground-truth (reference) text.
        y_pred: Predicted (hypothesis) text.

    Returns:
        The error rate as a float. It is 0.0 for an exact match and may
        exceed 1.0 when the prediction is much longer than the reference.
        Two strings that are empty after stripping trailing whitespace give
        0.0.

    Raises:
        TypeError: If either argument is not a ``str``.
        ZeroDivisionError: If the stripped reference is empty but the
            stripped prediction is not, because the rate is undefined.
    """
    # Raise instead of assert: asserts disappear under ``python -O``.
    if not isinstance(y_true, str) or not isinstance(y_pred, str):
        raise TypeError("cer() expects two str arguments")
    reference = y_true.rstrip()
    hypothesis = y_pred.rstrip()
    if not reference:
        # Nothing to normalise by: an empty prediction is a perfect match,
        # anything else has no defined rate.
        if not hypothesis:
            return 0.0
        raise ZeroDivisionError(
            "CER is undefined for an empty reference with a non-empty prediction"
        )
    n_err = editdistance.eval(reference, hypothesis)
    return n_err / len(reference)
def wer(transcript: str, reference: str) -> float:
    """Return the word error rate (WER) of ``transcript`` against ``reference``.

    Both strings are tokenised with ``str.split()``, so any run of
    whitespace separates words and leading/trailing whitespace is ignored.
    WER is the Levenshtein distance between the two word sequences divided
    by the number of words in ``reference``.

    Note: the positional order here is (text being scored, reference),
    which is the reverse of ``cer(y_true, y_pred)``. Prefer keyword
    arguments so the correct string is used as the denominator.

    Args:
        transcript: Text being scored (the hypothesis).
        reference: Ground-truth text; its word count is the denominator.

    Returns:
        The error rate as a float. It is 0.0 when both strings contain no
        words.

    Raises:
        TypeError: If either argument is not a ``str``.
        ZeroDivisionError: If ``reference`` contains no words but
            ``transcript`` does, because the rate is undefined.
    """
    # Raise instead of assert: asserts disappear under ``python -O``.
    if not isinstance(transcript, str) or not isinstance(reference, str):
        raise TypeError("wer() expects two str arguments")
    hyp_words = transcript.split()
    ref_words = reference.split()
    if not ref_words:
        # Nothing to normalise by: an empty transcript is a perfect match,
        # anything else has no defined rate.
        if not hyp_words:
            return 0.0
        raise ZeroDivisionError(
            "WER is undefined for an empty reference with a non-empty transcript"
        )
    # editdistance compares the two lists element-wise, i.e. word by word.
    n_err = editdistance.eval(hyp_words, ref_words)
    return n_err / len(ref_words)
if __name__ == "__main__":
    # Self-checks. Terminology follows the usual ASR/HTR convention: an
    # insertion is an extra token in the prediction, a deletion is a token
    # missing from it. ``wer`` is called with keyword arguments because its
    # positional order is (transcript, reference), the reverse of ``cer``;
    # that keeps the ground truth as the denominator for both metrics.

    # exact match
    cer_score = cer("hello world", "hello world")
    wer_score = wer(transcript="hello world", reference="hello world")
    assert cer_score == 0.
    assert wer_score == 0.

    # one character insertion: the prediction has an extra "!".
    # At word level "world!" vs "world" is one substitution out of 2 words.
    cer_score = cer("hello world", "hello world!")
    wer_score = wer(transcript="hello world!", reference="hello world")
    assert cer_score == 1. / 11.
    assert wer_score == 1. / 2.

    # one character substitution: "e" -> "a" ("hallo" vs "hello")
    cer_score = cer("hello world", "hallo world")
    wer_score = wer(transcript="hallo world", reference="hello world")
    assert cer_score == 1. / 11.
    assert wer_score == 1. / 2.

    # one character deletion: the prediction is missing the leading "h".
    # At word level "ello" vs "hello" is one substitution out of 2 words.
    cer_score = cer("hello world", "ello world")
    wer_score = wer(transcript="ello world", reference="hello world")
    assert cer_score == 1. / 11.
    assert wer_score == 1. / 2.
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.