Skip to content

Instantly share code, notes, and snippets.

@bayerj
Created March 24, 2009 13:37
Show Gist options
  • Save bayerj/84085 to your computer and use it in GitHub Desktop.
Save bayerj/84085 to your computer and use it in GitHub Desktop.
class Arithmetic(GrammarDataSet):
    """Data set of randomly generated, bracketed arithmetic expressions.

    Words are built symbol by symbol from a small grammar over variables,
    binary operators and parentheses. For every prefix of a word the data set
    stores which symbols may legally follow it.

    The start and stop symbols (``self.start`` / ``self.stop``) are not
    defined in this class; they are expected to be provided by
    ``GrammarDataSet``.
    """

    open_par = '('
    close_par = ')'

    def __init__(self, maxnest=4, maxlength=50, size=250,
                 vars=None, operators=None):
        """Initialize an Arithmetic objective.

        maxnest   -- maximum nesting depth of parentheses.
        maxlength -- maximum length of a generated word.
        size      -- number of words in each of the training and test sets.
        vars      -- iterable of variable symbols (default: 'x' and 'y').
        operators -- iterable of operator symbols (default: '*', '+', '-').
        """
        self.maxnest = maxnest
        self.maxlength = maxlength
        self.size = size
        # Always assign both attributes; previously user supplied values were
        # silently dropped and the attributes stayed undefined.
        self.vars = set(('x', 'y')) if vars is None else set(vars)
        if operators is None:
            self.operators = set(('*', '+', '-'))
        else:
            self.operators = set(operators)

        # This field is needed to construct arrays from symbol sets. Since the
        # symbol set will never have both start and stop (they are not possible
        # at the same point in the grammar) we will drop the start symbol.
        self._sorted_symbols = sorted(self.symbols)
        self._sorted_symbols.remove(self.start)

        super(Arithmetic, self).__init__()

    @property
    def symbols(self):
        """Set of all symbols of the grammar, including start and stop."""
        markers = set([self.open_par, self.close_par, self.start, self.stop])
        return self.vars | self.operators | markers

    def make_datasets(self):
        """Return a tuple (training set, test set) of sequential data sets.

        2 * size distinct words are generated and shuffled; the first `size`
        of them go into the test set, the remaining ones into the training
        set.

        NOTE: the generation loop only ends once enough *distinct* words have
        been found, so it does not terminate if the grammar cannot produce
        that many words within maxlength/maxnest.
        """
        words = set()
        while len(words) < 2 * self.size:
            words.add(self.dice_word(self.maxlength, self.maxnest))
        words = list(words)
        random.shuffle(words)
        test, train = words[:self.size], words[self.size:]

        # Subtract one, since start symbol is not used in output and stop
        # symbol is not used in input.
        dim = len(self.symbols) - 1
        trainds = SequentialDataSet(dim, dim)
        testds = SequentialDataSet(dim, dim)
        for word in test:
            self.add_word_to_ds(word, testds)
        for word in train:
            self.add_word_to_ds(word, trainds)
        return trainds, testds

    def add_word_to_ds(self, word, ds):
        """Append word to ds as a new sequence.

        One sample is added per prefix of the word: the input encodes the
        last symbol of the prefix, the target encodes the set of symbols that
        may follow the prefix.
        """
        ds.newSequence()
        for prefix in self.prefixes(word):
            following = self.lookaheadset(prefix)
            ds.addSample(self.arr_by_symbols(set(prefix[-1:])),
                         self.arr_by_symbols(following))

    def arr_by_symbols(self, this_symbols):
        """Return a list with one entry per symbol in sorted symbol order
        (start symbol excluded): 1.0 if the symbol is in this_symbols, -1.0
        otherwise.

        The start symbol has no slot of its own; it is encoded in the slot of
        the stop symbol. The given collection is not modified.
        """
        # Work on a copy so the caller's set is left untouched.
        present = set(this_symbols)
        if self.start in present:
            present.discard(self.start)
            present.add(self.stop)
        return [1. if symbol in present else -1.
                for symbol in self._sorted_symbols]

    def dice_word(self, maxlength, maxnest):
        """Return a random non-empty word of at most maxlength characters and
        a nesting depth of at most maxnest.

        Generation is repeated until cut() yields a non-empty word, so this
        does not return if maxlength is too small for any valid word.
        """
        while True:
            symbols = []
            # Generate one symbol more than allowed at most; cut() trims the
            # word back to a valid one afterwards.
            while len(symbols) <= maxlength:
                # Set of possible following symbols.
                candidates = self.lookaheadset(symbols, maxnest)
                if not candidates:
                    # Nothing may follow: the word is complete.
                    break
                symbols.append(self.onesample(candidates))
            word = self.cut(symbols, maxlength)
            if word:
                return word
            # Otherwise try again, no usable word was generated.

    def cut(self, word, maxlength):
        """Cut off a suffix from the word so that the word is still valid but
        has a length of at most maxlength.

        Words that are short enough are returned unchanged (joined to a
        string, with blanks removed). Return None if no prefix can be
        completed to a valid word within maxlength.
        """
        word = "".join(word).replace(" ", "")
        if len(word) <= maxlength:
            return word
        # Subtract 1 in order to make up for the stop symbol.
        for prefix in self.prefixes(word, maxlength - 1, reverse=True):
            if self.stop in self.lookaheadset(prefix):
                # Use the configured stop symbol instead of a literal so the
                # result stays consistent with lookaheadset().
                return prefix + self.stop
        return None

    def prefixes(self, word, maxlength=None, reverse=False):
        """Yield the non-empty prefixes of word, shortest first.

        maxlength -- longest prefix length to yield (default: whole word).
        reverse   -- if true, yield the longest prefix first.
        """
        word = "".join(word).replace(" ", "")
        # Clamp, so that a too large maxlength does not yield the complete
        # word several times.
        if maxlength is None or maxlength > len(word):
            maxlength = len(word)
        indices = range(1, maxlength + 1)
        if reverse:
            # range objects have no reverse() method on Python 3.
            indices = reversed(indices)
        for i in indices:
            yield word[:i]

    def onesample(self, iterable):
        """Return one randomly chosen item of iterable.

        Raise ValueError if iterable is empty.
        """
        bucket = list(iterable)
        if not bucket:
            raise ValueError("Empty iterable given.")
        return random.choice(bucket)

    def lookaheadset(self, word, maxnest=None):
        """Return the set of valid next symbols for word.

        word may be a string or a list of symbols. maxnest limits the nesting
        depth and defaults to self.maxnest.

        If the resulting set is empty, nothing may follow: the word already
        ends with the stop symbol (or with a symbol unknown to the grammar).
        """
        if maxnest is None:
            maxnest = self.maxnest
        # The first symbol is needed.
        if not word:
            return set([self.start])

        last = word[-1]
        nesting = word.count(self.open_par) - word.count(self.close_par)
        following = set()
        if last in (self.start, self.open_par) or last in self.operators:
            # An operand is expected: a variable or a new sub expression.
            following |= self.vars
            if nesting < maxnest:
                following.add(self.open_par)
        elif last in self.vars or last == self.close_par:
            # An operand is complete: continue with an operator, or close the
            # current level respectively the whole word.
            following |= self.operators
            if nesting > 0:
                following.add(self.close_par)
            else:
                following.add(self.stop)
        return following
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment