The skip-gram model with negative sampling: a PyTorch module with separate input and output embedding tables, plus a method for sampling noise-word embeddings.
import torch
import torch.nn as nn


class SkipGramNeg(nn.Module):
    def __init__(self, n_vocab, n_embed, noise_dist=None):
        super().__init__()

        self.n_vocab = n_vocab
        self.n_embed = n_embed
        self.noise_dist = noise_dist

        # define embedding layers for input and output words
        self.in_embed = nn.Embedding(n_vocab, n_embed)
        self.out_embed = nn.Embedding(n_vocab, n_embed)

        # initialize both embedding tables with a uniform distribution
        self.in_embed.weight.data.uniform_(-1, 1)
        self.out_embed.weight.data.uniform_(-1, 1)

    def forward_input(self, input_words):
        # return input vector embeddings
        input_vectors = self.in_embed(input_words)
        return input_vectors

    def forward_output(self, output_words):
        # return output vector embeddings
        output_vectors = self.out_embed(output_words)
        return output_vectors

    def forward_noise(self, batch_size, n_samples):
        """Generate noise vectors with shape (batch_size, n_samples, n_embed)."""
        if self.noise_dist is None:
            # sample words uniformly
            noise_dist = torch.ones(self.n_vocab)
        else:
            noise_dist = self.noise_dist

        # sample word indices from the noise distribution
        noise_words = torch.multinomial(noise_dist,
                                        batch_size * n_samples,
                                        replacement=True)

        # move the sampled indices to the same device as the embedding weights
        device = "cuda" if self.out_embed.weight.is_cuda else "cpu"
        noise_words = noise_words.to(device)

        # look up the noise embeddings and reshape to (batch_size, n_samples, n_embed)
        noise_vectors = self.out_embed(noise_words).view(batch_size, n_samples, self.n_embed)

        return noise_vectors
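
Below is a minimal usage sketch, not part of the original gist, showing one way the three forward methods can feed a negative-sampling loss. The loss class name, vocabulary size, embedding size, batch size, and the random index tensors are illustrative assumptions.

# A hedged sketch of a negative-sampling loss paired with SkipGramNeg.
# All sizes and names below are placeholders, not taken from the gist.
class NegativeSamplingLoss(nn.Module):
    def forward(self, input_vectors, output_vectors, noise_vectors):
        batch_size, embed_size = input_vectors.shape

        # reshape so batched matrix multiplication yields one score per pair
        input_vectors = input_vectors.view(batch_size, embed_size, 1)    # column vectors
        output_vectors = output_vectors.view(batch_size, 1, embed_size)  # row vectors

        # log-sigmoid score for each correct (input, output) pair
        out_loss = torch.bmm(output_vectors, input_vectors).sigmoid().log()
        out_loss = out_loss.squeeze()

        # log-sigmoid of the negated noise scores, summed over the noise samples
        noise_loss = torch.bmm(noise_vectors.neg(), input_vectors).sigmoid().log()
        noise_loss = noise_loss.squeeze().sum(1)

        # negate and average over the batch
        return -(out_loss + noise_loss).mean()


# Hypothetical single training step with placeholder sizes and random indices.
model = SkipGramNeg(n_vocab=5000, n_embed=300)
criterion = NegativeSamplingLoss()

inputs = torch.randint(0, 5000, (16,))   # centre-word indices
targets = torch.randint(0, 5000, (16,))  # context-word indices

input_vectors = model.forward_input(inputs)                       # (16, 300)
output_vectors = model.forward_output(targets)                    # (16, 300)
noise_vectors = model.forward_noise(batch_size=16, n_samples=5)   # (16, 5, 300)

loss = criterion(input_vectors, output_vectors, noise_vectors)
loss.backward()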