Skip to content

Instantly share code, notes, and snippets.

@nrnrnr
Created June 19, 2020 16:29
Show Gist options
  • Save nrnrnr/0cc1d8848ffae7f99f8ed2b332f19124 to your computer and use it in GitHub Desktop.
Save nrnrnr/0cc1d8848ffae7f99f8ed2b332f19124 to your computer and use it in GitHub Desktop.
Abstractions for use with luaproc
local filetable = require 'filetable'
local string = require 'string'
local stringf = string.format
local M = { }
-- Documentation strings for this module's exports, keyed by export name.
local __doc = { }
M.__doc = __doc
__doc.tostring = [[function(value) returns string
Return a string that, when fed to ofstring,
reconstructs an isomorphic value.
Input must be composed of scalars and tables,
no functions or userdata.
]]
__doc.ofstring = [[function(string) returns value
Undoes tostring, so ofstring(tostring(v)) is isomorphic to v.
]]
__doc.pickle = [[synonym for tostring]]
__doc.unpickle = [[synonym for ofstring]]
-- Serialization itself is delegated to the filetable module.
M.tostring = filetable.tostring
-- Lua 5.2+ removed `loadstring`; there, `load` accepts a string chunk.
-- On Lua 5.1/LuaJIT `loadstring` exists and is used.
local loadchunk = loadstring or load

-- Inverse of M.tostring: compiles `s` as a Lua expression, runs it, and
-- returns the result(s).
--   s : a string produced by M.tostring (or picklemod/picklechunk)
-- Raises an error carrying the compiler's message if `s` is not a valid
-- Lua expression.
-- NOTE: `s` is executed as Lua code, so only unpickle strings that come
-- from a trusted source.
function M.ofstring(s)
  local f = assert(loadchunk('return ' .. s))
  return f()
end
M.pickle = M.tostring
M.unpickle = M.ofstring
__doc.picklemod = [[function(string) returns pickled value
Given string x.y.z, returns "(require 'x.y')['z']", used
to pickle a function that lives in a module.
Also accepts arguments function(modname, membername)
]]
-- Pickle a reference to a member of a module.
--   picklemod('x.y.z')      splits at the LAST dot, same as
--   picklemod('x.y', 'z')   which yields the Lua source (require "x.y")["z"]
-- (the names are quoted with %q).
-- With a single argument, raises an error unless the name contains a dot
-- with a non-empty name on each side of the last dot; "x." or ".z" would
-- otherwise yield a pickle that can never be unpickled to a module member.
function M.picklemod(name, ext)
  if ext then
    return stringf('(require %q)[%q]', name, ext)
  end
  -- greedy (.+) makes the split happen at the last dot
  local modname, member = name:match '^(.+)%.([^%.]+)$'
  assert(modname, stringf(
    'ill-formed argument to picklemod; needs a dot with names on both sides (got %q)',
    name))
  return stringf('(require %q)[%q]', modname, member)
end
__doc.picklechunk = [[function(string) returns pickle
Pickles a sequence of statements ending in return.
When unpickled, produces the value returned.
]]
-- Template for picklechunk: the statements become the body of an anonymous
-- function that is called immediately.  The newline before `end` keeps a
-- trailing `--` comment in the statements from swallowing the `end`, and
-- the outermost parentheses truncate the call to a single value.
local chunk_template = '((function() %s\n end)())'
function M.picklechunk(stmts)
  local pickled = stringf(chunk_template, stmts)
  return pickled
end
return M
-- 'work crews of multiple threads'
local luaproc = require 'luaproc'
local serialize = require 'luaproc.serialize'
local io = require 'io'
local os = require 'os'
local string = require 'string'
local unpack = table.unpack or unpack
local modname = ... -- argument passed by `require` to this chunk; not used below
local stringf = string.format
-- Formats its arguments like string.format and writes the text to standard
-- error, returning whatever io.stderr:write returns.
local function eprintf(...)
  local message = string.format(...)
  return io.stderr:write(message)
end
----------------------------------------------------------------
--- load me after the last 'require'
-- Names present in _G at this point (after all dependencies are loaded);
-- thismodule() compares against this snapshot to catch stray globals.
local globals = { }; for x in pairs(_G) do globals[x] = true end
local M = { }
-- Returns the module table M, after raising an error that names the first
-- global it finds that was not in the snapshot above.
local function thismodule()
  for x in pairs(_G) do
    if not globals[x] then
      -- tostring: a global key need not be a string, and `..` on a
      -- table/boolean key would mask this message with a different error
      error('Accidentally defined global variable ' .. tostring(x))
    end
  end
  return M
end
local __doc = { }
M.__doc = __doc
----------------------------------------------------------------
__doc.__overview = [[
Completes a job of 'work', where work is an abstract type.
A job is represented by a table containing these fields:
{ work : work list
, unpack : function(work) returns scalar, ... -- defaults to table.unpack
, state : a value pickle option -- defaults to empty table
, worker : function pickled -- take result of unpack, return scalar list
, collector : function pickled -- merge scalar list from worker into state
, nworkers : int option -- how many worker threads to create; defaults to $NPROC or 16
, uid : string -- unique identifier
, status : function(work) -- optional, for side effect
}
Function types:
worker : function (work) returns results
collector : function(state, results)
status : function(work)
If I write the unpickling operation with vertical bars,
workcrew.run(job) is equivalent to the following:
local state = |job.state|
for _, w in ipairs(job.work) do -- parallel
job.status(w) -- atomic
local results = { |job.worker|(job.unpack(w)) } -- not atomic
|job.collector|(state, table.unpack(results)) -- atomic
end
return unpickle(pickle(state))
The first result returned by job.unpack must not be `false`.
The job and state tables will be copied multiple times, but only one
thing is mutated: the state table associated with the collector
thread.
Picklers and unpicklers can be found in luaproc.serialize.
]]
do
  -- The documentation for run ends with the overview's pseudocode, from the
  -- "local state" line through the "return" line.
  local pseudocode = assert(__doc.__overview:match '\n%s*local state.-return.-\n',
                            'workcrew: overview is missing its pseudocode')
  local summary = [[function(job) returns state
As described in the overview, create a work crew, run the job's work,
tear down the crew, and return the final state:
]]
  -- Replace the summary's trailing newline with the pseudocode.  A function
  -- replacement inserts the text literally, so a '%' in the pseudocode can
  -- never be misread as a gsub capture reference.
  __doc.run = (summary:gsub('\n$', function() return pseudocode end))
end
local pickle = serialize.pickle
local unpickle = serialize.unpickle
-- Runs `job` (fields described in __doc.__overview) on a fresh crew of
-- luaproc processes and returns the unpickled final state.
--
-- Processes, communicating over channels whose names are built from job.uid:
--   * one collector: unpickles job.state and job.collector, folds every
--     worker result into the state, and, once every worker has reported
--     shutdown, sends the pickled state on "<uid>.finish";
--   * nworkers workers: each unpickles job.worker and applies it to every
--     work item it receives on "<uid>.work-in", sending results to
--     "<uid>.results";
--   * the calling thread: feeds the work (calling job.status on each item),
--     shuts the workers down, then waits for the final state.
--
-- Every message on "<uid>.work-in" and "<uid>.results" starts with a boolean
-- tag: `true` followed by a payload, or a lone `false` meaning "this sender
-- is done".  Without the tag, a work item whose first value is `false`, or a
-- worker that returns just `false`, would be mistaken for a shutdown signal
-- and the crew would lose results or deadlock.
--
-- NOTE(review): luaproc copies a closure's upvalues into a new process only
-- when they are strings, numbers, booleans or nil (per the luaproc docs;
-- confirm for the version in use).  That is why the process bodies below
-- capture only pickles, channel names, paths and counts, and re-require or
-- unpickle everything else inside the new process.
-- NOTE(review): if job.worker or job.collector raises an error, its process
-- dies and the remaining processes block forever; errors are not relayed.
function M.run(job)
  local work = assert(job.work)
  assert(type(work) == 'table')
  local unpack = job.unpack or unpack
  local worker = assert(job.worker)
  local nworkers = assert(job.nworkers or tonumber(os.getenv 'NPROC') or 16)
  -- zero or fractional counts make the shutdown handshake never complete
  assert(type(nworkers) == 'number' and nworkers >= 1 and nworkers % 1 == 0,
         'nworkers must be a positive integer')
  local collector = assert(job.collector)
  local status = job.status or function () end
  local jobstate = assert(job.state or '{}')
  local path = package.path
  local cpath = package.cpath
  local uid = assert(job.uid, 'no identifier that distinguishes channels')
  local workchan = stringf('%s.work-in', uid) -- work sent here
  local sink     = stringf('%s.results', uid) -- results sent here
  local finish   = stringf('%s.finish', uid)  -- pickled final state sent here
  local channels = { workchan, sink, finish }

  -- Body of worker process number `wno`.
  local function mkworker(wno)
    return function()
      package.path = path
      package.cpath = cpath
      local luaproc = require 'luaproc'
      local serial = require 'luaproc.serialize'
      local worker = serial.unpickle(worker)
      local codefile = require 'codefile'
      codefile.salt('luaproc workcrew worker ' .. wno)
      -- Handles one tagged message; returns false once told to stop.
      local function consume(more, ...)
        if more then
          luaproc.send(sink, true, worker(...))
          return true
        else
          luaproc.send(sink, false) -- tell the collector this worker is done
          return false
        end
      end
      while consume(luaproc.receive(workchan)) do
        -- nothing
      end
    end
  end

  -- Body of the collector process.
  local function collect()
    package.path = path
    package.cpath = cpath
    local luaproc = require 'luaproc'
    local serial = require 'luaproc.serialize'
    -- kept from the original process body; presumably so the unpickled
    -- collector can rely on these libraries being loaded (TODO confirm)
    require 'io'
    require 'string'
    local live = nworkers -- workers that have not yet reported shutdown
    local state = serial.unpickle(jobstate)
    local collector = serial.unpickle(collector)
    -- Handles one tagged message from a worker.
    local function grab(more, ...)
      if more then
        collector(state, ...)
      else
        live = live - 1
      end
    end
    while live > 0 do
      grab(luaproc.receive(sink))
    end
    luaproc.send(finish, serial.pickle(state))
  end

  luaproc.setnumworkers(nworkers + 2)
  -- creation fails if a channel of that name already exists, which means
  -- another crew is (or was left) using this uid
  for _, chan in ipairs(channels) do
    assert(luaproc.newchannel(chan))
  end
  assert(luaproc.newproc(collect))
  for i = 1, nworkers do
    assert(luaproc.newproc(mkworker(i)))
  end
  -- fill the work queue, blocking as needed
  for _, w in ipairs(work) do
    luaproc.send(workchan, true, unpack(w))
    status(w)
  end
  -- now shut the workers down
  for _ = 1, nworkers do
    luaproc.send(workchan, false)
  end
  local finalstate = unpickle(assert(luaproc.receive(finish)))
  luaproc.wait()
  -- every process has finished, so the channels can be released
  for _, chan in ipairs(channels) do
    luaproc.delchannel(chan)
  end
  return finalstate
end
-- Export M, but only after checking that loading this file defined no new
-- globals (thismodule raises an error naming the first one it finds).
return thismodule()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment