Skip to content

Instantly share code, notes, and snippets.

@OhadRubin
Last active July 4, 2023 07:02
Show Gist options
  • Save OhadRubin/442098aaf073e2176d09db7bf80a19e1 to your computer and use it in GitHub Desktop.
Save OhadRubin/442098aaf073e2176d09db7bf80a19e1 to your computer and use it in GitHub Desktop.
Sampling from synchronous probabilistic context-free grammars (SPCFGs)
import nltk
import random
class ProductionRule:
    """Thin wrapper around a single probabilistic production.

    The wrapped object only needs to provide ``lhs()``, ``rhs()`` and
    ``prob()`` methods; in this file it is a production taken from an
    ``nltk`` PCFG.

    Instances deliberately keep the default identity-based equality and
    hashing, so two wrappers around equal productions remain distinct
    dictionary keys.
    """

    def __init__(self, production):
        """Store *production*, the underlying production object."""
        self.production = production

    def __repr__(self):
        """Return a debugging representation showing the wrapped production."""
        return f"{type(self).__name__}({self.production!r})"

    def get_lhs(self):
        """Return the left-hand side of the wrapped production."""
        return self.production.lhs()

    def get_rhs(self):
        """Return the right-hand side of the wrapped production."""
        return self.production.rhs()

    def get_prob(self):
        """Return the probability attached to the wrapped production."""
        return self.production.prob()
class Grammar:
    """A PCFG assembled from plain ``"LHS -> RHS"`` rule strings.

    Rules are collected with :meth:`add_rule`; :meth:`finalize` then gives
    every rule a uniform probability among the rules sharing its left-hand
    side and builds the ``nltk`` PCFG.

    Attributes:
        raw_rules: Rule strings in the order they were added.
        pcfg: The ``nltk`` PCFG built by :meth:`finalize`, or ``None``
            before it has been called.
        rules: ``ProductionRule`` wrappers around the PCFG's productions;
            empty before :meth:`finalize` has been called.
    """

    def __init__(self):
        """Create an empty, not yet finalized grammar."""
        self.raw_rules = []
        # Defined up front so the grammar can be queried before finalize()
        # without hitting an AttributeError.
        self.pcfg = None
        self.rules = []

    def add_rule(self, rule):
        """Append *rule*, a string of the form ``"LHS -> RHS"``.

        The string is not validated here; it is parsed in :meth:`finalize`.
        """
        self.raw_rules.append(rule)

    def finalize(self):
        """Build ``self.pcfg`` and ``self.rules`` from the collected rules.

        Each rule receives probability ``1 / k``, where ``k`` is the number
        of added rules that share its left-hand side.

        NOTE(review): a raw rule that uses ``|`` alternation only gets one
        probability annotation appended at its very end -- verify how nltk
        interprets that before relying on it.

        Raises:
            ValueError: If no rules have been added.
        """
        if not self.raw_rules:
            raise ValueError("cannot finalize a grammar that has no rules")

        # Parse every left-hand side once; it is needed both for counting
        # and for assigning the probabilities.
        left_sides = [rule.split("->", 1)[0].strip() for rule in self.raw_rules]
        alternatives = {}
        for lhs in left_sides:
            alternatives[lhs] = alternatives.get(lhs, 0) + 1

        weighted_rules = [
            f"{rule} [{1.0 / alternatives[lhs]}]"
            for rule, lhs in zip(self.raw_rules, left_sides)
        ]
        self.pcfg = nltk.grammar.PCFG.fromstring("\n".join(weighted_rules))
        self.rules = [
            ProductionRule(production) for production in self.pcfg.productions()
        ]

    def rules_with_lhs(self, lhs):
        """Return the rules whose left-hand side equals *lhs*.

        The result is a new list in grammar order; it is empty when nothing
        matches or when the grammar has not been finalized yet.
        """
        return [rule for rule in self.rules if rule.get_lhs() == lhs]
class SynchronousGrammar:
    """A pair of grammars whose rules are linked one-to-one.

    Every call to :meth:`add_rule` adds one source rule and the target rule
    that translates it. After :meth:`finalize`, :meth:`generate_sentence`
    samples a derivation from the source grammar and builds the target
    sentence from the linked target rules.

    Attributes:
        source_grammar: ``Grammar`` that derivations are sampled from.
        target_grammar: ``Grammar`` holding the linked target-side rules.
        rule_mapping: Maps each source ``ProductionRule`` to its target
            ``ProductionRule``; filled by :meth:`finalize`.
    """

    def __init__(self):
        """Create an empty synchronous grammar."""
        self.source_grammar = Grammar()
        self.target_grammar = Grammar()
        self.rule_mapping = {}

    def add_rule(self, source_rule, target_rule):
        """Add a linked pair of ``"LHS -> RHS"`` rule strings."""
        self.source_grammar.add_rule(source_rule)
        self.target_grammar.add_rule(target_rule)

    def finalize(self):
        """Finalize both grammars and link their productions by position.

        Raises:
            ValueError: If the two grammars end up with a different number
                of productions, or if a rule pair has a different number of
                nonterminals on its two right-hand sides.
        """
        self.source_grammar.finalize()
        self.target_grammar.finalize()
        source_rules = self.source_grammar.rules
        target_rules = self.target_grammar.rules
        # The mapping pairs productions by index, so a count mismatch would
        # silently link unrelated rules.
        if len(source_rules) != len(target_rules):
            raise ValueError(
                "source and target grammars produced a different number of "
                f"productions ({len(source_rules)} vs {len(target_rules)})"
            )
        # Surface unalignable rule pairs now rather than in the middle of
        # sampling.
        for source_rule, target_rule in zip(source_rules, target_rules):
            self._link_nonterminals(source_rule.get_rhs(), target_rule.get_rhs())
        self.rule_mapping = dict(zip(source_rules, target_rules))

    def sample_rule(self, grammar, symbol):
        """Sample a production for a given symbol.

        The choice is weighted by each candidate rule's probability.

        Raises:
            IndexError: If *grammar* has no rule whose left-hand side is
                *symbol*.
        """
        rules = grammar.rules_with_lhs(symbol)
        if not rules:
            raise IndexError(f"no production rule expands symbol {symbol!r}")
        weights = [rule.get_prob() for rule in rules]
        return random.choices(rules, weights=weights)[0]

    @staticmethod
    def _link_nonterminals(source_rhs, target_rhs):
        """Link each target nonterminal to a source nonterminal.

        Returns a list with one entry per nonterminal of *target_rhs*, in
        order; each entry is an index into the sequence of nonterminals of
        *source_rhs*. The k-th occurrence of a symbol on the target side is
        linked to the k-th occurrence of the same symbol on the source side.
        If the two sides do not use the same nonterminal names, the
        nonterminals are linked by position instead.

        Raises:
            ValueError: If the two sides contain a different number of
                nonterminals.
        """
        source_nts = [
            sym for sym in source_rhs if isinstance(sym, nltk.grammar.Nonterminal)
        ]
        target_nts = [
            sym for sym in target_rhs if isinstance(sym, nltk.grammar.Nonterminal)
        ]
        if len(source_nts) != len(target_nts):
            raise ValueError(
                f"cannot align right-hand sides {source_rhs!r} and {target_rhs!r}: "
                "they contain a different number of nonterminals"
            )

        occurrences = {}
        for index, sym in enumerate(source_nts):
            occurrences.setdefault(sym, []).append(index)
        cursors = {sym: iter(indices) for sym, indices in occurrences.items()}

        links = []
        for sym in target_nts:
            index = next(cursors.get(sym, iter(())), None)
            if index is None:
                # The target side uses different nonterminal names, so fall
                # back to pairing nonterminals in order.
                return list(range(len(target_nts)))
            links.append(index)
        return links

    def generate_sentence(self, symbol=nltk.grammar.Nonterminal('S')):
        """Generate a sentence and its translation from the synchronous PCFG.

        A source rule for *symbol* is sampled and its nonterminals are
        expanded recursively from left to right. The target sentence is then
        assembled by walking the linked target rule's right-hand side, so
        the target side may reorder constituents and may use a different
        number of terminals than the source side.

        Args:
            symbol: Nonterminal to expand. A plain string is accepted and
                converted to a nonterminal.

        Returns:
            A ``(source_sentence, target_sentence)`` pair of token lists.

        Raises:
            IndexError: If no source rule expands *symbol*.
            ValueError: If the sampled rule pair cannot be aligned.

        Recursion depth follows the depth of the sampled derivation, so a
        recursive grammar can in principle exceed the interpreter's
        recursion limit.
        """
        if isinstance(symbol, str):
            symbol = nltk.grammar.Nonterminal(symbol)

        source_rule = self.sample_rule(self.source_grammar, symbol)
        target_rule = self.rule_mapping[source_rule]
        source_rhs = source_rule.get_rhs()
        target_rhs = target_rule.get_rhs()

        # Expand the source side, remembering each nonterminal's translation
        # in source order.
        source_sentence = []
        translations = []
        for sym in source_rhs:
            if isinstance(sym, nltk.grammar.Nonterminal):
                sub_source_sentence, sub_target_sentence = self.generate_sentence(sym)
                source_sentence.extend(sub_source_sentence)
                translations.append(sub_target_sentence)
            else:
                source_sentence.append(sym)

        # Assemble the target side in the target rule's own order.
        links = iter(self._link_nonterminals(source_rhs, target_rhs))
        target_sentence = []
        for sym in target_rhs:
            if isinstance(sym, nltk.grammar.Nonterminal):
                target_sentence.extend(translations[next(links)])
            else:
                target_sentence.append(sym)
        return source_sentence, target_sentence
# Example usage: build a tiny English -> German synchronous grammar,
# sample one derivation and print the sentence pair.
synchronous_grammar = SynchronousGrammar()
for source_rule, target_rule in (
    ("S -> NP VP", "S -> NP VP"),
    ("NP -> 'I'", "NP -> 'Ich'"),
    ("NP -> 'You'", "NP -> 'Du'"),
    ("VP -> V NP", "VP -> V NP"),
    ("VP -> VP NP", "VP -> VP NP"),
    ("V -> 'see'", "V -> 'sehe'"),
):
    synchronous_grammar.add_rule(source_rule, target_rule)
synchronous_grammar.finalize()

# The source sentence is printed first, its translation on the next line.
source_sentence, target_sentence = synchronous_grammar.generate_sentence()
print(' '.join(source_sentence), ' '.join(target_sentence), sep="\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment