Skip to content

Instantly share code, notes, and snippets.

@soumith
Created December 15, 2017 21:52
Show Gist options
  • Save soumith/8c73eb07b81c01298203da5537f8f2a3 to your computer and use it in GitHub Desktop.
Save soumith/8c73eb07b81c01298203da5537f8f2a3 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as Var
class TreeDecoder(nn.Module):
    """Recursive decoder that emits a nested dict / list / leaf tree.

    Every node is produced in two steps: a 3-way classifier picks the node
    type (``NODE_DICT``, ``NODE_LIST`` or ``NODE_STR``), then the matching
    ``generate_*`` method produces the node, recursing through
    ``generate_tree`` for children.

    When a ground-truth tree is supplied the decoder is teacher-forced: every
    decision follows the gold tree and the negative log-likelihood of each
    gold decision is accumulated in ``self.loss``. When ``truth`` is ``None``
    the decoder follows its own argmax predictions.

    All intermediate tensors are built with a leading dimension of 1, so the
    decoder handles a single example per call.
    """

    NODE_DICT, NODE_LIST, NODE_STR = 0, 1, 2

    def __init__(self, input_size, max_key, max_ident, max_depth, max_length):
        """Build the prediction heads and the (initially empty) vocabularies.

        Args:
            input_size: width of the node representation / LSTM hidden state.
            max_key: capacity of the dict-key vocabulary.
            max_ident: capacity of the leaf-identifier vocabulary.
            max_depth: recursion budget that ``forward`` loads into
                ``self.depth``.
            max_length: maximum number of children emitted per list / dict.
        """
        super(TreeDecoder, self).__init__()
        self.input_size = input_size
        self.max_key = max_key
        self.max_ident = max_ident
        self.max_depth = max_depth
        self.max_length = max_length
        # used to predict (values of list) / (key / val of dict)
        self.rnn = nn.LSTMCell(input_size + max_key, input_size)
        # predicts NODE_DICT / NODE_LIST / NODE_STR
        self.pred_node_type = nn.Linear(input_size, 3)
        # predicts the string identifier
        self.pred_ident = nn.Linear(input_size, max_ident)
        # predicts the key of dict from hidden of rnn
        self.pred_key = nn.Linear(input_size, max_key)
        # predicts whether to stop generating or not
        self.pred_stop = nn.Linear(input_size, 2)
        # linear transform that projects inp that has one hot back to input_size
        self.transform = nn.Linear(input_size + max_ident + 3, input_size)
        # Vocabularies grow lazily as new symbols are seen during training.
        self.ident_dict = dict()
        self.ident_list = []
        self.key_dict = dict()
        self.key_list = []

    def key(self, s):
        """Return the index of dict key ``s``, registering it on first use.

        Raises:
            AssertionError: if the key vocabulary already holds ``max_key``
                entries.
        """
        if s in self.key_dict:
            return self.key_dict[s]
        i = len(self.key_dict)
        assert i < self.max_key
        self.key_dict[s] = i
        self.key_list.append(s)
        return i

    def ident(self, s):
        """Return the index of leaf identifier ``s``, registering it on first use.

        Raises:
            AssertionError: if the identifier vocabulary already holds
                ``max_ident`` entries.
        """
        if s in self.ident_dict:
            return self.ident_dict[s]
        i = len(self.ident_dict)
        assert i < self.max_ident
        self.ident_dict[s] = i
        self.ident_list.append(s)
        return i

    def unkey(self, i):
        """Map key index ``i`` back to its key, or ``'<UNK>'`` if unregistered."""
        return self.key_list[i] if i < len(self.key_list) else '<UNK>'

    def unident(self, i):
        """Map identifier index ``i`` back to its value, or ``'<UNK>'``."""
        # Bounds-check against the identifier vocabulary itself; checking the
        # key vocabulary here mis-decoded identifiers whenever the two
        # vocabularies had different sizes.
        return self.ident_list[i] if i < len(self.ident_list) else '<UNK>'

    def onehot(self, oh_type, maxval=None):
        """Return a ``(1, maxval)`` one-hot Variable with column ``oh_type`` set.

        ``maxval`` defaults to ``max_ident + 3``, the extra width expected by
        ``self.transform``. Relies on ``self._type``, which ``forward`` sets.
        """
        if maxval is None:
            maxval = self.max_ident + 3
        o = self._type.new(1, maxval).zero_()
        o[0, oh_type] = 1
        return Var(o)

    def forward(self, encoder_out, truth=None):
        """Decode a tree from ``encoder_out``.

        Args:
            encoder_out: encoder states; they are summed over dimension 1 and
                the result is used as the root representation.
            truth: optional ground-truth tree (nested dict / list / str /
                int). When given, decoding is teacher-forced and
                ``self.loss`` accumulates the training loss.

        Returns:
            The tree produced by ``generate_tree`` for the root.
        """
        self.loss = 0
        # Template tensor: ``.new`` on it creates tensors of the same type as
        # the module's parameters.
        self._type = self.pred_key.weight.data
        self.encoder_out = encoder_out
        self.depth = self.max_depth
        # sum over sequence, make it a bag of words
        bag = encoder_out.sum(dim=1)
        return self.generate_tree(bag, truth)

    def generate_tree(self, inp, truth=None):
        """Predict the type of one node and generate it.

        Args:
            inp: node representation fed to the node-type classifier.
            truth: gold subtree for teacher forcing, or ``None`` to follow
                the model's own prediction.

        Returns:
            The generated node, or ``None`` once the depth budget is used up.

        Raises:
            TypeError: if ``truth`` is neither a dict, list, str nor int.
        """
        # if max_depth is reached, then shortcut
        if self.depth < 0:
            return None

        node_type_truth = None
        if truth is not None:
            if isinstance(truth, dict):
                node_type_truth = self.NODE_DICT
            elif isinstance(truth, list):
                node_type_truth = self.NODE_LIST
            elif isinstance(truth, (str, int)):
                node_type_truth = self.NODE_STR
            else:
                # Fail here with a clear message instead of printing and then
                # crashing inside the loss computation.
                raise TypeError("Node type unknown: " + str(type(truth)))

        self.depth -= 1
        # first, predict node type:
        node_type_logits = F.log_softmax(self.pred_node_type(inp), dim=1)
        _, node_type_pred = node_type_logits.max(1)
        if truth is not None:
            # loss between predicted node and groundtruth
            self.loss += F.nll_loss(node_type_logits,
                                    Var(self._type.new([node_type_truth]).long()))
            # stay on gold standard path
            node_type_pred = node_type_truth
        else:
            # int() yields a plain Python int whatever type .data[0] returns.
            node_type_pred = int(node_type_pred.data[0])

        # generate the corresponding type
        if node_type_pred == self.NODE_STR:
            generate = self.generate_string
        elif node_type_pred == self.NODE_LIST:
            generate = self.generate_list
        else:
            generate = self.generate_dict
        next_inp = torch.cat([inp, self.onehot(node_type_pred)], 1)
        out = generate(self.transform(next_inp), truth)
        self.depth += 1
        return out

    def predict_stop(self, inp, truth=None):
        """Decide whether to stop emitting children.

        Args:
            inp: representation fed to the stop classifier.
            truth: gold stop decision (truthy means stop) or ``None``.

        Returns:
            ``True`` to stop. With ``truth`` given, the gold decision is
            returned and its loss is added to ``self.loss``.
        """
        stop_logits = F.log_softmax(self.pred_stop(inp), dim=1)
        _, stop_pred = stop_logits.max(1)
        if truth is not None:
            truth_v = Var(self._type.new([1 if truth else 0]).long())
            self.loss += F.nll_loss(stop_logits, truth_v)
            stop_pred = 1 if truth else 0
        else:
            stop_pred = int(stop_pred.data[0])
        return stop_pred == 1

    def generate_string(self, inp, truth=None):
        """Generate a leaf node.

        Returns:
            Without ``truth``: the predicted identifier (or ``'<UNK>'``).
            With ``truth``: the gold identifier index as a LongTensor
            Variable, after adding its loss to ``self.loss``.
        """
        ident_logits = F.log_softmax(self.pred_ident(inp), dim=1)
        _, pred = ident_logits.max(1)
        if truth is not None:
            assert isinstance(truth, (str, int))
            truth = Var(self._type.new([self.ident(truth)]).long())
            self.loss += F.nll_loss(ident_logits, truth)
            # NOTE(review): this returns the index tensor rather than the
            # identifier itself; kept as-is for compatibility -- confirm intent.
            pred = truth
        else:
            pred = self.unident(int(pred.data[0]))
        return pred

    def generate_list(self, inp, truth=None):
        """Generate a list node with at most ``max_length`` children.

        Each step advances the LSTM cell, asks ``predict_stop`` whether the
        list is finished and otherwise generates one child from the hidden
        state. Falsy gold items are replaced by the string ``'None'``.
        """
        # list is predicted by an LSTM/GRU, so initialize hidden states
        hx = Var(self._type.new(1, self.input_size).zero_())
        cx = Var(self._type.new(1, self.input_size).zero_())
        # Lists have no "previous key", so the key slots of the RNN input are
        # all zero; the input is the same on every step, so build it once.
        inp_rnn = torch.cat([inp, Var(self._type.new(1, self.max_key).zero_())], 1)
        res = []
        for _ in range(self.max_length):
            hx, cx = self.rnn(inp_rnn, (hx, cx))
            stop = self.predict_stop(hx, None if truth is None else len(res) == len(truth))
            if stop:
                break
            truth_item = None if truth is None else (truth[len(res)] or 'None')
            next_inp = torch.cat([hx, Var(self._type.new(1, self.max_ident + 3).zero_())], 1)
            tree = self.generate_tree(self.transform(next_inp), truth_item)
            res.append(tree)
        return res

    def generate_dict(self, inp, truth=None):
        """Generate a dict node with at most ``max_length`` entries.

        Gold keys are visited in sorted order. Each step feeds the previous
        key (one-hot) into the LSTM cell, predicts stop / key, then generates
        the value subtree. Falsy gold values are replaced by ``'None'``.
        """
        # dict is predicted by an LSTM/GRU, so initialize hidden states
        hx = Var(self._type.new(1, self.input_size).zero_())
        cx = Var(self._type.new(1, self.input_size).zero_())
        true_keys = None if truth is None else sorted(truth.keys())
        prev_key = 0
        res = {}
        for ii in range(self.max_length):
            next_inp = torch.cat([inp, self.onehot(prev_key, self.max_key)], 1)
            hx, cx = self.rnn(next_inp, (hx, cx))
            stop_truth = None if truth is None else (ii == len(true_keys))
            stop = self.predict_stop(hx, stop_truth)
            if stop:
                break
            key_logits = F.log_softmax(self.pred_key(hx), dim=1)
            _, key = key_logits.max(1)
            if truth is not None:
                key_ii = self.key(true_keys[ii])
                self.loss += F.nll_loss(key_logits, Var(self._type.new([key_ii]).long()))
                key = key_ii
            else:
                key = int(key.data[0])
            key_str = self.unkey(key)
            # The first 3 one-hot slots are used for node types, so keys are
            # offset by 3. NOTE(review): the one-hot is max_ident + 3 wide, so
            # this indexes out of range if max_key exceeds max_ident.
            inp_new = self.transform(torch.cat([hx, self.onehot(key + 3)], 1))
            truth_new = None if truth is None else (truth[key_str] or 'None')
            tree = self.generate_tree(inp_new, truth_new)
            res[key_str] = tree
            prev_key = key
        return res
# code below here is used for prototyping
import pg_query
from pg_query import parse_sql
from pg_query.printer import RawStream
import pdb
import zss
def tree_to_zss(t):
    """Convert a nested dict / list / str / int tree into a ``zss.Node`` tree.

    Ints are converted to their string form and become leaf nodes like
    strings. A list becomes a ``'**LIST**'`` node with one child per element.
    A dict becomes a ``'**DICT**'`` node with one child per key (in sorted
    key order); each key node has the converted value as its only child.
    Any other type trips the trailing assertion.
    """
    label = str(t) if isinstance(t, int) else t
    if isinstance(label, str):
        return zss.Node(label)
    if isinstance(label, list):
        list_node = zss.Node('**LIST**')
        for element in label:
            list_node.addkid(tree_to_zss(element))
        return list_node
    if isinstance(label, dict):
        dict_node = zss.Node('**DICT**')
        for name in sorted(label.keys()):
            value_node = tree_to_zss(label[name])
            dict_node.addkid(zss.Node(name, [value_node]))
        return dict_node
    assert False
def tree_edit_distance(t1, t2):
    """Return ``zss.simple_distance`` between the zss forms of two trees."""
    left = tree_to_zss(t1)
    right = tree_to_zss(t2)
    return zss.simple_distance(left, right)
def tree_depth(t):
    """Return the depth of a nested dict / list / str / int tree.

    A leaf (str or int) has depth 1. A container has depth one more than its
    deepest child; an empty container therefore has depth 1 (previously
    ``max()`` raised ValueError on it). Any other type yields ``None``.
    """
    if isinstance(t, dict):
        return max((tree_depth(v) for v in t.values()), default=0) + 1
    if isinstance(t, list):
        return max((tree_depth(v) for v in t), default=0) + 1
    if isinstance(t, (str, int)):
        return 1
    return None
def tree_length(t):
    """Return the largest child count of any container in the tree.

    A leaf (str or int) counts as 1. For a dict or list the result is the
    larger of its own size and the maximum over its children; an empty
    container yields 0 (previously ``max()`` raised ValueError on it).
    Any other type yields ``None``.
    """
    if isinstance(t, dict):
        widest_child = max((tree_length(v) for v in t.values()), default=0)
        return max(widest_child, len(t))
    if isinstance(t, list):
        widest_child = max((tree_length(v) for v in t), default=0)
        return max(widest_child, len(t))
    if isinstance(t, (str, int)):
        return 1
    return None
def tree_str_sorted(t):
    """Render a tree as a string with dict keys in sorted order.

    Dicts render as ``{k: v, ...}``, lists as ``[a, b, ...]``, strings as
    themselves and ints via ``str``. Any other type yields ``None``.
    """
    if isinstance(t, dict):
        return '{' + ', '.join(
            [str(k) + ': ' + tree_str_sorted(t[k]) for k in sorted(t.keys())]) + '}'
    if isinstance(t, list):
        # The unreachable second ``return`` that followed this one was removed.
        return '[' + ', '.join([tree_str_sorted(c) for c in t]) + ']'
    if isinstance(t, str):
        return t
    if isinstance(t, int):
        return str(t)
    return None
# Prototype fixtures: a natural-language question, a SQL query, and the first
# element of the result of parsing that query with pg_query.
test_nl = "what is the capital of the state that borders the state that borders texas"
test_sql_string = (
    "SELECT state.capital FROM state WHERE state.state_name IN "
    "(SELECT border_info.border FROM border_info WHERE border_info.state_name IN "
    "(SELECT border_info.border FROM border_info WHERE border_info.state_name = 'texas'))"
)
# parse_sql's result is indexed; only its first element is kept.
test_tree = parse_sql(test_sql_string)[0]
def print_traverse(t, indent=0):
    """Pretty-print a nested dict / list / str / int tree.

    Each printed line starts with ``indent`` spaces; children are printed
    with two more. Dicts are wrapped in ``{`` / ``}`` with each key on its
    own line before its value; lists are wrapped in ``[`` / ``]``.

    Raises:
        Exception: for a node that is not a dict, list, str or int.
    """
    pad = ' ' * indent
    if isinstance(t, (str, int)):
        print(pad, t)
        return
    if isinstance(t, list):
        print(pad, '[')
        for element in t:
            print_traverse(element, indent + 2)
        print(pad, ']')
        return
    if isinstance(t, dict):
        print(pad, '{')
        for name, child in t.items():
            print(pad, name)
            print_traverse(child, indent + 2)
        print(pad, '}')
        return
    raise Exception('cannot handle ' + str(t))
# Create the module-level ``stored_params`` slot once; if the name already
# exists in the module globals, its current value is left untouched.
globals().setdefault('stored_params', None)
def test(retrain=False):
    """Fit a TreeDecoder to the single ``test_tree`` example.

    Runs 50 optimisation steps when ``retrain`` is true, otherwise one. The
    encoder output is random noise drawn after fixing the seed. The loss is
    printed every 10th epoch; every 100th epoch the decoder is run without
    ground truth and the prediction is compared with the gold tree.
    """
    global stored_params
    torch.manual_seed(90211)
    num_steps, hidden_size = 5, 512
    encoder_out = Var(torch.randn(1, num_steps, hidden_size))
    groundtruth = test_tree
    decoder = TreeDecoder(input_size=hidden_size, max_key=50, max_ident=50,
                          max_depth=50, max_length=10)
    optimizer = torch.optim.Adam(decoder.parameters(), lr=0.001)
    num_epochs = 50 if retrain else 1
    for epoch in range(num_epochs):
        completed = epoch + 1
        optimizer.zero_grad()
        # Teacher-forced pass; the decoder accumulates its loss in .loss
        decoder(encoder_out, groundtruth)
        decoder.loss.backward()
        if completed % 10 == 0:
            print(epoch, decoder.loss.data[0])
        optimizer.step()
        # NOTE(review): with at most 50 epochs this evaluation never runs.
        if completed % 100 != 0:
            continue
        pred = decoder(encoder_out, None)
        print('edit distance', tree_edit_distance(groundtruth, pred))
        print(RawStream()(pg_query.Node(groundtruth)))
        print(RawStream()(pg_query.Node(pred)))
        print(tree_str_sorted(groundtruth))
        print(tree_str_sorted(pred))
# Script entry point: run the prototype training loop in retrain mode.
if __name__ == "__main__":
    test(retrain=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment