Created
February 20, 2018 20:41
-
-
Save arunmallya/668e0f31aedb3563c3fa020b4116e8a8 to your computer and use it in GitHub Desktop.
Autograd snippet for Binarizer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
DEFAULT_THRESHOLD = 5e-3


class Binarizer(torch.autograd.Function):
    """Binarize a real-valued tensor to {0, 1} with a pass-through gradient.

    Forward: every element less than or equal to the threshold becomes 0 and
    every element greater than it becomes 1. NaN entries satisfy neither
    comparison and are therefore left unchanged. The input tensor itself is
    never modified; a new tensor is returned.

    Backward: the incoming gradient is handed to the input unmodified, so the
    thresholding step is treated as the identity during backpropagation.

    Preferred usage::

        outputs = Binarizer.apply(inputs, threshold)

    The older instance-based form, ``Binarizer(threshold)(inputs)``, is still
    accepted and simply forwards to ``apply``.
    """

    def __init__(self, threshold=DEFAULT_THRESHOLD):
        """Remember ``threshold`` for the instance-call convenience form."""
        # The base initializer is intentionally not chained: the instance is
        # only a holder for the threshold used by ``__call__`` below, and all
        # real work happens in the static ``forward``/``backward`` methods.
        self.threshold = threshold

    def __call__(self, inputs):
        """Binarize ``inputs`` using the threshold stored on this instance."""
        return type(self).apply(inputs, self.threshold)

    @staticmethod
    def forward(ctx, inputs, threshold=DEFAULT_THRESHOLD):
        """Return a copy of ``inputs`` with values mapped to 0 or 1.

        Args:
            ctx: Autograd context (unused; nothing is needed for backward).
            inputs: Tensor to binarize.
            threshold: Cut-off value; elements ``<= threshold`` map to 0 and
                elements ``> threshold`` map to 1.

        Returns:
            A new tensor with the same shape and dtype as ``inputs``.
        """
        # Work on a copy so the caller's tensor is left untouched.
        outputs = inputs.clone()
        # Both masks are taken from the *original* values, so the first
        # assignment cannot influence the second.
        outputs[inputs.le(threshold)] = 0
        outputs[inputs.gt(threshold)] = 1
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient straight through to ``inputs``.

        Returns:
            A pair ``(grad_inputs, None)``; the ``None`` corresponds to the
            non-differentiable ``threshold`` argument of ``forward``.
        """
        return grad_output, None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment