Last active
May 5, 2022 01:52
-
-
Save supposedly/2405d5cc2fb700c643b8b6d76b023e93 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from decimal import Decimal # for accurate math | |
# Shared zero value; used as the result for entries that are not present.
ZERO = Decimal('0')
class Unit:
    """A grammar symbol (node type) such as ``NP``, ``VP`` or ``Det``.

    Create one per category, e.g. ``P = Unit('P'); NP = Unit('NP')``.  The
    ``>`` and ``+`` operators are overloaded so that rules can be written
    the way they look on paper:

    * ``B + C`` builds the two-symbol right-hand side ``(B, C)``
    * ``A > rhs`` builds the rule key ``(A, rhs)``

    so ``VP > V + NP`` evaluates to ``(VP, (V, NP))`` and ``Det > 'the'``
    evaluates to ``(Det, 'the')``.
    """

    def __init__(self, nick=None):
        """Store *nick* as the display name, defaulting to ``id(self)``."""
        self.nick = id(self) if nick is None else nick

    def __repr__(self):
        return str(self.nick)

    def __gt__(self, other):
        """Return the rule key ``(self, other)``; this is not a comparison."""
        return (self, other)

    def __add__(self, other):
        """Return the right-hand-side pair ``(self, other)``."""
        return (self, other)
class DictThatIsntWhiny(dict):
    """A ``dict`` whose ``d[key]`` yields ``ZERO`` instead of raising ``KeyError``.

    Only subscript access is affected: a missing key is *not* inserted, and
    ``in``, ``.get()`` and iteration behave exactly as on a plain ``dict``.
    """

    def __missing__(self, key):
        # dict.__getitem__ calls this hook on subclasses when the key is absent.
        return ZERO
def define(*units, namespace=None):
    """Create one ``Unit`` per name and bind it under that name.

    ``define('P', 'N', 'V')`` is shorthand for
    ``P = Unit('P'); N = Unit('N'); V = Unit('V')``.

    Args:
        *units: Names of the units to create; each name is also used as the
            unit's ``nick``.
        namespace: Mapping that receives the bindings.  When ``None`` the
            globals of the module that defines this function are used.

    Returns:
        The mapping the units were written into.
    """
    if namespace is None:
        namespace = globals()
    for name in units:
        namespace[name] = Unit(name)
    return namespace
def parse(words, P):
    """Fill a probabilistic CKY chart for *words* under the grammar *P*.

    Args:
        words: Sequence of words making up the sentence.
        P: Mapping from rule key to probability.  A rule key is the pair
            ``(A, rhs)`` produced by ``A > rhs``; ``rhs`` is either a word
            (lexical rule) or a 2-tuple ``(B, C)`` (binary rule).  Each
            probability is passed through ``Decimal()``.

    Returns:
        A ``(table, back)`` pair.  ``table`` maps ``(i, j, A)`` to the
        highest probability found for symbol ``A`` over the span ``i..j``;
        absent entries read as ``ZERO``.  ``back`` maps the same key to the
        ``(k, B, C)`` that produced that best value (split point and the
        two child symbols), i.e. the backpointers needed to rebuild the
        best derivation.  Only spans built from binary rules have a
        ``back`` entry.

    Span indices are shifted down by one: the word at position ``j`` of
    *words* occupies the span ``(j - 1, j)``, so the first index is ``-1``.
    """
    P = DictThatIsntWhiny({rule: Decimal(prob) for rule, prob in P.items()})

    # Left-hand sides in grammar order, de-duplicated.  A list (rather than a
    # set of id-hashed objects) keeps the chart's insertion order stable.
    units = list(dict.fromkeys(lhs for lhs, _ in P))

    # The binary rules never change, so select them once instead of
    # re-scanning the whole grammar for every (i, k, j) combination.
    binary_rules = [
        (lhs, rhs[0], rhs[1], prob)
        for (lhs, rhs), prob in P.items()
        if isinstance(rhs, tuple) and len(rhs) == 2
    ]

    table = DictThatIsntWhiny({})
    back = {}
    for j, word in enumerate(words):
        # Lexical rules: seed the one-word span ending at j.
        for A in units:
            if (A, word) in P:
                table[j - 1, j, A] = P[A, word]
        # Longer spans ending at j, from the shortest (i = j - 2) outwards.
        for i in range(j - 2, -2, -1):
            for k in range(i + 1, j):
                for A, B, C, rule_prob in binary_rules:
                    left = table[i, k, B]
                    right = table[k, j, C]
                    if left > 0 and right > 0:
                        candidate = rule_prob * left * right
                        if table[i, j, A] < candidate:
                            table[i, j, A] = candidate
                            back[i, j, A] = (k, B, C)
    return table, back
def format(words, P):
    """Render the chart computed by ``parse()`` as a printable text table.

    Args:
        words: The sentence, either as a whitespace-separated string or as
            a sequence of words.  A sequence passed in is left unmodified.
        P: The grammar, in the form accepted by ``parse()``.

    Returns:
        A string with one header row holding the words followed by one row
        per span start.  Each cell lists ``"<unit>: <probability>"`` for the
        span that starts at that row and ends at that column; when several
        units cover the same span they are separated by ``", "``.  Columns
        are padded to a common width and rows are separated by dashes.
    """
    if isinstance(words, str):
        words = words.split()
    num_words = len(words)
    table, _ = parse(words, P)

    # Gather every entry per cell first so that several units covering the
    # same span are all shown instead of overwriting one another.
    # Row offset: +1 because span starts begin at -1, +1 for the header row.
    cells = {}
    for (y, x, unit), probability in table.items():
        cells.setdefault((y + 2, x), []).append(f'{unit}: {probability!s}')

    # The header is a copy: the padding below rewrites rows in place and must
    # not leak into a list supplied by the caller.
    formatted_table = [list(words)] + [[''] * num_words for _ in range(num_words)]
    max_lengths = [len(word) for word in words]  # for padding
    for (row_index, column), entries in cells.items():
        text = ', '.join(entries)
        formatted_table[row_index][column] = text
        max_lengths[column] = max(max_lengths[column], len(text))

    for row in formatted_table:
        for i, string in enumerate(row):
            row[i] = string.ljust(max_lengths[i])

    # the 3 is the length of the ' | ' string
    separator = '-' * (sum(max_lengths) + num_words * 3 - 3)
    return f'\n{separator}\n'.join(' | '.join(row) for row in formatted_table)
# Example run.  define() injects the units into this module's globals, which
# is why names such as Det and NP can be used below without being assigned.
define(
    'Det',
    'P',
    'Aux',
    'V',
    'VP',
    'N',
    'NP',
    'S',
    'PP'
)
print(
    format('Kim adores snow in Oslo', {
        Det > 'the': '.6',
        Det > 'these': '.2',
        Det > 'this': '.2',
        P > 'before': '.4',
        P > 'on': '.4',
        P > 'in': '.2',
        Aux > 'is': '.55',
        Aux > 'can': '.25',
        Aux > 'does': '.2',
        V > 'snores': '.3',
        V > 'adores': '.7',
        NP > 'snow': '.1',
        NP > 'Kim': '.1',
        NP > 'Oslo': '.1',
        NP > 'Waikiki': '.1',
        PP > P + S: '.1',
        PP > P + NP: '.9',
        NP > NP + PP: '.2',
        NP > Det + N: '.4',
        VP > VP + PP: '.3',
        VP > V + NP: '.45',
        VP > V + S: '.25',
        S > Aux + S: '.2',
        S > NP + VP: '.8'
    })
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment