Created
August 8, 2017 04:43
-
-
Save ilblackdragon/779938a95b8d90f30cd94ff38eb1e538 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections

import torch
from torch import optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

import utils
class Fold(object):
    """Record operations first, then execute them in batches.

    Calls are registered with :meth:`add`, which returns a lightweight
    :class:`Fold.Node` handle instead of computing anything. Each recorded
    call is assigned a *step*: 0 when none of its arguments are nodes,
    otherwise one more than the largest step among its node arguments.
    :meth:`apply` then walks the steps in order and, for every op name
    within a step, calls the corresponding method once with all recorded
    argument tuples batched together.
    """

    class Node(object):
        """Handle for the (future) result of one recorded call.

        Attributes set in ``__init__``:
            op: name of the method to look up on the module in ``apply``.
            step: execution step the call was scheduled in.
            index: position of the call within its ``(step, op)`` batch.
            args: the positional arguments the call was recorded with.
            split_idx: -1 for a single-output result; otherwise which
                output of a multi-output op this node refers to.
        """

        def __init__(self, op, step, index, *args):
            self.op = op
            self.step = step
            self.index = index
            self.args = args
            # -1 means "not split"; `_batch_args` only treats values >= 0
            # as an index into a tuple/list of outputs.
            self.split_idx = -1

        def split(self, num):
            """Split resulting node, if function returns multiple values.

            Returns a list of ``num`` new nodes that share this node's op,
            step, index and args, with ``split_idx`` set to 0..num-1.
            """
            nodes = []
            for idx in range(num):
                node = Fold.Node(self.op, self.step, self.index, *self.args)
                node.split_idx = idx
                nodes.append(node)
            return nodes

        def __repr__(self):
            return "[%d:%d]%s" % (self.step, self.index, self.op)

    def __init__(self):
        # step -> op name -> list of argument tuples, in insertion order.
        self.steps = collections.defaultdict(
            lambda: collections.defaultdict(list))
        # op name -> argument tuple -> Node, used to de-duplicate calls.
        self.cached_nodes = collections.defaultdict(dict)
        # Counts every `add` call, including ones served from the cache.
        self.total_nodes = 0

    def add(self, op, *args):
        """Add op to the fold.

        Args:
            op: name of the method that `apply` will look up and call.
            *args: arguments for that call; any ``Fold.Node`` among them
                makes this call run at a later step than that node.
                The argument tuple is used as a dict key, so it must be
                hashable.

        Returns:
            The ``Fold.Node`` for this ``(op, args)`` pair. A repeated
            pair returns the previously created node.
        """
        self.total_nodes += 1
        cache = self.cached_nodes[op]
        if args not in cache:
            # Schedule one step after the latest node dependency (or at 0).
            step = max([0] + [arg.step + 1 for arg in args
                              if isinstance(arg, Fold.Node)])
            bucket = self.steps[step][op]
            node = Fold.Node(op, step, len(bucket), *args)
            bucket.append(args)
            cache[args] = node
        return cache[args]

    def _batch_args(self, arg_lists, values):
        """Turn per-position argument sequences into batched tensors.

        Args:
            arg_lists: iterable of sequences; each sequence holds the
                values for one argument position across the batch.
            values: ``values[step][op]`` as filled in by `apply`.

        Returns:
            A list with one entry per sequence. A sequence whose first
            element is a ``Fold.Node`` is resolved through ``values`` and
            concatenated along dimension 0; any other sequence is wrapped
            as ``Variable(torch.LongTensor(...))``. Only the first element
            is inspected to choose between the two paths.
        """
        res = []
        for arg in arg_lists:
            if isinstance(arg[0], Fold.Node):
                pieces = []
                for x in arg:
                    outputs = values[x.step][x.op]
                    if x.split_idx >= 0:
                        # Multi-output op: pick the output, then the row.
                        pieces.append(outputs[x.split_idx][x.index])
                    else:
                        pieces.append(outputs[x.index])
                res.append(torch.cat(pieces, 0))
            else:
                res.append(Variable(torch.LongTensor(arg)))
        return res

    def apply(self, nn, nodes):
        """Apply current fold to given neural module.

        Args:
            nn: object whose attributes named after the recorded ops are
                called with the batched arguments. (Inside this method
                the name shadows any module-level ``nn``; it is kept for
                interface compatibility.)
            nodes: list of node sequences to gather once everything has
                been computed, in the format `_batch_args` accepts.

        Returns:
            A list with one batched tensor per entry of ``nodes``.
        """
        values = {}
        for step in sorted(self.steps):
            values[step] = {}
            for op in self.steps[step]:
                func = getattr(nn, op)
                # zip(*...) transposes [call][position] -> [position][call].
                batched_args = self._batch_args(
                    zip(*self.steps[step][op]), values)
                res = func(*batched_args)
                # Results are cut into as many chunks as the first batched
                # argument has rows, so `Node.index` can address them.
                num_chunks = batched_args[0].size()[0]
                if isinstance(res, (tuple, list)):
                    values[step][op] = [
                        torch.chunk(x, num_chunks) for x in res
                    ]
                else:
                    values[step][op] = torch.chunk(res, num_chunks)
        return self._batch_args(nodes, values)
if __name__ == "__main__":
    # Small driver: record a fold, build a toy module, apply the fold to
    # it, and print the output size together with the timer report.
    timer = utils.Timer()
    timer.start()

    f = Fold()
    # 'value' returns two outputs (see TestEncoder.value), so each node is
    # split in two and only the first half is used.
    v1, _ = f.add('value', 1).split(2)
    v2, _ = f.add('value', 2).split(2)
    r = v1
    for i in range(1000):
        # Both calls receive the same (op, args) on every iteration, so
        # after the first pass Fold.add returns the cached nodes.
        # NOTE(review): if a 1000-deep chain was intended, the first
        # assignment would have to be dropped -- confirm with the author.
        r = f.add('attr', v1, v2)
        r = f.add('attr', r, v2)
    timer.tag('fold')

    class TestEncoder(nn.Module):
        """Toy module exposing the 'value' and 'attr' ops recorded above."""

        def __init__(self):
            super(TestEncoder, self).__init__()
            self.embed = nn.Embedding(10, 10)
            self.out = nn.Linear(20, 10)

        def value(self, idx):
            # Returns the embedding twice to exercise the multi-output path.
            return self.embed(idx), self.embed(idx)

        def attr(self, left, right):
            # Concatenate along dim 1 before the linear layer.
            return self.out(torch.cat([left, right], 1))

    te = TestEncoder()
    timer.tag('encoder: created')
    enc = f.apply(te, [[r]])
    timer.tag('encoder: apply')
    print(enc[0].size())
    print(timer.report())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment