-- Gist lebedov/cb26d28ef9bbb142c393cbf00c1c2fe9, created September 16, 2016 13:37.
-- Class for passing data to Torch's nn.StochasticGradient.
-- Note: this file may contain hidden or bidirectional Unicode text that could be
-- interpreted or compiled differently than it appears below. To review it, open
-- the file in an editor that reveals hidden Unicode characters.
local class = require('class')

--- Dataset: wraps parallel input/label containers so that they can be used
-- by nn.StochasticGradient in Torch. The class object itself is created
-- with the `class` rock; methods and the indexing hook are attached below.
local Dataset = class('Dataset')
--- Initialize a Dataset from parallel sample/target containers.
-- Both containers are stored by reference; nothing is copied.
-- @param inputs  indexable collection of samples (presumably a Torch tensor
--                whose first dimension enumerates samples — TODO confirm)
-- @param labels  indexable collection of targets, indexed in step with inputs
function Dataset:__init(inputs, labels)
   self.inputs, self.labels = inputs, labels
end
--- Number of samples in the dataset.
-- If `inputs` has a `size` method (presumably a Torch tensor), the result is
-- the first entry of `inputs:size()`, exactly as before. If `inputs` is a
-- plain Lua array (a table with no `size` field), the length operator is
-- used instead. That case assumes the array is a sequence with no nil holes.
-- @treturn number
function Dataset:size()
   local inputs = self.inputs
   -- Plain Lua tables have no :size() method, so fall back to `#`.
   if type(inputs) == 'table' and inputs.size == nil then
      return #inputs
   end
   return inputs:size()[1]
end
-- Route instance lookups so that a Dataset can be indexed like an array
-- of samples:
--   * numeric keys are sample indices and yield a fresh pair
--     {self.inputs[k], self.labels[k]};
--   * every other key (method names such as `size`, class metadata) is
--     resolved by the class's original __index, so method dispatch works.
-- Only numeric keys are treated as indices. Unknown string keys therefore
-- resolve to nil rather than to a bogus pair. This also avoids unbounded
-- recursion when a field such as `inputs` is missing, because reading
-- `self.inputs` re-enters this function with a string key.
local base_index = Dataset.__index
Dataset.__index = function(self, k)
   if type(k) == 'number' then
      return {self.inputs[k], self.labels[k]}
   end
   -- The original __index may be a table (prototype) or a function.
   if type(base_index) == 'function' then
      return base_index(self, k)
   end
   return base_index[k]
end
-- Module export: the class is available as `require(...).Dataset`.
local exports = {}
exports.Dataset = Dataset
return exports
-- Sign up for free to join this conversation on GitHub.
-- Already have an account? Sign in to comment.