@kzinmr
Created January 15, 2021 12:13
import torch


class WordDropout(torch.nn.Module):
    """
    Implementation of word dropout. Randomly drops out entire words (or characters) in embedding space.
    """

    def __init__(self, dropout_rate=0.05, inplace=False):
        super(WordDropout, self).__init__()
        self.dropout_rate = dropout_rate
        self.inplace = inplace

    def forward(self, x):
        # Identity at evaluation time or when the dropout rate is zero.
        if not self.training or not self.dropout_rate:
            return x

        # Sample one Bernoulli keep/drop decision per (batch, position) and
        # broadcast it over the embedding dimension, so entire word vectors
        # are zeroed out rather than individual features.
        m = x.data.new(x.size(0), x.size(1), 1).bernoulli_(1 - self.dropout_rate)
        # torch.autograd.Variable is a legacy no-op wrapper in PyTorch >= 0.4;
        # it simply keeps the mask out of the autograd graph.
        mask = torch.autograd.Variable(m, requires_grad=False)
        return mask * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)