Created March 24, 2009 13:37
Save bayerj/84085 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Arithmetic(GrammarDataSet):
    """Grammar dataset of random, well-formed arithmetic expressions.

    Words are built from variables, binary operators and parentheses and are
    framed by the ``start`` and ``stop`` symbols.  Every word is turned into
    a sequence of samples mapping the current symbol to the set of symbols
    that may legally follow it.

    NOTE(review): ``start`` and ``stop`` are not defined in this class;
    presumably they are class attributes of ``GrammarDataSet`` (``stop``
    looks like ``'>'``) -- confirm.  They must exist before ``__init__`` runs,
    because ``__init__`` reads ``self.symbols`` before calling the base
    constructor.  All symbols are assumed to be single characters, since
    words are handled as strings.
    """

    open_par = '('
    close_par = ')'

    def __init__(self, maxnest=4, maxlength=50, size=250,
                 vars=None, operators=None):
        """Initialize an Arithmetic objective.

        :param maxnest: maximal parenthesis nesting depth of a word.
        :param maxlength: maximal number of symbols in a word, counting the
            start and stop symbols.
        :param size: number of words in each of the training and test
            datasets.
        :param vars: iterable of variable symbols; defaults to ``x``, ``y``.
        :param operators: iterable of binary operator symbols; defaults to
            ``*``, ``+``, ``-``.
        """
        self.maxnest = maxnest
        self.maxlength = maxlength
        self.size = size
        # Copy into fresh sets so callers' collections are never aliased.
        self.vars = set(('x', 'y')) if vars is None else set(vars)
        self.operators = (set(('*', '+', '-')) if operators is None
                          else set(operators))
        # This field is needed to construct arrays from symbol sets. Since the
        # symbol set will never have both start and stop (they are not possible
        # at the same point in the grammar) we will drop the start symbol.
        self._sorted_symbols = sorted(self.symbols)
        self._sorted_symbols.remove(self.start)
        super(Arithmetic, self).__init__()

    @property
    def symbols(self):
        """Set of every symbol of the grammar, including start and stop."""
        return self.vars | self.operators | set([self.open_par,
                                                 self.close_par,
                                                 self.start, self.stop])

    def make_datasets(self):
        """Sample ``2 * size`` distinct words and split them into datasets.

        :returns: ``(trainds, testds)``, two ``SequentialDataSet`` objects
            holding ``size`` word sequences each.

        NOTE: loops forever if fewer than ``2 * size`` distinct words can be
        generated for the configured ``maxlength``/``maxnest``.
        """
        words = set()
        while len(words) < 2 * self.size:
            words.add(self.dice_word(self.maxlength, self.maxnest))
        # Sort before shuffling: set order of strings depends on hash
        # randomization, which would make a seeded ``random`` irreproducible.
        words = sorted(words)
        random.shuffle(words)
        test, train = words[:self.size], words[self.size:]
        # One unit per encoded symbol; start and stop share the stop unit
        # (see ``arr_by_symbols``), hence one fewer than ``len(self.symbols)``.
        dim = len(self._sorted_symbols)
        trainds = SequentialDataSet(dim, dim)
        testds = SequentialDataSet(dim, dim)
        for word in test:
            self.add_word_to_ds(word, testds)
        for word in train:
            self.add_word_to_ds(word, trainds)
        return trainds, testds

    def add_word_to_ds(self, word, ds):
        """Append *word* to *ds* as a new sequence.

        For every non-empty prefix of the word, one sample is added whose
        input encodes the prefix's last symbol and whose target encodes the
        set of symbols that may follow the prefix.
        """
        ds.newSequence()
        for prefix in self.prefixes(word):
            successors = self.lookaheadset(prefix)
            ds.addSample(self.arr_by_symbols(set(prefix[-1:])),
                         self.arr_by_symbols(successors))

    def arr_by_symbols(self, this_symbols):
        """Encode a collection of symbols as a list of +1/-1 floats.

        Position ``i`` is ``1.`` if ``self._sorted_symbols[i]`` is among
        *this_symbols* and ``-1.`` otherwise.  The start symbol has no
        position of its own and is encoded at the stop symbol's position.
        *this_symbols* itself is not modified.
        """
        present = set(this_symbols)
        if self.start in present:
            present.discard(self.start)
            present.add(self.stop)
        return [1. if symbol in present else -1.
                for symbol in self._sorted_symbols]

    def dice_word(self, maxlength, maxnest):
        """Randomly generate a valid word of at most *maxlength* symbols.

        Symbols are drawn uniformly from the lookahead set until the word is
        complete or has grown beyond *maxlength*; over-long words are then
        truncated by ``cut``.  Generation is retried until a non-empty word
        results.

        :param maxlength: maximal word length, counting start and stop.
        :param maxnest: maximal parenthesis nesting depth.
        :returns: the word as a string.
        """
        while True:
            word = []
            while len(word) <= maxlength:
                successors = self.lookaheadset(word, maxnest)
                if not successors:
                    # Word is complete (ends with the stop symbol).
                    break
                # Sorted so the draw depends only on the RNG state, not on
                # hash-randomized set order.
                word.append(self.onesample(sorted(successors)))
            word = self.cut(word, maxlength)
            if word:
                return word
            # ``cut`` found no valid truncation; try again.

    def cut(self, word, maxlength):
        """Truncate *word* so it stays valid and has at most *maxlength* symbols.

        :param word: sequence of single-character symbols (list or string).
        :param maxlength: maximal length of the result, counting the stop
            symbol.
        :returns: the word as a string, unchanged if it is already short
            enough; otherwise the longest prefix of length at most
            ``maxlength - 1`` after which the stop symbol is allowed, with
            the stop symbol appended; ``None`` if no such prefix exists.
        """
        word = "".join(word).replace(" ", "")
        if len(word) <= maxlength:
            return word
        # Reserve one slot for the stop symbol appended below.
        for prefix in self.prefixes(word, maxlength - 1, reverse=True):
            if self.stop in self.lookaheadset(prefix):
                return prefix + self.stop
        return None

    def prefixes(self, word, maxlength=None, reverse=False):
        """Yield the non-empty prefixes of *word* as strings, shortest first.

        :param word: sequence of single-character symbols (list or string).
        :param maxlength: length of the longest prefix to yield; defaults to,
            and is capped at, the length of the word.
        :param reverse: if true, yield the longest prefix first.
        """
        word = "".join(word).replace(" ", "")
        if maxlength is None or maxlength > len(word):
            maxlength = len(word)
        lengths = range(1, maxlength + 1)
        if reverse:
            # ``reversed`` works on Python 2 lists and Python 3 ranges alike;
            # ``range(...).reverse()`` does not exist on Python 3.
            lengths = reversed(lengths)
        for length in lengths:
            yield word[:length]

    def onesample(self, iterable):
        """Return one element of *iterable* chosen uniformly at random.

        :raises ValueError: if *iterable* is empty.
        """
        bucket = list(iterable)
        if not bucket:
            raise ValueError("Empty iterable given.")
        return random.sample(bucket, 1)[0]

    def lookaheadset(self, word, maxnest=None):
        """Return the set of symbols that may follow *word*.

        :param word: sequence of single-character symbols (list or string).
        :param maxnest: nesting limit for opening parentheses; defaults to
            ``self.maxnest``.
        :returns: a set containing only the start symbol for the empty word;
            an empty set if the word is complete (ends with the stop symbol)
            or ends with a symbol outside the grammar.
        """
        # The first symbol is needed.
        if not word:
            return set([self.start])
        if maxnest is None:
            maxnest = self.maxnest
        last = word[-1]
        nesting = word.count(self.open_par) - word.count(self.close_par)
        successors = set()
        if last in (self.start, self.open_par) or last in self.operators:
            # An operand must follow: a variable or a nested sub-expression.
            successors |= self.vars
            if nesting < maxnest:
                successors.add(self.open_par)
        elif last in self.vars or last == self.close_par:
            # A complete operand was just closed.
            successors |= self.operators
            if nesting > 0:
                successors.add(self.close_par)
            else:
                successors.add(self.stop)
        return successors
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.