Last active
August 5, 2016 12:09
-
-
Save karandwivedi42/4d217ff054daf09c93b83096093ac8a1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'image' | |
-- Image transformations. Each constructor below returns a closure that takes
-- an image and returns the transformed image, so they can be chained
-- (e.g. with tnt.transform.compose).
local M = {}

--- Per-channel colour normalization.
-- With no argument this reproduces the original fixed transform: each of the
-- first three channels becomes (x + 0.5) / 10.
-- @tparam[opt] table meanstd `{mean = {m1, m2, m3}, std = {s1, s2, s3}}`;
--   when given, channel i becomes (x - mean[i]) / std[i]
-- @treturn function closure mapping an image to a normalized copy; the input
--   is cloned first, so it is never modified in place
M.ColorNormalize = function(meanstd)
   local mean, std
   if meanstd ~= nil then
      mean, std = meanstd.mean, meanstd.std
      if type(mean) ~= 'table' or type(std) ~= 'table' then
         error('ColorNormalize: meanstd needs `mean` and `std` tables', 2)
      end
   else
      -- Chosen so that (x - mean) / std == (x + 0.5) / 10, the old behaviour.
      mean, std = {-0.5, -0.5, -0.5}, {10, 10, 10}
   end
   return function(img)
      img = img:clone()
      -- Indexes the first dimension as channels; assumes at least 3 of them.
      for i = 1, 3 do
         img[i]:add(-mean[i])
         img[i]:div(std[i])
      end
      return img
   end
end
--- Resize so the shorter image side equals `size`, keeping the aspect ratio.
-- @tparam number size target length of the shorter side, in pixels
-- @treturn function closure taking an image whose size(2) is height and
--   size(3) is width. It returns the input untouched when the shorter side
--   already equals `size`. Otherwise it returns image.scale's result; the
--   proportional long-side length is passed as computed and may be fractional.
M.Scale = function(size)
   return function(input)
      local width, height = input:size(3), input:size(2)
      -- Nothing to do when the short side is already the requested length.
      if math.min(width, height) == size then
         return input
      end
      if width >= height then
         -- Landscape or square: height is the short side.
         return image.scale(input, width / height * size, size)
      end
      -- Portrait: width is the short side.
      return image.scale(input, size, height / width * size)
   end
end
--- Random horizontal mirroring.
-- @tparam number prob flip probability. Each call consumes one
--   torch.uniform() draw and flips when that draw is below `prob`.
-- @treturn function closure returning image.hflip(input) or `input` itself
M.HorizontalFlip = function(prob)
   return function(input)
      local draw = torch.uniform()
      if draw < prob then
         return image.hflip(input)
      end
      return input
   end
end
--- Random square crop.
-- @tparam number size side length of the crop, in pixels
-- @treturn function closure taking an image whose size(2) is height and
--   size(3) is width. It returns a size x size crop at a random offset, or
--   the input itself when it is already exactly size x size.
-- @raise error when the input is narrower or shorter than `size`
M.RandomCrop = function(size)
   return function(input)
      local w, h = input:size(3), input:size(2)
      if w == size and h == size then
         return input
      end
      -- Reject undersized inputs explicitly; otherwise torch.random below
      -- would be handed an inverted range and fail with an obscure error.
      if w < size or h < size then
         error(string.format('RandomCrop: input %sx%s is smaller than crop size %s',
            tostring(w), tostring(h), tostring(size)), 2)
      end
      -- torch.random(a, b) draws an integer offset in [a, b].
      local x1, y1 = torch.random(0, w - size), torch.random(0, h - size)
      local out = image.crop(input, x1, y1, x1 + size, y1 + size)
      assert(out:size(2) == size and out:size(3) == size, 'wrong crop size')
      return out
   end
end
-------------------------------------
-- Benchmark fixture: write one synthetic 3x256x256 image to disk. The loader
-- below re-reads this same file for every sample.
-- torch.rand draws from [0, 1), the usual range for float images. torch.randn
-- values mostly fall outside that range, and presumably get clipped when
-- written as 8-bit JPEG.
image.save('temp.jpg', torch.rand(3, 256, 256))
local tnt = require 'torchnet'
-- Samples per batch; the dataset below holds batchSize * 500 samples.
local batchSize = 256
--- Build the benchmark's parallel data loader.
-- Uses 4 worker threads. `init` re-requires the libraries and `closure`
-- builds the dataset, which suggests both run in separate worker Lua states
-- (confirm against torchnet's ParallelDatasetIterator docs). For that reason
-- they reference only the upvalues M, tnt and batchSize, plus globals set up
-- in `init`.
-- @treturn tnt.ParallelDatasetIterator yielding full batches of batchSize
--   augmented 224x224 crops (the last partial batch is skipped)
local function getIterator()
   local options = {
      nthread = 4,
      -- Runs once per worker. It loads the libraries there and publishes the
      -- transform table as a worker-local global for `closure` to use.
      init = function()
         require 'torchnet'
         require 'image'
         imtransform = M
      end,
      -- Builds each worker's dataset. Every one of the batchSize*500 indices
      -- re-reads temp.jpg; the images are then augmented and batched.
      closure = function()
         local samples = tnt.ListDataset{
            list = torch.range(1, batchSize * 500):long(),
            load = function(idx)
               local picture = image.load('temp.jpg'):float()
               return {input = picture, target = torch.LongTensor{idx}}
            end,
         }
         local pipeline = tnt.transform.compose{
            imtransform.Scale(256),
            imtransform.RandomCrop(224),
            imtransform.ColorNormalize(),
            imtransform.HorizontalFlip(0.5),
         }
         local augmented = samples:transform{input = pipeline}
         return augmented:batch(batchSize, 'skip-last')
      end,
   }
   return tnt.ParallelDatasetIterator(options)
end
-- Time one full pass over the loader. After each batch this prints the
-- wall-clock seconds elapsed since the reset, then prints once more at the end.
local iter = getIterator()
-- Local rather than global: nothing outside this script needs the timer.
local timer = torch.Timer()
timer:reset()
for _ in iter() do
   print(timer:time().real)
end
print(timer:time().real)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment