```python
import torch
from torch import nn
from torch.autograd import Variable


class AdaptiveSoftmax(nn.Module):
    """Adaptive softmax: a small 'head' softmax covers the most frequent
    words plus one slot per tail cluster; rare words live in tail clusters
    that are only evaluated for the batch rows that need them."""

    def __init__(self, input_size, cutoff):
        super().__init__()

        self.input_size = input_size
        self.cutoff = cutoff
        # head outputs the first cutoff[0] words plus one logit per tail cluster
        self.output_size = cutoff[0] + len(cutoff) - 1

        self.head = nn.Linear(input_size, self.output_size)
        self.tail = nn.ModuleList()

        for i in range(len(cutoff) - 1):
            # each tail cluster first projects the input down by a factor of
            # 4 ** i (bias=False on both linears), then scores the cluster's words
            seq = nn.Sequential(
                nn.Linear(input_size, input_size // 4 ** i, False),
                nn.Linear(input_size // 4 ** i, cutoff[i + 1] - cutoff[i], False),
            )

            self.tail.append(seq)

    def reset(self, init=0.1):
        self.head.weight.data.uniform_(-init, init)

        for tail in self.tail:
            tail[0].weight.data.uniform_(-init, init)
            tail[1].weight.data.uniform_(-init, init)

    def set_target(self, target):
        # record, per tail cluster, the indices of the batch rows whose
        # target falls inside that cluster; forward() uses these to skip
        # clusters that no row in the batch needs
        self.id = []

        for i in range(len(self.cutoff) - 1):
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))

            if mask.sum() > 0:
                self.id.append(Variable(mask.float().nonzero().squeeze(1)))
            else:
                self.id.append(None)

    def forward(self, input):
        # head logits for every row; tail logits only for the rows selected
        # in set_target (index_select along dim 0, the batch dimension)
        output = [self.head(input)]

        for i in range(len(self.id)):
            if self.id[i] is not None:
                output.append(self.tail[i](input.index_select(0, self.id[i])))
            else:
                output.append(None)

        return output

    def log_prob(self, input):
        # full log-probabilities over the whole vocabulary, for evaluation
        lsm = nn.LogSoftmax(dim=1).cuda()  # dim made explicit; input is N x features

        head_out = self.head(input)
        batch_size = head_out.size(0)
        prob = torch.zeros(batch_size, self.cutoff[-1]).cuda()

        lsm_head = lsm(head_out)
        prob.narrow(1, 0, self.output_size).add_(
            lsm_head.narrow(1, 0, self.output_size).data
        )

        for i in range(len(self.tail)):
            pos = self.cutoff[i]
            i_size = self.cutoff[i + 1] - pos
            # log P(word) = log P(cluster | head) + log P(word | cluster)
            buffer = lsm_head.narrow(1, self.cutoff[0] + i, 1)
            buffer = buffer.expand(batch_size, i_size)
            lsm_tail = lsm(self.tail[i](input))
            prob.narrow(1, pos, i_size).copy_(buffer.data).add_(lsm_tail.data)

        return prob


class AdaptiveLoss(nn.Module):
    def __init__(self, cutoff):
        super().__init__()

        self.cutoff = cutoff
        self.criterions = nn.ModuleList()

        # one summed cross-entropy per softmax (head + each tail cluster);
        # the sum is divided by the batch size in forward()
        for i in self.cutoff:
            self.criterions.append(nn.CrossEntropyLoss(size_average=False))

    def remap_target(self, target):
        # head targets: words beyond cutoff[0] are replaced by their cluster's
        # slot in the head; tail targets are shifted to be cluster-local
        new_target = [target.clone()]

        for i in range(len(self.cutoff) - 1):
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
            new_target[0][mask] = self.cutoff[0] + i

            if mask.sum() > 0:
                new_target.append(target[mask].add(-self.cutoff[i]))
            else:
                new_target.append(None)

        return new_target

    def forward(self, input, target):
        batch_size = input[0].size(0)
        target = self.remap_target(target.data)

        output = 0.0

        for i in range(len(input)):
            if input[i] is not None:
                # valid class indices run from 0 to size(1) - 1
                assert target[i].min() >= 0 and target[i].max() < input[i].size(1)
                criterion = self.criterions[i]
                output += criterion(input[i], Variable(target[i]))

        output /= batch_size

        return output
```
Line 108: Should be `if input[i] is not None:`
Sorry for the late reply... I found that comments on gists give no notifications.
jerrybai1995: That's it. `index_select` works by selecting the elements of the input (here, the batch rows along dim 0) whose indices are given by the `id` tensor.
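
A tiny illustration of that, with made-up values:

```python
import torch

x = torch.Tensor([[1, 2], [3, 4], [5, 6]])  # 3 rows (batch), 2 features
id = torch.LongTensor([0, 2])               # rows whose targets fall in this cluster

print(x.index_select(0, id))
# 1  2
# 5  6
# -> rows 0 and 2 are selected; dim 0 is the batch dimension,
#    which is why the gist passes 0 rather than 1
```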
temporaer: Yes, I corrected it.
I got an error as follows:

```
  File "text8.py", line 121, in <module>
    train()
  File "text8.py", line 78, in train
    loss = criterion(output, Y_var)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/AdaptiveSoftmaxPyTorch/adasoft.py", line 114, in forward
    output += self.criterion(input[i], Variable(target[i]))
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 71, in forward
    raise NotImplementedError
NotImplementedError
```
Any idea?
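
One possible cause, judging only from the traceback: line 114 there calls `self.criterion(...)` (singular), while the gist above indexes `self.criterions[i]`. Calling an `nn.ModuleList` directly raises exactly this `NotImplementedError`, because `ModuleList` defines no `forward`. A minimal reproduction of that failure mode:

```python
import torch
from torch import nn
from torch.autograd import Variable

criterions = nn.ModuleList([nn.CrossEntropyLoss(), nn.CrossEntropyLoss()])

logits = Variable(torch.randn(4, 10))
labels = Variable(torch.LongTensor([1, 2, 3, 4]))

criterions[0](logits, labels)  # fine: index into the list, then call the module
criterions(logits, labels)     # raises NotImplementedError: ModuleList has no forward
```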
Line 48: Shouldn't it be `output.append(self.tail[i](input.index_select(1, self.id[i])))` instead of dim 0? I'm assuming `input` has dimension `N x input_size`.