Created
October 5, 2011 15:34
-
-
Save fab13n/1264745 to your computer and use it in GitHub Desktop.
TreeQuery prototype
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
local walk = require('metalua.treequery.walk')

-- Module table.
local M = { }

-- Support for old-style modules: the module is also published as the
-- global `treequery`.
treequery = M
-- Multimap helper: `mmap` associates each key (a node) with a set of values,
-- stored as a table whose keys are the values and whose entries are `true`.
-- Adds `x` to the set of `node`, creating the set on first use.
-- @return false when `node` is nil (nothing is recorded), nothing otherwise
local function mmap_add (mmap, node, x)
  if node == nil then return false end
  local set = mmap[node]
  if not set then
    set = { }
    mmap[node] = set
  end
  set[x] = true
end
-- Multimap helper: remove `x` from the set associated with `node`, and drop
-- the set itself once it becomes empty.
-- Currently unused: the whole set is thrown away instead.
-- @return true if `x` was in `node`'s set, false otherwise
local function mmap_remove (mmap, node, x)
  local set = mmap[node]
  if not set or not set[x] then return false end
  set[x] = nil
  -- Test emptiness *after* removing x. Testing before it always sees x,
  -- so empty sets would never be discarded.
  if next(set) == nil then mmap[node] = nil end
  return true
end
-- Maps each `Id{ } occurrence to the `Id{ } that binds it, so a binding can
-- be retrieved from any of its occurrences. Filled by Q:execute's
-- `occurrence` callback. Weak keys and values ('kv') so that the
-- table does not keep AST nodes alive.
-- Deliberately global: the `local` keyword was commented out upstream.
OCC2BIND = setmetatable({ }, { __mode = 'kv' })
-- Metatable shared by all treequery instances; methods are looked up in it.
local Q = { }
Q.__index = Q
--- Treequery constructor.
-- The resulting object allows filtering of, and operations on, the AST.
-- @param root the AST to visit
-- @return a treequery visitor instance
function M.treequery(root)
  local query = {
    root        = root,
    unsatisfied = 0,    -- count of positional predicates currently blocking selection
    predicates  = { },  -- positional predicate records, see add_pos_filter
    until_up    = { },  -- node -> set of predicates to retire before going up node
    from_up     = { },  -- node -> set of predicates to activate after going up node
    up_f        = false,
    down_f      = false,
    filters     = { },  -- plain node predicates, see Q:filter
  }
  return setmetatable(query, Q)
end
-- Shared implementation of the positional filters (Q.after, Q.under and
-- their negations). Records a predicate on the query and returns the query
-- so that calls can be chained.
-- @param position where the predicate applies ('after' or 'under' from the public API)
-- @param inverted true for the not_* variants
-- @param inclusive whether the matching node itself may be selected
-- @param f a predicate function, or a tag name (extra tags in `...`)
local function add_pos_filter(self, position, inverted, inclusive, f, ...)
  local pred = f
  if type(pred) == 'string' then pred = M.has_tag(pred, ...) end
  if not inverted then
    -- A non-inverted predicate blocks selection until it gets satisfied.
    self.unsatisfied = self.unsatisfied + 1
  end
  table.insert(self.predicates, {
    pred      = pred,
    position  = position,
    satisfied = false,
    inverted  = inverted or false,
    inclusive = inclusive or false,
  })
  return self
end
-- TODO: offer an API for inclusive pos_filters

--- Select nodes which come after one satisfying predicate f.
function Q :after(f, ...)
  return add_pos_filter(self, 'after', false, false, f, ...)
end

--- Select nodes which do not come after one satisfying predicate f.
function Q :not_after(f, ...)
  return add_pos_filter(self, 'after', true, false, f, ...)
end

--- Select nodes which are under one satisfying predicate f.
function Q :under(f, ...)
  return add_pos_filter(self, 'under', false, false, f, ...)
end

--- Select nodes which are not under one satisfying predicate f.
function Q :not_under(f, ...)
  return add_pos_filter(self, 'under', true, false, f, ...)
end
--- Select nodes which satisfy predicate f.
-- @param f a predicate function(query, node, parent, ...), or a tag name
--   (further tag names may follow in `...`)
-- @return the query itself, for chaining
function Q :filter(f, ...)
  if type(f) == 'string' then f = M.has_tag(f, ...) end
  local filters = self.filters
  filters[#filters + 1] = f
  return self
end
-- Private helper: true when every filter registered through Q:filter accepts
-- the node described by `...` (node, parent, ...). Filters are tried in
-- registration order and evaluation stops at the first rejection.
local function passes_filters(self, ...)
  for _, f in ipairs(self.filters) do
    if not f(self, ...) then return false end
  end
  return true
end

-- Private: walk the AST, keep the positional predicates' state up to date,
-- apply the filters and run the down/up callbacks set by Q:foreach on
-- selected nodes.
-- @return whatever walk.guess returns
function Q :execute()
  local cfg = { }
  -- TODO: optimize away not_under & not_after by pruning the tree

  function cfg.down(...)
    local node, parent = ...
    local satisfied = self.unsatisfied == 0
    for _, x in ipairs(self.predicates) do
      if not x.satisfied and x.pred(self, ...) then
        x.satisfied = true
        -- Satisfying a regular predicate removes one blocker;
        -- satisfying an inverted one adds a blocker.
        local inc = x.inverted and 1 or -1
        if x.position == 'under' then
          -- satisfied from after we get down this node...
          self.unsatisfied = self.unsatisfied + inc
          -- ...until before we get up this node
          mmap_add(self.until_up, node, x)
        elseif x.position == 'after' then
          -- satisfied from after we get up this node...
          mmap_add(self.from_up, node, x)
          -- ...until before we get up this node's parent
          mmap_add(self.until_up, parent, x)
        elseif x.position == 'under_or_after' then
          -- satisfied from after we get down this node...
          -- (previously updated the nonexistent self.satisfied field)
          self.unsatisfied = self.unsatisfied + inc
          -- ...until before we get up this node's parent
          mmap_add(self.until_up, parent, x)
        else
          error "position not understood"
        end
        if x.inclusive then satisfied = self.unsatisfied == 0 end
      end
    end
    if satisfied and passes_filters(self, ...) and self.down_f then
      self.down_f(...)
    end
  end

  function cfg.up(...)
    local node = ...
    -- Retire predicates whose scope ends before we go up this node
    local expiring = self.until_up[node]
    if expiring then
      for x in pairs(expiring) do
        self.unsatisfied = self.unsatisfied + (x.inverted and -1 or 1)
        x.satisfied = false
      end
      self.until_up[node] = nil
    end
    -- Execute the up callback
    -- TODO: cache the filter passing result from the down callback
    -- TODO: skip if there's no callback
    if self.unsatisfied == 0 and passes_filters(self, ...) and self.up_f then
      self.up_f(...)
    end
    -- Activate predicates which take effect after we go up this node
    local starting = self.from_up[node]
    if starting then
      for x in pairs(starting) do
        self.unsatisfied = self.unsatisfied + (x.inverted and 1 or -1)
      end
      self.from_up[node] = nil
    end
  end

  -- A binding identifier is treated as a leaf: go down then immediately up.
  function cfg.binder(id_node, ...)
    cfg.down(id_node, ...)
    cfg.up(id_node, ...)
  end

  -- Remember which binder each occurrence refers to, when it has one.
  function cfg.occurrence(binder, occ)
    if binder then OCC2BIND[occ] = binder[1] end
  end

  return walk.guess(cfg, self.root)
end
--- Execute a function on each selected node.
-- @param down function run when going down a node, i.e. before its children
--   have been examined
-- @param up function run when going up a node, i.e. after its children
--   have been examined
-- At least one of them must be given.
function Q :foreach(down, up)
  if not (up or down) then
    error "iterator not implemented"
  end
  self.down_f, self.up_f = down, up
  return self :execute()
end
--- Return the list of nodes selected by the query, in the order they are
-- first reached (on the way down).
function Q :list()
  local selected = { }
  local function collect(node) selected[#selected + 1] = node end
  self :foreach(collect)
  return selected
end
--- Return the first matching element, or nil when nothing matches.
-- Aborts the traversal by raising a private sentinel error as soon as a node
-- is selected. Any other error raised during the traversal (e.g. by a
-- predicate) is propagated instead of being silently swallowed.
-- TODO: implement properly with a 'break' return; aborting leaves the
-- query's internal predicate state mid-traversal, and this won't behave
-- correctly if coroutines are involved.
-- @return the arguments passed to the down callback for the first selected
--   node (node, parent, ...), or nil
function Q :first()
  local STOP = { }  -- unique sentinel, cannot collide with real errors
  local result = nil
  local function f(...)
    result = { n = select('#', ...), ... }
    error(STOP)
  end
  local ok, err = pcall(function() return self :foreach(f) end)
  if not ok and err ~= STOP then error(err, 0) end
  if result == nil then return nil end
  return (unpack or table.unpack)(result, 1, result.n)
end
--- Pretty printer for queries: shows the root AST.
function Q :__tostring()
  local dump = table.tostring(self.root, 'nohash')
  return "treequery(" .. dump .. ")"
end
--- @section Predicates

--- Return a predicate which is true if the tested node's tag is one of
-- those listed as arguments.
-- @param ... a sequence of tag names
function M.has_tag(...)
  local tag_list = { ... }
  if #tag_list ~= 1 then
    -- Several (or zero) tags: use a set for membership tests.
    local wanted = { }
    for _, t in ipairs(tag_list) do wanted[t] = true end
    return function(self, node)
      local node_tag = node.tag
      return node_tag and wanted[node_tag]
    end
  end
  -- Single tag: a direct comparison is enough.
  local single = tag_list[1]
  return function(self, node) return node.tag == single end
end
--- Predicate to test whether a node represents an expression.
-- Note that `Call and `Invoke nodes match even in a statement context.
M.is_expr = M.has_tag(
  'Nil', 'Dots', 'True', 'False', 'Number', 'String',
  'Function', 'Table', 'Op', 'Paren', 'Call', 'Invoke',
  'Id', 'Index')
-- Helper for is_stat: tags which always denote a statement.
local STAT_TAGS = { Do=1, Set=1, While=1, Repeat=1, If=1, Fornum=1,
                    Forin=1, Local=1, Localrec=1, Return=1, Break=1 }

--- Predicate to test whether a node represents a statement.
-- It is context-aware, i.e. it recognizes `Call and `Invoke nodes
-- used in a statement context (directly inside a tag-less block) as such.
-- @param node the tested node
-- @param parent the node's parent; may be nil when node is the query root
-- @return true if node is a statement
function M.is_stat(self, node, parent)
  local tag = node.tag
  if not tag then return false end
  if STAT_TAGS[tag] then return true end
  if tag == 'Call' or tag == 'Invoke' then
    -- Without a parent the context is unknown: answer false rather than
    -- indexing nil.
    return parent ~= nil and parent.tag == nil
  end
  return false
end
--- Predicate to test whether a node represents a statements block,
-- i.e. a node without a tag.
function M.is_block(self, node)
  local tag = node.tag
  return tag == nil
end
-- Tags of nodes whose first child is a list of bound identifiers.
local BINDER_GRAND_PARENT_TAG = {
  Local = true, Localrec = true, Forin = true, Function = true }

--- Predicate to test whether `a` is a binding `Id.
-- It succeeds when `a` is an `Id and either `b` is a `Fornum whose first
-- child is `a`, or `b` is a `Local/`Localrec/`Forin/`Function whose first
-- child (a list) contains `a`.
function M.is_binder(self, a, b)
  if a.tag ~= 'Id' or not b then return false end
  local holder_tag = b.tag
  if holder_tag == 'Fornum' then return b[1] == a end
  if not BINDER_GRAND_PARENT_TAG[holder_tag] then return false end
  for _, candidate in ipairs(b[1]) do
    if candidate == a then return true end
  end
  return false
end
--- Return the binder `Id recorded for an occurrence, or nil.
-- Only populated for occurrences seen by a previously executed query.
function M.binder(node)
  local found = OCC2BIND[node]
  return found
end

--- Return a predicate true for occurrences bound by `binder`.
function M.is_bound_by(binder)
  return function(self, node)
    return OCC2BIND[node] == binder
  end
end
--- Transform a predicate on a node into a predicate on one of its ancestors.
-- For instance if p tests whether a node has property P, then parent(p)
-- tests whether this node's parent has property P.
-- The ancestor level is given by n, with 1 being the node itself,
-- 2 its parent, 3 its grand-parent etc.
-- @param[optional] n the ancestor level to examine, default=2
-- @param pred the predicate to transform; a tag name (optionally followed by
--   more tag names in `...`) is turned into a predicate through M.has_tag
-- @return a predicate
function M.parent(n, pred, ...)
  if type(n) ~= 'number' then
    -- n omitted: every argument belongs to the predicate specification.
    -- (Previously tested the undefined global `a`, so an explicit n was
    -- always discarded.)
    return M.parent(2, n, pred, ...)
  end
  if type(pred) == 'string' then pred = M.has_tag(pred, ...) end
  return function(self, ...)
    -- select(n, ...) yields the n-th ancestor followed by its own ancestors.
    return select(n, ...) and pred(self, select(n, ...))
  end
end
--- Predicate to test the position of a node in its parent.
-- The predicate succeeds if the node is the i-th child of its parent,
-- with a <= i <= b.
-- nth(a) is equivalent to nth(a, a).
-- Non-positive indices count from the last child (-1 is the last one),
-- similar to string.sub().
-- @param a lower bound
-- @param b upper bound, defaults to a
-- @return a predicate; false for nodes without a parent
function M.nth(a, b)
  b = b or a
  return function(self, node, parent)
    if not parent then return false end
    local count = #parent
    local first = a
    if first <= 0 then first = count + first + 1 end
    if first > count then return false end
    local last = b
    if last <= 0 then
      last = count + last + 1
    elseif last > count then
      last = count
    end
    for i = first, last do
      if parent[i] == node then return true end
    end
    return false
  end
end
-- Build a function returning the comments attached to one side of a node
-- (node.lineinfo[which_side].comments), joined with newlines, or nil when
-- there are none.
local function comment_extractor(which_side)
  return function(node)
    local info = node.lineinfo
    local side = info and info[which_side]
    local comments = side and side.comments
    if not comments then return nil end
    local texts = { }
    for _, record in ipairs(comments) do
      texts[#texts + 1] = record[1]
    end
    return table.concat(texts, '\n')
  end
end
M.comment_prefix = comment_extractor 'first'
M.comment_suffix = comment_extractor 'last'

--- Shortcut for the query constructor: calling the module, M(root),
-- is the same as M.treequery(root).
function M.__call(self, ...)
  return self.treequery(...)
end

setmetatable(M, M)
return M
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment