Last active
February 16, 2019 07:51
-
-
Save jorendorff/55fd6cb694f69b802912340b5a6bba21 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
class IndexedList:
    """An append-only, duplicate-free sequence of hashable values.

    Behaves like a list for reads (len, indexing, iteration, membership)
    and additionally offers O(1) position lookup via index(). Appending a
    value that is already present is a silent no-op.
    """

    def __init__(self, values=None):
        self._list = []
        self._index = {}
        if values is not None:
            for item in values:
                self.append(item)

    def append(self, value):
        """Add `value` at the end unless it is already present."""
        if value in self._index:
            return
        self._index[value] = len(self._list)
        self._list.append(value)

    def __contains__(self, value):
        return value in self._index

    def __len__(self):
        return len(self._list)

    def __getitem__(self, position):
        return self._list[position]

    def __iter__(self):
        return iter(self._list)

    def index(self, value):
        """Return the position of `value`; raises KeyError if absent."""
        return self._index[value]
def topo_sort(iterable, predecessors):
    """Topological sort the values in `iterable`, discarding duplicates.

    predecessors(value) returns the value's predecessors; it should return a
    list of values that are in `iterable`.

    If possible, return a list that is a permutation of `list(set(iterable))`,
    such that for all X, Y in iterable, if X in predecessors(Y), then X
    appears before Y in the output.

    If no such permutation exists, raise a ValueError.
    """
    done = set()         # values already emitted to `out`
    in_progress = set()  # values on the current DFS stack, for cycle detection
    out = []

    def add(value):
        if value in done:
            return
        if value in in_progress:
            # We reached `value` again while still expanding its own
            # predecessors: the dependency graph contains a cycle.
            # (The original checked `value not in out`, an O(n) list scan
            # per visit; a dedicated set makes this O(1).)
            raise ValueError(f"cycle detected involving {value!r}")
        in_progress.add(value)
        for p in predecessors(value):
            add(p)
        in_progress.discard(value)
        done.add(value)
        out.append(value)

    for value in iterable:
        add(value)
    return out
# An arithmetic-expression grammar with the usual precedence levels:
# 'add' ('+', '-') binds loosest, then 'mul' ('*', '/'), then 'pre'
# (unary minus and right-associative '^'), then 'prim' (atoms).
# Keys are nonterminal names; each value is a list of productions, and
# each production is a list of symbols. A symbol that is a key of this
# dict is a nonterminal; anything else ('N', 'v', '+', ...) is a terminal.
grammar = {
    'add': [
        ['mul'],
        ['add', '+', 'mul'],
        ['add', '-', 'mul'],
    ],
    'mul': [
        ['pre'],
        ['mul', '*', 'pre'],
        ['mul', '/', 'pre'],
    ],
    'pre': [
        ['prim'],
        ['-', 'pre'],
        ['prim', '^', 'pre'],
    ],
    'prim': [
        ['N'],
        ['v'],
        ['(', 'add', ')'],
    ],
}
def is_nt(symbol):
    """Return True when `symbol` is a nonterminal, i.e. a key of the global grammar."""
    return symbol in grammar.keys()
class Grammar:
    """A compiled grammar that can enumerate fixed-length sentences.

    Built by compile_grammar(). Given a nonterminal `nt`, a sentence
    length, and an integer `fuel` in [0, counts_by_nt[nt][length]),
    sentence() decodes `fuel` into the fuel-th distinct sentence of that
    length — so a uniformly random fuel yields a uniformly random sentence.
    """
    def __init__(self, indexified_grammar, suffixes, counts_by_nt, counts_by_seq):
        # Maps each nonterminal to a list of suffix indexes, one per production.
        self.indexified_grammar = indexified_grammar
        # IndexedList of (head_symbol, tail_suffix_index_or_None) pairs.
        self.suffixes = suffixes
        # counts_by_nt[nt][n]: number of length-n sentences derivable from nt.
        self.counts_by_nt = counts_by_nt
        # counts_by_seq[i][n]: number of length-n sentences matching suffix i.
        self.counts_by_seq = counts_by_seq
    def is_nt(self, symbol):
        # A symbol is a nonterminal iff it has productions in this grammar.
        return symbol in self.indexified_grammar
    def sentence(self, nt, length, fuel):
        """Return the fuel-th length-`length` sentence (a list of symbols)
        derivable from nonterminal `nt`.

        Raises IndexError if `fuel` is not in range for this nt and length.
        """
        if not (0 <= fuel < self.counts_by_nt[nt][length]):
            raise IndexError("sentence index out of range")
        # Select a production: skip whole productions' sentence counts
        # until fuel falls inside one of them.
        for index in self.indexified_grammar[nt]:
            d = self.counts_by_seq[index][length]
            if fuel < d:
                return self.sequence(index, length, fuel)
            fuel -= d
    def sequence(self, index, length, fuel):
        """Return the fuel-th length-`length` sentence matching suffix `index`."""
        head, tail = self.suffixes[index]
        if tail is None:
            if self.is_nt(head):
                return self.sentence(head, length, fuel)
            else:
                # A lone terminal matches exactly one sentence, of length 1.
                assert fuel == 0
                return [head]
        else:
            if self.is_nt(head):
                # Try each split: head derives k symbols, tail the rest.
                for k in range(1, length):
                    d_head = self.counts_by_nt[head][k]
                    d_tail = self.counts_by_seq[tail][length - k]
                    d = d_head * d_tail
                    if fuel < d:
                        # Within this split, fuel encodes a (head-choice,
                        # tail-choice) pair in mixed radix: divmod by d_tail.
                        return (self.sentence(head, k, fuel // d_tail) +
                                self.sequence(tail, length - k, fuel % d_tail))
                    fuel -= d
            else:
                # Terminal head consumes one symbol; recurse on the tail.
                return [head] + self.sequence(tail, length - 1, fuel)
def compile_grammar(grammar, length):
    """Precompute sentence counts for `grammar` for all lengths up to `length`.

    Returns a Grammar object whose counts_by_nt[nt][n] gives the number of
    distinct length-n sentences derivable from nonterminal nt, for every
    n in range(length + 1), enabling O(1)-indexed uniform sampling.
    """
    nts = list(grammar.keys())

    def is_nt(symbol):
        return symbol in grammar

    # First, just compute the set of all suffixes, saving the index of each
    # whole production. A suffix is a (head_symbol, tail_suffix_index) pair,
    # with tail None marking the last symbol of a production.
    indexified_grammar = {nt: [] for nt in nts}
    suffixes = IndexedList()
    for nt in nts:
        for prod in grammar[nt]:
            acc = None
            for i in reversed(range(len(prod))):
                pair = (prod[i], acc)
                if pair not in suffixes:
                    suffixes.append(pair)
                acc = suffixes.index(pair)
            indexified_grammar[nt].append(acc)

    # Now list all the data dependencies. Note that a sequence `( term )` does
    # not have to be counted after its suffix `term )`! This is because we
    # count sequences `term )` of length N-1 in a previous round before trying
    # to count sequences `( term )` of length N. And no symbol ever matches the
    # empty string.
    #
    # BUG FIX: the original used `[[]] * len(suffixes)`, which makes every
    # slot alias ONE shared list — a latent aliasing bug (benign only because
    # the code below happens to reassign rather than mutate). A comprehension
    # gives each slot its own independent list.
    dependencies = [[] for _ in suffixes]
    for prod_index_list in indexified_grammar.values():
        for index in prod_index_list:
            head, tail = suffixes[index]
            if is_nt(head) and tail is None:
                # A full production consisting of only a nonterminal can be
                # counted only after all that nonterminal's productions have
                # been counted.
                dependencies[index] = indexified_grammar[head]

    # Use the dependencies to sort the suffixes.
    sorted_indexes = topo_sort(range(len(suffixes)), dependencies.__getitem__)
    worklist = []
    for index in sorted_indexes:
        head, tail = suffixes[index]
        # The nonterminals whose counts this suffix contributes to, i.e.
        # those for which this suffix is a whole production.
        nts_satisfied = [nt for nt in nts
                         if index in indexified_grammar[nt]]
        worklist.append((index, head, tail, nts_satisfied))

    # Ready to calculate, shortest lengths first, so each round can rely on
    # the counts computed in all previous rounds.
    counts_by_seq = [[0] for _ in suffixes]
    counts_by_nt = {nt: [0] * (length + 1) for nt in nts}
    for current_length in range(1, length + 1):
        for index, head, tail, nts_satisfied in worklist:
            if tail is None:
                if is_nt(head):
                    # Safe to read the current length: topo_sort placed all
                    # of head's productions earlier in this round.
                    n = counts_by_nt[head][current_length]
                else:
                    # A lone terminal matches exactly one length-1 sentence.
                    n = 1 if current_length == 1 else 0
            else:
                if is_nt(head):
                    # Sum over every split of the length between the head
                    # nonterminal (k symbols) and the rest of the production.
                    n = sum(counts_by_nt[head][k] * counts_by_seq[tail][current_length - k]
                            for k in range(1, current_length))
                else:
                    # A terminal head consumes exactly one symbol.
                    n = counts_by_seq[tail][current_length - 1]
            counts_by_seq[index].append(n)
            for nt in nts_satisfied:
                counts_by_nt[nt][current_length] += n
    return Grammar(indexified_grammar, suffixes, counts_by_nt, counts_by_seq)
def main():
    """Demo: count the length-20 'add' sentences and print a uniform sample."""
    maxlen = 20
    g = compile_grammar(grammar, maxlen)
    n = g.counts_by_nt['add'][maxlen]
    print(f"There are {n} strings " +
          f"of length {maxlen} that match the 'add' nonterminal. Here are a few uniformly selected ones:")
    for _ in range(15):
        # Each index in range(n) names exactly one sentence, so a uniform
        # random index yields a uniformly random sentence.
        i = random.randrange(n)
        print(f"{i}: {' '.join(g.sentence('add', maxlen, i))}")


if __name__ == "__main__":
    # BUG FIX: the original called main() unconditionally at import time;
    # guard it so importing this module doesn't run the demo.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment