serving data (gist by @braingineer, created April 22, 2016)
from __future__ import print_function, division
import yaml
import time
import sys
import numpy as np
import itertools
from keras.utils.np_utils import to_categorical
from baal.utils import loggers
try:
    input = raw_input  # Python 2 compatibility
except NameError:
    pass

class DataServer(object):
    """
    Required config:
        disable_logger: {True/False}
        saving_prefix
        batch_size

    It is also good to "stage" it by adding the following:
        train_data
        dev_data
        train_vocab
    """
    def __init__(self, config):
        self.__dict__.update(config)
        log_name = self.saving_prefix
        self.logger = loggers.duallog(log_name, "info", "logs/",
                                      disable=self.disable_logger)
    def print_everything(self):
        for key, value in self.__dict__.items():
            if isinstance(value, (bool, str, int, float)):
                print("{}: {}".format(key, value))
    @classmethod
    def from_file(cls, config_file):
        with open(config_file) as fp:
            config = yaml.safe_load(fp)  # safe_load: no arbitrary object construction
        return cls(config)
    @property
    def num_train_batches(self):
        return len(self.train_data) // self.batch_size

    @property
    def num_dev_batches(self):
        return len(self.dev_data) // self.batch_size

    @property
    def num_train_samples(self):
        if self.subepochs > 0:
            return self.num_train_batches // self.subepochs * self.batch_size
        return self.num_train_batches * self.batch_size

    @property
    def num_dev_samples(self):
        return self.num_dev_batches * self.batch_size

    def stage(self, *args, **kwargs):
        """ would be great to load things here """
        pass
    def serve_single(self, data, unseen_data=None):
        """serve a single sample from a dataset;
        yield X and Y of size appropriate to sample_size

        ### an example implementation
        for data_i in np.random.choice(len(data), len(data), replace=False):
            in_X = np.zeros(self.max_sequence_len)
            out_Y = np.zeros(self.max_sequence_len, dtype=np.int32)
            bigram_data = zip(data[data_i][0:-1], data[data_i][1:])
            for datum_j, (datum_in, datum_out) in enumerate(bigram_data):
                in_X[datum_j] = datum_in
                out_Y[datum_j] = datum_out
            yield in_X, out_Y
        """
        raise NotImplementedError
    def serve_batch(self, data, unseen_data=None):
        """serve a batch of samples from a dataset;
        yield X and Y of sizes appropriate to (batch,) + sample_size

        ### an example implementation
        ### yields (batch, sequence) and (batch, sequence, vocab)
        dataiter = self.serve_sentence(data)
        V = self.vocab_size
        S = self.max_sequence_len
        B = self.batch_size
        while dataiter:
            in_X = np.zeros((B, S), dtype=np.int32)
            out_Y = np.zeros((B, S, V), dtype=np.int32)
            next_batch = list(itertools.islice(dataiter, 0, self.batch_size))
            if len(next_batch) < self.batch_size:
                return  # exhausted: end the generator (PEP 479)
            for d_i, (d_X, d_Y) in enumerate(next_batch):
                in_X[d_i] = d_X
                out_Y[d_i] = to_categorical(d_Y, V)
            yield in_X, out_Y
        """
        raise NotImplementedError
    def _data_gen(self, data, forever=True, unseen_data=None):
        working = True
        while working:
            for batch in self.serve_batch(data, unseen_data):
                yield batch
            working = working and forever

    def dev_gen(self, forever=True):
        return self._data_gen(self.dev_data, forever, self.dev_unseen)

    def train_gen(self, forever=True):
        return self._data_gen(self.train_data, forever, self.train_unseen)
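

# Minimal usage sketch (not part of the original gist): subclass DataServer,
# implement serve_batch, and pull batches from train_gen(). The subclass name,
# the toy data, and the config values below are hypothetical.
class EchoServer(DataServer):
    def serve_batch(self, data, unseen_data=None):
        # walk the data in full batches; X == Y here, just to show the contract
        for i in range(0, len(data) - self.batch_size + 1, self.batch_size):
            batch = np.asarray(data[i:i + self.batch_size])
            yield batch, batch

server = EchoServer({"disable_logger": True, "saving_prefix": "demo",
                     "batch_size": 4, "train_data": list(range(100)),
                     "train_unseen": None, "subepochs": 0})
X, Y = next(server.train_gen(forever=False))  # X.shape == (4,)


# --- loggers module (imported above as baal.utils.loggers) ---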
import logging
import baal

members = ["treecut", "hlfdebug", "nlg", "csp", "trees"]
levels = {"debug": logging.DEBUG, "warning": logging.WARNING,
          "info": logging.INFO, "error": logging.ERROR,
          "critical": logging.CRITICAL}

def shell_log(loggername="", level="debug"):
    logger = logging.getLogger(loggername)
    if not logger.handlers:
        ch = logging.StreamHandler()
        ch.setLevel(levels[level])
        logger.addHandler(ch)
    logger.setLevel(levels[level])
    return logger

def file_log(loggername, filename, level="debug"):
    logger = logging.getLogger(loggername)
    fh = logging.FileHandler(filename)
    fh.setLevel(levels[level])
    logger.addHandler(fh)
    logger.setLevel(levels[level])
    return logger

def duallog(loggername, shell_level="info", file_loc="logs/", disable=False):
    logger = logging.getLogger(loggername)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers and not disable:
        print(file_loc)
        if baal.utils.general.ensure_dir(file_loc + ("/" if file_loc[-1] != "/" else "")):
            print("Created directory for logger at {}".format(file_loc))
        fh = logging.FileHandler("{}/{}.debug.log".format(file_loc, loggername))
        fh.setLevel(logging.DEBUG)
        sh = logging.StreamHandler()
        sh.setLevel(levels[shell_level])
        logger.addHandler(fh)
        logger.addHandler(sh)
    return logger

def turn_on(name, level="debug", shell=True, filename=None):
    if shell:
        shell_log(name, level)
    elif filename:
        file_log(name, filename, level)


def get(name, level='debug', turn_on=True):
    if turn_on:
        return shell_log(name, level)
    else:
        logger = logging.getLogger(name)
        logger.addHandler(NullHandler())
        return logger


def set_level(name, level):
    logging.getLogger(name).setLevel(levels[level])


class NullHandler(logging.Handler):
    def emit(self, record):
        pass
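

# Usage sketch (not in the original): duallog writes DEBUG and above to
# logs/<name>.debug.log and echoes shell_level and above to the console.
# The logger name "demo" is hypothetical.
log = duallog("demo", shell_level="info", file_loc="logs/")
log.info("shown in shell and written to logs/demo.debug.log")
log.debug("written to the file only")


# --- experiment config (YAML) ---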
### file stuff
word_vocab: data/wsj_chiang.vocab
train_filepath: $gist$/data/chiang_wsj_train_supertags.pkl
dev_filepath: $gist$/data/chiang_wsj_dev_supertags.pkl
test_filepath: $gist$/data/chiang_wsj_test_supertags.pkl
checkpoint_weights: ki_weights.h5
### parameters
num_epochs: 500
batch_size: 32
attention_size: 50
pos_embedding_size: 35
type_embedding_size: 15
word_embedding_size: 50
LR: 0.001
max_grad_norm: 10.
grad_clip_threshold: 5.0
unroll_lstms: True
data_frequency_cutoff: 2
subepochs: 5
#dropout parameters
p_emb_dropout: 0.5
p_W_dropout: 0.35
p_U_dropout: 0.35
p_dense_dropout: 0.0
p_summary_dropout: 0.5
p_individual_summary_dropout: 0
weight_decay: 1e-6
### logging and saving
# 5p0 has some great great great stuff!
saving_prefix: fergus_5p1p2 # upped sizes; not compatible with 5p0
disable_logger: False
from_checkpoint: False
### parameters set at runtime
max_num_supertags: 0
max_spine_length: 0
max_context_size: 0
max_num_children: 0
word_vocab_size: 0
################
### bit info is for multiple embeddings that get concatted
# [(pos_vocab_size, pos_embedding_size) , (type_vocab_size, type_embedding_size)]
bit_info: ~
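# a hypothetical filled-in value (vocab/embedding sizes made up for illustration):
#   bit_info: [[46, 35], [28, 15]]


# --- Vocabulary (bijective string <-> int mapping) ---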
import os
import json
from numpy.random import randint


class Vocabulary(object):
    """
    Taken from Tim Vieira & his GitHub: https://github.com/timvieira/arsenal

    Bijective mapping from strings to integers. Used for turning a bunch of
    string observations into integers (presumably for indexing and other fun
    things).

    >>> a = Vocabulary()
    >>> [a[x] for x in 'abcd']
    [0, 1, 2, 3]
    >>> map(a.lookup, range(4))
    ['a', 'b', 'c', 'd']
    >>> a.stop_growth()  # stopping growth means it will be nice about its closed-ness
    >>> a['e']
    >>> a.freeze()  # freezing means harsher punishment for going outside its domain
    >>> a.add('z')
    Traceback (most recent call last):
      ...
    ValueError: Vocabulary is frozen. Key "z" not found.
    >>> print a.plaintext()
    a
    b
    c
    d
    """
    def __init__(self, random_int=None, use_mask=False):
        self._mapping = {}   # str -> int
        self._flip = {}      # int -> str; timv: consider using array or list
        self._i = 0
        self._frozen = False
        self._growing = True
        self._random_int = random_int  # if non-zero, will randomly assign
                                       # integers (between 0 and random_int)
                                       # as index (possibly with collisions)
        self.unk_symbol = "<UNK>"
        self.mask_symbol = "<MASK>"
        self.emit_unks = False
        self.use_mask = use_mask
        if self.use_mask:
            self.add(self.mask_symbol)
    def __repr__(self):
        return 'Vocabulary(size=%s,frozen=%s)' % (len(self), self._frozen)

    def freeze(self, emit_unks=False):
        self.emit_unks = emit_unks
        if emit_unks and "<UNK>" not in self:
            self.add("<UNK>")
        self._frozen = True

    def stop_growth(self):
        self._growing = False

    @classmethod
    def from_iterable(cls, s):
        "Assumes keys are strings."
        inst = cls()
        for x in s:
            inst.add(x)
        return inst

    def keyset(self):
        keys = set(self._mapping.keys())
        if self.mask_symbol in keys:
            keys.remove(self.mask_symbol)
        return keys

    def keys(self):
        return self._mapping.iterkeys()

    def items(self):
        return self._mapping.iteritems()
    def filter_generator(self, seq, emit_none=False):
        """
        Apply Vocabulary to sequence while filtering. By default, `None` is not
        emitted, so please note that the output sequence may have fewer items.
        """
        if emit_none:
            for s in seq:
                yield self[s]
        else:
            for s in seq:
                x = self[s]
                if x is not None:
                    yield x

    def filter(self, seq, *args, **kwargs):
        return list(self.filter_generator(seq, *args, **kwargs))

    def add_many(self, x):
        return [self.add(k) for k in x]

    def lookup(self, i):
        if i is None:
            return None
        #assert isinstance(i, int)
        return self._flip[i]

    def lookup_many(self, x):
        for k in x:
            yield self.lookup(k)

    def __contains__(self, k):
        #assert isinstance(k, basestring)
        return k in self._mapping
    def __getitem__(self, k):
        try:
            return self._mapping[k]
        except KeyError:
            #if not isinstance(k, basestring):
            #    raise ValueError("Invalid key (%s): only strings allowed." % (k,))
            if self._frozen:
                if self.emit_unks:
                    return self._mapping[self.unk_symbol]
                else:
                    raise ValueError('Vocabulary is frozen. Key "%s" not found.' % (k,))
            if not self._growing:
                if self.emit_unks:
                    return self._mapping[self.unk_symbol]
                else:
                    return None
            if self._random_int:
                x = self._mapping[k] = randint(0, self._random_int)
            else:
                x = self._mapping[k] = self._i
                self._i += 1
            self._flip[x] = k
            return x

    add = __getitem__
    def __setitem__(self, k, v):
        assert k not in self._mapping
        if self._frozen:
            raise ValueError("Vocabulary is frozen. Key '%s' cannot be changed." % (k,))
        assert isinstance(v, int)
        self._mapping[k] = v
        self._flip[v] = k
    def __iter__(self):
        for i in xrange(len(self)):
            yield self._flip[i]

    def enum(self):
        for i in xrange(len(self)):
            yield (i, self._flip[i])

    def __len__(self):
        return len(self._mapping)

    def plaintext(self):
        "assumes keys are strings"
        return '\n'.join(self)

    def exact_save(self):
        return self._mapping
    @classmethod
    def exact_load(cls, exact_dict):
        if "mapping" in exact_dict and isinstance(exact_dict['mapping'], dict):
            exact_dict, config = exact_dict['mapping'], exact_dict['config']
        else:
            config = {}
        new_vocab = cls()
        for k, v in exact_dict.items():
            assert isinstance(v, int)
            new_vocab._mapping[k] = v
            new_vocab._flip[v] = k
        new_vocab._i = len(new_vocab)  # next fresh index to assign
        new_vocab.__dict__.update(config)
        return new_vocab
    @classmethod
    def load(cls, filename):
        if not os.path.exists(filename):
            return cls()
        with open(filename) as fp:
            return cls.exact_load(json.load(fp))
    def _config(self):
        config = {"emit_unks": self.emit_unks,
                  "use_mask": self.use_mask,
                  "_frozen": self._frozen,
                  "_growing": self._growing}
        return config
    def save(self, filename, exactly=False):
        with open(filename, 'w') as fp:
            json.dump({"mapping": self._mapping,
                       "config": self._config()}, fp)