Created
December 15, 2017 21:52
-
-
Save soumith/8c73eb07b81c01298203da5537f8f2a3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable as Var | |
class TreeDecoder(nn.Module):
    """Decode a fixed-size vector into a tree of dicts, lists and str/int leaves.

    For every node the decoder first classifies its type (dict / list / leaf)
    and then expands it:

    * dict -- an LSTM cell repeatedly predicts "stop?" and, if not stopping,
      a key from the key vocabulary; each key's value is decoded recursively;
    * list -- the same LSTM cell repeatedly predicts "stop?"; each element is
      decoded recursively;
    * leaf -- a linear classifier picks one identifier from the identifier
      vocabulary.

    When ``truth`` is given (teacher forcing), decoding follows the
    ground-truth tree and NLL losses are accumulated in ``self.loss``. When it
    is ``None``, the decoder greedily follows its own argmax predictions.

    Args:
        input_size: width of the input features and of the LSTM hidden state.
        max_key: capacity of the dict-key vocabulary.
        max_ident: capacity of the leaf-identifier vocabulary.
        max_depth: recursion budget; nodes nested more than ``max_depth``
            levels below the root are decoded as ``None``.
        max_length: maximum number of children generated per list/dict.
    """

    NODE_DICT, NODE_LIST, NODE_STR = 0, 1, 2
    # Leaf label that stands in for ``None`` values inside ground-truth trees.
    NONE_LEAF = 'None'

    def __init__(self, input_size, max_key, max_ident, max_depth, max_length):
        super(TreeDecoder, self).__init__()
        self.input_size = input_size
        self.max_key = max_key
        self.max_ident = max_ident
        self.max_depth = max_depth
        self.max_length = max_length
        # Shared recurrent cell for children of lists and dicts; its input is
        # the node embedding concatenated with a one-hot of the previous key.
        self.rnn = nn.LSTMCell(input_size + max_key, input_size)
        # predicts NODE_DICT / NODE_LIST / NODE_STR
        self.pred_node_type = nn.Linear(input_size, 3)
        # predicts the leaf identifier
        self.pred_ident = nn.Linear(input_size, max_ident)
        # predicts the key of a dict entry from the rnn hidden state
        self.pred_key = nn.Linear(input_size, max_key)
        # predicts whether to stop generating children (class 1 == stop)
        self.pred_stop = nn.Linear(input_size, 2)
        # Projects [state, one-hot tag of width max_ident + 3] back to input_size.
        self.transform = nn.Linear(input_size + max_ident + 3, input_size)
        # Vocabularies grown lazily from ground-truth trees (value <-> index).
        self.ident_dict = dict()
        self.ident_list = []
        self.key_dict = dict()
        self.key_list = []

    @staticmethod
    def _intern(table, items, s, capacity):
        """Return the index of ``s`` in ``table``, registering it if unseen."""
        if s in table:
            return table[s]
        i = len(table)
        assert i < capacity
        table[s] = i
        items.append(s)
        return i

    def key(self, s):
        """Return the vocabulary index of dict key ``s``, adding it if new."""
        return self._intern(self.key_dict, self.key_list, s, self.max_key)

    def ident(self, s):
        """Return the vocabulary index of leaf ``s``, adding it if new."""
        return self._intern(self.ident_dict, self.ident_list, s, self.max_ident)

    def unkey(self, i):
        """Map a key index back to its key, or '<UNK>' if never registered."""
        return self.key_list[i] if i < len(self.key_list) else '<UNK>'

    def unident(self, i):
        """Map an identifier index back to its leaf, or '<UNK>' if unknown."""
        # Bound-check against the identifier vocabulary (not the key one).
        return self.ident_list[i] if i < len(self.ident_list) else '<UNK>'

    def onehot(self, oh_type, maxval=None):
        """Return a (1, maxval) one-hot row with a 1 at ``oh_type``.

        ``maxval`` defaults to ``max_ident + 3``, the width of the tag that
        ``transform`` consumes. Requires ``forward`` to have set ``_type``.
        """
        if maxval is None:
            maxval = self.max_ident + 3
        o = self._type.new(1, maxval).zero_()
        o[0, oh_type] = 1
        return Var(o)

    def _target(self, index):
        """Wrap a class index as a (1,) LongTensor target for ``nll_loss``."""
        return Var(self._type.new([index]).long())

    @staticmethod
    def _argmax_index(logits):
        """Return the argmax over dim 1 of a (1, C) logit row as a Python int."""
        _, idx = logits.max(1)
        # int() accepts both a Python number (old PyTorch) and a 0-dim tensor
        # (newer PyTorch) for .data[0].
        return int(idx.data[0])

    def _node_type_of(self, truth):
        """Classify a ground-truth node; raise TypeError for unsupported types."""
        if isinstance(truth, dict):
            return self.NODE_DICT
        if isinstance(truth, list):
            return self.NODE_LIST
        if isinstance(truth, (str, int)):
            return self.NODE_STR
        raise TypeError("Node type unknown: " + str(type(truth)))

    def _child_truth(self, value):
        """Map a ground-truth child value to the tree the decoder should learn.

        Only ``None`` is replaced (by ``NONE_LEAF``); falsy values such as
        ``0``, ``''``, ``[]`` and ``{}`` are kept as-is.
        """
        return self.NONE_LEAF if value is None else value

    def forward(self, encoder_out, truth=None):
        """Decode a tree from ``encoder_out``; teacher-force if ``truth`` given.

        ``encoder_out`` is summed over dim 1 into a bag-of-words vector. The
        helpers build batch-1 tensors, so presumably ``encoder_out`` is
        (1, seq, input_size) -- TODO confirm with callers.
        Resets ``self.loss`` (accumulated NLL when ``truth`` is given).
        """
        self.loss = 0
        self._type = self.pred_key.weight.data
        self.encoder_out = encoder_out
        self.depth = self.max_depth
        # sum over sequence, make it a bag of words
        bag = encoder_out.sum(dim=1)
        return self.generate_tree(bag, truth)

    def generate_tree(self, inp, truth=None):
        """Decode one node of any type; return None if the depth budget is spent."""
        if self.depth < 0:
            return None
        self.depth -= 1
        try:
            node_type_logits = F.log_softmax(self.pred_node_type(inp), dim=1)
            if truth is not None:
                node_type = self._node_type_of(truth)
                # loss between predicted node type and ground truth; then stay
                # on the gold path regardless of the prediction
                self.loss += F.nll_loss(node_type_logits, self._target(node_type))
            else:
                node_type = self._argmax_index(node_type_logits)
            # tag the input with the chosen node type before expanding it
            next_inp = self.transform(torch.cat([inp, self.onehot(node_type)], 1))
            if node_type == self.NODE_STR:
                return self.generate_string(next_inp, truth)
            if node_type == self.NODE_LIST:
                return self.generate_list(next_inp, truth)
            return self.generate_dict(next_inp, truth)
        finally:
            # restore the budget even if decoding raised
            self.depth += 1

    def predict_stop(self, inp, truth=None):
        """Return True if generation of children should stop at this step."""
        stop_logits = F.log_softmax(self.pred_stop(inp), dim=1)
        if truth is not None:
            stop = 1 if truth else 0
            self.loss += F.nll_loss(stop_logits, self._target(stop))
        else:
            stop = self._argmax_index(stop_logits)
        return stop == 1

    def generate_string(self, inp, truth=None):
        """Decode a leaf: return ``truth`` when teacher-forcing, else the prediction."""
        ident_logits = F.log_softmax(self.pred_ident(inp), dim=1)
        if truth is not None:
            assert isinstance(truth, (str, int))
            self.loss += F.nll_loss(ident_logits, self._target(self.ident(truth)))
            return truth
        return self.unident(self._argmax_index(ident_logits))

    def generate_list(self, inp, truth=None):
        """Decode up to ``max_length`` list elements with the shared LSTM cell."""
        hx = Var(self._type.new(1, self.input_size).zero_())
        cx = Var(self._type.new(1, self.input_size).zero_())
        # Lists have no keys, so the key slot of the rnn input is all zeros,
        # and children get an all-zero tag before the transform.
        no_key = Var(self._type.new(1, self.max_key).zero_())
        no_tag = Var(self._type.new(1, self.max_ident + 3).zero_())
        rnn_inp = torch.cat([inp, no_key], 1)
        res = []
        for _ in range(self.max_length):
            hx, cx = self.rnn(rnn_inp, (hx, cx))
            stop_truth = None if truth is None else len(res) == len(truth)
            if self.predict_stop(hx, stop_truth):
                break
            item_truth = None if truth is None else self._child_truth(truth[len(res)])
            child_inp = self.transform(torch.cat([hx, no_tag], 1))
            res.append(self.generate_tree(child_inp, item_truth))
        return res

    def generate_dict(self, inp, truth=None):
        """Decode up to ``max_length`` (key, subtree) pairs; keys are sorted in truth."""
        hx = Var(self._type.new(1, self.input_size).zero_())
        cx = Var(self._type.new(1, self.input_size).zero_())
        true_keys = None if truth is None else sorted(truth.keys())
        # NOTE(review): the first step feeds the one-hot of index 0, which is
        # indistinguishable from "the previous key was key 0".
        prev_key = 0
        res = {}
        for ii in range(self.max_length):
            rnn_inp = torch.cat([inp, self.onehot(prev_key, self.max_key)], 1)
            hx, cx = self.rnn(rnn_inp, (hx, cx))
            stop_truth = None if truth is None else ii == len(true_keys)
            if self.predict_stop(hx, stop_truth):
                break
            key_logits = F.log_softmax(self.pred_key(hx), dim=1)
            if truth is not None:
                key = self.key(true_keys[ii])
                self.loss += F.nll_loss(key_logits, self._target(key))
            else:
                key = self._argmax_index(key_logits)
            key_str = self.unkey(key)
            # NOTE(review): keys are tagged in slot key + 3 of the
            # (max_ident + 3)-wide one-hot, so this assumes max_key <= max_ident.
            child_inp = self.transform(torch.cat([hx, self.onehot(key + 3)], 1))
            child_truth = None if truth is None else self._child_truth(truth[key_str])
            res[key_str] = self.generate_tree(child_inp, child_truth)
            prev_key = key
        return res
# code below here is used for prototyping | |
import pg_query | |
from pg_query import parse_sql | |
from pg_query.printer import RawStream | |
import pdb | |
import zss | |
def tree_to_zss(t):
    """Convert a dict/list/leaf tree into a ``zss.Node`` tree.

    Leaves (str, int, and ``None`` rendered as ``'None'``) become labelled
    nodes. A list becomes a ``**LIST**`` node with one child per item. A dict
    becomes a ``**DICT**`` node with one child per key (in sorted order); that
    child is labelled with the key and holds the converted value.

    Raises:
        TypeError: if ``t`` contains a node of any other type.
    """
    if t is None:
        # TreeDecoder returns None when its depth budget runs out; label it
        # like the 'None' leaf used for None values in ground-truth trees.
        t = 'None'
    if isinstance(t, int):
        t = str(t)
    if isinstance(t, str):
        return zss.Node(t)
    if isinstance(t, list):
        node = zss.Node('**LIST**')
        for child in t:
            node.addkid(tree_to_zss(child))
        return node
    if isinstance(t, dict):
        node = zss.Node('**DICT**')
        for k in sorted(t.keys()):
            node.addkid(zss.Node(k, [tree_to_zss(t[k])]))
        return node
    raise TypeError('cannot convert to zss tree: ' + str(type(t)))
def tree_edit_distance(t1, t2):
    """Return the tree edit distance between ``t1`` and ``t2``.

    Both trees are first converted with ``tree_to_zss`` and then compared
    with ``zss.simple_distance``.
    """
    left = tree_to_zss(t1)
    right = tree_to_zss(t2)
    return zss.simple_distance(left, right)
def tree_depth(t):
    """Return the number of levels in tree ``t``; a leaf has depth 1.

    Empty dicts/lists count as a single level, and ``None`` is treated as a
    leaf. Raises TypeError for any other node type.
    """
    if isinstance(t, dict):
        children = list(t.values())
    elif isinstance(t, list):
        children = t
    elif t is None or isinstance(t, (str, int)):
        return 1
    else:
        raise TypeError('unsupported tree node: ' + str(type(t)))
    # `or [0]` keeps empty containers from crashing max()
    return 1 + max([tree_depth(c) for c in children] or [0])
def tree_length(t):
    """Return the largest fan-out in tree ``t``.

    A container's value is the maximum of its own child count and its
    children's values. Leaves (str, int, None) count as 1, and an empty
    container counts as 0. Raises TypeError for any other node type.
    """
    if isinstance(t, dict):
        children = list(t.values())
    elif isinstance(t, list):
        children = t
    elif t is None or isinstance(t, (str, int)):
        return 1
    else:
        raise TypeError('unsupported tree node: ' + str(type(t)))
    return max([tree_length(c) for c in children] + [len(children)])
def tree_str_sorted(t):
    """Render tree ``t`` as a compact string with dict keys in sorted order.

    Dicts render as ``{k: v, ...}``, lists as ``[a, b, ...]``, strings as
    themselves, ints via ``str`` and ``None`` as ``'None'``. Raises TypeError
    for any other node type.
    """
    if t is None:
        return 'None'
    if isinstance(t, dict):
        items = (str(k) + ': ' + tree_str_sorted(t[k]) for k in sorted(t.keys()))
        return '{' + ', '.join(items) + '}'
    if isinstance(t, list):
        return '[' + ', '.join(tree_str_sorted(c) for c in t) + ']'
    if isinstance(t, str):
        return t
    if isinstance(t, int):
        return str(t)
    raise TypeError('unsupported tree node: ' + str(type(t)))
# Example pair for prototyping: a natural-language question, the SQL that
# answers it, and that SQL's parse tree (the first statement parse_sql returns).
test_nl = ("what is the capital of the state "
           "that borders the state that borders texas")
test_sql_string = (
    "SELECT state.capital FROM state "
    "WHERE state.state_name IN ("
    "SELECT border_info.border FROM border_info "
    "WHERE border_info.state_name IN ("
    "SELECT border_info.border FROM border_info "
    "WHERE border_info.state_name = 'texas'))"
)
test_tree = parse_sql(test_sql_string)[0]
def print_traverse(t, indent=0):
    """Pretty-print a dict/list/leaf tree, nesting children two spaces deeper.

    Dicts print ``{``, then each key followed by its value, then ``}``. Lists
    print ``[``, their items, then ``]``. str/int leaves print as-is. Any other
    node raises ``Exception('cannot handle ...')``.
    """
    pad = ' ' * indent
    deeper = indent + 2
    if isinstance(t, (str, int)):
        print(pad, t)
        return
    if isinstance(t, dict):
        print(pad, '{')
        for name in t:
            print(pad, name)
            print_traverse(t[name], deeper)
        print(pad, '}')
        return
    if isinstance(t, list):
        print(pad, '[')
        for child in t:
            print_traverse(child, deeper)
        print(pad, ']')
        return
    raise Exception('cannot handle ' + str(t))
# Module-level slot for stored parameters; setdefault keeps an existing value
# if the name is already defined in this namespace (e.g. on re-execution).
globals().setdefault('stored_params', None)
def test(retrain=False):
    """Overfit a TreeDecoder on ``test_tree`` from a fixed random encoding.

    Trains for 50 epochs when ``retrain`` is true, otherwise for 1 epoch,
    printing the loss every 10 epochs. Afterwards it greedily decodes a tree
    and prints the edit distance to the ground truth, both trees rendered as
    SQL, and both trees as sorted strings.
    """
    # NOTE(review): declared but never read or assigned in this function.
    global stored_params
    torch.manual_seed(90211)
    seq_length = 5
    input_size = 512
    encoder_out = Var(torch.randn(1, seq_length, input_size))
    groundtruth = test_tree
    decoder = TreeDecoder(input_size=input_size, max_key=50, max_ident=50,
                          max_depth=50, max_length=10)
    opt = torch.optim.Adam(decoder.parameters(), lr=0.001)
    n_epoch = 50 if retrain else 1
    for epoch in range(n_epoch):
        opt.zero_grad()
        decoder(encoder_out, groundtruth)
        decoder.loss.backward()
        if (epoch + 1) % 10 == 0:
            # .view(-1)[0] works for both 1-element (old) and 0-dim (new) losses
            print(epoch, float(decoder.loss.data.view(-1)[0]))
        opt.step()
    # Evaluate once after training. An every-100-epochs check could never
    # fire here, since at most 50 epochs run.
    pred = decoder(encoder_out, None)
    print('edit distance', tree_edit_distance(groundtruth, pred))
    print(RawStream()(pg_query.Node(groundtruth)))
    try:
        print(RawStream()(pg_query.Node(pred)))
    except Exception as err:  # a poorly decoded tree need not be valid SQL
        print('could not render prediction as SQL:', err)
    print(tree_str_sorted(groundtruth))
    print(tree_str_sorted(pred))
if __name__ == "__main__":
    # Script entry point: run the full 50-epoch training demo.
    test(retrain=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment