Created May 18, 2015 14:54
-
-
Save guy4261/7d6af7292114c6f44dc2 to your computer and use it in GitHub Desktop.
Implements an iterator over the pixels of the Torch MNIST dataset.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Load the MNIST training split from its Torch ASCII serialization, then
-- convert the pixel tensor to Torch's default tensor type.
-- NOTE(review): presumably done so the later sums are computed in that type
-- rather than the stored type; confirm against the .t7 file's actual type.
train = torch.load('mnist.t7/train_32x32.t7', 'ascii')
train.data = train.data:type(torch.getdefaulttensortype())
--- Build a stateful iterator over every element of a 4-D dataset tensor.
-- Elements are produced in row-major order (the last index varies fastest):
-- data[1][1][1][1], data[1][1][1][2], ... After the final element, every
-- further call returns nil, so the iterator can drive a generic `for` loop.
-- @param dataset table whose `data` field is indexed with four subscripts;
--   assumes a 4-D tensor (for the 32x32 MNIST file, presumably
--   samples x channels x 32 x 32 -- TODO confirm)
-- @return function yielding one pixel value per call, then nil
function pixels(dataset)
  local dims = (#dataset.data):totable()
  local data = dataset.data
  -- The extents are read once; the tensor's shape is not expected to change
  -- while iterating.
  local n1, n2, n3, n4 = dims[1], dims[2], dims[3], dims[4]
  -- Cursor state, shared by every call of the returned closure.
  local d1, d2, d3, d4 = 1, 1, 1, 1

  return function()
    while d1 <= n1 do
      while d2 <= n2 do
        while d3 <= n3 do
          if d4 <= n4 then
            -- Kept local: the original wrote this to a global (`cur`).
            local value = data[d1][d2][d3][d4]
            d4 = d4 + 1
            return value
          end
          -- Innermost row exhausted: advance to the next row.
          d3, d4 = d3 + 1, 1
        end
        d2, d3 = d2 + 1, 1
      end
      d1, d2 = d1 + 1, 1
    end
    -- Every element consumed: nil terminates a generic `for` loop.
    return nil
  end
end
-- Sanity check: the total accumulated through the `pixels` iterator should
-- match Torch's own reduction over the whole tensor.
local sum = 0
for pixel in pixels(train) do
  sum = sum + pixel
end
print(sum)
print(train.data:sum())
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment