--[[
Random Feature Extractor (Torch 7)
GitHub gist toshi-k/07ac1993b6f8b7fd47e33a9fd25b4bff (last active May 3, 2016, 15:23).
Save it to your computer and use it in GitHub Desktop.
Note from GitHub: this file contains bidirectional Unicode text that may be
interpreted or compiled differently than it appears below. To review it, open
the file in an editor that reveals hidden Unicode characters (see GitHub's
documentation on bidirectional Unicode characters).
--]]
-- Register nn.RandomFeatureExtractor as a subclass of nn.Module; `Parent`
-- is kept so the constructor can chain to nn.Module's __init.
local RandomFeatureExtractor, Parent =
   torch.class('nn.RandomFeatureExtractor', 'nn.Module')
--- Build the fixed random connectivity mask.
-- `mask` is an outputSize x inputSize 0/1 tensor. Row i has 1s in a random
-- subset of input columns; the subset size is drawn uniformly from
-- [kmin, kmax] independently for every row.
-- @param inputSize  number of input features
-- @param outputSize number of output features
-- @param kmin       minimum number of inputs connected to each output
-- @param kmax       maximum number of inputs connected to each output
--                   (must not exceed inputSize)
function RandomFeatureExtractor:__init(inputSize, outputSize, kmin, kmax)
   Parent.__init(self)

   -- Check the sampling range up front. If a draw exceeded inputSize, the
   -- loop below would index past the end of the permutation, so a bad
   -- kmax used to crash only on some random draws.
   assert(kmin and kmax and kmin <= kmax and kmax <= inputSize,
      'RandomFeatureExtractor: expected kmin <= kmax <= inputSize')

   self.mask = torch.Tensor(outputSize, inputSize):zero()
   for i = 1, outputSize do
      -- math.random takes at most two arguments: PUC Lua rejects a third
      -- one ("wrong number of arguments"), while LuaJIT silently ignores it.
      local num_samp = math.random(kmin, kmax)
      -- The first num_samp entries of a random permutation form a uniformly
      -- random subset of the input columns, without repetition.
      local index_samp = torch.randperm(inputSize)
      local row = self.mask[i]
      for j = 1, num_samp do
         row[index_samp[j]] = 1
      end
   end

   self.inputSize = inputSize
   self.outputSize = outputSize
   self.kmin = kmin
   self.kmax = kmax
   self.output = torch.Tensor()
   self.gradInput = torch.Tensor()
end
--- Forward pass: project the input through the fixed 0/1 mask.
-- A 1-D input (inputSize) yields output = mask * input (outputSize).
-- A 2-D input (batchSize x inputSize) yields input * mask^T
-- (batchSize x outputSize). Any other rank raises an error.
-- @param input 1-D or 2-D tensor
-- @return self.output
function RandomFeatureExtractor:updateOutput(input)
   local nDim = input:dim()
   if nDim == 1 then
      self.output:resize(self.outputSize)
      self.output:mv(self.mask, input)
   elseif nDim == 2 then
      self.batchSize = input:size(1)
      self.output:resize(self.batchSize, self.outputSize)
      self.output:mm(input, self.mask:t())
   else
      -- Fail clearly instead of deep inside size()/mm() for 0-D or 3-D+ input.
      error('input must be vector or matrix')
   end
   return self.output
end
--- Backward pass: gradient of the loss with respect to the input.
-- The mask is constant, so only gradInput is computed.
-- 1-D: gradInput = mask^T * gradOutput (inputSize).
-- 2-D: gradInput = gradOutput * mask (batchSize x inputSize).
-- Any other input rank raises an error, as in updateOutput.
-- @param input      the tensor passed to updateOutput
-- @param gradOutput gradient with respect to self.output
-- @return self.gradInput
function RandomFeatureExtractor:updateGradInput(input, gradOutput)
   local nDim = input:dim()
   if nDim == 1 then
      self.gradInput:resizeAs(input)
      self.gradInput:mv(self.mask:t(), gradOutput)
   elseif nDim == 2 then
      self.batchSize = input:size(1)
      self.gradInput:resize(self.batchSize, self.inputSize)
      self.gradInput:mm(gradOutput, self.mask)
   else
      error('input must be vector or matrix')
   end
   return self.gradInput
end
--- Human-readable summary: the class name followed by the layer sizes and
-- the [kmin, kmax] range of inputs sampled per output.
function RandomFeatureExtractor:__tostring__()
   local summary = string.format('(%d -> %d, kmin: %d, kmax: %d)',
      self.inputSize, self.outputSize, self.kmin, self.kmax)
   return torch.type(self) .. summary
end
--[[
<<References>>
[1] 12th-place solution for the Otto Group Product Classification Challenge
    on Kaggle, by tks0123456789.
    https://github.com/tks0123456789/kaggle-Otto
--]]
-- Sign up for free to join this conversation on GitHub.
-- Already have an account? Sign in to comment.