Skip to content

Instantly share code, notes, and snippets.

@fab13n
Created September 29, 2011 10:24
Show Gist options
  • Save fab13n/1250474 to your computer and use it in GitHub Desktop.
Save fab13n/1250474 to your computer and use it in GitHub Desktop.
require 'metalua.walk'
-- Module table: every public function and predicate below is stored here.
local M = { }
-- support for old-style modules: the module table is also published as the
-- global `treequery` (presumably for callers relying on the global rather
-- than on require's return value)
treequery=M
--- Multimap helper: record value `x` in the set associated with key `node`.
-- A multimap here is a table mapping each key to a set `{ [value]=true }`.
-- @param mmap the multimap to update
-- @param node the key; nil is rejected since it cannot index a table
-- @param x the value to add to the key's set
-- @return false when node is nil, nothing otherwise
local function mmap_add (mmap, node, x)
   if node == nil then return false end
   local bucket = mmap[node]
   if not bucket then
      -- first value for this key: create its set
      mmap[node] = { [x] = true }
      return
   end
   bucket[x] = true
end
-- currently unused, I throw the whole set away
--- Multimap helper: remove value `x` from the set associated with `node`.
-- When the set becomes empty, the key itself is removed from the multimap.
-- @param mmap the multimap to update
-- @param node the key
-- @param x the value to remove
-- @return true if x was present and has been removed, false otherwise
local function mmap_remove (mmap, node, x)
   local set = mmap[node]
   if not set or not set[x] then return false end
   set[x] = nil
   -- Check emptiness *after* removal: testing next(set) beforehand always
   -- succeeded (x itself was still in the set), so empty sets were never dropped.
   if next(set) == nil then mmap[node] = nil end
   return true
end
-- Metatable shared by all treequery instances. It is its own __index, so
-- the methods defined on MT are reachable from every instance.
local MT = { }
MT.__index = MT
--- treequery constructor.
-- The resulting object lets you chain positional predicates and filters
-- (each of those methods returns the object itself), then operate on the
-- selected nodes of the AST through :foreach() or :list().
-- @param root the AST to visit
-- @return a fresh treequery instance
function M.treequery(root)
   local query = {
      root        = root,
      -- number of positional constraints currently unmet; nodes are only
      -- selected while it is 0
      unsatisfied = 0,
      predicates  = { },
      until_up    = { },
      from_up     = { },
      up_f        = false,
      down_f      = false,
      filters     = { },
   }
   return setmetatable(query, MT)
end
--- Shared implementation of the positional filters (after, under, ...).
-- @param self the treequery being built
-- @param position the kind of positional constraint, e.g. 'after' or 'under'
-- @param inverted true for the negated forms (not_after, not_under)
-- @param inclusive whether the matching node itself may be selected
-- @param f a predicate, or a tag name (extra arguments are more tag names)
-- @return self, to allow chaining
local function add_pos_filter(self, position, inverted, inclusive, f, ...)
   local pred = f
   if type(pred) == 'string' then pred = M.has_tag(pred, ...) end
   -- A positive constraint starts out unmet, a negated one starts out met.
   if not inverted then self.unsatisfied = self.unsatisfied + 1 end
   local predicates = self.predicates
   predicates[#predicates + 1] = {
      pred      = pred,
      position  = position,
      satisfied = false,
      inverted  = inverted or false,
      inclusive = inclusive or false,
   }
   return self
end
-- TODO: offer an API for inclusive pos_filters

--- Select nodes which come after a node satisfying predicate f: the scope
-- opens once that node has been walked up, and closes before its parent is.
function MT :after(f, ...)
   return add_pos_filter(self, 'after', false, false, f, ...)
end

--- Select nodes which are not after a node satisfying predicate f.
function MT :not_after(f, ...)
   return add_pos_filter(self, 'after', true, false, f, ...)
end

--- Select nodes which are under (descendants of) a node satisfying predicate f.
function MT :under(f, ...)
   return add_pos_filter(self, 'under', false, false, f, ...)
end

--- Select nodes which are not under a node satisfying predicate f.
function MT :not_under(f, ...)
   return add_pos_filter(self, 'under', true, false, f, ...)
end
--- Select only nodes which satisfy predicate f.
-- @param f a predicate called as f(self, node, parent, ...), or a tag name
--   (extra arguments are then further tag names, see M.has_tag)
-- @return self, to allow chaining
function MT :filter(f, ...)
   local pred = f
   if type(pred) == 'string' then pred = M.has_tag(pred, ...) end
   local filters = self.filters
   filters[#filters + 1] = pred
   return self
end
-- private helper: apply filters and execute up/down callbacks when applicable
--- Walk `self.root` and run `self.down_f` / `self.up_f` on selected nodes.
-- Positional predicates are tracked with the counter `self.unsatisfied`:
-- a node can only be selected while it is 0. Such a node must also pass
-- every predicate in `self.filters` before a callback is run. Callbacks and
-- predicates receive the walker's arguments (node, parent, ...).
-- The predicate bookkeeping stored on `self` is mutated during the walk.
-- @return whatever `walk.guess` returns
function MT :execute()
   local cfg = { }

   -- True when every plain filter accepts the current node.
   local function passes_filters(...)
      for _, f in ipairs(self.filters) do
         if not f(self, ...) then return false end
      end
      return true
   end

   -- TODO: optimize away not_under & not_after by pruning the tree
   function cfg.down(...)
      local node, parent = ...
      -- Selection is decided before this node's own predicate matches are
      -- recorded, so e.g. under(p) does not select the p-node itself
      -- (unless the predicate is inclusive).
      local satisfied = self.unsatisfied == 0
      for _, x in ipairs(self.predicates) do
         if not x.satisfied and x.pred(self, ...) then
            x.satisfied = true
            -- a match removes an unmet constraint, or adds one when inverted
            local inc = x.inverted and 1 or -1
            if x.position == 'under' then
               -- satisfied from after we get down this node...
               self.unsatisfied = self.unsatisfied + inc
               -- ...until before we get up this node
               mmap_add(self.until_up, node, x)
            elseif x.position == 'after' then
               -- satisfied from after we get up this node...
               mmap_add(self.from_up, node, x)
               -- ...until before we get up this node's parent
               mmap_add(self.until_up, parent, x)
            elseif x.position == 'under_or_after' then
               -- satisfied from after we get down this node...
               -- (the counter is `unsatisfied`; there is no `satisfied` field
               -- on self, so updating that one raised an arithmetic-on-nil error)
               self.unsatisfied = self.unsatisfied + inc
               -- ...until before we get up this node's parent
               mmap_add(self.until_up, parent, x)
            else
               error "position not understood"
            end
            if x.inclusive then satisfied = self.unsatisfied == 0 end
         end
      end
      if satisfied and passes_filters(...) and self.down_f then
         self.down_f(...)
      end
   end

   function cfg.up(...)
      local node = ...
      -- Release predicates whose scope ends before we go up this node.
      local ending = self.until_up[node]
      if ending then
         for x in pairs(ending) do
            self.unsatisfied = self.unsatisfied + (x.inverted and -1 or 1)
            x.satisfied = false
         end
         self.until_up[node] = nil
      end
      -- Execute the up callback
      -- TODO: cache the filter passing result from the down callback
      -- TODO: skip if there's no callback
      if self.unsatisfied == 0 and passes_filters(...) and self.up_f then
         self.up_f(...)
      end
      -- Activate predicates whose scope starts after we go up this node.
      local starting = self.from_up[node]
      if starting then
         for x in pairs(starting) do
            self.unsatisfied = self.unsatisfied + (x.inverted and 1 or -1)
         end
         self.from_up[node] = nil
      end
   end

   -- Compatibility hack for old versions of metalua.walk: binders get a
   -- single callback, which runs both the down and the up steps.
   -- Declared local: it used to leak as a global `down_up`.
   local function down_up(...) cfg.down(...); cfg.up(...) end
   return walk.guess({expr=cfg, stat=cfg, block=cfg, binder=down_up}, self.root)
end
--- Execute callbacks on each selected node.
-- At least one of the two callbacks must be given.
-- @param down function executed when we go down a node, i.e. before its
--   children have been examined.
-- @param up function executed when we go up a node, i.e. after its
--   children have been examined.
-- @return the result of the underlying walk
function MT :foreach(down, up)
   if up or down then
      self.up_f = up
      self.down_f = down
      return self :execute()
   end
   error "iterator not implemented"
end
--- Return the list of nodes selected by a given treequery, in the order in
-- which the walk goes down into them.
-- @return a sequence of AST nodes
function MT :list()
   local selected = { }
   local function collect(node) selected[#selected + 1] = node end
   self :foreach(collect)
   return selected
end
--- @section Predicates
--- Return a predicate which holds when the tested node's tag is one of the
-- tag names given as arguments.
-- @param ... a sequence of tag names
-- @return a predicate function(self, node): with exactly one tag it returns
--   a boolean; otherwise it returns a truthy value iff the tag is listed
function M.has_tag(...)
   local wanted = {...}
   if #wanted ~= 1 then
      -- Several tags: index them in a set for constant-time lookup.
      local tagset = { }
      for _, name in ipairs(wanted) do tagset[name] = true end
      return function(self, node)
         local tag = node.tag
         return tag and tagset[tag]
      end
   end
   local only = wanted[1]
   return function(self, node) return node.tag == only end
end
--- Predicate to test whether a node represents an expression.
-- Note: `Call and `Invoke nodes match even when used as statements;
-- M.is_stat is the context-aware test for those.
M.is_expr = M.has_tag(
   'Nil', 'Dots', 'True', 'False', 'Number', 'String',
   'Function', 'Table', 'Op', 'Paren',
   'Call', 'Invoke', 'Id', 'Index')
-- helper for is_stat: tags which always denote a statement
local STAT_TAGS = { Do=1, Set=1, While=1, Repeat=1, If=1, Fornum=1,
                    Forin=1, Local=1, Localrec=1, Return=1, Break=1 }

--- Predicate to test whether a node represents a statement.
-- It is context-aware, i.e. it recognizes `Call and `Invoke nodes
-- used in a statement context as such: a call is a statement when its
-- parent is a block (a node without a tag).
-- @param self the treequery (unused)
-- @param node the node to test
-- @param parent the node's parent, or nil when node is the walk's root
-- @return true if node is a statement, false otherwise
function M.is_stat(self, node, parent)
   local tag = node.tag
   if not tag then return false
   elseif STAT_TAGS[tag] then return true
   elseif tag=='Call' or tag=='Invoke' then
      -- A root call has no parent to tell its context: answer false rather
      -- than indexing nil.
      return parent ~= nil and parent.tag == nil
   else return false end
end
--- Predicate to test whether a node represents a block of statements,
-- i.e. a node which carries no tag.
-- @param self the treequery (unused)
-- @param node the node to test
-- @return true iff node has no tag
function M.is_block(self, node)
   local tag = node.tag
   return tag == nil
end
--- Transform a predicate on a node into a predicate on one of its ancestors.
-- For instance if p tests whether a node has property P,
-- then parent(p) tests whether this node's parent has property P.
-- The ancestor level is given by n, with 1 being the node itself,
-- 2 its parent, 3 its grand-parent etc.
-- @param[optional] n the ancestor level to examine, default=2
-- @param pred the predicate to transform; if it is a string, it and any
--   following arguments are tag names passed to M.has_tag
-- @return a predicate
function M.parent(n, pred, ...)
   if type(n) ~= 'number' then
      -- n was omitted: shift every argument one slot to the right, so that
      -- extra tag names (e.g. parent('Call', 'Invoke')) are kept.
      return M.parent(2, n, pred, ...)
   end
   if type(pred)=='string' then pred = M.has_tag(pred, ...) end
   return function(self, ...)
      -- the walker passes (node, parent, grand-parent, ...): pick level n
      return select(n, ...) and pred(self, select(n, ...))
   end
end
--- Predicate to test the position of a node in its parent.
-- The predicate succeeds if the node is the i-th child of its parent,
-- with a <= i <= b.
-- nth(a) is equivalent to nth(a, a).
-- Bounds <= 0 count from the last child: -1 is the last child, -2 the one
-- before it, etc. Unlike string.sub(), 0 maps to one past the last child.
-- @param a lower bound
-- @param[optional] b upper bound, defaults to a
-- @return a predicate
function M.nth(a, b)
   b = b or a
   return function(self, node, parent)
      if not parent then return false end
      local nchildren = #parent
      local lo = a <= 0 and nchildren + a + 1 or a
      if lo > nchildren then return false end
      -- a negative bound reaching before the first child starts at the first
      -- child, so non-child slots (index 0, negative keys) are never consulted
      if lo < 1 then lo = 1 end
      local hi = b <= 0 and nchildren + b + 1 or b
      if hi > nchildren then hi = nchildren end
      for i = lo, hi do
         if parent[i] == node then return true end
      end
      return false
   end
end
-- Module export; the same table is also assigned to the global `treequery`.
return M
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment