Created
November 7, 2020 05:16
-
-
Save mnemnion/0a23f8ae5c7acd77874b4d7506b2f6b8 to your computer and use it in GitHub Desktop.
Simple currying and partial application in Lua
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--[[
Simple currying module.

curry(fn, x) binds x as the first argument of fn. Currying an already
curried function again does not stack one wrapper per call: for up to five
bound arguments, a single closure calls the original function directly
with all of them (see `currier`). partial(fn, ...) binds several arguments
at once.
]]

-- Side table recording how each flattened curried function was built:
-- its bound arguments at [1]..[n], their count in `n`, and the original
-- function in `fn`. The keys (the curried functions) are weak, so an entry
-- becomes collectable once its curried function is otherwise unreferenced.
local weak_keys = { __mode = 'k' }
local _curried = setmetatable({}, weak_keys)
-- Flattened wrapper constructors, indexed by the number of bound arguments.
-- currier[n](fn, a1, ..., an) returns a closure that calls
-- fn(a1, ..., an, ...), i.e. one level of indirection regardless of how
-- many curry calls produced those n arguments.
-- Slot 1 is a placeholder and is never called: a single bound argument is
-- wrapped directly by curry. More slots can be appended if longer
-- flattened chains are ever worthwhile.
local currier = {
  [1] = false,
  [2] = function(f, p1, p2)
    return function(...) return f(p1, p2, ...) end
  end,
  [3] = function(f, p1, p2, p3)
    return function(...) return f(p1, p2, p3, ...) end
  end,
  [4] = function(f, p1, p2, p3, p4)
    return function(...) return f(p1, p2, p3, p4, ...) end
  end,
  [5] = function(f, p1, p2, p3, p4, p5)
    return function(...) return f(p1, p2, p3, p4, p5, ...) end
  end,
}
--- Bind `param` as the first argument of `fn`.
--
-- If `fn` was itself returned by a previous `curry` call with at most four
-- bound arguments, the result is a single closure (built by `currier`) that
-- calls the original function with all bound arguments directly, instead of
-- one wrapper per `curry` call. Once five arguments are bound this way,
-- further currying falls back to a plain wrapper.
--
-- @tparam function|table fn a function, or a table whose metatable has `__call`
-- @param param value to bind as the first argument (nil is bound as nil)
-- @treturn function a function `g` such that `g(...)` calls `fn(param, ...)`
local function curry(fn, param)
  -- Guard getmetatable's result before indexing it: a plain table without a
  -- metatable (or with a non-table __metatable) must hit the assertion
  -- message, not an "attempt to index" error.
  local fn_type = type(fn)
  local mt = fn_type == 'table' and getmetatable(fn)
  assert(fn_type == 'function' or (type(mt) == 'table' and mt.__call),
         '#1 of curry must be a function or callable table')

  local pre = _curried[fn]
  if not pre then
    -- First binding: record the original function and its argument so a
    -- later curry of the result can be flattened into one closure.
    local curried = function(...) return fn(param, ...) end
    _curried[curried] = { param, n = 1, fn = fn }
    return curried
  end

  if pre.n > 4 then
    -- currier has no flattened wrapper beyond five arguments; just add one
    -- more level of indirection around the already-curried function.
    return function(...) return fn(param, ...) end
  end

  -- Copy the bound arguments so the previous curried function's record is
  -- not mutated, then append the new argument.
  local n = pre.n + 1
  local post = { n = n, fn = pre.fn }
  for i = 1, pre.n do
    post[i] = pre[i]
  end
  post[n] = param

  -- Pass the slots explicitly rather than via `unpack`, which is a global
  -- only in Lua 5.1/LuaJIT (5.2+ has table.unpack). currier[n] declares
  -- exactly n bound parameters, so the extra (nil) slots are discarded.
  local curried = currier[n](post.fn,
                             post[1], post[2], post[3], post[4], post[5])
  _curried[curried] = post
  return curried
end
--- Bind any number of leading arguments to `fn`.
--
-- Equivalent to applying `curry` once per argument, left to right, so
-- `partial(f, a, b)(c)` calls `f(a, b, c)`. The argument count comes from
-- `select('#', ...)`, so explicit nil arguments are bound as well. With no
-- extra arguments, `fn` is returned unchanged.
--
-- @tparam function|table fn function (or callable table) to bind
-- @param ... arguments to bind, in order
-- @return the bound function, or `fn` itself when `...` is empty
local function partial(fn, ...)
  if select('#', ...) == 0 then
    return fn
  end
  -- Bind the first remaining argument, then recurse on the rest; this is a
  -- proper tail call, so long argument lists do not grow the stack.
  return partial(curry(fn, (...)), select(2, ...))
end
--- Module interface.
-- `curry(fn, x)` binds one leading argument;
-- `partial(fn, ...)` binds any number of leading arguments.
return {
  partial = partial,
  curry = curry,
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Currying is easy!
A practical example is binding a database connection as the first parameter of a function that performs a transaction on some data. This makes it easier to reason about which code can interact with the database, since we are less promiscuous about sharing copies of the connection: a User object can make changes using updateUser, but can't touch, say, the product table.
We use this in production, and there are in fact tests covering the edge cases, but they would be inconvenient to publish at this time.