RNN model with sum of char embeddings.
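As implemented below, each word is represented by the sum of its character embeddings (a bag-of-chars vector), concatenated with a pretrained word embedding when one is supplied. The sequence is run through a stacked LSTM or BLSTM, projected to per-tag scores, and decoded either greedily (argmax) or, when `crf=True`, with a CRF transition matrix and Viterbi decoding. The code targets the TensorFlow 1.x / `tf.contrib` API and the `baseline` library's `Tagger` interface; a usage sketch follows the code.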
import json
import os

import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.contrib.layers import fully_connected, xavier_initializer
from tensorflow.python.platform import gfile

from baseline.model import Tagger
from baseline.tf.tfy import *  # also provides stacked_lstm() used below
class RNNWordBoWCharTaggerModel(Tagger):

    def save_values(self, basename):
        self.saver.save(self.sess, basename)

    def save_md(self, basename):
        path = basename.split('/')
        base = path[-1]
        outdir = '/'.join(path[:-1])
        # Persist the graph definition plus the saver def, labels and vocabs
        tf.train.write_graph(self.sess.graph_def, outdir, base + '.graph', as_text=False)
        with open(basename + '.saver', 'w') as f:
            f.write(str(self.saver.as_saver_def()))
        with open(basename + '.labels', 'w') as f:
            json.dump(self.labels, f)
        if len(self.word_vocab) > 0:
            with open(basename + '-word.vocab', 'w') as f:
                json.dump(self.word_vocab, f)
        with open(basename + '-char.vocab', 'w') as f:
            json.dump(self.char_vocab, f)
    def make_feed_dict(self, x, xch, lengths, y=None, do_dropout=False):
        pkeep = 1.0 - self.pdrop_value if do_dropout else 1.0
        feed_dict = {self.x: x, self.xch: xch, self.lengths: lengths, self.pkeep: pkeep}
        if y is not None:
            feed_dict[self.y] = y
        return feed_dict
    def save(self, basename):
        self.save_md(basename)
        self.save_values(basename)
    @staticmethod
    def load(basename, **kwargs):
        model = RNNWordBoWCharTaggerModel()
        model.sess = kwargs.get('sess', tf.Session())
        checkpoint_name = kwargs.get('checkpoint_name', basename)
        checkpoint_name = checkpoint_name or basename
        with open(basename + '.saver') as fsv:
            saver_def = tf.train.SaverDef()
            text_format.Merge(fsv.read(), saver_def)
            print('Loaded saver def')
        with gfile.FastGFile(basename + '.graph', 'rb') as f:
            gd = tf.GraphDef()
            gd.ParseFromString(f.read())
        model.sess.graph.as_default()
        tf.import_graph_def(gd, name='')
        print('Imported graph def')
        model.sess.run(saver_def.restore_op_name, {saver_def.filename_tensor_name: checkpoint_name})
        # Rebind placeholders and outputs by name; the output/* names depend
        # on the ops built in create()
        model.x = tf.get_default_graph().get_tensor_by_name('x:0')
        model.xch = tf.get_default_graph().get_tensor_by_name('xch:0')
        model.y = tf.get_default_graph().get_tensor_by_name('y:0')
        model.lengths = tf.get_default_graph().get_tensor_by_name('lengths:0')
        model.pkeep = tf.get_default_graph().get_tensor_by_name('pkeep:0')
        model.best = tf.get_default_graph().get_tensor_by_name('output/ArgMax:0')
        model.probs = tf.get_default_graph().get_tensor_by_name('output/transpose:0')
        try:
            model.A = tf.get_default_graph().get_tensor_by_name('Loss/transitions:0')
            print('Found transition matrix in graph, setting crf=True')
            model.crf = True
        except KeyError:
            print('Failed to get transition matrix, setting crf=False')
            model.A = None
            model.crf = False
        with open(basename + '.labels', 'r') as f:
            model.labels = json.load(f)
        model.word_vocab = {}
        if os.path.exists(basename + '-word.vocab'):
            with open(basename + '-word.vocab', 'r') as f:
                model.word_vocab = json.load(f)
        with open(basename + '-char.vocab', 'r') as f:
            model.char_vocab = json.load(f)
        model.saver = tf.train.Saver(saver_def=saver_def)
        return model
    def __init__(self):
        pass

    def save_using(self, saver):
        self.saver = saver
    def _compute_word_level_loss(self, mask):
        nc = len(self.labels)
        # Per-token cross entropy, masked so padded time steps contribute nothing
        cross_entropy = tf.one_hot(self.y, nc, axis=-1) * tf.log(tf.nn.softmax(self.probs))
        cross_entropy = -tf.reduce_sum(cross_entropy, axis=2)
        cross_entropy *= mask
        # Sum over time, then average over the batch
        cross_entropy = tf.reduce_sum(cross_entropy, axis=1)
        all_loss = tf.reduce_mean(cross_entropy, name="loss")
        return all_loss
    def _compute_sentence_level_loss(self):
        # CRF negative log-likelihood over whole sequences; crf_log_likelihood
        # also creates the transition matrix, captured here as self.A
        ll, self.A = tf.contrib.crf.crf_log_likelihood(self.probs, self.y, self.lengths)
        return tf.reduce_mean(-ll)
    def create_loss(self):
        with tf.variable_scope("Loss"):
            gold = tf.cast(self.y, tf.float32)
            mask = tf.sign(gold)  # 1 wherever a non-padding label is present
            if self.crf is True:
                print('crf=True, creating SLL')
                all_loss = self._compute_sentence_level_loss()
            else:
                print('crf=False, creating WLL')
                all_loss = self._compute_word_level_loss(mask)
        return all_loss
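    # NOTE: a sketch, not part of the original gist. The hand-rolled loss in
    # _compute_word_level_loss multiplies a one-hot mask by log-softmax, which
    # can produce NaNs when the softmax underflows. An equivalent, more stable
    # TF 1.x formulation uses the fused op and masks by sequence length rather
    # than by the sign of the label:
    #
    #     losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
    #         labels=self.y, logits=self.probs)
    #     mask = tf.sequence_mask(self.lengths, tf.shape(self.y)[1], dtype=tf.float32)
    #     all_loss = tf.reduce_mean(tf.reduce_sum(losses * mask, axis=1))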
    def get_vocab(self, vocab_type='word'):
        return self.word_vocab if vocab_type == 'word' else self.char_vocab

    def get_labels(self):
        return self.labels
    def predict(self, x, xch, lengths):
        feed_dict = self.make_feed_dict(x, xch, lengths)
        # We can probably conditionally add the loss here
        preds = []
        if self.crf is True:
            # Fetch the unary scores and the transition matrix, then
            # Viterbi-decode each sentence up to its true length
            probv, tranv = self.sess.run([self.probs, self.A], feed_dict=feed_dict)
            for pij, sl in zip(probv, lengths):
                unary = pij[:sl]
                viterbi, _ = tf.contrib.crf.viterbi_decode(unary, tranv)
                preds.append(viterbi)
        else:
            # Greedy decode: argmax tags for the whole (B, T) batch
            bestv = self.sess.run(self.best, feed_dict=feed_dict)
            for pij, sl in zip(bestv, lengths):
                unary = pij[:sl]
                preds.append(unary)
        return preds
    @staticmethod
    def create(labels, word_vec, char_vec, **kwargs):
        model = RNNWordBoWCharTaggerModel()
        model.sess = kwargs.get('sess', tf.Session())
        mxlen = kwargs.get('maxs', 100)
        maxw = kwargs.get('maxw', 100)
        wsz = kwargs.get('wsz', 30)  # unused in this variant
        hsz = int(kwargs['hsz'])
        pdrop = kwargs.get('dropout', 0.5)
        rnntype = kwargs.get('rnntype', 'blstm')
        nlayers = kwargs.get('layers', 1)
        model.labels = labels
        model.crf = bool(kwargs.get('crf', False))
        char_dsz = char_vec.dsz
        nc = len(labels)
        model.x = tf.placeholder(tf.int32, [None, mxlen], name="x")
        model.xch = tf.placeholder(tf.int32, [None, mxlen, maxw], name="xch")
        model.y = tf.placeholder(tf.int32, [None, mxlen], name="y")
        model.lengths = tf.placeholder(tf.int32, [None], name="lengths")
        model.pkeep = tf.placeholder(tf.float32, name="pkeep")
        model.pdrop_value = pdrop
        model.word_vocab = {}
        if word_vec is not None:
            model.word_vocab = word_vec.vocab
        model.char_vocab = char_vec.vocab
        seed = np.random.randint(10e8)

        if word_vec is not None:
            with tf.name_scope("WordLUT"):
                Ww = tf.Variable(tf.constant(word_vec.weights, dtype=tf.float32), name="W")
                # Zero out the embedding for the padding index (row 0)
                we0 = tf.scatter_update(Ww, tf.constant(0, dtype=tf.int32, shape=[1]), tf.zeros(shape=[1, word_vec.dsz]))
                with tf.control_dependencies([we0]):
                    wembed = tf.nn.embedding_lookup(Ww, model.x, name="embeddings")

        Wch = tf.Variable(tf.constant(char_vec.weights, dtype=tf.float32), name="Wch")
        ce0 = tf.scatter_update(Wch, tf.constant(0, dtype=tf.int32, shape=[1]), tf.zeros(shape=[1, char_dsz]))

        with tf.variable_scope("Chars2Word"):
            with tf.control_dependencies([ce0]):
                # Look up char embeddings for the flattened (B*T, W) batch, then
                # sum over each word's characters: a bag-of-chars word vector
                rnnchar_bt_x_w = tf.reshape(model.xch, [-1, maxw])
                cembed = tf.nn.embedding_lookup(Wch, rnnchar_bt_x_w, name="embeddings")
                word_char = tf.reshape(tf.reduce_sum(cembed, 1), [-1, mxlen, char_dsz])

        joint = word_char if word_vec is None else tf.concat(values=[wembed, word_char], axis=2)
        embedseq = tf.nn.dropout(joint, model.pkeep)

        with tf.variable_scope("Words2Hidden"):
            if rnntype == 'blstm':
                rnnfwd = stacked_lstm(hsz, model.pkeep, nlayers)
                rnnbwd = stacked_lstm(hsz, model.pkeep, nlayers)
                rnnout, _ = tf.nn.bidirectional_dynamic_rnn(rnnfwd, rnnbwd, embedseq, sequence_length=model.lengths, dtype=tf.float32)
                # The forward and backward outputs are joined on the hidden axis
                rnnout = tf.concat(axis=2, values=rnnout)
            else:
                rnnfwd = stacked_lstm(hsz, model.pkeep, nlayers)
                rnnout, _ = tf.nn.dynamic_rnn(rnnfwd, embedseq, sequence_length=model.lengths, dtype=tf.float32)

        with tf.variable_scope("output"):
            hout = hsz
            if rnntype == 'blstm':
                # A BLSTM doubles the effective hidden size
                hout *= 2
            # Flatten from (B, T, H) to (B*T, H) for the dense layers
            rnnout_bt_x_h = tf.reshape(rnnout, [-1, hout])
            init = xavier_initializer(True, seed)
            with tf.contrib.slim.arg_scope([fully_connected], weights_initializer=init):
                hidden = tf.nn.dropout(fully_connected(rnnout_bt_x_h, hsz, activation_fn=tf.nn.tanh), model.pkeep)
                preds = fully_connected(hidden, nc, activation_fn=None, weights_initializer=init)
            # Back to (B, T, nc) per-tag scores
            model.probs = tf.reshape(preds, [-1, mxlen, nc])
            model.best = tf.argmax(model.probs, 2)
        return model
def create_model(labels, word_embeddings, char_embeddings, **kwargs):
    return RNNWordBoWCharTaggerModel.create(labels, word_embeddings, char_embeddings, **kwargs)


def load_model(modelname, **kwargs):
    return RNNWordBoWCharTaggerModel.load(modelname, **kwargs)
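A minimal usage sketch, illustrative and not part of the gist: it assumes the baseline library is installed (for `stacked_lstm`), and fakes the embedding objects, which in baseline expose `.vocab`, `.dsz`, and a `.weights` matrix whose row 0 is the padding entry. All names and shapes below are hypothetical.

# Illustrative usage only; the embedding stand-ins below are hypothetical.
import numpy as np

class FakeEmbeddings(object):
    """Stand-in for baseline's pretrained embedding objects."""
    def __init__(self, vsz, dsz):
        self.vocab = {}
        self.dsz = dsz
        self.weights = np.random.randn(vsz, dsz).astype(np.float32)

labels = {'<PAD>': 0, 'O': 1, 'B-PER': 2, 'I-PER': 3}
word_vec = FakeEmbeddings(vsz=1000, dsz=100)
char_vec = FakeEmbeddings(vsz=80, dsz=30)

model = create_model(labels, word_vec, char_vec, hsz=200, maxs=50, maxw=20)
model.sess.run(tf.global_variables_initializer())

# One batch of two zero-padded sentences of lengths 5 and 3
x = np.zeros((2, 50), dtype=np.int32)        # word ids, (B, T)
xch = np.zeros((2, 50, 20), dtype=np.int32)  # char ids, (B, T, W)
lengths = np.array([5, 3], dtype=np.int32)
print(model.predict(x, xch, lengths))        # list of per-sentence tag ids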