import torch
from torch import nn
from torch.autograd import Variable


class AdaptiveSoftmax(nn.Module):
    def __init__(self, input_size, cutoff):
        super().__init__()

        self.input_size = input_size
        self.cutoff = cutoff
        # The head predicts the cutoff[0] most frequent words plus one
        # "cluster" slot for each tail.
        self.output_size = cutoff[0] + len(cutoff) - 1

        self.head = nn.Linear(input_size, self.output_size)
        self.tail = nn.ModuleList()

        for i in range(len(cutoff) - 1):
            # Each tail projects the features down to input_size // 4 ** i
            # dimensions before predicting the words in cluster i; both
            # layers are bias-free.
            seq = nn.Sequential(
                nn.Linear(input_size, input_size // 4 ** i, False),
                nn.Linear(input_size // 4 ** i, cutoff[i + 1] - cutoff[i], False)
            )

            self.tail.append(seq)

    def reset(self, init=0.1):
        self.head.weight.data.uniform_(-init, init)

        for tail in self.tail:
            tail[0].weight.data.uniform_(-init, init)
            tail[1].weight.data.uniform_(-init, init)

    def set_target(self, target):
        # For each tail cluster, record which batch rows have a target
        # inside it (None if the cluster is unused in this batch).
        self.id = []

        for i in range(len(self.cutoff) - 1):
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))

            if mask.sum() > 0:
                self.id.append(Variable(mask.float().nonzero().squeeze(1)))

            else:
                self.id.append(None)

    def forward(self, input):
        # Head logits for every row; tail logits only for the rows whose
        # target falls in that cluster (rows selected along dim 0).
        output = [self.head(input)]

        for i in range(len(self.id)):
            if self.id[i] is not None:
                output.append(self.tail[i](input.index_select(0, self.id[i])))

            else:
                output.append(None)

        return output

    def log_prob(self, input):
        lsm = nn.LogSoftmax(dim=1).cuda()

        head_out = self.head(input)

        batch_size = head_out.size(0)
        prob = torch.zeros(batch_size, self.cutoff[-1]).cuda()

        lsm_head = lsm(head_out)
        prob.narrow(1, 0, self.output_size).add_(lsm_head.narrow(1, 0, self.output_size).data)

        for i in range(len(self.tail)):
            pos = self.cutoff[i]
            i_size = self.cutoff[i + 1] - pos

            # log P(word) = log P(cluster | head) + log P(word | cluster)
            buffer = lsm_head.narrow(1, self.cutoff[0] + i, 1)
            buffer = buffer.expand(batch_size, i_size)

            lsm_tail = lsm(self.tail[i](input))
            prob.narrow(1, pos, i_size).copy_(buffer.data).add_(lsm_tail.data)

        return prob


class AdaptiveLoss(nn.Module):
    def __init__(self, cutoff):
        super().__init__()

        self.cutoff = cutoff
        self.criterions = nn.ModuleList()

        for i in self.cutoff:
            self.criterions.append(nn.CrossEntropyLoss(size_average=False))

    def remap_target(self, target):
        # Head targets: words in a tail cluster are replaced by that
        # cluster's slot id in the head output.
        new_target = [target.clone()]

        for i in range(len(self.cutoff) - 1):
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
            new_target[0][mask] = self.cutoff[0] + i

            # Tail targets are shifted so each cluster starts at 0.
            if mask.sum() > 0:
                new_target.append(target[mask].add(-self.cutoff[i]))

            else:
                new_target.append(None)

        return new_target

    def forward(self, input, target):
        batch_size = input[0].size(0)
        target = self.remap_target(target.data)

        output = 0.0

        for i in range(len(input)):
            if input[i] is not None:
                # Class indices must lie in [0, num_classes).
                assert target[i].min() >= 0 and target[i].max() < input[i].size(1)

                criterion = self.criterions[i]
                output += criterion(input[i], Variable(target[i]))

        output /= batch_size

        return output
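
For reference, here is a minimal usage sketch in the same old-style (pre-0.4, Variable-based) PyTorch API the gist uses. The cutoff list, sizes, and random data below are illustrative assumptions, not part of the gist:

# Illustrative sketch: wiring AdaptiveSoftmax + AdaptiveLoss for one train step.
# All sizes and the cutoff list are made-up assumptions.
input_size = 128
cutoff = [20, 60, 100]  # head covers ids 0..19; two tails; vocab size 100

model = AdaptiveSoftmax(input_size, cutoff)
criterion = AdaptiveLoss(cutoff)

hidden = Variable(torch.randn(32, input_size))            # e.g. RNN features
target = Variable(torch.LongTensor(32).random_(0, 100))   # word ids in [0, 100)

model.set_target(target.data)  # mark which tail clusters this batch needs
output = model(hidden)         # [head logits, tail logits or None, ...]
loss = criterion(output, target)
loss.backward()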
Line 108: Should be if input[i] is not None:
Sorry for the late reply... I found that comments on gists give no notifications.
jerrybai1995: That's it. index_select selects the elements of the input along the given dimension at the positions specified by the id tensor.
temporaer: Yes, I corrected it.
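
To illustrate the index_select reply above, a tiny sketch (values are made up): selecting along dim 0 keeps whole rows, which is why dim 0 is the right choice for an N x input_size batch.

# Made-up example of index_select along dim 0.
x = torch.arange(0, 6).view(3, 2)  # 3 rows (batch) x 2 features
idx = torch.LongTensor([0, 2])     # rows whose targets fall in some tail cluster
x.index_select(0, idx)             # keeps rows 0 and 2: [[0, 1], [4, 5]]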
I got an error as follows:

File "text8.py", line 121, in <module>
    train()
File "text8.py", line 78, in train
    loss = criterion(output, Y_var)
File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
File "/AdaptiveSoftmaxPyTorch/adasoft.py", line 114, in forward
    output += self.criterion(input[i], Variable(target[i]))
File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 71, in forward
    raise NotImplementedError
NotImplementedError

Any idea?
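
A likely cause, judging from the traceback: the last frame is the base nn.Module.forward, which raises NotImplementedError when a module that defines no forward of its own is called directly, and calling an nn.ModuleList itself does exactly that. The frame shows self.criterion(...), while the gist indexes self.criterions[i]. Assuming adasoft.py assigned the ModuleList to self.criterion, a hypothetical fix:

# Hypothetical: if self.criterion holds the nn.ModuleList, calling it directly
# hits nn.Module.forward and raises NotImplementedError; index into it instead.
output += self.criterion[i](input[i], Variable(target[i]))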
Line 48: Shouldn't it be

output.append(self.tail[i](input.index_select(1, self.id[i])))

instead of dim 0? I'm assuming input has dimension N x input_size.