|
-- vim: set syntax=lua ts=2 sw=2 expandtab: |
|
local lib=require"lib" |
|
local the=lib.settings[[ |
|
sway101: a little LUA learning lab for exploring sampling |
|
(c) 2023, Tim Menzies <[email protected]> BSD-2 |
|
|
|
USAGE: lua tests.lua [OPTIONS] [ -g ACTION] |
|
|
|
OPTIONS: |
|
-b --bins initial number of bins = 16 |
|
-c --cohen cohen's D = .35 |
|
-f --file data file = ../data/auto93.csv |
|
-F --Far distance to far points = .95 |
|
-g --go start up action = nothing |
|
-G --Goal goal criteria = plan |
|
-h --help show help = false |
|
-m --min cluster leaf size=N^min = .5 |
|
-n --nums how many nums to cache = 256 |
|
-p --p distance exponent = 2 |
|
-r --rest expansion of best = 3 |
|
-S --Some rows to explore for poles = 512 |
|
-s --seed random number seed = 937162211 |
|
-w --wild run tests,not protected = false]] |
|
--------------------------------------------------------------------------------------------------- |
|
local kap,map,o,oo,push = lib.kap, lib.map, lib.o, lib.oo, lib.push |
|
local rand,rint,lt,gt,sort = lib.rand,lib.rint,lib.lt, lib.gt, lib.sort |
|
local copy,inc,per = lib.copy,lib.inc,lib.per |
|
local fmt,any,many = lib.fmt,lib.any, lib.many |
|
local median,stdev,entropy = lib.median, lib.stdev, lib.entropy |
|
|
|
local isNum,isSym,isData,isRow |
|
-- Return `t.cells`: truthy when `t` carries a `cells` field (as tables built by ROW do).
function isRow(t)
  local cells = t.cells
  return cells
end
|
-- Return `t.lo`: truthy when `t` carries a `lo` field (as tables built by NUM do).
function isNum(t)
  local lo = t.lo
  return lo
end
|
-- Return `t.most`: truthy when `t` carries a `most` field (as tables built by SYM do).
-- Note that a `most` of 0 is still truthy in Lua.
function isSym(t)
  local most = t.most
  return most
end
|
-- Return `t.rows`: truthy when `t` carries a `rows` field (as tables built by DATA do).
function isData(t)
  local rows = t.rows
  return rows
end
|
--------------------------------------------------------------------------------------------------- |
|
local NUM,SYM,COL,COLS |
|
-- Build one column from its position `n` and header name `s`.
-- Names that start with an upper-case letter become NUM columns, all others SYM.
-- Trailing marks on the name set flags: "X" = ignored, "!" = klass,
-- and any of "!", "+", "-" = goal.
function COL(n,s,   col)
  local maker = SYM
  if s:find"^[A-Z]" then maker = NUM end
  col = maker(n,s)
  local name = col.txt
  col.isIgnored = name:find"X$"
  col.isKlass   = name:find"!$"
  col.isGoal    = name:find"[!+-]$"
  return col
end
|
|
|
-- Build the set of columns for a list of header names `ss`.
-- Returns {names, all, x, y [, klass]}: every column goes into `all`;
-- non-ignored columns are also filed under `y` (goals) or `x` (everything else),
-- and a klass column is remembered in `klass`.
function COLS(ss,    col,cols)
  cols = {names=ss, all={}, x={}, y={}}
  for n,s in pairs(ss) do
    col = push(cols.all, COL(n,s))
    if not col.isIgnored then
      if col.isKlass then cols.klass = col end
      if col.isGoal then
        push(cols.y, col)
      else
        push(cols.x, col)
      end
    end
  end
  return cols
end
|
|
|
-- Create a summary of a numeric column at position `n` named `s` (default "").
-- `has` caches seen values, `ok` records whether that cache is currently sorted,
-- `lo`/`hi` start at +/- infinity, and `w` is -1 when the name ends in "-", else 1.
function NUM(n,s)
  local txt = s or ""
  local weight = 1
  if txt:find"-$" then weight = -1 end
  return {at=n, txt=txt, n=0, has={}, ok=false,
          hi=-math.huge, lo=math.huge, w=weight}
end
|
|
|
-- Create a summary of a symbolic column at position `n` named `s` (default "").
-- `has` holds per-symbol counts; `most`/`mode` track the largest count and its symbol.
function SYM(n,s)
  local sym = {n=0, has={}, most=0, mode=nil}
  sym.at  = n
  sym.txt = s or ""
  return sym
end
|
--------------------------------------------------------------------------------------------------- |
|
local add,has,mid,div,norm,merge,merged,includes |
|
-- Update column summary `col` with value `x`, counted `n` times (default 1).
-- "?" cells are skipped (and returned unchanged).
-- SYM: bump the count for `x` and refresh `most`/`mode` when it takes the lead.
-- NUM: widen `lo`/`hi`, then keep at most `the.nums` values in `col.has`; once the
-- cache is full, a new value overwrites a slot picked by `rint(1,#t)` with
-- probability `the.nums/col.n`. Any cache change marks the column as unsorted.
function add(col,x,  n,  t)
  if x == "?" then return x end
  n = n or 1
  col.n = col.n + n
  t = col.has
  if isSym(col) then
    local count = inc(t,x,n)
    if count > col.most then col.most, col.mode = t[x], x end
    return
  end
  col.lo = math.min(x, col.lo)
  col.hi = math.max(x, col.hi)
  local size = #t
  if size < the.nums then
    col.ok = false
    t[size + 1] = x
  elseif rand() < the.nums/col.n then
    col.ok = false
    t[rint(1,size)] = x
  end
end
|
|
|
-- Return the values cached in `col`. For NUM columns whose cache has changed
-- since the last call, sort it first and flag it as sorted.
function has(col)
  local stale = isNum(col) and not col.ok
  if stale then
    sort(col.has)
    col.ok = true
  end
  return col.has
end
|
|
|
-- Central tendency of a column: the mode for SYM columns, otherwise the median
-- of the (sorted) cached values.
-- An explicit `if` is used rather than `a and b or c`: with the ternary idiom a
-- SYM whose mode is still nil would wrongly fall through to `median` applied
-- to its table of symbol counts.
function mid(col)
  if isSym(col) then return col.mode end
  return median(has(col))
end
|
|
|
-- Diversity of a column: `stdev` of the sorted cache for NUM columns,
-- `entropy` of the symbol counts otherwise.
function div(col)
  local spread = isNum(col) and stdev(has(col))
  return spread or entropy(col.has)
end
|
|
|
-- Map `x` onto 0..1 using the `lo`..`hi` range seen by `num`.
-- "?" is passed through untouched.
-- When the range is not positive (a constant column where hi == lo, or an empty
-- column whose lo/hi are still +/- infinity) return 0 instead of dividing by
-- zero, which used to yield NaN and poison every later distance or comparison.
function norm(num,x)
  if x == "?" then return x end
  local range = num.hi - num.lo
  if not (range > 0) then return 0 end
  return (x - num.lo)/range
end
|
|
|
-- Return a new column that combines `col1` and `col2`; neither input is changed
-- provided `copy` returns an independent copy (presumably a deep copy -- confirm).
-- SYM: every symbol count of `col2` is added to the copy of `col1`.
-- NUM: every cached value of `col2` is added, then lo/hi are set to span both inputs.
function merge(col1,col2,   new)
  new = copy(col1)
  if isSym(col1) then
    for x,n in pairs(col2.has) do add(new,x,n) end
    return new
  end
  for _,v in pairs(col2.has) do add(new,v) end
  new.lo = math.min(col1.lo, col2.lo)
  new.hi = math.max(col1.hi, col2.hi)
  return new
end
|
|
|
-- Try to merge two columns. Returns the merged column when merging is warranted,
-- otherwise nil. Merging is warranted when
--   (a) `nSmall` is given and either input has fewer than `nSmall` items, or
--   (b) `nFar` is given, the columns are numeric and their mids differ by less
--       than `nFar`, or
--   (c) the merged column is no more diverse than the size-weighted average
--       diversity of the parts.
-- The parentheses in (a) matter: `and` binds tighter than `or`, so without them
-- `col2.n < nSmall` was evaluated (and raised an error) whenever `nSmall` was nil.
function merged(col1,col2,nSmall,nFar,   new)
  new = merge(col1,col2)
  if nSmall and (col1.n < nSmall or col2.n < nSmall) then return new end
  if nFar and not isSym(col1) and math.abs(mid(col1) - mid(col2)) < nFar then
    return new
  end
  if div(new) <= (div(col1)*col1.n + div(col2)*col2.n)/new.n then return new end
  return nil
end
|
|
|
-- Does column `col` cover value `x`?
-- "?" is always covered. For SYM columns the answer is the stored count of `x`
-- (truthy) or false when `x` was never seen; for NUM columns it is whether
-- lo <= x <= hi.
-- The SYM case returns early: previously an unseen symbol fell through to the
-- numeric comparison against the nil `lo`/`hi` of a SYM and raised an error.
function includes(col,x)
  if x == "?" then return true end
  if isSym(col) then return col.has[x] or false end
  return x >= col.lo and x <= col.hi
end
|
--------------------------------------------------------------------------------------------------- |
|
local ROW,DATA,row,better,column,stats |
|
-- Wrap a list of cell values `t` as a row, with `best` and `evaluated` both false.
function ROW(t)
  local new = {best=false, evaluated=false}
  new.cells = t
  return new
end
|
|
|
-- Build a DATA (rows plus column summaries) from `src`, which may be
--   a string : handed to `lib.csv` together with the row handler,
--   a DATA   : only its column names are reused (same structure, no rows),
--   a list   : each item is handed to the row handler.
-- Items of the optional list `also` are then added as well.
function DATA(src, also,   data,add)
  data = {rows={}, cols=nil}
  add  = function(t) row(data,t) end
  local kind = type(src)
  if kind == "string" then
    lib.csv(src, add)
  elseif isData(src) then
    add(src.cols.names)
  else
    map(src, add)
  end
  map(also or {}, add)
  return data
end
|
|
|
-- Feed one item `t` into `data`.
-- The first item seen defines the columns (it is passed to COLS). Every later
-- item is wrapped with ROW unless it already is a row, stored in `data.rows`,
-- and used to update each x and y column summary.
function row(data,t)
  if not data.cols then
    data.cols = COLS(t)
    return
  end
  if not isRow(t) then t = ROW(t) end
  push(data.rows, t)
  for _,cols in pairs{data.cols.x, data.cols.y} do
    for _,col in pairs(cols) do
      add(col, t.cells[col.at])
    end
  end
end
|
|
|
-- Is `row1` preferred to `row2` on the goal (y) columns of `data`?
-- For each goal, the normalized values x,y of the two rows contribute
-- -exp(w*(x-y)/n) to row1's tally and -exp(w*(y-x)/n) to row2's tally, where w is
-- the column weight and n the number of goals; row1 wins when its mean tally is
-- the smaller one.
function better(data,row1,row2)
  local goals = data.cols.y
  local n = #goals
  local s1, s2 = 0, 0
  for _,col in pairs(goals) do
    local x = norm(col, row1.cells[col.at])
    local y = norm(col, row2.cells[col.at])
    s1 = s1 - math.exp(col.w * (x-y)/n)
    s2 = s2 - math.exp(col.w * (y-x)/n)
  end
  return s1/n < s2/n
end
|
|
|
-- Summarize columns of `data` as a table keyed by column name.
-- `what` is the per-column summary function (default `mid`), `cols` the columns
-- to report (default: the goal columns), and `nPlaces` is forwarded to `lib.rnd`
-- for numeric columns. The result also carries "N", the number of rows.
function stats(data,  what,cols,nPlaces,   fun,tmp)
  fun = function(_,col)
    local v = (what or mid)(col)
    if isNum(col) then v = lib.rnd(v,nPlaces) or v end
    return v, col.txt
  end
  tmp = kap(cols or data.cols.y, fun)
  tmp["N"] = #data.rows
  return tmp
end
|
--------------------------------------------------------------------------------------------------- |
|
local dist,around,half,halves,optimize |
|
-- Distance between two rows over the x columns of `data`, using exponent `i.p`:
-- (sum(gap^p)/n)^(1/p), where n is the number of x columns.
-- Symbolic gap: 0 when equal, else 1 (two unknowns count as 1).
-- Numeric gap : absolute difference of the normalized values; a single unknown is
-- replaced by whichever of 0 or 1 is further from the known value; two unknowns
-- count as 1.
function dist(i,data,row1,row2,   sym,num)
  sym = function(a,b)
    if a == "?" and b == "?" then return 1 end
    return a == b and 0 or 1
  end
  num = function(a,b)
    if a == "?" then
      if b == "?" then return 1 end
      a = b < .5 and 1 or 0
    elseif b == "?" then
      b = a < .5 and 1 or 0
    end
    return math.abs(a - b)
  end
  local xs = data.cols.x
  local d, n = 0, #xs
  for _,col in pairs(xs) do
    local a, b = row1.cells[col.at], row2.cells[col.at]
    local gap
    if isNum(col) then
      gap = num(norm(col,a), norm(col,b))
    else
      gap = sym(a,b)
    end
    d = d + gap^i.p
  end
  return (d/n)^(1/i.p)
end
|
|
|
-- Pair every row in `rows` with its distance to `row1` and return the
-- {dist=..., row=...} pairs sorted with `lt"dist"`.
function around(i,data,row1,rows,   fun)
  fun = function(row2)
    return {dist=dist(i,data,row1,row2), row=row2}
  end
  local neighbors = map(rows, fun)
  return sort(neighbors, lt"dist")
end
|
--------------------------------------------------------------------------------------------------- |
|
-- Split `rows` in two along the line between two distant rows A and B.
-- A and B are picked from a sample of `i.Some` rows: A is far from a randomly
-- chosen row, B is far from A, where "far" means `i.Far` of the way along the
-- sorted list of neighbours. Every row is projected onto the A-B line with the
-- cosine rule, x = (a^2 + c^2 - b^2)/(2c), and the rows sorted by x are dealt
-- into `left` (first half) and `right` (the rest).
-- Returns left, right, A, B, c (the A-B distance).
-- Fixes relative to the earlier version:
--   * the "far" index is clamped to >= 1 (for tiny samples i.Far*#rows floors to
--     0, and indexing element 0 crashed);
--   * the split test is `n <= #rows/2`, so even-sized inputs are halved evenly
--     and a 2-row input no longer yields an empty left half (which could make
--     recursive callers loop forever);
--   * when A and B coincide (c == 0) every projection is 0 rather than NaN/inf.
function half(i,data,rows,   project,some,A,B,c,left,right,far,gap)
  gap = function(r1,r2) return dist(i,data,r1,r2) end
  far = function(r,rs)
    local k = math.max(1, (i.Far * #rs) // 1)
    return around(i,data,r,rs)[k].row
  end
  project = function(r)
    local a, b = gap(r,A), gap(r,B)
    local x = 0
    if c ~= 0 then x = (a^2 + c^2 - b^2)/(2*c) end
    return {row=r, x=x}
  end
  some = many(rows,i.Some)
  A    = far(any(some),some)
  B    = far(A,some)
  c    = gap(A,B)
  left,right = {},{}
  for n,tmp in ipairs(sort(map(rows,project),lt"x")) do
    push(n <= #rows/2 and left or right, tmp.row)
  end
  return left,right,A,B,c
end
|
|
|
-- Recursively bisect `rows` (default: all rows of `data`) into a binary tree.
-- Recursion stops once a node holds `stop` rows or fewer; `stop` defaults to
-- (#rows)^i.min, computed once at the root and passed down unchanged.
-- Each node is {rows=...}, with `left` and `right` subtrees on non-leaf nodes.
function halves(i,data,   rows,stop)
  rows = rows or data.rows
  stop = stop or (#rows)^i.min
  if #rows <= stop then return {rows=rows} end
  local left,right = half(i,data,rows)
  local node = {rows=rows}
  node.left  = halves(i,data,left,stop)
  node.right = halves(i,data,right,stop)
  return node
end
|
--------------------------------------------------------------------------------------------------- |
|
-- Repeatedly bisect `rows` (default: all rows of `data`), keeping only the half
-- nearer the better of the two poles, until `stop` rows or fewer remain
-- (`stop` defaults to (#rows)^i.min at the first call).
-- Returns two DATAs with the same columns as `data`: the surviving "best" rows,
-- and a sample of the discarded rows that is `rest` times as large as best.
-- The multiplier is read from `i.rest`, like every other setting used here,
-- falling back to the global `the.rest` so callers whose `i` lacks that field
-- behave exactly as before.
function optimize(i,data,   rows,stop,rest)
  rows = rows or data.rows
  stop = stop or (#rows)^i.min
  rest = rest or {}
  if #rows <= stop then
    local factor = i.rest or the.rest
    return DATA(data,rows), DATA(data, many(rest, #rows*factor))
  end
  local left,right,A,B = half(i,data,rows)
  if better(data,B,A) then left,right = right,left end
  for _,r in pairs(right) do push(rest,r) end
  return optimize(i,data,left,stop,rest)
end
|
--------------------------------------------------------------------------------------------------- |
|
local score |
|
-- Goal criteria: each maps `b` and `r` (the fractions of best and rest rows that
-- were selected) to a score where larger is better.
--   plan : favour selecting best     b^2/(b+r)
--   fear : favour selecting rest     r^2/(b+r)
--   tabu : favour selecting little   log(1/(b+r))
-- A tiny constant is added to b+r so that a selection covering nothing scores
-- a finite number instead of NaN (0/0) or infinity; NaN scores break sorting.
-- For any non-zero b+r the constant is far below floating-point resolution and
-- the results are unchanged.
local is = {}
is.plan = function(b,r) return b^2/(b + r + 1E-64) end
is.fear = function(b,r) return r^2/(b + r + 1E-64) end
is.tabu = function(b,r) return math.log(1/(b + r + 1E-64)) end
|
|
|
-- Score a selection that captured `b` of the `B` best rows and `r` of the `R`
-- rest rows, using the criterion named by `i.Goal` (a key of `is`).
-- A tiny constant guards the two ratios against division by zero.
function score(i,b,r,B,R)
  local tiny = 1E-64
  local criterion = is[i.Goal]
  return criterion(b/(B + tiny), r/(R + tiny))
end
|
|
|
-- NOTE(review): this looks like leftover debugging output; it runs (and prints)
-- every time this file is loaded -- confirm whether it should stay.
print("BR",score({Goal="plan"},4,0,11,33))
|
-- function extend(tests,test,filter) |
|
-- tests=copy(tests) |
|
-- tests[1+#tests]=(filter or lib.itself)(test) |
|
-- return tests end |
|
-- |
|
-- local function TREE(i,cols,best,rows0) |
|
-- local B,R = br(rows0,bestok-) |
|
-- local val = function(b,r) return score(i,b,r,B,R) end |
|
-- stop = (#rows0)^i.min |
|
-- local function tree(rows,path) |
|
-- if #rows >= 2*stop then |
|
-- tmp = map(cols, function(col) return branch(col,rows,,best,val) end) |
|
-- local b, r, yes, no, test = 0, 0, {}, {}, sort(tmp, gt"val")[1].test |
|
-- if test then |
|
-- for _,row in pairs(rows) do |
|
-- if row.y==best then b=b+1 else r=r+1 end |
|
-- push(accept(test,row) and yes or no, row) end |
|
-- return { |
|
-- rows = rows, |
|
-- score = val(b,r), |
|
-- tests = path, |
|
-- left = #yes< #rows and tree(yes,extend(path,test)), |
|
-- right = #no < #rows and tree(no, extend(path,test,flip))} end end |
|
-- end -------------------------------------------- |
|
-- return tree(rows0,{}) end |
|
-- |
|
-- function branch(col,rows,best,val) |
|
-- local function good(row) if row.cells[col.at] ~= "?" then return row end end |
|
-- rows = sort(map(rows,good), function(r1,r2) return r1.cells[col.at]< r2.cells[col.at] end) |
|
-- local function sym() |
|
-- local b,r = {},{} |
|
-- for j,row in pairs(rows) do |
|
-- local x=row.cells[col.at] |
|
-- b[x] = b[x] or 0 |
|
-- r[x] = r[x] or 0 |
|
-- if row.y==best then b[x]=b[x]+1 else r[x]=r[x]+1 end end |
|
-- local fun= function(x,b) return {val=val(b,r[x]), cut=x} end |
|
-- local tmp= sort(kap(b, fun),gt"val")[1] |
|
-- return {val=tmp.val, test={at=col.at,txt=col.txt,op="=",x=tmp.cut}} |
|
-- end ------------------------ |
|
-- function num() |
|
-- local b0,r0,op,cut = 0,0 |
|
-- local b1,r1 = br(rows, best) |
|
-- local best = val(b1,r1) |
|
-- for j,row in pairs(rows) do -- find the cut that minimizes expected value of entropy |
|
-- if row.y==best then b0=b0+1; b1=b1-1 else r0=r0+1; r1=r1-1 end |
|
-- local x=row.cells[col.at] |
|
-- if j < #rows and x ~= rows[j+1].cells[col.at] then |
|
-- local v1 = val(b0,r0) |
|
-- local v2 = val(b1,r1) |
|
-- if v1 > best then best,cut,op = v1,x,"<=" end |
|
-- if v2 > best then best,cut,op = v2,x,">" end end end |
|
-- if cut then |
|
-- return {val=best, test={at=col.at, txt=col.txt, op=op, x=cut}} end |
|
-- end -------------------------------------------------------- |
|
-- return (isNum(col) and num or sym)() end |
|
-- |
|
local bin,bins,merges,rank,noGaps,selects,sorted,select1 |
|
-- Map cell value `x` of column `col` to a bin key.
-- Unknowns ("?") and symbolic values are their own key. Numeric values are
-- rounded to the nearest multiple of the bin width (hi-lo)/(i.bins-1); a column
-- whose hi equals lo always maps to 1.
function bin(i,col,x,   tmp)
  if x == "?" or isSym(col) then return x end
  tmp = (col.hi - col.lo)/(i.bins - 1)
  if col.hi == col.lo then return 1 end
  return math.floor(x/tmp + .5)*tmp
end
|
|
|
-- For every column in `cols`, discretize its values across the classes in
-- `rowss` (a table mapping klass -> list of rows), keep that column's
-- top-scoring range, and hand the collected ranges to `selects`.
-- Per column:
--   1. every known cell is mapped to a bin; each bin tracks the x values seen
--      (a NUM) and the klasses seen (a SYM);
--   2. bins are ordered by their lowest x value;
--   3. for numeric columns, adjacent bins are merged (see `merges`) using
--      n/i.bins as the "too small" size and i.cohen*div(col) as "too close";
--   4. the first range returned by `sorted` is kept.
function bins(i,cols,best,rowss)
  local out = {}
  for _,col in pairs(cols) do
    local n, xys = 0, {}
    for klass,rows in pairs(rowss) do
      for _,r in pairs(rows) do
        local x = r.cells[col.at]
        if x ~= "?" then                 -- skip unknown cells
          n = n + 1
          local k  = bin(i,col,x)
          local xy = xys[k]
          if not xy then
            xy = {x=NUM(col.at,col.txt), y=SYM(col.at,col.txt)}
            xys[k] = xy
          end
          add(xy.y, klass)               -- which klasses fall in this bin
          add(xy.x, x)                   -- which x values fall in this bin
        end
      end
    end
    local ranges = sort(map(xys,lib.itself),
                        function(a,b) return a.x.lo < b.x.lo end)
    if not isSym(col) then
      ranges = merges(ranges, n/i.bins, i.cohen*div(col))
    end
    push(out, sorted(i,best,rowss,ranges)[1])
  end
  return selects(i,out,best,rowss)
end
|
|
|
-- Merge adjacent ranges bottom-up until nothing more can be merged.
-- One pass walks the list pairwise: when `merged` accepts two neighbouring x
-- summaries, the pair is replaced by a single range (y summaries combined with
-- `merge`) and the second item is skipped. If a pass shrinks the list, run
-- another pass on the result; otherwise return the input list unchanged.
function merges(xys0,nSmall,nFar)
  local xys1 = {}
  local j = 1
  while j <= #xys0 do
    -- xys0[j+1] is nil once j is the last index of the array
    local one, two = xys0[j], xys0[j+1]
    local x = two and merged(one.x,two.x,nSmall,nFar)
    if x then
      one = {x=x, y=merge(one.y, two.y)}
      j = j + 1                          -- item two has been absorbed
    end
    push(xys1, one)                      -- each step keeps exactly one item
    j = j + 1
  end
  if #xys0 == #xys1 then return xys0 end
  return merges(xys1,nSmall,nFar)
end
|
|
|
-- Stretch a sorted list of ranges (modified in place) so they cover the whole
-- number line: the first range starts at -infinity, the last ends at +infinity,
-- and every other range starts where its predecessor ends.
-- Returns `xys`. An empty list is returned untouched (it used to crash on
-- indexing the missing first element).
function noGaps(xys)
  local last = #xys
  if last == 0 then return xys end
  xys[1].x.lo    = -math.huge
  xys[last].x.hi =  math.huge
  for j = 2,last do xys[j].x.lo = xys[j-1].x.hi end
  return xys
end
|
|
|
-- Score every range in `xys` and return those with a positive score, sorted
-- with `gt"score"`.
-- B and R are the total numbers of best and non-best rows in `rowss`; for each
-- range, b and r are the klass counts recorded in its y summary.
function sorted(i,best,rowss,xys)
  local B, R = 0, 0
  for klass,rows in pairs(rowss) do
    local size = #rows
    if klass == best then B = B + size else R = R + size end
  end
  for _,xy in pairs(xys) do
    local b, r = 0, 0
    for klass,n in pairs(xy.y.has) do
      if klass == best then b = b + n else r = r + n end
    end
    xy.score = score(i,b,r,B,R)
  end
  local function positive(xy)
    if xy.score > 0 then return xy end
  end
  xys = map(xys, positive)               -- drop the non-positive scores
  return sort(xys, gt"score")
end
|
|
|
-- Does `row` fall inside range `xy`? True when the row's cell in that column is
-- unknown ("?"), equals the single value of a point range, or lies in lo..hi.
-- Declared `local`: this name is absent from the forward `local` list, so a
-- plain `function select` replaced Lua's built-in global `select` for the whole
-- program. Functions defined after this point still see this local.
local function select(xy,row)
  local x = row.cells[xy.x.at]
  if x == "?" then return true end
  local lo, hi = xy.x.lo, xy.x.hi
  if lo == hi and x == lo then return true end
  return x >= lo and x <= hi
end
|
|
|
-- Grow candidate rules from the ranges in `xys` and return them sorted with
-- `gt"score"`.
-- All rows are pooled (each is tagged with its klass in `row.klass`), the ranges
-- are ordered by score, and for stop = 1..#xys the first `stop` ranges are
-- combined by `select1` into one scored rule.
-- NOTE(review): the io.write/print calls look like debugging output -- confirm.
function selects(i,xys,best,rowss)
  local all, scored = {}, {}
  local B, R = 0, 0
  for klass,rows in pairs(rowss) do
    if klass == best then B = B + #rows else R = R + #rows end
    for j,r in pairs(rows) do
      io.write(j," ")
      r.klass = klass
      push(all,r)
    end
  end
  print("all",#all)
  xys = sort(xys,gt"score")
  for stop = 1,#xys do
    local rule = select1(i,stop,xys,best,all,B,R)
    if rule then push(scored,rule) end
  end
  map(scored, function(z) print("zz",#z.rule,z.score) end)
  return sort(scored,gt"score")
end
|
|
|
-- Build and score the rule made of the first `stop` ranges of `xys`.
-- Rows from `all` are filtered through each range in turn; the surviving rows
-- are counted as best (b) or rest (r) and scored against the totals B and R.
-- Returns {rule=<list of {at,txt,lo,hi}>, score=<number>}, or nil when some
-- range leaves no rows at all. The caller already tests for a nil result;
-- previously an empty selection was still scored, i.e. with b = r = 0, which
-- makes the criteria divide zero by zero.
-- NOTE(review): the print call looks like debugging output -- confirm.
function select1(i,stop,xys,best,all,B,R)
  local tmp, how = all, {}
  for j = 1,stop do
    tmp = map(tmp, function(row) if select(xys[j],row) then return row end end)
    if #tmp == 0 then return nil end     -- rule matches nothing: no rule
    local x = xys[j].x
    how[j] = {at=x.at, txt=x.txt, lo=x.lo, hi=x.hi}
  end
  local b, r = 0, 0
  for _,row in pairs(tmp) do
    if row.klass == best then b = b + 1 else r = r + 1 end
  end
  print("br",b,r,B,R)
  return {rule=how, score=score(i,b,r,B,R)}
end
|
|
|
|
|
|
|
-- local rule,rule1,branches,numBranches,symBranches |
|
-- function rule(i,cols,B,R,rows) |
|
-- return rule1(cols, rows, |
|
-- {}, |
|
-- (#rows)^i.min, |
|
-- function(b,r) return score(i,b,r,B,R) end) end |
|
-- |
|
-- function rule1(cols,rows,path,stop,val) |
|
-- if #rows < 2*stop then return path end |
|
-- tests = {} |
|
-- map(cols, function(col) branches(col,rows,val,tests) end) |
|
-- test = sort(tests,gt"val")[1] |
|
-- local rest,b,r = {},0,0 |
|
-- for _,row in pairs(rows) do |
|
-- if accept(test,row) then |
|
-- if row.best then b=b+1 else r=r+1 end |
|
-- else push(rest,row) end end |
|
-- push(path,{test=test,b=b,r=r,val=val}) |
|
-- if #no < #rows1 then return rule1(cols,no,path,stop,val) end end |
|
-- |
|
-- function numBranches(col,rows,val,out) |
|
-- local b1,r1=br(rows) |
|
-- local most,b0,r0 = val(b1,r1),0,0 |
|
-- local cut,op |
|
-- for j,row in pairs(rows) do |
|
-- if row.y then b0=b0+1; b1=b1-1 else r0=r0+1; r1=r1-1 end |
|
-- local x=row.cells[col.at] |
|
-- if j < #rows-stop and j> stop and x ~= rows[j+1].cells[col.at] then |
|
-- local v1 = val(b0,r0) |
|
-- local v2 = val(b1,r1) |
|
-- if v1 > most then most,cut,op = v1,x,"<=" end |
|
-- if v2 > most then most,cut,op = v2,x,">" end end end |
|
-- if cut then |
|
-- push(out, {val=most, test={at=col.at, txt=col.txt, op=op, x=cut}}) end end |
|
-- |
|
-- function symBranches(col,rows,val,out) |
|
-- local bs,rs = {},{} |
|
-- for j,row in pairs(rows) do |
|
-- local x=row.cells[col.at] |
|
-- bs[x] = bs[x] or 0 |
|
-- rs[x] = rs[x] or 0 |
|
-- if row.best then bs[x]=bs[x]+1 else rs[x]=rs[x]+1 end end |
|
-- for x1,br in pairs(bs) do |
|
-- push(out, {val=val(br,rs[x1]), test={at=col.at,txt=col.txt,op="=",x=x1}}) end end |
|
-- |
|
-- function branches(col,rows,val,out) |
|
-- local function good(row) if row.cells[col.at] ~= "?" then return row end end |
|
-- rows = sort(map(rows,good), function(r1,r2) return r1.cells[col.at] < r2.cells[col.at] end) |
|
-- (isNum(col) and numBranches or symBranches)(col,rows,vals,out) end |
|
-- |
|
--------------------------------------------------------------------------------------------------- |
|
-- Module result. `pcall(debug.getlocal,4,1)` succeeds only when stack level 4
-- exists; presumably that tells "loaded by another file" apart from "run
-- directly" -- TODO confirm. On success the result is whatever `lib.locals()`
-- returns (presumably this file's locals -- verify in lib), otherwise an empty table.
return pcall(debug.getlocal,4,1) and lib.locals() or {}