@ceshine · Last active September 30, 2017 03:46
Key Code Blocks of the PyTorch RNN Dropout Implementation
# https://github.com/salesforce/awd-lstm-lm/blob/dfd3cb0235d2caf2847a4d53e1cbd495b781b5d2/weight_drop.py#L5
import torch
from torch.nn import Parameter


# Applies dropout to the named weight matrices of the wrapped module,
# recomputing the dropped weights before every forward pass.
class WeightDrop(torch.nn.Module):
    def __init__(self, module, weights, dropout=0, variational=False):
        # ... (module, weights, dropout and variational are stored on self)
        self._setup()

    # ... (widget_demagnetizer_y2k_edition, the no-op used to replace flatten_parameters, is elided)

    def _setup(self):
        # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
        if issubclass(type(self.module), torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            # Remove the original parameter and re-register it under '<name>_raw';
            # the dropped-out version is derived from it on each forward pass.
            w = getattr(self.module, name_w)
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', Parameter(w.data))

    def _setweights(self):
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            w = None
            if self.variational:
                # Variational variant: a single mask value per row of the weight
                # matrix, expanded across the row, so each row is kept or dropped as a whole.
                mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
                if raw_w.is_cuda:
                    mask = mask.cuda()
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
                w = mask.expand_as(raw_w) * raw_w
            else:
                # Standard variant: drop individual weight entries.
                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)
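
For context, here is a minimal usage sketch. It is not part of the gist: it assumes the full weight_drop.py from the linked repository, and the layer sizes and dropout rate are made-up illustration values. It wraps a torch.nn.LSTM so that its hidden-to-hidden weight matrix is re-dropped on every forward pass, similar to how awd-lstm-lm wraps its LSTM layers.

# Usage sketch (illustrative values, not from the gist): apply a 50% weight
# drop to the hidden-to-hidden matrix of a single-layer LSTM.
import torch
from torch.autograd import Variable

lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=1)
# 'weight_hh_l0' is the hidden-to-hidden weight of layer 0 in torch.nn.LSTM.
wd_lstm = WeightDrop(lstm, ['weight_hh_l0'], dropout=0.5)

x = Variable(torch.randn(7, 3, 10))  # (seq_len, batch, input_size)
wd_lstm.train()                      # dropout is active in training mode
output, (h_n, c_n) = wd_lstm(x)      # fresh dropped weights are set before each call

Note that in the snippet above the non-variational branch respects self.training (so wd_lstm.eval() disables the dropout), while the variational branch passes training=True unconditionally at this commit.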