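-- Hogwild-style asynchronous training with torch-ipc and torch-autograd:
-- opt.hogwild worker threads pull examples from a shared work queue and
-- apply lock-free Adam updates to shared parameters. This is a fragment;
-- it assumes opt, grad (autograd), model, dataset, sys, and torch are
-- provided by the surrounding script.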
if opt.hogwild > 1 then
  local ipc = require 'libipc'
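  -- Three named queues: 'examples' feeds work to the workers,
  -- 'done' returns results, and 'ids' hands each worker a unique id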
  local q = ipc.workqueue('examples')
  local q2 = ipc.workqueue('done')
  local ids = ipc.workqueue('ids')
  for i = 1, opt.hogwild do
    ids:write(i)
  end
  -- Initialize the states
  local state = {}
  local update, states = grad.optim.adam(grad(model.f), state, model.parameters)
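  -- Run one update on the main thread so Adam allocates its per-parameter
  -- state tensors before they are handed to the workers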
  update(dataset.get().item[1].extra)
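  -- Spawn opt.hogwild workers; the parameter and optimizer-state tensors
  -- they receive are shared across threads, so updates are lock-free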
  local workers = ipc.map(opt.hogwild, function(opt, parameters, states)
    local ipc = require 'libipc'
    local grad = require 'autograd'
    local sys = require 'sys'
    local q = ipc.workqueue('examples')
    local q2 = ipc.workqueue('done')
    -- A local copy of the model
    local model = require 'cortex-core.projects.research.segmentation.model.model'(opt)
    -- Shared parameters
    model.parameters = parameters
    -- Replace the states with the global states
    local update, worker_states = grad.optim.adam(grad(model.f), {}, model.parameters)
    for i = 1, #worker_states do
      worker_states[i] = states[i]
    end
    -- Get the ID of this worker
    local id = ipc.workqueue('ids'):read()
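    -- Main worker loop: read examples until the stop sentinel (false) arrives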
    while true do
      print(id .. ':start', sys.clock())
      local data = q:read()
      local example, i = data[1], data[2]
      if not example then
        break
      end
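      -- Keep Adam's timestep in sync with the global example counter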
      for j = 1, #worker_states do
        worker_states[j].t = i
      end
      sys.tic()
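      -- One Adam step on the shared parameters, with no locking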
      local _, loss = update(example)
      local seq_len = example.in_arcs:size(1)
      -- Signal that another example can be added to the queue
      print(id .. ':stop', sys.clock(), sys.toc())
      q2:write({id, seq_len, loss})
    end
  end, opt, model.parameters, states)
  -- Queue the first few examples ahead of time so the workers never starve
  local t = 1
  for i = 1, opt.hogwild * opt.hogwild_queue do
    q:write({dataset.get().item[1].extra, t})
    t = t + 1
  end
  -- Start retrieving results and add a new example each time
  local ema
  local timer = torch.Timer()
  for epoch = 1, opt.num_epochs do
    local i = 1
    timer:reset()
    while i <= dataset.size() do
      q:write({dataset.get().item[1].extra, t})
      t = t + 1
      local result = q2:read()
      local id, seq_len, loss = result[1], result[2], result[3]
      print(id .. ':received', sys.clock())
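      -- Convert the per-sequence loss from nats to bits per character and
      -- track an exponential moving average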
      local bits_per_char = loss / seq_len / torch.log(2)
      ema = ema and ema * 0.9 + 0.1 * bits_per_char or bits_per_char
      print(epoch, i, seq_len, bits_per_char, ema)
      i = i + 1
    end
    print('epoch took', timer:time().real)
  end
  -- Signal the workers to stop: a false example makes each worker break out of its loop
  for i = 1, opt.hogwild do
    q:write({false, false})
  end
  workers:join()
end