# Created August 11, 2012 00:14.
# Saved from GitHub gist jaberg/3319213 (tags: Lua, Torch, Theano, Python).
# GitHub notice: this file may contain hidden or bidirectional Unicode text that
# could be interpreted or compiled differently than it appears; review it in an
# editor that reveals hidden Unicode characters.
import numpy as np
import theano
class TorchThunk(object):
    """Theano thunk that hands a node's raw NumPy buffers to a native function.

    When called, the thunk looks up the first input and the first output of
    ``node`` in ``storage``. It then calls
    ``fn(in_ptr, in_size, out_ptr, out_size)`` with their data addresses
    (``ndarray.ctypes.data``) and element counts (``ndarray.size``). The
    intent is that a foreign (Lua/Torch) routine fills the output buffer in
    place.

    Parameters
    ----------
    op : the Op that created this thunk (stored, not used here).
    node : Apply node being computed. ``node.inputs[0]`` and
        ``node.outputs[0]`` are used as keys into ``storage``, and every
        entry of ``node.outputs`` is used as a key into ``compute``.
    storage : mapping from variable to a list-like cell whose element 0 holds
        the variable's current value (NumPy array).
    compute : mapping from variable to a list-like cell whose element 0 is the
        "has been computed" flag; outputs are set to True after a run.
    no_recycling : stored unchanged; not used by this class.
    fn : callable invoked as ``fn(in_ptr, in_size, out_ptr, out_size)``.
        Optional so existing four-argument construction still works, but
        calling the thunk without it raises ``NotImplementedError``.
    """

    # Theano's linkers/VMs consult a ``lazy`` attribute on thunks; this thunk
    # always computes eagerly.
    lazy = False

    def __init__(self, op, node, storage, compute, no_recycling, fn=None):
        self.op = op
        self.node = node
        self.storage = storage
        self.compute = compute
        self.no_recycling = no_recycling
        # TODO(review): original note reads "lupa allocates lua stack thing";
        # presumably the Lua runtime (via lupa) is meant to be set up here and
        # to supply ``fn``. Until then, callers must pass ``fn`` explicitly.
        self.fn = fn

    def __call__(self):
        """Run ``fn`` on the node's first input/output buffers and mark outputs computed.

        Raises
        ------
        NotImplementedError
            If no native function was supplied to the constructor.
        """
        if self.fn is None:
            # Previously this path referenced an undefined global and could
            # only die with NameError; fail with an actionable message instead.
            raise NotImplementedError(
                "TorchThunk has no native function; pass fn=... to the constructor")
        x = self.storage[self.node.inputs[0]][0]
        y = self.storage[self.node.outputs[0]][0]
        # fill output buffers from lua: hand over raw data pointers and sizes
        self.fn(x.ctypes.data, x.size,
                y.ctypes.data, y.size)
        # Record that the outputs now hold valid values.
        for var in self.node.outputs:
            self.compute[var][0] = True
class TorchLinear(theano.Op):
    """Theano Op sketch that delegates a linear layer to Torch's ``nn.Linear`` via Lua.

    ``n_inputs`` and ``n_hid`` are stored on the instance but are not read by
    any method defined here.

    NOTE(review): no ``make_node``/``perform`` is defined, so this Op cannot yet
    be inserted into a Theano graph on its own. Presumably it is still a work
    in progress.
    """

    def __init__(self, n_inputs, n_hid):
        self.n_inputs = n_inputs
        self.n_hid = n_hid

    def lua_alloc_outputs(self, inputs, outputs):
        """Return a 1-tuple holding an output buffer of shape ``(X.shape[0], w.shape[1])``.

        Parameters
        ----------
        inputs : sequence of exactly three arrays ``(X, w, b)``.
        outputs : sequence of exactly one current output value (an array, or
            ``None`` if the storage cell has not been filled yet).

        Returns
        -------
        tuple
            ``(out,)``. The existing buffer is reused when its shape already
            matches; otherwise a new zero-filled ``np.zeros`` array (default
            dtype) is allocated.
        """
        X, w, b = inputs
        out, = outputs
        expected_shape = (X.shape[0], w.shape[1])
        # An empty storage cell (None) must be allocated too, not just a
        # buffer with the wrong shape.
        # NOTE(review): this assumes w is laid out (n_inputs, n_hid); Torch's
        # nn.Linear conventionally stores weight as (outputSize, inputSize) --
        # confirm the layout expected on the Lua side.
        if out is None or out.shape != expected_shape:
            out = np.zeros(expected_shape)
        return out,

    def lua_code(self, node, name, inputs, outputs):
        """Return Lua source defining a function that runs ``nn.Linear`` forward.

        The generated Lua function is named ``asdf`` (a placeholder; the
        ``name`` argument is not used). It takes
        ``(w_info, b_info, out_info, in_info)`` and wraps each buffer with
        ``torch_from_numpy``, a Lua-side helper that must be defined elsewhere.
        Its steps are:

        * replace ``nn.Linear.__init`` with a no-op, presumably so
          construction does not allocate its own parameters. This reassigns
          the constructor on the shared ``nn.Linear`` class.
        * build a module held in the Lua *global* ``module``.
        * point ``weight``, ``bias`` and ``output`` at the wrapped buffers.
        * call ``forward`` on the wrapped input.

        ``node``, ``name``, ``inputs`` and ``outputs`` are accepted but not used.
        """
        return """
asdf = function(w_info, b_info, out_info, in_info)
nn.Linear.__init = function() end
module = nn.Linear()
module.weight = torch_from_numpy(w_info)
module.bias = torch_from_numpy(b_info)
module.output = torch_from_numpy(out_info)
module:forward(torch_from_numpy(in_info))
end
"""

    def make_thunk(self, node, storage_map, compute_map, no_recycling=None,
                   impl=None):
        """Build the TorchThunk that executes ``node``.

        Parameters
        ----------
        node, storage_map, compute_map : forwarded unchanged to ``TorchThunk``.
        no_recycling : forwarded to ``TorchThunk``; defaults to a new empty list.
        impl : accepted and ignored. Later Theano releases pass this keyword
            to ``make_thunk``.
        """
        if no_recycling is None:
            # Fresh list per call: the previous ``[]`` default was one list
            # object shared by every thunk built without an explicit argument.
            no_recycling = []
        return TorchThunk(self, node, storage_map, compute_map, no_recycling)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.