Skip to content

Instantly share code, notes, and snippets.

@supposedly
Last active May 5, 2022 01:52
Show Gist options
  • Save supposedly/2405d5cc2fb700c643b8b6d76b023e93 to your computer and use it in GitHub Desktop.
Save supposedly/2405d5cc2fb700c643b8b6d76b023e93 to your computer and use it in GitHub Desktop.
from decimal import Decimal # for accurate math
# Shared Decimal zero; DictThatIsntWhiny returns it for keys it doesn't have
ZERO = Decimal('0')
# A Unit is a grammar symbol, i.e. a node type / syntactic category
# such as P, NP or VP
# eg do `P = Unit(); NP = Unit(); VP = Unit(); etc`
# The class overrides the > and + operators to make it easy to write
# production rules, like P > NP or VP > V+NP
class Unit:
    """A grammar symbol whose ``>`` and ``+`` operators build rule tuples.

    Both operators just pair the left operand with the right one. Because
    ``+`` binds tighter than ``>``, ``VP > V + NP`` evaluates to
    ``(VP, (V, NP))``, and ``NP > 'snow'`` evaluates to ``(NP, 'snow')``.
    """

    def __init__(self, nick=None):
        # Without an explicit display name, fall back to the object's id.
        if nick is None:
            nick = id(self)
        self.nick = nick

    def __repr__(self):
        # The nick may be an int (the id fallback), hence the str() call.
        return str(self.nick)

    def __gt__(self, other):
        # `self > other` -> the pair (left-hand side, right-hand side)
        pair = (self, other)
        return pair

    def __add__(self, other):
        # `self + other` -> the pair of the two operands
        pair = (self, other)
        return pair
# Dictionary except it returns 0 for keys it doesn't have
# instead of throwing an error
class DictThatIsntWhiny(dict):
    """A dict whose ``d[key]`` lookup yields ZERO for absent keys.

    Only item lookup is changed; a missing key is never inserted by
    reading it, and ``in`` / ``.get()`` keep their normal dict behaviour.
    """

    def __getitem__(self, key):
        if key in self:
            return dict.__getitem__(self, key)
        # Absent key: answer with the shared zero instead of a KeyError.
        return ZERO
# Shorthand for defining a bunch of units at once
# define('P', 'N', 'V')
# is the same as
# P = Unit('P'); N = Unit('N'); V = Unit('V')
def define(*units, namespace=None):
    """Create a Unit for every given name and bind it under that name.

    Parameters:
        *units: the names to define; each becomes ``Unit(name)``.
        namespace: mapping that receives the bindings. When omitted, the
            globals of the module that contains this function are used.

    Returns the mapping that was written to.
    """
    target = globals() if namespace is None else namespace
    for symbol_name in units:
        target[symbol_name] = Unit(symbol_name)
    return target
# Construct the actual table out of words (list of words in a
# sentence) and P (the grammar + probabilities)
# 'back' holds the backpointers: for each (i, j, A) cell it records the
# split point k and the children (B, C) that gave the best probability
def parse(words, P):
    """Fill a probabilistic CKY chart for `words` under the grammar `P`.

    Parameters:
        words: iterable of the sentence's words, in order.
        P: mapping from rule to probability. A rule key is the pair
            ``(A, word)`` for a lexical rule or ``(A, (B, C))`` for a
            binary rule, which is exactly what ``A > word`` and
            ``A > B+C`` produce for Unit symbols. Each value is passed
            to ``Decimal``.

    Returns:
        A ``(table, back)`` pair.
        table: DictThatIsntWhiny mapping ``(i, j, A)`` to the best
            probability found for symbol A over the span from i to j.
            Word number j (0-based) occupies the span ``(j - 1, j)``, so
            start indices begin at -1.
        back: dict mapping ``(i, j, A)`` to ``(k, B, C)``, the split
            point and children that produced the stored probability.
            Cells filled directly from a lexical rule have no entry.
    """
    P = DictThatIsntWhiny({rule: Decimal(prob) for rule, prob in P.items()})
    units = {lhs for lhs, _ in P}
    # The binary rules never change while the chart is filled, so pick
    # them out once instead of re-scanning the grammar for every split.
    binary_rules = [
        (lhs, rhs[0], rhs[1], prob)
        for (lhs, rhs), prob in P.items()
        if isinstance(rhs, tuple) and len(rhs) == 2
    ]
    table = DictThatIsntWhiny({})
    back = {}
    for j, word in enumerate(words):
        # Lexical step: every symbol that can rewrite to this word.
        # Keys are built as plain tuples, so symbols need not be Units.
        for A in units:
            rule = (A, word)
            if rule in P:
                table[j - 1, j, A] = P[rule]
        # Combine step: widen spans ending at j, shortest first, trying
        # every split point k strictly inside (i, j).
        for i in range(j - 2, -2, -1):
            for k in range(i + 1, j):
                for A, B, C, rule_prob in binary_rules:
                    left = table[i, k, B]
                    right = table[k, j, C]
                    if not (left > 0 and right > 0):
                        continue
                    candidate = rule_prob * left * right
                    # Strict comparison: on a tie the earlier find wins.
                    if table[i, j, A] < candidate:
                        table[i, j, A] = candidate
                        back[i, j, A] = (k, B, C)
    return table, back
# Prettify the result of parse() to make it printable
def format(words, P):
    """Parse `words` with grammar `P` and render the chart as a text table.

    Parameters:
        words: the sentence, either a string (split on whitespace) or an
            iterable of words. The argument itself is never modified.
        P: the grammar, passed straight through to parse().

    Returns:
        A string whose first row holds the words and whose following
        rows hold the chart: the entry for ``(y, x, unit)`` is shown in
        row ``y + 2``, column ``x`` as ``'unit: probability'``. When
        several units share one cell they are all listed, separated by
        ``', '``. Columns are padded to a common width, joined with
        ``' | '``, and rows are separated by a line of dashes.
    """
    if isinstance(words, str):
        words = words.split()
    # Take a private copy so the caller's sequence is never touched and
    # tuples/other iterables are accepted too.
    words = list(words)
    num_words = len(words)
    table, _ = parse(words, P)
    formatted_table = [list(words)] + [[''] * num_words for _ in range(num_words)]
    for (y, x, unit), probability in table.items():
        entry = f'{unit}: {probability!s}'
        cell = formatted_table[y + 2][x]
        # Keep every unit found for this span instead of overwriting.
        formatted_table[y + 2][x] = f'{cell}, {entry}' if cell else entry
    # Column widths come from the final cell contents (header included).
    max_lengths = [
        max(len(row[col]) for row in formatted_table)
        for col in range(num_words)
    ]
    padded_rows = [
        [cell.ljust(width) for cell, width in zip(row, max_lengths)]
        for row in formatted_table
    ]
    # the 3 is the length of the ' | ' string
    separator = '-' * (sum(max_lengths) + num_words * 3 - 3)
    return f'\n{separator}\n'.join(' | '.join(row) for row in padded_rows)
# Bind one Unit per grammar symbol in this module's globals.
define('Det', 'P', 'Aux', 'V', 'VP', 'N', 'NP', 'S', 'PP')

# Demo: print the chart for a sample sentence under a toy grammar.
print(format(
    'Kim adores snow in Oslo',
    {
        # Lexical rules: symbol > word
        Det > 'the': '.6',
        Det > 'these': '.2',
        Det > 'this': '.2',
        P > 'before': '.4',
        P > 'on': '.4',
        P > 'in': '.2',
        Aux > 'is': '.55',
        Aux > 'can': '.25',
        Aux > 'does': '.2',
        V > 'snores': '.3',
        V > 'adores': '.7',
        NP > 'snow': '.1',
        NP > 'Kim': '.1',
        NP > 'Oslo': '.1',
        NP > 'Waikiki': '.1',
        # Binary rules: symbol > left + right
        PP > P + S: '.1',
        PP > P + NP: '.9',
        NP > NP + PP: '.2',
        NP > Det + N: '.4',
        VP > VP + PP: '.3',
        VP > V + NP: '.45',
        VP > V + S: '.25',
        S > Aux + S: '.2',
        S > NP + VP: '.8',
    },
))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment