Last active
January 14, 2020 07:41
-
-
Save mjjimenez/900e7356766e988179bc47fe48128f9d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
torch::jit::script::Module forwardModel = torch::jit::load(forwardFilePath.UTF8String); | |
torch::jit::script::Module backwardModel = torch::jit::load(backwardFilePath.UTF8String); | |
auto w = torch::IValue(torch::full({10}, 1.0f)); | |
torch::IValue loss; | |
//Create train x | |
std::vector<torch::jit::IValue> train_x; | |
for (int i = 0; i < 500; i++) { | |
train_x.push_back(torch::rand({10})); | |
} | |
//Create train y | |
std::vector<torch::jit::IValue> train_y; | |
for (int i = 0; i < 500; i++) { | |
train_y.push_back(torch::rand({1})); | |
} | |
for (int i = 0; i < 500; i++) { | |
for(unsigned i = 0; i < train_x.size(); ++i) { | |
auto x = train_x[i]; | |
auto y = train_y[i]; | |
torch::autograd::AutoGradMode guard(false); | |
at::AutoNonVariableTypeMode non_var_type_mode(true); | |
std::vector<torch::jit::IValue> forwardArgs; | |
forwardArgs.push_back(x); | |
forwardArgs.push_back(y); | |
forwardArgs.push_back(w); | |
loss = forwardModel.forward(forwardArgs); | |
auto tensor = loss.toTensor(); | |
auto tensor_a = tensor.accessor<float, 1>(); | |
auto loss_value = tensor_a[0]; | |
std::cout << loss_value << std::endl; | |
torch::IValue loss_param = torch::IValue(torch::tensor(loss_value)); | |
w = backwardModel.forward({w, loss_param, x}); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ForwardModule(torch.nn.Module):
    """Forward pass of a toy linear model.

    Returns the signed residual ``sum(W * x) - y`` (shape follows ``y``),
    which the companion BackwardModule treats as the loss.
    """

    def forward(self, x, y, W):
        weighted = W * x
        return weighted.sum(dim=0) - y
class BackwardModule(torch.nn.Module):
    """Manual gradient step for the toy model.

    Uses the analytic gradient ``dL/dW = 2 * loss * x`` and applies a
    fixed learning rate of 1e-5, returning the updated weights.
    """

    def forward(self, W, loss, x):
        gradient = 2.0 * loss * x
        return W - 1e-5 * gradient
# Trace the two toy modules with representative example inputs and persist
# them as TorchScript archives for the C++ side to load.
forward_model = ForwardModule()
backward_model = BackwardModule()

# Example tensors used only to drive tracing (values are irrelevant).
W = torch.rand(10)
x = torch.rand(10)
y = torch.rand(1)
loss = torch.rand(1)

forward_model = torch.jit.trace(forward_model, (x, y, W))
forward_model.save("forward_model.pt")

backward_model = torch.jit.trace(backward_model, (W, loss, x))
backward_model.save("backward_model.pt")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.