Created
September 12, 2017 16:58
-
-
Save ikhlestov/a7c829c46463a84df7fa7ff6c287db13 to your computer and use it in GitHub Desktop.
PyTorch: self-defined layers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class MyFunction(torch.autograd.Function):
    """Straight-through sign function.

    Forward computes ``torch.sign(input)``. Backward passes the incoming
    gradient through unchanged, except where ``input >= 1`` or
    ``input <= -1``, where it is zeroed (hard-tanh style straight-through
    estimator).
    """

    @staticmethod
    def forward(ctx, input):
        # Save the input so backward can build its gradient mask.
        ctx.save_for_backward(input)
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # saved tensors - tuple of tensors, so we need get first
        # FIX: ctx.saved_variables was deprecated and later removed;
        # ctx.saved_tensors is the supported accessor.
        input, = ctx.saved_tensors
        # FIX: clone before masking — mutating grad_output in place can
        # corrupt gradients shared with other branches of the graph.
        grad_input = grad_output.clone()
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input
# usage: a custom Function is invoked through its .apply classmethod,
# never by instantiating the class.
x = torch.randn(10, 20)
y = MyFunction.apply(x)
# or: bind .apply to a plain callable name and call that instead
my_func = MyFunction.apply
y = my_func(x)
# and if we want to use inside nn.Module | |
class MyFunctionModule(torch.nn.Module): | |
def forward(self, x): | |
return MyFunction.apply(x) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment