Created
July 12, 2016 17:49
-
-
Save bartvm/ef2a5ba070d20866338aaae0028c7352 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Hogwild-style asynchronous training: spawn opt.hogwild worker threads that
-- share the model parameters and Adam optimizer state, and feed them training
-- examples through IPC work queues.
-- NOTE(review): the closing `end` of this `if` lies beyond the visible span.
if opt.hogwild > 1 then
  local ipc = require 'libipc'
  local example_queue = ipc.workqueue('examples')
  local result_queue = ipc.workqueue('done')
  local id_queue = ipc.workqueue('ids')
  -- Hand out one numeric ID per worker via the 'ids' queue.
  for worker_id = 1, opt.hogwild do
    id_queue:write(worker_id)
  end
  -- Run one update on the main thread first so the Adam state tensors exist
  -- before they are shared with the workers.
  local optim_state = {}
  local apply_update, optim_states = grad.optim.adam(grad(model.f), optim_state, model.parameters)
  apply_update(dataset.get().item[1].extra)
  local worker_pool = ipc.map(opt.hogwild, function(opt, shared_params, shared_states)
    local ipc = require 'libipc'
    local grad = require 'autograd'
    local sys = require 'sys'
    local example_queue = ipc.workqueue('examples')
    local result_queue = ipc.workqueue('done')
    -- Each worker constructs its own copy of the model...
    local model = require 'cortex-core.projects.research.segmentation.model.model'(opt)
    -- ...but points it at the parameter tensors shared by the main thread.
    model.parameters = shared_params
    -- Discard the freshly allocated optimizer states in favor of the shared ones.
    local update, worker_states = grad.optim.adam(grad(model.f), {}, model.parameters)
    for idx = 1, #worker_states do
      worker_states[idx] = shared_states[idx]
    end
    -- Pull this worker's ID off the queue.
    local id = ipc.workqueue('ids'):read()
    while true do
      print(id .. ':start', sys.clock())
      local msg = example_queue:read()
      local example, step = msg[1], msg[2]
      -- A false example is the shutdown sentinel from the main thread.
      if not example then
        break
      end
      -- Keep Adam's timestep field in sync with the global iteration counter.
      for idx = 1, #worker_states do
        worker_states[idx].t = step
      end
      sys.tic()
      local _, loss = update(example)
      local seq_len = example.in_arcs:size(1)
      -- Report back so the main thread can enqueue another example.
      print(id .. ':stop', sys.clock(), sys.toc())
      result_queue:write({id, seq_len, loss})
    end
  end, opt, model.parameters, optim_states)
  -- Prime the pipeline with a few examples per worker.
  local step = 1
  for _ = 1, opt.hogwild * opt.hogwild_queue do
    example_queue:write({dataset.get().item[1].extra, step})
    step = step + 1
  end
  -- Main loop: enqueue one fresh example for every result received.
  local loss_ema
  local epoch_timer = torch.Timer()
  for epoch = 1, opt.num_epochs do
    local seen = 1
    epoch_timer:reset()
    while seen <= dataset.size() do
      example_queue:write({dataset.get().item[1].extra, step})
      step = step + 1
      local result = result_queue:read()
      local id, seq_len, loss = result[1], result[2], result[3]
      print(id .. ':received', sys.clock())
      local bits_per_char = loss / seq_len / torch.log(2)
      -- Exponential moving average of the per-character loss (safe use of the
      -- and/or idiom: bits_per_char is numeric, never false).
      loss_ema = loss_ema and loss_ema * 0.9 + 0.1 * bits_per_char or bits_per_char
      print(epoch, seen, seq_len, bits_per_char, loss_ema)
      seen = seen + 1
    end
    print('epoch took', epoch_timer:time().real)
  end
  -- Send the shutdown sentinel to every worker, then wait for them to finish.
  for _ = 1, opt.hogwild do
    example_queue:write({false, false})
  end
  worker_pool:join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment