Skip to content

Instantly share code, notes, and snippets.

@szagoruyko
Created July 19, 2018 13:03
Show Gist options
  • Save szagoruyko/5f958d986a2973340f392a53b8e5575e to your computer and use it in GitHub Desktop.
Save szagoruyko/5f958d986a2973340f392a53b8e5575e to your computer and use it in GitHub Desktop.
require 'optim'
require 'image'
require 'xlua'
require 'cudnn'
local json = require 'cjson'
local tnt = require 'torchnet'
local utils = require 'models.utils'
-- local tds = require 'tds'
local iterm = require 'iterm'
local engines = require 'engines'
local grad = require 'autograd'
require 'FPR95Meter'
-- penlight's tablex is used by `log` below; require it explicitly instead of
-- relying on a global that only exists when run under the interactive `th` REPL.
local tablex = require 'pl.tablex'

-- Run configuration. xlua.envparams may replace these defaults
-- (presumably from environment variables — TODO confirm).
-- Every option is also attached to each `log` record.
local opt = {
   save_folder = 'logs',   -- NOTE(review): not read by this script
   batchSize = 256,        -- train: batchSize/2 triplets per batch; test: batchSize/2 matching + batchSize/2 non-matching pairs
   learningRate = 0.1,     -- learningRate/weightDecay/momentum/dampening reach the optimizer via `config = opt`
   weightDecay = 0.0005,
   momentum = 0.9,
   -- t0 = 1e+4,
   -- eta0 = 0.1,
   max_epoch = 1200,       -- passed as `maxepoch` to engine:train
   model = '2ch',          -- NOTE(review): not read by this script
   optimMethod = 'asgd',   -- NOTE(review): not read; training uses grad.optim.sgd
   backend = 'cudnn',      -- NOTE(review): not read; cudnn is used unconditionally
   train_set = 'notredame',              -- path given to torch.load
   test_set = 'liberty',                 -- path given to torch.load
   train_matches = 'm50_500000_500000_0.txt', -- key selecting the train pair lists
   test_matches = 'm50_100000_100000_0.txt',  -- key selecting the test pair lists
   nDonkeys = 6,           -- data-loading worker threads
   manualSeed = 555,       -- NOTE(review): never applied (no torch.manualSeed call)
   grad_clamp = 2,         -- NOTE(review): not read by this script
   data_type = 'torch.CudaTensor',       -- target type for `cast`
   testOnly = false,       -- NOTE(review): not read by this script
   nGPU = 1,               -- NOTE(review): not read by this script
   alpha = 1,              -- triplet-loss margin
   dampening = 0,
}
opt = xlua.envparams(opt)
print(opt)

--- Convert a tensor to opt.data_type.
local function cast(x) return x:type(opt.data_type) end

--- Print one JSON stats record prefixed with 'json_stats: '.
-- The record is `t` merged with all of `opt` (penlight tablex.merge, dup=true).
local function log(t) print('json_stats: '..json.encode(tablex.merge(t,opt,true))) end
-- Data loading --

--- Load the train and test splits and convert their indices to 1-based.
-- Each split comes from torch.load(opt.train_set / opt.test_set). Its `info`
-- tensor is shifted by +1. Its `matches` / `nonmatches` fields are filled from
-- split[1][<matches file key>], also shifted by +1 because the stored indices
-- are 0-based.
-- @return table with fields `train` and `test`
local function loadProvider()
   print'Loading train and test data'
   local splits = {}
   splits.train = torch.load(opt.train_set)
   splits.test = torch.load(opt.test_set)

   splits.train.info:add(1)
   splits.test.info:add(1)

   for _, pair_type in ipairs{'matches', 'nonmatches'} do
      local train_pairs = splits.train[1][opt.train_matches][pair_type]
      local test_pairs = splits.test[1][opt.test_matches][pair_type]
      splits.train[pair_type] = train_pairs + 1
      splits.test[pair_type] = test_pairs + 1
   end

   return splits
end
-- Load both splits once at startup; getIterator's worker closures read
-- `provider` as an upvalue.
-- NOTE(review): `mode` is never read. getIterator declares its own `mode`
-- parameter, which shadows this one.
local provider, mode = loadProvider(), 'train'
--- Build a multi-threaded torchnet iterator over one split of `provider`.
-- mode == 'train': yields tensors of opt.batchSize/2 triplets. Each triplet
--   holds three 64x64 patches in the order {anchor, negative, positive}. The
--   anchor and positive come from the same point and are distinct patches
--   whenever that point has more than one.
-- any other mode: yields {input = N x 2 x 64 x 64, target = N} batches built
--   from the split's `matches` (target 1) and `nonmatches` (target -1) lists.
-- @tparam string mode key of `provider` ('train' or 'test')
-- @return tnt.ParallelDatasetIterator
local function getIterator(mode)
   return tnt.ParallelDatasetIterator{
      nthread = opt.nDonkeys,
      init = function()
         require 'torchnet'
         require 'image'
         -- deliberately global: the closure below runs in the same worker
         -- thread and looks up `tds` there
         tds = require 'tds'
      end,
      closure = function()
         local dataset = provider[mode]
         if mode == 'train' then
            -- group patch indices by their (1-based) point label
            local function perClassTable(info)
               local t = tds.Vec()
               for _ = 1, info:max() do t:insert(tds.Vec()) end
               for j = 1, info:numel() do t[info[j]]:insert(j) end
               return t
            end
            local idxs = perClassTable(dataset.info)
            print('Found dataset with '..#idxs..' unique points')
            -- the negative-sampling loop below never terminates with fewer
            -- than two points
            assert(#idxs >= 2, 'triplet sampling needs at least 2 unique points')
            local list_dataset = tnt.ListDataset{
               list = dataset.info,
               load = function(idx)
                  local mean = dataset.patches_mean[idx][1]
                  return (dataset.patches[idx]:float() - mean) / 256
               end,
            }
            local triplet_dataset = tnt.ListDataset{
               list = torch.range(1,#idxs):long(),
               load = function(idx)
                  -- pick a different point to supply the negative
                  local other_idx = idx
                  while other_idx == idx do
                     other_idx = torch.random(#idxs)
                  end
                  local instances_a = idxs[idx]
                  local instances_b = idxs[other_idx]
                  -- anchor and positive come from the same point but must be
                  -- different patches whenever possible; otherwise d(a,p) == 0
                  -- and the triplet carries no positive-pair signal
                  local anchor = torch.random(#instances_a)
                  local positive = anchor
                  if #instances_a > 1 then
                     while positive == anchor do
                        positive = torch.random(#instances_a)
                     end
                  end
                  return {
                     list_dataset.load(instances_a[anchor]),
                     list_dataset.load(instances_b[torch.random(#instances_b)]),
                     list_dataset.load(instances_a[positive]),
                  }
               end
            }
            -- merge each triplet into one tensor, batch opt.batchSize/2
            -- triplets (dropping a partial last batch), then unwrap the table
            return triplet_dataset
               :transform(function(x) return {tnt.utils.table.mergetensor(x)} end)
               :batch(opt.batchSize / 2, 'skip-last')
               :transform(function(x) return x[1] end)
         else
            -- one batched dataset per pair list; each sample is a 2x64x64 pair
            -- with target 1 (match) or -1 (non-match)
            local function getListDataset(pair_type)
               local list_dataset = tnt.ListDataset{
                  list = dataset[pair_type],
                  load = function(idx)
                     local im = torch.FloatTensor(2,64,64)
                     for i=1,2 do
                        local mean = dataset.patches_mean[idx[i]][1]
                        im[i]:copy(dataset.patches[idx[i]]):add(-mean):div(256)
                     end
                     return {
                        input = im,
                        target = torch.LongTensor{pair_type == 'matches' and 1 or -1},
                     }
                  end,
               }
               return list_dataset:batch(opt.batchSize / 2, 'include-last')
            end
            local concat = tnt.ConcatDataset{
               datasets = {
                  getListDataset'matches',
                  getListDataset'nonmatches',
               }
            }
            -- visit indices 1, n/2+1, 2, n/2+2, ... so that :batch(2) joins the
            -- k-th matches batch with the k-th nonmatches batch; assumes both
            -- lists produce the same number of batches
            local n = concat:size()
            local multi_idx = torch.range(1,n):view(2,-1):t():reshape(n)
            return concat
               :sample(function(dataset, idx) return multi_idx[idx] end)
               :batch(2)
               :transform{
                  input = function(x) return x:view(-1,2,64,64) end,
                  target = function(y) return y:view(-1) end,
               }
         end
      end
   }
end
-- Descriptor network: the first branch of a pretrained siam2stream model,
-- followed by nn.Normalize(2).
-- NOTE(review): absolute, machine-specific checkpoint path.
local full_model = torch.load('/home/zagoruys/raid/Zoo/torch/deepcompare/siam2stream/siam2stream_notredame.t7')
local branch = full_model:get(1):get(1)
local desc = nn.Sequential():add(branch):add(nn.Normalize(2))

-- append nn.Contiguous to the branch's first two sub-modules
for stream = 1, 2 do
   branch:get(stream):add(nn.Contiguous())
end

-- give every convolution gradient buffers shaped like its parameters
-- (presumably the checkpoint was stored without them — TODO confirm)
local convs = desc:findModules'nn.SpatialConvolution'
for _, conv in ipairs(convs) do
   conv.gradWeight = conv.weight:clone()
   conv.gradBias = conv.bias:clone()
end

cudnn.convert(desc:cuda(), cudnn)
print(desc)

-- embedding: 256x512 linear projection applied to the descriptor output,
-- initialised from N(0, 0.1)
local W = torch.CudaTensor(256,512):normal(0,0.1)
----------------------- Tester ---------------------
local test_engine = tnt.AutogradEngine()
local fpr95meter = tnt.FPR95Meter()

-- reusable buffers of type opt.data_type that every test batch is copied into
local inputs, targets = cast(torch.Tensor()), cast(torch.Tensor())

-- clear the meter at the start of each evaluation run
test_engine.hooks.onStart = function(state)
   fpr95meter:reset()
end

-- copy the loaded batch into the reusable buffers and point the sample at them
test_engine.hooks.onSample = function(state)
   local sample = state.sample
   inputs:resize(sample.input:size()):copy(sample.input)
   targets:resize(sample.target:size()):copy(sample.target)
   sample.input, sample.target = inputs, targets
end
--- Evaluate the current embedding on the test split.
-- Each pair's two patches go through `desc` separately and are projected by W.
-- The pair is scored by its negative squared L2 distance (higher means more
-- similar), and that score is fed to fpr95meter.
-- @return the result of fpr95meter:value()
local function test()
   -- project one patch per row into the embedding space, one column per
   -- sample; `*` returns a new tensor, so the result is not overwritten when
   -- desc:forward reuses its output buffer
   local function project(patches)
      return W * desc:forward(patches):t()
   end

   local score_pairs = function(params, inputs, targets)
      local left = project(inputs:select(2,1))
      local right = project(inputs:select(2,2))
      local sq_diff = torch.pow(left - right, 2)
      local dists = torch.sum(sq_diff, 1)
      fpr95meter:add(-dists, targets)
      return 0, dists
   end

   test_engine:test{
      network = score_pairs,
      params = {W},
      iterator = getIterator'test',
      data_type = 'torch.CudaTensor',
   }
   return fpr95meter:value()
end
---------------------- Embedding ------------------
-- Squared L2 norm of each row: sum of squares along dimension 2.
local l2_2 = function(x) return torch.sum(torch.pow(x,2),2) end
-- Autograd-compatible functional form of `desc`; `params` holds its parameters
-- so they can be optimised together with the embedding.
local g,params = grad.functionalize(desc)
-- Triplet margin loss for one training batch.
-- params.embedding: projection matrix (W); params.D: descriptor parameters.
-- inputs: patches of opt.batchSize/2 triplets, stacked so that every 3
-- consecutive rows form one triplet in {anchor, negative, positive} order.
-- The view below hard-codes exactly opt.batchSize/2 triplets, which relies on
-- the 'skip-last' batching of the training iterator.
-- Returns mean over triplets of max(0, ||a-p||^2 - ||a-n||^2 + opt.alpha).
local f = function(params, inputs)
local W = params.embedding
local x = (g(params.D, inputs) * torch.transpose(W,1,2)):view(opt.batchSize/2,3,-1)
local x_a = torch.select(x,2,1)
local x_p = torch.select(x,2,3)
local x_n = torch.select(x,2,2)
return torch.mean(torch.cmax(- l2_2(x_a - x_n) + l2_2(x_a - x_p) + opt.alpha, 0))
end
---------------------- Training engine ------------------
local engine = tnt.AutogradEngine()
-- mean training loss over the current epoch (reset after each epoch is logged)
local meters = {
   loss = tnt.AverageValueMeter()
}
-- batches per epoch; only used by the commented-out progress print below
local epochSize
engine.hooks.onStart = function(state)
   engines.classification.onStart(state)
   epochSize = state.iterator:exec'size'[1]
end
-- copy each batch of triplet patches into a reusable CUDA buffer and feed it
-- to the network as (3 * triplets) x 1 x 64 x 64
engine.hooks.onSample = function(state)
   state.sample_cache = state.sample_cache or torch.CudaTensor()
   state.sample_cache:resize(#state.sample):copy(state.sample)
   -- state.sample = {input = desc:forward(state.sample_cache:view(-1,1,64,64))}
   state.sample = {input = state.sample_cache:view(-1,1,64,64)}
end
engine.hooks.onForward = function(state)
   meters.loss:add(state.loss)
   -- print(('Epoch %d [%d/%d], %.6f'):format(
   -- state.epoch+1, state.t % epochSize, epochSize, state.loss))
end
-- evaluate on the test split and log one JSON record per epoch
engine.hooks.onEndEpoch = function(state)
   log{
      testFPR95 = test(),
      loss = meters.loss:value(),
      epoch = state.epoch,
   }
   -- start the next epoch's average from scratch; without this the logged
   -- loss is a running mean over every epoch trained so far
   meters.loss:reset()
   -- torch.save(paths.concat(opt.save_folder, 'model.t7'), W)
end
-- Train the embedding W together with the descriptor parameters using
-- autograd's SGD. `opt` is passed as the optimizer config (presumably read
-- for learningRate, momentum, weightDecay, dampening — TODO confirm).
-- NOTE(review): opt.optimMethod ('asgd') is ignored here.
local train_network = grad(f, {optimize = true})
engine:train{
   network = train_network,
   params = {embedding = W, D = params},
   iterator = getIterator'train',
   maxepoch = opt.max_epoch,
   optimMethod = grad.optim.sgd,
   config = opt,
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment