require 'optim'
require 'image'
require 'xlua'
require 'cudnn'
local json = require 'cjson'
local tablex = require 'pl.tablex' -- penlight; used by log() below
local tnt = require 'torchnet'
local utils = require 'models.utils'
-- local tds = require 'tds'
local iterm = require 'iterm'
local engines = require 'engines'
local grad = require 'autograd'
require 'FPR95Meter'
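-- Note: 'engines' and 'FPR95Meter' are project-local modules (a custom engine
-- table and a torchnet meter); 'models.utils' and 'iterm' are required above
-- but not referenced anywhere below.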
local opt = {
   save_folder = 'logs',
   batchSize = 256,
   learningRate = 0.1,
   weightDecay = 0.0005,
   momentum = 0.9,
   -- t0 = 1e+4,
   -- eta0 = 0.1,
   max_epoch = 1200,
   model = '2ch',
   optimMethod = 'asgd',
   backend = 'cudnn',
   train_set = 'notredame',
   test_set = 'liberty',
   train_matches = 'm50_500000_500000_0.txt',
   test_matches = 'm50_100000_100000_0.txt',
   nDonkeys = 6,
   manualSeed = 555,
   grad_clamp = 2,
   data_type = 'torch.CudaTensor',
   testOnly = false,
   nGPU = 1,
   alpha = 1,
   dampening = 0,
}
opt = xlua.envparams(opt)
print(opt)
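-- xlua.envparams overrides each field of opt from an identically named
-- environment variable, so a run can be configured without editing the file,
-- e.g. (invocation is illustrative): batchSize=128 test_set=notredame th train.lua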
local function cast(x) return x:type(opt.data_type) end
local function log(t) print('json_stats: '..json.encode(tablex.merge(t,opt,true))) end
-- Data loading --
local function loadProvider()
   print'Loading train and test data'
   local p = {
      train = torch.load(opt.train_set),
      test = torch.load(opt.test_set),
   }
   p.train.info:add(1)
   p.test.info:add(1)
   -- assign matches and nonmatches;
   -- add 1 because the original indexes are 0-based
   for i,v in ipairs{'matches', 'nonmatches'} do
      p.train[v] = p.train[1][opt.train_matches][v] + 1
      p.test[v] = p.test[1][opt.test_matches][v] + 1
   end
   return p
end
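-- Expected layout of each split (a UBC/Brown patch dataset serialized to .t7):
-- patches (N x 64 x 64), patches_mean (N x 1), info (N point ids; patches with
-- the same id depict the same 3D point), plus per-file match/nonmatch pair lists.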
local provider = loadProvider()
local function getIterator(mode)
   return tnt.ParallelDatasetIterator{
      nthread = opt.nDonkeys,
      init = function()
         require 'torchnet'
         require 'image'
         tds = require 'tds'
      end,
      closure = function()
         local dataset = provider[mode]
         if mode == 'train' then
            -- group patch indices by 3D point id: idxs[c] holds all patches of point c
            local function perClassTable(info)
               local t = tds.Vec()
               for i=1,info:max() do t:insert(tds.Vec()) end
               for j=1,info:numel() do t[info[j]]:insert(j) end
               return t
            end
            local idxs = perClassTable(dataset.info)
            print('Found dataset with '..#idxs..' unique points')
            -- each patch is mean-centered and scaled to roughly [-0.5, 0.5]
            local list_dataset = tnt.ListDataset{
               list = dataset.info,
               load = function(idx)
                  local mean = dataset.patches_mean[idx][1]
                  return (dataset.patches[idx]:float() - mean) / 256
               end,
            }
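            -- Triplet sampling: for each 3D point draw an anchor and a positive
            -- from the same point, and a negative from a random other point. The
            -- load order is {anchor, negative, positive}; f() below selects
            -- channels 1/2/3 accordingly as x_a, x_n, x_p.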
            local triplet_dataset = tnt.ListDataset{
               list = torch.range(1,#idxs):long(),
               load = function(idx)
                  local other_idx = idx
                  while other_idx == idx do
                     other_idx = torch.random(#idxs)
                  end
                  local instances_a = idxs[idx]
                  local instances_b = idxs[other_idx]
                  return {
                     list_dataset.load(instances_a[torch.random(#instances_a)]),
                     list_dataset.load(instances_b[torch.random(#instances_b)]),
                     list_dataset.load(instances_a[torch.random(#instances_a)]),
                  }
               end
            }
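            -- Each triplet is merged into one 3 x 64 x 64 tensor; batching
            -- batchSize/2 of them gives a (batchSize/2) x 3 x 64 x 64 batch,
            -- i.e. 3*batchSize/2 patches per descriptor forward pass.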
            return triplet_dataset
               :transform(function(x) return {tnt.utils.table.mergetensor(x)} end)
               :batch(opt.batchSize / 2, 'skip-last')
               :transform(function(x) return x[1] end)
         else
            -- test mode: load labeled pairs; target is 1 for a match, -1 otherwise
            local function getListDataset(pair_type)
               local list_dataset = tnt.ListDataset{
                  list = dataset[pair_type],
                  load = function(idx)
                     local im = torch.FloatTensor(2,64,64)
                     for i=1,2 do
                        local mean = dataset.patches_mean[idx[i]][1]
                        im[i]:copy(dataset.patches[idx[i]]):add(-mean):div(256)
                     end
                     return {
                        input = im,
                        target = torch.LongTensor{pair_type == 'matches' and 1 or -1},
                     }
                  end,
               }
               return list_dataset:batch(opt.batchSize / 2, 'include-last')
            end
            local concat = tnt.ConcatDataset{
               datasets = {
                  getListDataset'matches',
                  getListDataset'nonmatches',
               }
            }
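            -- Interleave match and nonmatch mini-batches: multi_idx maps
            -- 1,2,3,... to 1, n/2+1, 2, n/2+2, ... so that each :batch(2)
            -- below merges one match batch with one nonmatch batch, yielding
            -- balanced test batches of opt.batchSize pairs.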
            local n = concat:size()
            local multi_idx = torch.range(1,n):view(2,-1):t():reshape(n)
            return concat
               :sample(function(dataset, idx) return multi_idx[idx] end)
               :batch(2)
               :transform{
                  input = function(x) return x:view(-1,2,64,64) end,
                  target = function(y) return y:view(-1) end,
               }
         end
      end
   }
end
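-- Pretrained descriptor: take the first tower of a siam2stream network
-- (presumably from the 'deepcompare' models of Zagoruyko & Komodakis,
-- CVPR 2015) and L2-normalize its output.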
local full_model = torch.load('/home/zagoruys/raid/Zoo/torch/deepcompare/siam2stream/siam2stream_notredame.t7')
local desc = nn.Sequential()
   :add(full_model:get(1):get(1))
   :add(nn.Normalize(2))
desc:get(1):get(1):add(nn.Contiguous())
desc:get(1):get(2):add(nn.Contiguous())
-- reallocate gradient buffers (likely stripped when the model was saved)
for i,v in ipairs(desc:findModules'nn.SpatialConvolution') do
   v.gradWeight = v.weight:clone()
   v.gradBias = v.bias:clone()
end
cudnn.convert(desc:cuda(), cudnn)
print(desc)
-- embedding
local W = torch.CudaTensor(256,512):normal(0,0.1)
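-- W projects the 512-d descriptor down to a 256-d embedding; it is applied as
-- W * desc(x)^T both at test time and inside the training loss f() below.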
----------------------- Tester ---------------------
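-- FPR95Meter (required above) computes the standard patch-matching metric:
-- false positive rate at 95% true positive recall; lower is better.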
local test_engine = tnt.AutogradEngine()
local fpr95meter = tnt.FPR95Meter()
local inputs = cast(torch.Tensor())
local targets = cast(torch.Tensor())
test_engine.hooks.onSample = function(state)
   inputs:resize(state.sample.input:size()):copy(state.sample.input)
   targets:resize(state.sample.target:size()):copy(state.sample.target)
   state.sample.input = inputs
   state.sample.target = targets
end
test_engine.hooks.onStart = function(state)
   fpr95meter:reset()
end
local function test()
   test_engine:test{
      network = function(params, inputs, targets)
         -- embed both patches of each pair and score by squared L2 distance;
         -- distances are negated so that matches (small distance) score higher
         local x_l = W * desc:forward(inputs:select(2,1)):t()
         local x_p = W * desc:forward(inputs:select(2,2)):t()
         local dists = torch.sum(torch.pow(x_l - x_p,2),1)
         fpr95meter:add(-dists, targets)
         return 0, dists -- dummy loss; the meter is updated above
      end,
      params = {W},
      iterator = getIterator'test',
      data_type = 'torch.CudaTensor',
   }
   return fpr95meter:value()
end
---------------------- Embedding ------------------
-- squared L2 norm of each row
local l2_2 = function(x) return torch.sum(torch.pow(x,2),2) end
local g,params = grad.functionalize(desc)
local f = function(params, inputs)
   local W = params.embedding
   local x = (g(params.D, inputs) * torch.transpose(W,1,2)):view(opt.batchSize/2,3,-1)
   local x_a = torch.select(x,2,1) -- anchor
   local x_p = torch.select(x,2,3) -- positive
   local x_n = torch.select(x,2,2) -- negative
   return torch.mean(torch.cmax(- l2_2(x_a - x_n) + l2_2(x_a - x_p) + opt.alpha, 0))
end
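-- f is the triplet margin loss
--   L = mean( max(0, ||x_a - x_p||^2 - ||x_a - x_n||^2 + alpha) )
-- computed on embedded descriptors; since params carries both the embedding W
-- and the descriptor weights D, the pretrained network is fine-tuned end to end.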
local engine = tnt.AutogradEngine()
local meters = {
   loss = tnt.AverageValueMeter()
}
local epochSize
engine.hooks.onStart = function(state)
   engines.classification.onStart(state)
   epochSize = state.iterator:exec'size'[1]
end
engine.hooks.onSample = function(state)
   state.sample_cache = state.sample_cache or torch.CudaTensor()
   state.sample_cache:resize(#state.sample):copy(state.sample)
   -- state.sample = {input = desc:forward(state.sample_cache:view(-1,1,64,64))}
   state.sample = {input = state.sample_cache:view(-1,1,64,64)}
end
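-- onSample copies the (batchSize/2) x 3 x 64 x 64 CPU batch into a reused GPU
-- buffer and flattens it to (3*batchSize/2) x 1 x 64 x 64, so all triplet
-- members go through the descriptor in a single forward pass.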
engine.hooks.onForward = function(state)
   meters.loss:add(state.loss)
   -- print(('Epoch %d [%d/%d], %.6f'):format(
   --    state.epoch+1, state.t % epochSize, epochSize, state.loss))
end
engine.hooks.onEndEpoch = function(state)
   log{
      testFPR95 = test(),
      loss = meters.loss:value(),
      epoch = state.epoch,
   }
   -- torch.save(paths.concat(opt.save, 'model.t7'), W)
end
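-- grad(f, {optimize = true}) returns a function that evaluates f together with
-- its gradients w.r.t. params (autograd's optimized mode caches the compiled
-- tape). Note that opt.optimMethod = 'asgd' is not what actually runs here:
-- grad.optim.sgd is passed explicitly.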
engine:train{
   network = grad(f, {optimize = true}),
   params = {embedding = W, D = params},
   iterator = getIterator'train',
   maxepoch = opt.max_epoch,
   optimMethod = grad.optim.sgd,
   config = opt,
}