-- Register nn.RandomFeatureExtractor as a subclass of nn.Module; `Parent` is
-- the nn.Module class, used below to run the base-class constructor.
local RandomFeatureExtractor, Parent = torch.class('nn.RandomFeatureExtractor', 'nn.Module')

--- Build a fixed, randomly wired feature extractor with no learnable parameters.
-- A binary mask of size outputSize x inputSize is drawn once here. Row i gets
-- between kmin and kmax ones, placed at distinct random input columns, so each
-- output feature is the sum of a random subset of the input features.
-- NOTE: the subset sizes come from Lua's math.random while the subsets come
-- from torch.randperm, so reproducing a mask requires seeding both
-- (math.randomseed and torch.manualSeed).
-- @tparam number inputSize number of input features
-- @tparam number outputSize number of output features
-- @tparam number kmin minimum number of inputs summed into each output (>= 0)
-- @tparam number kmax maximum number of inputs summed into each output (<= inputSize)
function RandomFeatureExtractor:__init(inputSize, outputSize, kmin, kmax)
	Parent.__init(self)

	-- Fail fast on impossible ranges. Otherwise construction would error (or
	-- misbehave) only when an out-of-range subset size happens to be drawn.
	assert(0 <= kmin and kmin <= kmax, string.format(
		'RandomFeatureExtractor: need 0 <= kmin <= kmax (got kmin=%s, kmax=%s)',
		tostring(kmin), tostring(kmax)))
	assert(kmax <= inputSize, string.format(
		'RandomFeatureExtractor: need kmax <= inputSize (got kmax=%s, inputSize=%s)',
		tostring(kmax), tostring(inputSize)))

	self.mask = torch.Tensor(outputSize, inputSize):zero()

	for i = 1, outputSize do
		-- math.random accepts at most two arguments; stock Lua (5.1-5.4)
		-- raises "wrong number of arguments" if given a third.
		local numSamp = math.random(kmin, kmax)
		-- The first numSamp entries of a random permutation are distinct columns.
		local perm = torch.randperm(inputSize)
		local row = self.mask[i] -- view into row i of the mask
		for j = 1, numSamp do
			row[perm[j]] = 1
		end
	end

	self.inputSize = inputSize
	self.outputSize = outputSize
	self.kmin = kmin
	self.kmax = kmax

	self.output = torch.Tensor()
	self.gradInput = torch.Tensor()
end

--- Forward pass through the fixed mask.
-- For a 1D input (inputSize), output = mask * input (size outputSize).
-- Otherwise the input is treated as a batch with rows along dimension 1, and
-- output = input * mask^T (batch x outputSize).
-- Also records the batch size in self.batchSize for batched inputs.
-- @param input 1D or 2D tensor
-- @return self.output
function RandomFeatureExtractor:updateOutput(input)
	if input:dim() ~= 1 then
		local nBatch = input:size(1)
		self.batchSize = nBatch
		self.output:resize(nBatch, self.outputSize):mm(input, self.mask:t())
		return self.output
	end

	self.output:resize(self.outputSize):mv(self.mask, input)
	return self.output
end

--- Backward pass: propagate gradOutput through the transpose of the mask.
-- For a 1D input, gradInput = mask^T * gradOutput (same size as input).
-- For a batched input, gradInput = gradOutput * mask (batch x inputSize).
-- The module has no parameters, so this is its only gradient computation.
-- @param input the tensor given to updateOutput
-- @param gradOutput gradient w.r.t. self.output
-- @return self.gradInput
function RandomFeatureExtractor:updateGradInput(input, gradOutput)
	if input:dim() ~= 1 then
		local nBatch = input:size(1)
		self.batchSize = nBatch
		self.gradInput:resize(nBatch, self.inputSize):mm(gradOutput, self.mask)
		return self.gradInput
	end

	self.gradInput:resizeAs(input):mv(self.mask:t(), gradOutput)
	return self.gradInput
end

--- Human-readable summary: type name followed by sizes and the k range.
function RandomFeatureExtractor:__tostring__()
	local typeName = torch.type(self)
	local details = string.format('(%d -> %d, kmin: %d, kmax: %d)',
		self.inputSize, self.outputSize, self.kmin, self.kmax)
	return typeName .. details
end

--[[
 <<References>>
  [1] 12th-place solution to the Otto Group Product Classification Challenge on Kaggle,
      by tks0123456789.
      https://github.com/tks0123456789/kaggle-Otto
--]]