serving data
from __future__ import print_function, division
import yaml
import time
import sys
import numpy as np
import itertools
from keras.utils.np_utils import to_categorical
from baal.utils import loggers

try:
    input = raw_input  # Python 2 compatibility
except NameError:
    pass


class DataServer(object):
    """
    Required config keys:
        disable_logger: {True/False}
        saving_prefix
        batch_size

    It is also good to "stage" it by adding the following:
        train_data
        dev_data
        train_vocab
    """
    def __init__(self, config):
        self.__dict__.update(config)
        log_name = self.saving_prefix
        self.logger = loggers.duallog(log_name, "info", "logs/",
                                      disable=self.disable_logger)

    def print_everything(self):
        for key, value in self.__dict__.items():
            if isinstance(value, (bool, str, int, float)):
                print("{}: {}".format(key, value))

    @classmethod
    def from_file(cls, config_file):
        with open(config_file) as fp:
            config = yaml.safe_load(fp)  # safe_load avoids arbitrary object construction
        return cls(config)

    @property
    def num_train_batches(self):
        return len(self.train_data) // self.batch_size

    @property
    def num_dev_batches(self):
        return len(self.dev_data) // self.batch_size

    @property
    def num_train_samples(self):
        if self.subepochs > 0:
            return self.num_train_batches // self.subepochs * self.batch_size
        return self.num_train_batches * self.batch_size

    @property
    def num_dev_samples(self):
        return self.num_dev_batches * self.batch_size

    def stage(self, *args, **kwargs):
        """ would be great to load things here """
        pass

    def serve_single(self, data, unseen_data=None):
        """Serve a single sample from a dataset;
        yield X and Y of size appropriate to sample_size.

        ### an example implementation
        for data_i in np.random.choice(len(data), len(data), replace=False):
            in_X = np.zeros(self.max_sequence_len)
            out_Y = np.zeros(self.max_sequence_len, dtype=np.int32)
            bigram_data = zip(data[data_i][0:-1], data[data_i][1:])
            for datum_j, (datum_in, datum_out) in enumerate(bigram_data):
                in_X[datum_j] = datum_in
                out_Y[datum_j] = datum_out
            yield in_X, out_Y
        """
        raise NotImplementedError

    def serve_batch(self, data, unseen_data=None):
        """Serve a batch of samples from a dataset;
        yield X and Y of sizes appropriate to (batch,) + sample_size.

        ### an example implementation
        ### yields (batch, sequence) and (batch, sequence, vocab)
        dataiter = self.serve_single(data)
        V = self.vocab_size
        S = self.max_sequence_len
        B = self.batch_size
        while True:
            in_X = np.zeros((B, S), dtype=np.int32)
            out_Y = np.zeros((B, S, V), dtype=np.int32)
            next_batch = list(itertools.islice(dataiter, 0, self.batch_size))
            if len(next_batch) < self.batch_size:
                return  # PEP 479: return, not StopIteration, ends a generator cleanly
            for d_i, (d_X, d_Y) in enumerate(next_batch):
                in_X[d_i] = d_X
                out_Y[d_i] = to_categorical(d_Y, V)
            yield in_X, out_Y
        """
        raise NotImplementedError

    def _data_gen(self, data, forever=True, unseen_data=None):
        working = True
        while working:
            for batch in self.serve_batch(data, unseen_data):
                yield batch
            working = working and forever

    def dev_gen(self, forever=True):
        return self._data_gen(self.dev_data, forever, self.dev_unseen)

    def train_gen(self, forever=True):
        return self._data_gen(self.train_data, forever, self.train_unseen)
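
The two serve_* hooks above are deliberately left as NotImplementedError; a subclass fills them in along the lines of the docstring examples. Below is a minimal, hypothetical sketch of such a subclass (the name BigramServer is mine, and max_sequence_len / vocab_size are assumed to arrive through the config dict):

class BigramServer(DataServer):
    def serve_single(self, data, unseen_data=None):
        # shuffle sample indices, emit one padded (input, target) pair at a time
        for data_i in np.random.choice(len(data), len(data), replace=False):
            in_X = np.zeros(self.max_sequence_len, dtype=np.int32)
            out_Y = np.zeros(self.max_sequence_len, dtype=np.int32)
            bigrams = zip(data[data_i][:-1], data[data_i][1:])
            for datum_j, (datum_in, datum_out) in enumerate(bigrams):
                in_X[datum_j] = datum_in
                out_Y[datum_j] = datum_out
            yield in_X, out_Y

    def serve_batch(self, data, unseen_data=None):
        # stack batch_size single samples; drop the final short batch
        dataiter = self.serve_single(data)
        B, S, V = self.batch_size, self.max_sequence_len, self.vocab_size
        while True:
            next_batch = list(itertools.islice(dataiter, B))
            if len(next_batch) < B:
                return
            in_X = np.zeros((B, S), dtype=np.int32)
            out_Y = np.zeros((B, S, V), dtype=np.int32)
            for d_i, (d_X, d_Y) in enumerate(next_batch):
                in_X[d_i] = d_X
                out_Y[d_i] = to_categorical(d_Y, V)
            yield in_X, out_Y

The next file is the baal.utils.loggers module that DataServer imports above.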
import logging
import baal

members = ["treecut", "hlfdebug", "nlg", "csp", "trees"]

levels = {"debug": logging.DEBUG, "warning": logging.WARNING,
          "info": logging.INFO, "error": logging.ERROR,
          "critical": logging.CRITICAL}


def shell_log(loggername="", level="debug"):
    logger = logging.getLogger(loggername)
    if not logger.handlers:
        ch = logging.StreamHandler()
        ch.setLevel(levels[level])
        logger.addHandler(ch)
    logger.setLevel(levels[level])
    return logger


def file_log(loggername, filename, level="debug"):
    logger = logging.getLogger(loggername)
    fh = logging.FileHandler(filename)
    fh.setLevel(levels[level])
    logger.addHandler(fh)
    logger.setLevel(levels[level])
    return logger


def duallog(loggername, shell_level="info", file_loc="logs/", disable=False):
    logger = logging.getLogger(loggername)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers and not disable:
        print(file_loc)
        if baal.utils.general.ensure_dir(file_loc + ("/" if file_loc[-1] != "/" else "")):
            print("Created directory for logger at {}".format(file_loc))
        fh = logging.FileHandler("{}/{}.debug.log".format(file_loc, loggername))
        fh.setLevel(logging.DEBUG)
        sh = logging.StreamHandler()
        sh.setLevel(levels[shell_level])
        logger.addHandler(fh)
        logger.addHandler(sh)
    return logger


def turn_on(name, level="debug", shell=True, filename=None):
    if shell:
        shell_log(name, level)
    elif filename:
        file_log(name, filename, level)


def get(name, level='debug', turn_on=True):
    if turn_on:
        return shell_log(name, level)
    else:
        logger = logging.getLogger(name)
        logger.addHandler(NullHandler())
        return logger


def set_level(name, level):
    logging.getLogger(name).setLevel(levels[level])


class NullHandler(logging.Handler):
    def emit(self, record):
        pass
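
A quick, hypothetical illustration of the module: duallog is the entry point DataServer.__init__ uses above, writing everything to logs/<name>.debug.log while echoing only shell_level and above to the terminal:

# Hypothetical usage sketch, mirroring the call in DataServer.__init__.
log = duallog("fergus_5p1p2", shell_level="info", file_loc="logs/")
log.debug("written to logs/fergus_5p1p2.debug.log only")
log.info("written to both the log file and the shell")

Next is the YAML config that DataServer.from_file consumes.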
### file stuff
word_vocab: data/wsj_chiang.vocab
train_filepath: $gist$/data/chiang_wsj_train_supertags.pkl
dev_filepath: $gist$/data/chiang_wsj_dev_supertags.pkl
test_filepath: $gist$/data/chiang_wsj_test_supertags.pkl
checkpoint_weights: ki_weights.h5

### parameters
num_epochs: 500
batch_size: 32
attention_size: 50
pos_embedding_size: 35
type_embedding_size: 15
word_embedding_size: 50
LR: 0.001
max_grad_norm: 10.
grad_clip_threshold: 5.0
unroll_lstms: True
data_frequency_cutoff: 2
subepochs: 5

# dropout parameters
p_emb_dropout: 0.5
p_W_dropout: 0.35
p_U_dropout: 0.35
p_dense_dropout: 0.0
p_summary_dropout: 0.5
p_individual_summary_dropout: 0
weight_decay: 1.0e-6  # note: PyYAML needs the decimal point to parse this as a float, not a string

### logging and saving
# 5p0 has some great great great stuff!
saving_prefix: fergus_5p1p2  # upped sizes; not compatible with 5p0
disable_logger: False
from_checkpoint: False

### parameters set at runtime
max_num_supertags: 0
max_spine_length: 0
max_context_size: 0
max_num_children: 0
word_vocab_size: 0

################
### bit info is for multiple embeddings that get concatted
# [(pos_vocab_size, pos_embedding_size), (type_vocab_size, type_embedding_size)]
bit_info: ~
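
Since DataServer.__init__ copies every key of the loaded YAML straight into the instance __dict__, each entry above becomes an attribute. A hypothetical round trip (the filename config.yaml and the BigramServer subclass from the earlier sketch are assumptions):

# Hypothetical: load the config above; every YAML key becomes an attribute.
server = BigramServer.from_file("config.yaml")
print(server.batch_size)     # 32
print(server.saving_prefix)  # fergus_5p1p2
server.print_everything()    # dumps all scalar-valued config entries

The last file is the Vocabulary class used to map tokens to integer ids.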
import os
import json
from numpy.random import randint


class Vocabulary(object):
    """
    Taken from Tim Vieira & his Github: https://github.com/timvieira/arsenal

    Bijective mapping from strings to integers.
    Used for turning a bunch of string observations into integers
    (presumably for indexing and other fun things).

    >>> a = Vocabulary()
    >>> [a[x] for x in 'abcd']
    [0, 1, 2, 3]
    >>> [a.lookup(i) for i in range(4)]
    ['a', 'b', 'c', 'd']
    >>> a.stop_growth()   # after a growth stop, unknown keys quietly map to None
    >>> a['e']
    >>> a.freeze()        # freezing is harsher: unknown keys raise instead
    >>> a.add('z')
    Traceback (most recent call last):
      ...
    ValueError: Vocabulary is frozen. Key "z" not found.
    >>> print(a.plaintext())
    a
    b
    c
    d
    """
    def __init__(self, random_int=None, use_mask=False):
        self._mapping = {}   # str -> int
        self._flip = {}      # int -> str; timv: consider using array or list
        self._i = 0
        self._frozen = False
        self._growing = True
        self._random_int = random_int  # if non-zero, will randomly assign
                                       # integers (between 0 and random_int) as
                                       # index (possibly with collisions)
        self.unk_symbol = "<UNK>"
        self.mask_symbol = "<MASK>"
        self.emit_unks = False
        self.use_mask = use_mask
        if self.use_mask:
            self.add(self.mask_symbol)

    def __repr__(self):
        return 'Vocabulary(size=%s,frozen=%s)' % (len(self), self._frozen)

    def freeze(self, emit_unks=False):
        self.emit_unks = emit_unks
        if emit_unks and self.unk_symbol not in self:
            self.add(self.unk_symbol)
        self._frozen = True

    def stop_growth(self):
        self._growing = False

    @classmethod
    def from_iterable(cls, s):
        "Assumes keys are strings."
        inst = cls()
        for x in s:
            inst.add(x)
        return inst

    def keyset(self):
        keys = set(self._mapping.keys())
        if self.mask_symbol in keys:
            keys.remove(self.mask_symbol)
        return keys

    def keys(self):
        return iter(self._mapping)

    def items(self):
        return iter(self._mapping.items())

    def filter_generator(self, seq, emit_none=False):
        """
        Apply Vocabulary to sequence while filtering. By default, `None` is not
        emitted, so please note that the output sequence may have fewer items.
        """
        if emit_none:
            for s in seq:
                yield self[s]
        else:
            for s in seq:
                x = self[s]
                if x is not None:
                    yield x

    def filter(self, seq, *args, **kwargs):
        return list(self.filter_generator(seq, *args, **kwargs))

    def add_many(self, x):
        return [self.add(k) for k in x]

    def lookup(self, i):
        if i is None:
            return None
        # assert isinstance(i, int)
        return self._flip[i]

    def lookup_many(self, x):
        for k in x:
            yield self.lookup(k)

    def __contains__(self, k):
        # assert isinstance(k, str)
        return k in self._mapping

    def __getitem__(self, k):
        try:
            return self._mapping[k]
        except KeyError:
            # if not isinstance(k, str):
            #     raise ValueError("Invalid key (%s): only strings allowed." % (k,))
            if self._frozen:
                if self.emit_unks:
                    return self._mapping[self.unk_symbol]
                else:
                    raise ValueError('Vocabulary is frozen. Key "%s" not found.' % (k,))
            if not self._growing:
                if self.emit_unks:
                    return self._mapping[self.unk_symbol]
                else:
                    return None
            if self._random_int:
                x = self._mapping[k] = randint(0, self._random_int)
            else:
                x = self._mapping[k] = self._i
                self._i += 1
            self._flip[x] = k
            return x

    add = __getitem__

    def __setitem__(self, k, v):
        assert k not in self._mapping
        if self._frozen:
            raise ValueError('Vocabulary is frozen. Key "%s" cannot be changed.' % (k,))
        assert isinstance(v, int)
        self._mapping[k] = v
        self._flip[v] = k

    def __iter__(self):
        for i in range(len(self)):
            yield self._flip[i]

    def enum(self):
        for i in range(len(self)):
            yield (i, self._flip[i])

    def __len__(self):
        return len(self._mapping)

    def plaintext(self):
        "assumes keys are strings"
        return '\n'.join(self)

    def exact_save(self):
        return self._mapping

    @classmethod
    def exact_load(cls, exact_dict):
        if "mapping" in exact_dict and isinstance(exact_dict['mapping'], dict):
            exact_dict, config = exact_dict['mapping'], exact_dict['config']
        else:
            config = {}
        new_vocab = cls()
        for k, v in exact_dict.items():
            assert isinstance(v, int)
            new_vocab._mapping[k] = v
            new_vocab._flip[v] = k
        # resume indexing after the largest saved index
        new_vocab._i = max(new_vocab._flip) + 1 if new_vocab._flip else 0
        new_vocab.__dict__.update(config)
        return new_vocab

    @classmethod
    def load(cls, filename):
        if not os.path.exists(filename):
            return cls()
        with open(filename) as fp:
            return cls.exact_load(json.load(fp))

    def _config(self):
        config = {"emit_unks": self.emit_unks,
                  "use_mask": self.use_mask,
                  "_frozen": self._frozen,
                  "_growing": self._growing}
        return config

    def save(self, filename, exactly=False):
        with open(filename, 'w') as fp:
            json.dump({"mapping": self._mapping,
                       "config": self._config()}, fp)