Skip to content

Instantly share code, notes, and snippets.

@bartvm
Created July 28, 2016 17:25
Show Gist options
  • Save bartvm/67072495c48eb29553bf7d31a6ebedd3 to your computer and use it in GitHub Desktop.
Save bartvm/67072495c48eb29553bf7d31a6ebedd3 to your computer and use it in GitHub Desktop.
local nn = require 'nn'
--- nn.Flatten: a container that merges the leading `flattenDims` dimensions
-- of its input into one before calling its single child module, then splits
-- them back out of the child's output. For example, with flattenDims = 2 a
-- T x N x D input reaches the child as (T*N) x D, and a (T*N) x K child output
-- is returned as T x N x K. The child is attached with the inherited
-- `Container:add(module)`; only `self.modules[1]` is used.
--
-- Declared `local` so that `Flatten` and especially the generic name `parent`
-- do not leak into (and clobber) the global environment; the class remains
-- reachable as `nn.Flatten`.
local Flatten, parent = torch.class('nn.Flatten', 'nn.Container')

-- Returns a LongStorage equal to `sizes` with its first `flattenDims`
-- entries replaced by their product.
local function mergeLeading(sizes, flattenDims)
   local merged = 1
   for i = 1, flattenDims do
      merged = merged * sizes[i]
   end
   local shape = {merged}
   for i = flattenDims + 1, #sizes do
      shape[#shape + 1] = sizes[i]
   end
   return torch.LongStorage(shape)
end

-- Returns a LongStorage whose first `flattenDims` entries come from
-- `leadingSizes` and whose remaining entries are `flatSizes[2..]`; this undoes
-- mergeLeading on the child's output shape.
local function splitLeading(leadingSizes, flatSizes, flattenDims)
   local shape = {}
   for i = 1, flattenDims do
      shape[i] = leadingSizes[i]
   end
   for i = 2, #flatSizes do
      shape[#shape + 1] = flatSizes[i]
   end
   return torch.LongStorage(shape)
end

-- Views `input` and `gradOutput` with the flat shapes recorded by the most
-- recent updateOutput call, i.e. the shapes the child module worked with.
-- view() needs contiguous memory; contiguous() returns the tensor itself
-- (no copy) when it already is contiguous.
local function flatArgs(self, input, gradOutput)
   return input:contiguous():view(self.shape),
          gradOutput:contiguous():view(self.outputShape)
end

--- Constructor.
-- @tparam number flattenDims number of leading input dimensions to merge
-- (a non-negative integer; e.g. 2 for time x batch inputs)
function Flatten:__init(flattenDims)
   assert(type(flattenDims) == 'number' and flattenDims >= 0
             and flattenDims % 1 == 0,
          'nn.Flatten: flattenDims must be a non-negative integer')
   self.flattenDims = flattenDims
   parent.__init(self)
end

--- Forward pass: merge the leading dimensions, run the child, split them out.
-- Records `inputShape`, `shape` (flat input shape) and `outputShape`
-- (child output shape) for the backward methods.
-- @param input tensor with at least `flattenDims` dimensions
-- @return the child's output viewed with the leading dimensions restored
function Flatten:updateOutput(input)
   local child = self.modules[1]
   assert(child, 'nn.Flatten: no module to wrap; call :add(module) first')
   assert(input:dim() >= self.flattenDims,
          'nn.Flatten: input has fewer dimensions than flattenDims')

   self.inputShape = input:size()
   self.shape = mergeLeading(self.inputShape, self.flattenDims)
   -- contiguous() avoids a view() failure on e.g. transposed inputs.
   self.flattened = input:contiguous():view(self.shape)

   local flattenedOutput = child:updateOutput(self.flattened)
   self.outputShape = flattenedOutput:size()
   self.output = flattenedOutput:contiguous():view(
      splitLeading(self.inputShape, self.outputShape, self.flattenDims))
   return self.output
end

--- Gradient w.r.t. the input, computed by the child on the flat views and
-- reshaped back to the input's shape. Provided because callers may invoke
-- updateGradInput directly rather than backward.
-- @param input the tensor given to the preceding updateOutput
-- @param gradOutput gradient w.r.t. this module's output
-- @return gradInput with the same shape as `input`
function Flatten:updateGradInput(input, gradOutput)
   local flatInput, flatGradOutput = flatArgs(self, input, gradOutput)
   local flatGradInput = self.modules[1]:updateGradInput(flatInput,
                                                         flatGradOutput)
   self.gradInput = flatGradInput:contiguous():view(self.inputShape)
   return self.gradInput
end

--- Accumulates the child's parameter gradients on the flat views.
-- @param input the tensor given to the preceding updateOutput
-- @param gradOutput gradient w.r.t. this module's output
-- @param scale optional gradient scale, forwarded unchanged to the child
function Flatten:accGradParameters(input, gradOutput, scale)
   local flatInput, flatGradOutput = flatArgs(self, input, gradOutput)
   self.modules[1]:accGradParameters(flatInput, flatGradOutput, scale)
end

--- Full backward pass delegated to the child's own backward (so any
-- child-specific backward logic is kept), with `scale` forwarded.
-- @param input the tensor given to the preceding updateOutput
-- @param gradOutput gradient w.r.t. this module's output
-- @param scale optional gradient scale, forwarded unchanged to the child
-- @return gradInput with the same shape as `input`
function Flatten:backward(input, gradOutput, scale)
   local flatInput, flatGradOutput = flatArgs(self, input, gradOutput)
   local flatGradInput = self.modules[1]:backward(flatInput, flatGradOutput,
                                                  scale)
   self.gradInput = flatGradInput:contiguous():view(self.inputShape)
   return self.gradInput
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment