require 'torch'
require 'nn'
require 'optim'

-- to specify these at runtime, you can do, e.g.:
--   $ lr=0.001 th main.lua
opt = {
  dataset = 'video2',   -- which data loader to use (see data.lua)
  nThreads = 32,        -- how many threads to use for pre-fetching data
  batchSize = 64,       -- self-explanatory
  loadSize = 128,       -- when loading images, resize first to this size
  fineSize = 64,        -- crop this size from the loaded image
  frameSize = 32,
  lr = 0.0002,          -- learning rate
  lr_decay = 1000,      -- how often to decay the learning rate (in epochs)
  lambda = 0.1,
  beta1 = 0.5,          -- momentum term for adam
  meanIter = 0,         -- how many iterations to retrieve for mean estimation
  saveIter = 1000,      -- write a checkpoint on this interval
  niter = 100,          -- number of iterations through the dataset
  ntrain = math.huge,   -- how big one epoch should be
  gpu = 1,              -- which GPU to use; consider using CUDA_VISIBLE_DEVICES instead
  cudnn = 1,            -- whether to use cudnn or not
  finetune = '',        -- if set, will load this network instead of starting from scratch
  name = 'beach100',    -- the name of the experiment
  randomize = 1,        -- whether to shuffle the data file or not
  cropping = 'random',  -- options for data augmentation
  display_port = 8001,  -- port to push graphs
  display_id = 1,       -- window ID when pushing graphs
  mean = {0,0,0},
  data_root = '/data/vision/torralba/crossmodal/flickr_videos/',
  data_list = '/data/vision/torralba/crossmodal/flickr_videos/scene_extract/lists-full/_b_beach.txt.train',
}
-- one-line argument parser. parses environment variables to override the defaults
for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
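
-- Example override (illustrative values only): any key in `opt` can be set from
-- the shell, e.g.
--   $ gpu=2 batchSize=32 name=beach_test th main.lua
-- Numeric strings are coerced with tonumber(); anything else is kept as a string.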

print(opt)

torch.manualSeed(0)
torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')

-- if using GPU, select indicated one
if opt.gpu > 0 then
  require 'cunn'
  cutorch.setDevice(opt.gpu)
end

-- create data loader
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
print("Dataset: " .. opt.dataset, " Size: ", data:size())

-- define the model
local net
local netD
local mask_net
local motion_net
local static_net
local penalty_net

if opt.finetune == '' then -- build network from scratch
  net = nn.Sequential()

  static_net = nn.Sequential()
  static_net:add(nn.View(-1, 100, 1, 1))
  static_net:add(nn.SpatialFullConvolution(100, 512, 4,4))
  static_net:add(nn.SpatialBatchNormalization(512)):add(nn.ReLU(true))
  static_net:add(nn.SpatialFullConvolution(512, 256, 4,4, 2,2, 1,1))
  static_net:add(nn.SpatialBatchNormalization(256)):add(nn.ReLU(true))
  static_net:add(nn.SpatialFullConvolution(256, 128, 4,4, 2,2, 1,1))
  static_net:add(nn.SpatialBatchNormalization(128)):add(nn.ReLU(true))
  static_net:add(nn.SpatialFullConvolution(128, 64, 4,4, 2,2, 1,1))
  static_net:add(nn.SpatialBatchNormalization(64)):add(nn.ReLU(true))
  static_net:add(nn.SpatialFullConvolution(64, 3, 4,4, 2,2, 1,1))
  static_net:add(nn.Tanh())
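
  -- Shape note (derived from the layer arguments above): static_net maps a
  -- 100-d noise vector, viewed as 100x1x1, through 4x4 -> 8x8 -> 16x16 ->
  -- 32x32 -> 64x64 upsampling steps, producing a 3 x 64 x 64 background image
  -- in [-1, 1] via the final Tanh.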

  local net_video = nn.Sequential()
  net_video:add(nn.View(-1, 100, 1, 1, 1))
  net_video:add(nn.VolumetricFullConvolution(100, 512, 2,4,4))
  net_video:add(nn.VolumetricBatchNormalization(512)):add(nn.ReLU(true))
  net_video:add(nn.VolumetricFullConvolution(512, 256, 4,4,4, 2,2,2, 1,1,1))
  net_video:add(nn.VolumetricBatchNormalization(256)):add(nn.ReLU(true))
  net_video:add(nn.VolumetricFullConvolution(256, 128, 4,4,4, 2,2,2, 1,1,1))
  net_video:add(nn.VolumetricBatchNormalization(128)):add(nn.ReLU(true))
  net_video:add(nn.VolumetricFullConvolution(128, 64, 4,4,4, 2,2,2, 1,1,1))
  net_video:add(nn.VolumetricBatchNormalization(64)):add(nn.ReLU(true))

  local mask_out = nn.VolumetricFullConvolution(64,1, 4,4,4, 2,2,2, 1,1,1)
  penalty_net = nn.L1Penalty(opt.lambda, true)
  mask_net = nn.Sequential():add(mask_out):add(nn.Sigmoid()):add(penalty_net)
  gen_net = nn.Sequential():add(nn.VolumetricFullConvolution(64,3, 4,4,4, 2,2,2, 1,1,1)):add(nn.Tanh())
  net_video:add(nn.ConcatTable():add(gen_net):add(mask_net))
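
  -- Shape note: net_video maps the same 100-d code, viewed as 100x1x1x1,
  -- through 2x4x4 -> 4x8x8 -> 8x16x16 -> 16x32x32 upsampling to a 64-channel
  -- feature map, then splits into two heads:
  --   gen_net  -> 3 x 32 x 64 x 64 foreground video (Tanh)
  --   mask_net -> 1 x 32 x 64 x 64 soft mask in [0, 1] (Sigmoid), with an L1
  --               penalty weighted by opt.lambda to encourage a sparse foreground.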

  -- [1] is generated video, [2] is mask, and [3] is static
  net:add(nn.ConcatTable():add(net_video):add(static_net)):add(nn.FlattenTable())

  -- video .* mask (with repmat on mask)
  motion_net = nn.Sequential()
    :add(nn.ConcatTable()
      :add(nn.SelectTable(1))
      :add(nn.Sequential()
        :add(nn.SelectTable(2))
        :add(nn.Squeeze())
        :add(nn.Replicate(3, 2)))) -- replicate the mask over color channels
    :add(nn.CMulTable())

  -- static .* (1-mask) (then repmatted)
  local sta_part = nn.Sequential()
    :add(nn.ConcatTable()
      :add(nn.Sequential()
        :add(nn.SelectTable(3))
        :add(nn.Replicate(opt.frameSize, 3))) -- replicate the static frame over time
      :add(nn.Sequential()
        :add(nn.SelectTable(2))
        :add(nn.Squeeze())
        :add(nn.MulConstant(-1))
        :add(nn.AddConstant(1))
        :add(nn.Replicate(3, 2)))) -- replicate (1-mask) over color channels
    :add(nn.CMulTable())

  net:add(nn.ConcatTable():add(motion_net):add(sta_part)):add(nn.CAddTable())
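
  -- The generator output is therefore a per-voxel blend,
  --   output = mask .* video + (1 - mask) .* static,
  -- i.e. a batchSize x 3 x 32 x 64 x 64 clip in which the mask decides, per
  -- pixel and frame, whether the moving stream or the static background shows.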

  netD = nn.Sequential()
  netD:add(nn.VolumetricConvolution(3,64, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(64,128, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.VolumetricBatchNormalization(128,1e-3)):add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(128,256, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.VolumetricBatchNormalization(256,1e-3)):add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(256,512, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.VolumetricBatchNormalization(512,1e-3)):add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(512,2, 2,4,4, 1,1,1, 0,0,0))
  netD:add(nn.View(2):setNumInputDims(4))
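
  -- Shape note: the discriminator downsamples a 3 x 32 x 64 x 64 clip through
  -- 64x16x32x32 -> 128x8x16x16 -> 256x4x8x8 -> 512x2x4x4, and the last
  -- convolution collapses this to 2 logits per clip (real vs. fake), which
  -- nn.View(2) exposes to the criterion below.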

  -- initialize the model
  local function weights_init(m)
    local name = torch.type(m)
    if name:find('Convolution') then
      m.weight:normal(0.0, 0.01)
      m.bias:fill(0)
    elseif name:find('BatchNormalization') then
      if m.weight then m.weight:normal(1.0, 0.02) end
      if m.bias then m.bias:fill(0) end
    end
  end
  net:apply(weights_init)  -- loop over all layers, applying weights_init
  netD:apply(weights_init)
  mask_out.weight:normal(0, 0.01)
  mask_out.bias:fill(0)
else -- load an existing network
  -- note: only the generator is loaded here; netD (and the mask/motion/static
  -- sub-networks used for visualization below) remain undefined in this branch
  print('loading ' .. opt.finetune)
  net = torch.load(opt.finetune)
end

print('Generator:')
print(net)
print('Discriminator:')
print(netD)

-- define the loss
local criterion = nn.CrossEntropyCriterion()
local real_label = 1
local fake_label = 2
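
-- Rather than a sigmoid + binary cross-entropy head, the discriminator emits
-- two logits and nn.CrossEntropyCriterion (LogSoftMax + ClassNLLCriterion) is
-- used, so "real" and "fake" are simply the class indices 1 and 2 above.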

-- create the data placeholders
local noise = torch.Tensor(opt.batchSize, 100)
local target = torch.Tensor(opt.batchSize, 3, opt.frameSize, opt.fineSize, opt.fineSize)
local label = torch.Tensor(opt.batchSize)
local err, errD

-- timers to roughly profile performance
local tm = torch.Timer()
local data_tm = torch.Timer()

-- ship everything to the GPU if needed
if opt.gpu > 0 then
  noise = noise:cuda()
  target = target:cuda()
  label = label:cuda()
  net:cuda()
  netD:cuda()
  criterion:cuda()
end

-- convert to cudnn if needed
-- if this errors on you, you can disable it, but training will be slightly slower
if opt.gpu > 0 and opt.cudnn > 0 then
  require 'cudnn'
  net = cudnn.convert(net, cudnn)
  netD = cudnn.convert(netD, cudnn)
end

-- get a flat vector of parameters
local parameters, gradParameters = net:getParameters()
local parametersD, gradParametersD = netD:getParameters()
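
-- Note: getParameters() flattens each network's weights into one contiguous
-- tensor, so it is called once, after :cuda()/cudnn.convert(); converting the
-- networks afterwards would invalidate these flattened views.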

-- show graphics
disp = require 'display'
disp.url = 'http://localhost:' .. opt.display_port .. '/events'
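
-- This assumes a display server is already listening on opt.display_port
-- (with the szym/display package, typically launched separately with
-- something like `th -ldisplay.start 8001`).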

-- optimization closures
-- the optimizer will call these functions to get the losses and gradients
local data_im, data_label
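
-- fDx performs one discriminator step: forward/backward a real batch with
-- label 1, then a freshly generated fake batch with label 2, and return the
-- average of the two losses. The fake clip is copied into `target`, so no
-- gradient flows back into the generator here.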
local fDx = function(x)
  gradParametersD:zero()

  -- fetch data
  data_tm:reset(); data_tm:resume()
  data_im = data:getBatch()
  data_tm:stop()

  -- ship to GPU
  noise:normal()
  target:copy(data_im)
  label:fill(real_label)

  -- forward/backward on real examples
  local output = netD:forward(target)
  errD = criterion:forward(output, label)
  local df_do = criterion:backward(output, label)
  netD:backward(target, df_do)

  -- generate fake examples
  local fake = net:forward(noise)
  target:copy(fake)
  label:fill(fake_label)

  -- forward/backward on fake examples
  local output = netD:forward(target)
  errD = errD + criterion:forward(output, label)
  local df_do = criterion:backward(output, label)
  netD:backward(target, df_do)

  errD = errD / 2

  return errD, gradParametersD
end
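
-- fx performs one generator step: it reuses netD.output from the fake pass in
-- fDx, relabels those fakes as "real" (the usual non-saturating generator
-- objective), and backpropagates through netD with updateGradInput (leaving
-- netD's gradParameters untouched) and then through the generator, where
-- nn.L1Penalty also adds the sparsity gradient for the mask.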
local fx = function(x)
  gradParameters:zero()

  label:fill(real_label)
  local output = netD.output
  err = criterion:forward(output, label)
  local df_do = criterion:backward(output, label)
  local df_dg = netD:updateGradInput(target, df_do)
  net:backward(noise, df_dg)

  return err, gradParameters
end

local counter = 0
local history = {}

-- parameters for the optimization
-- very important: you must only create this table once!
-- the optimizer will add fields to this table (such as momentum)
local optimState = {
  learningRate = opt.lr,
  beta1 = opt.beta1,
}
local optimStateD = {
  learningRate = opt.lr,
  beta1 = opt.beta1,
}

-- main training loop
for epoch = 1,opt.niter do -- for each epoch
  for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do -- for each mini-batch
    collectgarbage() -- necessary sometimes

    tm:reset()

    -- do one iteration
    optim.adam(fDx, parametersD, optimStateD)
    optim.adam(fx, parameters, optimState)

    if counter % 10 == 0 then
      table.insert(history, {counter, err, errD})
      disp.plot(history, {win=opt.display_id+1, title=opt.name, labels = {"iteration", "err", "errD"}})
    end

    if counter % 100 == 0 then
      local vis = net.output:float()
      local vis_tab = {}
      for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, {}, i, {}, {} }]) end
      disp.image(torch.cat(vis_tab, 3), {win=opt.display_id, title=(opt.name .. ' gen')})

      local vis = motion_net.output:float()
      local vis_tab = {}
      for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, {}, i, {}, {} }]) end
      disp.image(torch.cat(vis_tab, 3), {win=opt.display_id+3, title=(opt.name .. ' motion')})

      local vis = static_net.output:float()
      disp.image(vis, {win=opt.display_id+4, title=(opt.name .. ' static')})

      local vis = mask_net.output:float():squeeze()
      local vis_lo = vis:min()
      local vis_hi = vis:max()
      local vis_tab = {}
      for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, i, {}, {} }]) end
      disp.image(torch.cat(vis_tab, 2), {win=opt.display_id+2, title=(opt.name .. ' mask ' .. string.format('%.2f %.2f', vis_lo, vis_hi))})
    end
    counter = counter + 1

    print(('%s: Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f '
          .. ' Err: %.4f ErrD: %.4f L1: %.4f'):format(
          opt.name, epoch, ((i-1) / opt.batchSize),
          math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
          tm:time().real, data_tm:time().real,
          err and err or -1, errD and errD or -1, penalty_net.loss))

    -- save checkpoint
    -- :clearState() compacts the model so it takes less space on disk
    if counter % opt.saveIter == 0 then
      print('Saving ' .. opt.name .. '/iter' .. counter .. '_net.t7')
      paths.mkdir('checkpoints')
      paths.mkdir('checkpoints/' .. opt.name)
      torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_net.t7', net:clearState())
      torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_netD.t7', netD:clearState())
      torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_history.t7', history)
    end
  end

  -- decay the learning rate, if requested
  if opt.lr_decay > 0 and epoch % opt.lr_decay == 0 then
    opt.lr = opt.lr / 10
    print('Decreasing learning rate to ' .. opt.lr)

    -- create a new optimState to reset momentum
    -- (note: only the generator's optimState is recreated here; optimStateD keeps its original learning rate)
    optimState = {
      learningRate = opt.lr,
      beta1 = opt.beta1,
    }
  end
end
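
-- Usage sketch (illustrative, not executed here): loading a saved generator
-- checkpoint and sampling from it. The path below is hypothetical; substitute
-- one written by the loop above.
--
--   local gen = torch.load('checkpoints/beach100/iter1000_net.t7')
--   gen:evaluate()
--   local z = torch.Tensor(8, 100):normal()
--   if opt.gpu > 0 then z = z:cuda() end
--   local videos = gen:forward(z)  -- 8 x 3 x 32 x 64 x 64, values in [-1, 1]
--   print(videos:size())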