Last active
July 4, 2023 07:02
-
-
Save OhadRubin/442098aaf073e2176d09db7bf80a19e1 to your computer and use it in GitHub Desktop.
Sampling from Synchronous Probabilistic Context-Free Grammars
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random
from collections import Counter

import nltk
class ProductionRule:
    """Wrapper around an nltk production that exposes its parts through getters.

    No ``__eq__``/``__hash__`` is defined, so instances compare and hash by
    identity: two wrappers around equal productions are distinct dict keys.
    """

    def __init__(self, production):
        # The wrapped production; the getters below call its lhs(), rhs()
        # and prob() methods.
        self.production = production

    def _ask(self, accessor):
        """Call the wrapped production's zero-argument method named *accessor*."""
        return getattr(self.production, accessor)()

    def get_lhs(self):
        """Return the production's left-hand side."""
        return self._ask("lhs")

    def get_rhs(self):
        """Return the production's right-hand side."""
        return self._ask("rhs")

    def get_prob(self):
        """Return the production's probability."""
        return self._ask("prob")
class Grammar:
    """A PCFG built from plain CFG rule strings with uniform probabilities.

    Rules are queued with ``add_rule`` and compiled by ``finalize``. Every rule
    receives probability ``1 / n``, where ``n`` is the number of queued rules
    sharing its left-hand side.
    """

    def __init__(self):
        # Rule strings in insertion order, e.g. "NP -> 'I'".
        self.raw_rules = []
        # Set by finalize(). They start empty so that lookups before finalize()
        # return nothing instead of raising AttributeError.
        self.pcfg = None
        self.rules = []

    def add_rule(self, rule):
        """Queue a single-production rule string such as ``"S -> NP VP"``."""
        self.raw_rules.append(rule)

    def finalize(self):
        """Compile the queued rules into an nltk PCFG and wrap its productions.

        Each rule string gets a ``[p]`` suffix with ``p = 1 / n``, where ``n`` is
        the number of queued rules sharing its left-hand side. The result is
        parsed with ``nltk.grammar.PCFG.fromstring``. On success this sets
        ``self.pcfg`` and ``self.rules``, one ``ProductionRule`` per queued rule,
        in insertion order. Errors raised by nltk's parser propagate unchanged.

        Raises:
            ValueError: if the rules do not yield exactly one production each,
                e.g. a rule written with ``|`` alternatives or a blank/comment
                line. Callers rely on ``self.rules[i]`` corresponding to
                ``self.raw_rules[i]``.
        """
        def lhs_name(rule):
            # Text before the first "->", e.g. "NP" for "NP -> 'I'".
            return rule.split("->", 1)[0].strip()

        lhs_counts = Counter(lhs_name(rule) for rule in self.raw_rules)
        weighted_rules = [
            f"{rule} [{1.0 / lhs_counts[lhs_name(rule)]}]" for rule in self.raw_rules
        ]
        pcfg = nltk.grammar.PCFG.fromstring("\n".join(weighted_rules))
        rules = [ProductionRule(production) for production in pcfg.productions()]
        if len(rules) != len(self.raw_rules):
            # A '|' line becomes several productions (only the last one gets
            # the probability suffix). Blank/comment lines become none.
            # Either way the rule/production correspondence is lost.
            raise ValueError(
                f"{len(self.raw_rules)} rules produced {len(rules)} productions; "
                "each rule must define exactly one production ('|' alternatives, "
                "blank lines and comments are not supported)"
            )
        # Assign only after validation so a failed finalize() leaves the
        # previous state intact.
        self.pcfg = pcfg
        self.rules = rules

    def rules_with_lhs(self, lhs):
        """Return the finalized rules whose left-hand side equals *lhs*, in order."""
        return [rule for rule in self.rules if rule.get_lhs() == lhs]
class SynchronousGrammar:
    """Synchronous PCFG that samples a source sentence together with its translation.

    Rules are added in linked pairs with ``add_rule``. The i-th source rule is
    paired with the i-th target rule. Sampling is driven by the source grammar
    alone: each expansion picks a source rule by probability, and the paired
    target rule only decides how the target side is written out.

    Nonterminals on a paired right-hand side are linked when ``finalize`` runs:

    * If both sides contain the same nonterminals (as a multiset), the k-th
      occurrence of ``X`` in the target rule is linked to the k-th occurrence
      of ``X`` in the source rule. This lets the target reorder constituents,
      e.g. ``"VP -> V NP"`` paired with ``"VP -> NP V"``.
    * Otherwise, if both sides contain the same number of nonterminals, they
      are linked by position (i-th with i-th).

    Terminals are emitted independently on each side, so paired rules may
    contain different numbers of words.
    """

    def __init__(self):
        self.source_grammar = Grammar()
        self.target_grammar = Grammar()
        # Source ProductionRule -> paired target ProductionRule (set by finalize).
        self.rule_mapping = {}
        # Source ProductionRule -> target right-hand side. Each nonterminal is
        # replaced by the int index of its linked nonterminal among the source
        # rule's nonterminals; terminals are kept as-is.
        self._target_layouts = {}

    def add_rule(self, source_rule, target_rule):
        """Add a linked pair of rule strings, e.g. ``"NP -> 'I'"`` / ``"NP -> 'Ich'"``."""
        self.source_grammar.add_rule(source_rule)
        self.target_grammar.add_rule(target_rule)

    def finalize(self):
        """Compile both grammars and link every source rule to its target rule.

        Raises:
            ValueError: if the two grammars yield different numbers of
                productions, so rules cannot be paired. Also raised if a rule
                pair has a different number of nonterminals on each side.
        """
        self.source_grammar.finalize()
        self.target_grammar.finalize()
        source_rules = self.source_grammar.rules
        target_rules = self.target_grammar.rules
        if len(source_rules) != len(target_rules):
            raise ValueError(
                f"source grammar has {len(source_rules)} productions but target "
                f"grammar has {len(target_rules)}; rules cannot be paired"
            )
        # ProductionRule hashes by identity, so duplicate rule strings still
        # get their own entries.
        rule_mapping = dict(zip(source_rules, target_rules))
        target_layouts = {
            source_rule: self._link_nonterminals(
                source_rule.get_rhs(), target_rule.get_rhs()
            )
            for source_rule, target_rule in rule_mapping.items()
        }
        self.rule_mapping = rule_mapping
        self._target_layouts = target_layouts

    @staticmethod
    def _link_nonterminals(source_rhs, target_rhs):
        """Return *target_rhs* with each nonterminal replaced by its linked source index."""
        source_nts = [sym for sym in source_rhs if isinstance(sym, nltk.grammar.Nonterminal)]
        target_nts = [sym for sym in target_rhs if isinstance(sym, nltk.grammar.Nonterminal)]
        if len(source_nts) != len(target_nts):
            raise ValueError(
                f"cannot link nonterminals of {source_rhs!r} and {target_rhs!r}: "
                f"{len(source_nts)} vs {len(target_nts)}"
            )
        if Counter(source_nts) == Counter(target_nts):
            # Link by name. The k-th X on the target side takes the k-th X
            # on the source side, which allows reordering.
            unused = {}
            for index, sym in enumerate(source_nts):
                unused.setdefault(sym, []).append(index)
            order = iter([unused[sym].pop(0) for sym in target_nts])
        else:
            # Names differ (e.g. renamed target categories): link by position.
            order = iter(range(len(target_nts)))
        return tuple(
            next(order) if isinstance(sym, nltk.grammar.Nonterminal) else sym
            for sym in target_rhs
        )

    def sample_rule(self, grammar, symbol):
        """Sample a production for a given symbol, weighted by rule probability."""
        rules = grammar.rules_with_lhs(symbol)
        weights = [rule.get_prob() for rule in rules]
        return random.choices(rules, weights=weights)[0]

    def generate_sentence(self, symbol=nltk.grammar.Nonterminal('S')):
        """Generate a sentence and its translation from the synchronous PCFG.

        Args:
            symbol: nonterminal to expand; defaults to ``S``.

        Returns:
            ``(source_tokens, target_tokens)``, two lists of terminal symbols.

        Expansion is recursive, with one call per nested nonterminal, so a very
        deep derivation can hit Python's recursion limit.
        """
        source_rule = self.sample_rule(self.source_grammar, symbol)
        # Expand source nonterminals depth-first, left to right, keeping each
        # child's translation so the target side can place it.
        source_sentence = []
        sub_targets = []
        for source_sym in source_rule.get_rhs():
            if isinstance(source_sym, nltk.grammar.Nonterminal):
                sub_source, sub_target = self.generate_sentence(source_sym)
                source_sentence.extend(sub_source)
                sub_targets.append(sub_target)
            else:
                source_sentence.append(source_sym)
        # Write the target side from the paired rule's layout. Terminals go in
        # verbatim; an int stands for the translation of a linked source child.
        target_sentence = []
        for item in self._target_layouts[source_rule]:
            if isinstance(item, int):
                target_sentence.extend(sub_targets[item])
            else:
                target_sentence.append(item)
        return source_sentence, target_sentence
# Example usage: sample one English sentence together with its German translation.
_EXAMPLE_RULE_PAIRS = (
    ("S -> NP VP", "S -> NP VP"),
    ("NP -> 'I'", "NP -> 'Ich'"),
    ("NP -> 'You'", "NP -> 'Du'"),
    ("VP -> V NP", "VP -> V NP"),
    ("VP -> VP NP", "VP -> VP NP"),
    ("V -> 'see'", "V -> 'sehe'"),
)

synchronous_grammar = SynchronousGrammar()
for _source_rule, _target_rule in _EXAMPLE_RULE_PAIRS:
    synchronous_grammar.add_rule(_source_rule, _target_rule)
synchronous_grammar.finalize()

source_sentence, target_sentence = synchronous_grammar.generate_sentence()
# Print the source sentence first, then its translation.
for _tokens in (source_sentence, target_sentence):
    print(' '.join(_tokens))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment