Created
February 18, 2020 05:25
-
-
Save caspark/8826ae47aef433d2dbbfae21604c5f03 to your computer and use it in GitHub Desktop.
2020-02-17-kaldi-breaking-grammar
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging, os | |
import dragonfly | |
# Flip the condition below to True to get verbose debug output from the
# engine and Kaldi loggers while reproducing the problem; otherwise use
# INFO-level logging plus dragonfly's own log setup.
if False:
    logging.basicConfig(level=logging.DEBUG)
    for logger_name, logger_level in (
        ('grammar.decode', logging.INFO),
        ('compound', logging.INFO),
        ('engine', logging.DEBUG),
        ('kaldi', logging.DEBUG),
    ):
        logging.getLogger(logger_name).setLevel(logger_level)
else:
    logging.basicConfig(level=logging.INFO)
    from dragonfly.log import setup_log
    setup_log()
class KaldiBreakerRule(dragonfly.CompoundRule):
    """Minimal compound rule used to break the Kaldi backend.

    This is the grammar from the "kaldi-breaking-grammar" repro. The spec makes
    both extras optional, and both extras are 1-3 repetitions of the single
    word "escape":

    * ``alpha`` repeats an Alternative whose two children are the same Literal.
    * ``beta`` repeats that Literal directly.

    The same utterance can therefore be parsed several ways.
    NOTE(review): the duplicated and overlapping alternatives look deliberate,
    since they are what the repro exercises. Do not deduplicate them.
    """
    spec = "[<alpha>] [<beta>]"
    extras = [
        # 1-3 repetitions of an Alternative that lists the same literal twice.
        dragonfly.Repetition(name="alpha", min=1, max=3,
            child=dragonfly.Alternative(name="alpha_alternative", children=[
                dragonfly.Literal("escape"),
                dragonfly.Literal("escape"),
            ])
        ),
        # 1-3 repetitions of the bare literal. This overlaps with <alpha>, so
        # "escape" can be matched by either extra.
        dragonfly.Repetition(min=1, max=3, name="beta",
            child=dragonfly.Literal("escape"),
        ),
    ]

    def _process_recognition(self, node, extras):
        """Recognition callback: print the matched node and extras."""
        print(f"Breaker recognized! node={node} and extras={extras}")
# Connect to the Kaldi engine using the locally stored model directory.
kaldi_model_dir = 'models/daanzu_20200201_1ep-biglm'
engine = dragonfly.get_engine("kaldi", model_dir=kaldi_model_dir)
engine.connect()

# Put the breaking rule into a fresh grammar and load it into the engine.
grammar = dragonfly.Grammar(name="mygrammar")
breaker_rule = KaldiBreakerRule()
grammar.add_rule(breaker_rule)
grammar.load()

# To inspect the loaded grammar's structure, the companion helper module can
# be used at this point, e.g.
# utils_dragonfly.get_grammar_complexity_tree(grammar).

print("Preparing for recognition...")
engine.prepare_for_recognition()
print("Listening...")
engine.do_recognition()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
class ComplexityNode(object):
    """One node of a complexity tree built by build_complexity_tree.

    Attributes:
        item: the wrapped rule or element.
        children: the child nodes, filled in by the tree builder.
        total_descendents: the size of the subtree rooted here. The node
            counts itself, hence the starting value of 1.
    """

    def __init__(self, item):
        self.item = item
        self.children = list()  # populated by the tree builder
        self.total_descendents = 1  # this node itself; children add theirs
def build_complexity_tree(thing):
    """Build a ComplexityNode tree mirroring the structure of a dragonfly rule.

    Traversal rules:

    * A Rule descends into its root ``element``.
    * A RuleRef descends into the referenced rule's root ``element``.
    * Anything else is treated as an element and descends into its
      ``children``.

    Each node's ``total_descendents`` counts the node itself plus every node
    below it.

    Children are ordered largest subtree first. Among the children of an
    Alternative, equal-sized subtrees are further ordered by ``str(item)``.
    Any remaining ties keep their original order, because the sort is stable.

    Args:
        thing: a dragonfly Rule, RuleRef, or element.

    Returns:
        The root ComplexityNode for *thing*.
    """
    # This module only imports ``logging`` at top level, so the dragonfly
    # classes used in the isinstance checks are brought into scope here.
    from dragonfly import Alternative, Rule, RuleRef

    node = ComplexityNode(thing)
    if isinstance(thing, Rule):
        children = [thing.element]
    elif isinstance(thing, RuleRef):
        children = [thing.rule.element]
    else:
        # Presumably an element; anything exposing ``children`` works here.
        children = thing.children

    for child in children:
        child_node = build_complexity_tree(child)
        node.children.append(child_node)
        node.total_descendents += child_node.total_descendents

    # A single stable sort with a composite key is equivalent to the previous
    # "sort by str, then stable-sort by size descending" two-pass approach.
    if isinstance(thing, Alternative):
        node.children.sort(key=lambda n: (-n.total_descendents, str(n.item)))
    else:
        node.children.sort(key=lambda n: -n.total_descendents)
    return node
def get_rule_complexity_tree(rule, depth_threshold=10, complexity_threshold=10):
    """Render *rule*'s complexity tree as indented text.

    Each output line shows a node's ``repr``, indented by one space per depth
    level after a ``"- "`` marker and left-justified to 75 characters. It is
    followed by the size of that node's subtree. Each node's children are
    listed directly beneath it, one level deeper.

    Args:
        rule: the rule or element to render; passed to build_complexity_tree.
        depth_threshold: currently ignored. It is kept for backward
            compatibility because the truncation logic that used it is
            disabled.
        complexity_threshold: currently ignored, kept for backward
            compatibility.

    Returns:
        The rendered tree, or "" if building or rendering raised. The
        exception is logged.
    """
    def render(node, depth):
        label = " " * depth + "- " + repr(node.item)
        lines = ["%-75s %d" % (label, node.total_descendents)]
        for child in node.children:
            child_text = render(child, depth + 1)
            if child_text:
                lines.append(child_text)
        # join() instead of repeated string += while assembling the subtree.
        return "\n".join(lines)

    try:
        return render(build_complexity_tree(rule), 0)
    except Exception:
        # Best-effort diagnostic helper: log and return an empty report.
        logging.exception("failed to build complexity tree")
        return ""
def get_grammar_complexity_score(grammar):
    """Sum the complexity-tree sizes of all exported rules in *grammar*.

    Rules that are not exported are skipped. If anything raises while the
    score is computed, the exception is logged and 0 is returned instead.
    """
    try:
        score = 0
        for rule in grammar.rules:
            if not rule.exported:
                continue
            score += build_complexity_tree(rule).total_descendents
        return score
    except Exception:
        logging.exception("failed to build grammar complexity score")
        return 0
def get_grammar_complexity_tree(grammar, threshold=5):
    """Build a text report of *grammar*'s rules and their complexity trees.

    The first line gives the total rule count and how many rules are exported
    and imported. The rendered tree of each exported rule then follows on
    its own line(s). *threshold* is passed positionally to
    get_rule_complexity_tree as its ``depth_threshold`` argument.
    """
    all_rules = grammar.rules
    exported = [rule for rule in all_rules if rule.exported]
    imported_count = sum(1 for rule in all_rules if rule.imported)

    sections = ["%s: %d rules (%d exported, %d imported):" % (
        grammar, len(all_rules), len(exported), imported_count,
    )]
    sections.extend(get_rule_complexity_tree(rule, threshold) for rule in exported)
    return "\n".join(sections)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment