Created
October 13, 2015 20:34
-
-
Save jcjohnson/cfdf1e52dcc2c6b1050a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Demo / smoke test for nn.TimesTwo: runs one forward and one backward pass
-- on random data, prints every tensor involved, then checks both results.
require 'nn'
require 'TimesTwo'

-- Shape of the random test tensors (shared by input and gradOutput).
local ROWS, COLS = 4, 5

local times_two = nn.TimesTwo()

-- Forward pass: output is expected to be exactly 2 * input.
local input = torch.randn(ROWS, COLS)
local output = times_two:forward(input)
print('here is input:')
print(input)
print('here is output:')
print(output)

-- Backward pass, issued after forward on the same input (nn.Module convention):
-- gradInput is expected to be exactly 2 * gradOutput.
local gradOutput = torch.randn(ROWS, COLS)
local gradInput = times_two:backward(input, gradOutput)
print('here is gradOutput:')
print(gradOutput)
print('here is gradInput:')
print(gradInput)

-- Multiplying a finite value by 2 is exact in floating point, so an exact
-- (zero-tolerance) comparison is valid here.
assert((output - input * 2):abs():max() == 0,
  'forward output is not 2 * input')
assert((gradInput - gradOutput * 2):abs():max() == 0,
  'gradInput is not 2 * gradOutput')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- nn.TimesTwo: a parameter-free nn.Module that multiplies its input by 2.
-- Forward:  output    = 2 * input
-- Backward: gradInput = 2 * gradOutput   (since d(2x)/dx = 2)
require 'nn'

-- torch.class registers the class globally as nn.TimesTwo and returns the
-- class table together with its parent (nn.Module) class table.
local TimesTwo, Parent = torch.class('nn.TimesTwo', 'nn.Module')

-- Constant factor applied in both the forward and backward passes.
local SCALE = 2

--- Constructor. Holds no state of its own; defers entirely to nn.Module.
function TimesTwo:__init()
  Parent.__init(self)
end

--- Forward pass.
-- Stores SCALE * input into the module's self.output buffer.
-- @param input tensor to be doubled
-- @return self.output
function TimesTwo:updateOutput(input)
  local result = self.output
  result:mul(input, SCALE)
  return result
end

--- Backward pass (gradient w.r.t. the input).
-- Stores SCALE * gradOutput into the module's self.gradInput buffer; the
-- input argument is not needed because the derivative is constant.
-- @param input the tensor passed to the preceding forward call (unused)
-- @param gradOutput gradient of the loss w.r.t. this module's output
-- @return self.gradInput
function TimesTwo:updateGradInput(input, gradOutput)
  local grad = self.gradInput
  grad:mul(gradOutput, SCALE)
  return grad
end
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment