-- vim: set syntax=lua ts=2 sw=2 expandtab: |
local lib=require"lib" |
local the=lib.settings[[ |
sway101: a little LUA learning lab for exploring sampling |
(c) 2023, Tim Menzies <[email protected]> BSD-2 |
USAGE: lua tests.lua [OPTIONS] [ -g ACTION] |
-b --bins initial number of bins = 16 |
-c --cohen cohen's D = .35 |
-f --file data file = ../data/auto93.csv |
-F --Far distance to far points = .95 |
-g --go start up action = nothing |
-G --Goal goal criteria = plan |
-h --help show help = false |
-m --min cluster leaf size=N^min = .5 |
-n --nums how many nums to cache = 256 |
-p --p distance exponent = 2 |
-r --rest expansion of best = 3 |
-S --Some rows to explore for poles = 512 |
-s --seed random number seed = 937162211 |
-w --wild run tests,not protected = false]] |
--------------------------------------------------------------------------------------------------- |
local kap,map,o,oo,push = lib.kap, lib.map, lib.o, lib.oo, lib.push |
local rand,rint,lt,gt,sort = lib.rand,lib.rint,lib.lt, lib.gt, lib.sort |
local copy,inc,per = lib.copy,lib.inc,lib.per |
local fmt,any,many = lib.fmt,lib.any, lib.many |
local median,stdev,entropy = lib.median, lib.stdev, lib.entropy |
local isNum,isSym,isData,isRow |
-- Type predicates. Each one hands back the field that marks a table as that
-- kind of thing (so the result is "truthy or nil", not a strict boolean).
isRow  = function(t) return t.cells end -- ROWs carry `cells`
isNum  = function(t) return t.lo    end -- NUMs carry `lo`
isSym  = function(t) return t.most  end -- SYMs carry `most`
isData = function(t) return t.rows  end -- DATAs carry `rows`
--------------------------------------------------------------------------------------------------- |
-- Build the header for column number `n` whose name is `s`.
-- Names that start with an upper-case letter become NUMs, all others SYMs.
-- Name suffixes set flags: "X" = ignored, "!" = klass, one of "!+-" = goal.
-- The flags hold whatever string.find returns (a position or nil).
function COL(n,s,    col)
  local maker = SYM
  if s:find"^[A-Z]" then maker = NUM end
  col = maker(n,s)
  local txt = col.txt
  col.isIgnored = txt:find"X$"
  col.isKlass   = txt:find"!$"
  col.isGoal    = txt:find"[!+-]$"
  return col end
-- Build the column headers for a list of column names `ss`.
-- Returns a table with:
--   names : `ss` itself
--   all   : one COL per name, in column order
--   x, y  : the non-ignored non-goal / goal columns
--   klass : the non-ignored klass column, if any (last one seen wins)
-- `ss` is a sequence, so it is walked with ipairs: pairs gives no ordering
-- promise, and `all`, `x` and `y` are meant to follow column order.
function COLS(ss,    col,cols)
  cols = {names=ss, all={}, x={}, y={}}
  for n,s in ipairs(ss) do
    col = push(cols.all, COL(n,s))
    if not col.isIgnored then
      if col.isKlass then cols.klass = col end
      push(col.isGoal and cols.y or cols.x, col) end end
  return cols end
-- Create an empty summary of a numeric column.
--   n : column position
--   s : column name (defaults to "")
-- `w` is -1 when the name ends in "-", else 1. `lo`/`hi` start at
-- +infinity/-infinity; `has` starts empty and `ok` starts false.
function NUM(n,s)
  local name   = s or ""
  local weight = 1
  if name:find"-$" then weight = -1 end
  return {at=n, txt=name, n=0, has={}, ok=false,
          hi=-math.huge, lo=math.huge, w=weight} end
-- Create an empty summary of a symbolic column.
--   n : column position
--   s : column name (defaults to "")
-- `has` starts empty, `most` starts at 0 and `mode` is unset.
function SYM(n,s)
  local sym = {at=n, n=0, has={}, most=0}
  sym.txt  = s or ""
  sym.mode = nil
  return sym end
--------------------------------------------------------------------------------------------------- |
local add,has,mid,div,norm,merge,merged,includes |
-- Update column summary `col` with value `x`, seen `n` times (default 1).
-- A "?" is ignored and handed straight back; otherwise nothing is returned.
-- SYMs: call inc(t,x,n) and, if its result exceeds `col.most`, record
--       t[x] and x as the new most/mode.
-- NUMs: widen lo/hi, then keep at most `the.nums` values in `col.has`;
--       once that cache is full, a new value overwrites a random slot with
--       probability the.nums/col.n. Any cache write clears `col.ok`.
function add(col,x,    n,t)
  if x=="?" then return x end
  n     = n or 1
  col.n = col.n + n
  t     = col.has
  if isSym(col) then
    if inc(t,x,n) > col.most then col.most, col.mode = t[x], x end
    return end
  col.lo = math.min(x, col.lo)
  col.hi = math.max(x, col.hi)
  local slot
  if #t < the.nums then slot = 1 + #t
  elseif rand() < the.nums/col.n then slot = rint(1,#t) end
  if slot then col.ok = false; t[slot] = x end end
-- Return the values cached in `col.has`.
-- For a NUM whose `ok` flag is unset, the cache is sorted first and the
-- flag is then set so later calls skip the sort.
function has(col)
  local stale = isNum(col) and not col.ok
  if stale then
    sort(col.has)
    col.ok = true end
  return col.has end
-- Central tendency of a column: the mode for SYMs, the median of the
-- cached values for everything else.
-- Written as an explicit `if` rather than `a and b or c`: an empty SYM has
-- mode == nil, and the and/or form would then fall through and call
-- median() on the SYM's table of counts. An empty SYM now yields nil.
function mid(col)
  if isSym(col) then return col.mode end
  return median(has(col)) end
-- Diversity of a column: stdev of the cached values for NUMs,
-- entropy of `col.has` otherwise.
function div(col)
  if isNum(col) then
    local sd = stdev(has(col))
    -- mirrors the and/or chain: only a nil/false stdev falls through
    if sd then return sd end end
  return entropy(col.has) end
-- Map `x` onto 0..1 using the lo/hi range seen by `num`.
--   "?" is returned unchanged.
--   When hi == lo (a constant column) the result is 0; the plain formula
--   would divide zero by zero and leak NaN into every distance and
--   comparison built on top of it.
function norm(num,x)
  if x=="?" then return x end
  local lo,hi = num.lo, num.hi
  if hi == lo then return 0 end
  return (x - lo)/(hi - lo) end
-- Return a new column that combines `col1` and `col2`; neither is changed
-- by this function itself (the result starts as copy(col1)).
-- SYMs: every symbol of col2 is added with its count.
-- NUMs: every cached value of col2 is added, then lo/hi are set to span
--       both inputs.
function merge(col1,col2,    new)
  new = copy(col1)
  if isSym(col1) then
    for x,n in pairs(col2.has) do add(new,x,n) end
    return new end
  for _,n in pairs(col2.has) do add(new,n) end
  new.lo = math.min(col1.lo, col2.lo)
  new.hi = math.max(col1.hi, col2.hi)
  return new end
-- Decide whether `col1` and `col2` should be merged.
-- Returns the merged column when any of these hold, else nil:
--   1. `nSmall` is given and either input has fewer than `nSmall` items;
--   2. `nFar` is given, the columns are not SYMs, and their mids are
--      closer than `nFar`;
--   3. the merged column's diversity is no worse than the size-weighted
--      diversity of the two inputs.
-- The size test is parenthesised on purpose: `and` binds tighter than
-- `or`, so without the parentheses `col2.n < nSmall` was evaluated (and
-- raised an error) whenever `nSmall` was nil.
function merged(col1,col2,nSmall,nFar,    new)
  new = merge(col1,col2)
  if nSmall and (col1.n < nSmall or col2.n < nSmall) then return new end
  if nFar and not isSym(col1) and math.abs(mid(col1) - mid(col2)) < nFar then return new end
  if div(new) <= (div(col1)*col1.n + div(col2)*col2.n)/new.n then
    return new end end
-- Does `col` cover the value `x`?
--   "?"  : always true.
--   SYMs : the stored count for `x` when it has been seen, else false.
--   NUMs : true when lo <= x <= hi.
-- Split into explicit branches: in the single and/or expression a SYM that
-- had never seen `x` fell through to the numeric test and compared `x`
-- against the SYM's missing lo/hi, raising an error.
function includes(col,x)
  if x=="?" then return true end
  if isSym(col) then return col.has[x] or false end
  return x >= col.lo and x <= col.hi end
--------------------------------------------------------------------------------------------------- |
local ROW,DATA,row,better,column,stats |
-- Wrap a list of cell values `t` as a row record; both flags start false.
function ROW(t)
  local r = {cells=t}
  r.best, r.evaluated = false, false
  return r end
-- Build a DATA ({rows=..., cols=...}) from `src`, then append `also`.
--   src is a string : handed to lib.csv, which is given `add` as callback
--   src is a DATA   : only its column names are reused (same layout, no rows)
--   otherwise       : src is treated as a list and each item is added
--   also            : optional list of further items to add
function DATA(src,also,    data,add)
  data = {rows={}, cols=nil}
  add  = function(t) row(data,t) end
  if type(src) ~= "string" then
    if isData(src) then add(src.cols.names) else map(src,add) end
  else
    lib.csv(src,add) end
  map(also or {}, add)
  return data end
-- Feed one item `t` into `data`.
-- The first item seen (while `data.cols` is unset) defines the columns.
-- Every later item is wrapped as a ROW if needed, stored in `data.rows`,
-- and used to update the x and y column summaries (ignored columns are in
-- neither list, so they are not updated).
function row(data,t)
  if not data.cols then data.cols = COLS(t); return end
  if not isRow(t) then t = ROW(t) end
  push(data.rows, t)
  local headers = data.cols
  for _,group in pairs{headers.x, headers.y} do
    for _,col in pairs(group) do
      add(col, t.cells[col.at]) end end end
-- True when `row1` beats `row2` across the goal (y) columns.
-- For each goal the normalised values are compared, signed by the column
-- weight `w`, and the exponential of that difference is subtracted from a
-- running total per row; row1 wins when its averaged total is the smaller.
-- NOTE(review): a "?" in a goal cell would reach the subtraction as a
-- string; callers presumably guarantee goal cells are present - confirm.
function better(data,row1,row2)
  local goals  = data.cols.y
  local n      = #goals
  local s1, s2 = 0, 0
  for _,col in pairs(goals) do
    local x = norm(col, row1.cells[col.at])
    local y = norm(col, row2.cells[col.at])
    s1 = s1 - math.exp(col.w * (x-y)/n)
    s2 = s2 - math.exp(col.w * (y-x)/n) end
  return s1/n < s2/n end
-- Summarise columns of `data`.
--   what    : function applied to each column (defaults to `mid`)
--   cols    : columns to report (defaults to the goal columns)
--   nPlaces : passed to lib.rnd for numeric columns
-- The per-column results are collected with kap, using the column name as
-- the second value returned to it; the row count is stored under "N".
function stats(data,what,cols,nPlaces,    fun,tmp)
  fun = function(_,col)
    local v = (what or mid)(col)
    if isNum(col) then v = lib.rnd(v,nPlaces) or v end
    return v, col.txt end
  tmp = kap(cols or data.cols.y, fun)
  tmp["N"] = #data.rows
  return tmp end
--------------------------------------------------------------------------------------------------- |
local dist,around,half,halves,optimize |
-- Distance between `row1` and `row2` over the x columns of `data`.
-- Per-column gaps are raised to the power `i.p`, summed, divided by the
-- number of x columns, and the 1/i.p root of that is returned.
--   SYM gap : 0 when the two cells are equal, else 1 (two "?" also give 1).
--   NUM gap : |a-b| on normalised values; a lone "?" is replaced by
--             whichever of 0 or 1 lies on the far side of .5 from the
--             known value, and two "?" give 1.
function dist(i,data,row1,row2,    sym,num)
  sym = function(a,b)
    if a=="?" and b=="?" then return 1 end
    return a==b and 0 or 1 end
  num = function(a,b)
    if a=="?" and b=="?" then return 1 end
    if     a=="?" then a = b < .5 and 1 or 0
    elseif b=="?" then b = a < .5 and 1 or 0 end
    return math.abs(a - b) end
  local xs    = data.cols.x
  local count = #xs
  local total = 0
  for _,col in pairs(xs) do
    local a,b = row1.cells[col.at], row2.cells[col.at]
    local gap = isNum(col) and num(norm(col,a), norm(col,b)) or sym(a,b)
    total = total + gap^i.p end
  return (total/count)^(1/i.p) end
-- Pair every row in `rows` with its distance from `row1` and return those
-- {dist=..., row=...} records sorted using lt"dist".
function around(i,data,row1,rows,    fun)
  fun = function(row2)
    return {dist=dist(i,data,row1,row2), row=row2} end
  local measured = map(rows,fun)
  return sort(measured, lt"dist") end
--------------------------------------------------------------------------------------------------- |
-- Split `rows` in two using a pair of distant rows.
--   1. `some` = many(rows, i.Some), the rows used to look for the poles.
--   2. A is the row found `i.Far` of the way along the distance-sorted
--      list from a random member of `some`; B is found the same way from A.
--   3. Every row is projected onto the line A-B (cosine rule, c = dist(A,B))
--      and rows are sorted by that projection `x`.
--   4. The first half of the sorted rows goes left, the remainder right.
-- Returns left, right, A, B, c.
-- Fixes relative to the first draft:
--   * the split test is `n <= #rows/2`; with `<`, two rows put nothing on
--     the left, and larger inputs were split unevenly;
--   * the far index is clamped to at least 1, since i.Far * #rows floors
--     to 0 when only one row is available.
function half(i,data,rows,    project,some,A,B,c,left,right,far,gap)
  function gap(r1,r2) return dist(i,data,r1,r2) end
  function far(r,rows)
    local nth = math.max(1, (i.Far * #rows) // 1)
    return around(i,data,r,rows)[nth].row end
  function project(r,    a,b)
    a, b = gap(r,A), gap(r,B)
    return {row=r, x=(a^2 + c^2 - b^2)/(2*c)} end
  some = many(rows,i.Some)
  A = far(any(some),some)
  B = far(A,some)
  c = gap(A,B)
  left,right = {},{}
  for n,tmp in pairs(sort(map(rows,project),lt"x")) do
    push(n <= #rows/2 and left or right, tmp.row) end
  return left,right,A,B,c end
-- Recursively bi-cluster `rows` (default: all rows of `data`).
-- Returns a tree of {rows=..., left=..., right=...}; leaves have only `rows`.
--   stop : leaf size; defaults to (#rows)^i.min at the top call and is
--          passed down unchanged.
-- A node also becomes a leaf when half() fails to divide it (one side
-- comes back empty): recursing on an unchanged row set would never end.
function halves(i,data,    rows,stop)
  rows = rows or data.rows
  stop = stop or (#rows)^i.min
  if #rows <= stop then return {rows=rows} end
  local left,right = half(i,data,rows)
  if #left == 0 or #right == 0 then return {rows=rows} end
  return {rows  = rows,
          left  = halves(i,data,left,stop),
          right = halves(i,data,right,stop)} end
--------------------------------------------------------------------------------------------------- |
-- Repeatedly halve `rows`, each time keeping the half nearer the better
-- pole, until at most `stop` rows remain.
-- Returns two DATAs with the layout of `data`:
--   1. the surviving ("best") rows;
--   2. a sample, drawn with many(), of the discarded rows, sized
--      #best * rest.
--   stop : defaults to (#rows)^i.min at the top call.
--   rest : accumulator of discarded rows (callers normally omit it).
-- The multiplier is read from `i.rest`, like every other tunable here,
-- falling back to the global `the.rest` when `i` does not carry one.
function optimize(i,data,    rows,stop,rest)
  rows = rows or data.rows
  stop = stop or (#rows)^i.min
  rest = rest or {}
  if #rows <= stop then
    local factor = i.rest or the.rest
    return DATA(data,rows), DATA(data, many(rest, #rows*factor)) end
  local left,right,A,B = half(i,data,rows)
  if better(data,B,A) then left,right = right,left end -- keep B's side
  for _,r in pairs(right) do push(rest,r) end
  return optimize(i,data,left,stop,rest) end
--------------------------------------------------------------------------------------------------- |
local score |
-- Goal criteria, indexed by name. `b` and `r` are the fractions of best
-- and rest rows selected by some rule.
--   plan : high when many best and few rest are selected
--   fear : high when many rest and few best are selected
--   tabu : high when few rows of either kind are selected
-- A tiny constant keeps the denominators non-zero, so selecting nothing
-- scores 0 (plan, fear) or a large finite value (tabu) instead of NaN/inf.
local is={}
do
  local tiny = 1E-64
  is.plan = function(b,r) return b^2/(b+r+tiny) end
  is.fear = function(b,r) return r^2/(b+r+tiny) end
  is.tabu = function(b,r) return math.log(1/(b+r+tiny)) end
end
-- Score a rule that selects `b` of the `B` best rows and `r` of the `R`
-- rest rows, using the criterion named by `i.Goal`.
-- Raises an error naming the goal when `i.Goal` is not a known criterion.
function score(i,b,r,B,R)
  local tiny = 1E-64
  local goal = is[i.Goal]
  if not goal then error("unknown goal: "..tostring(i.Goal), 2) end
  return goal(b/(B+tiny), r/(R+tiny)) end
-- NOTE(review): trace output produced every time this file is loaded.
print("BR",score({Goal="plan"},4,0,11,33))
-- function extend(tests,test,filter) |
-- tests=copy(tests) |
-- tests[1+#tests]=(filter or lib.itself)(test) |
-- return tests end |
-- |
-- local function TREE(i,cols,best,rows0) |
-- local B,R = br(rows0,bestok-) |
-- local val = function(b,r) return score(i,b,r,B,R) end |
-- stop = (#rows0)^i.min |
-- local function tree(rows,path) |
-- if #rows >= 2*stop then |
-- tmp = map(cols, function(col) return branch(col,rows,,best,val) end) |
-- local b, r, yes, no, test = 0, 0, {}, {}, sort(tmp, gt"val")[1].test |
-- if test then |
-- for _,row in pairs(rows) do |
-- if row.y==best then b=b+1 else r=r+1 end |
-- push(accept(test,row) and yes or no, row) end |
-- return { |
-- rows = rows, |
-- score = val(b,r), |
-- tests = path, |
-- left = #yes< #rows and tree(yes,extend(path,test)), |
-- right = #no < #rows and tree(no, extend(path,test,flip))} end end |
-- end -------------------------------------------- |
-- return tree(rows0,{}) end |
-- |
-- function branch(col,rows,best,val) |
-- local function good(row) if row.cells[col.at] ~= "?" then return row end end |
-- rows = sort(map(rows,good), function(r1,r2) return r1.cells[col.at]< r2.cells[col.at] end) |
-- local function sym() |
-- local b,r = {},{} |
-- for j,row in pairs(rows) do |
-- local x=row.cells[col.at] |
-- b[x] = b[x] or 0 |
-- r[x] = r[x] or 0 |
-- if row.y==best then b[x]=b[x]+1 else r[x]=r[x]+1 end end |
-- local fun= function(x,b) return {val=val(b,r[x]), cut=x} end |
-- local tmp= sort(kap(b, fun),gt"val")[1] |
-- return {val=tmp.val, test={at=col.at,txt=col.txt,op="=",x=tmp.cut}} |
-- end ------------------------ |
-- function num() |
-- local b0,r0,op,cut = 0,0 |
-- local b1,r1 = br(rows, best) |
-- local best = val(b1,r1) |
-- for j,row in pairs(rows) do -- find the cut that minimizes expected value of entropy |
-- if row.y==best then b0=b0+1; b1=b1-1 else r0=r0+1; r1=r1-1 end |
-- local x=row.cells[col.at] |
-- if j < #rows and x ~= rows[j+1].cells[col.at] then |
-- local v1 = val(b0,r0) |
-- local v2 = val(b1,r1) |
-- if v1 > best then best,cut,op = v1,x,"<=" end |
-- if v2 > best then best,cut,op = v2,x,">" end end end |
-- if cut then |
-- return {val=best, test={at=col.at, txt=col.txt, op=op, x=cut}} end |
-- end -------------------------------------------------------- |
-- return (isNum(col) and num or sym)() end |
-- |
local bin,bins,merges,rank,noGaps,selects,sorted,select1 |
-- Map cell value `x` of column `col` to a bin key.
--   "?" and values of SYM columns are returned unchanged.
--   Numeric values are rounded to the nearest multiple of
--   (hi - lo)/(i.bins - 1); a constant column (hi == lo) maps to 1.
function bin(i,col,x,    tmp)
  if x=="?" or isSym(col) then return x end
  tmp = (col.hi - col.lo)/(i.bins - 1)
  if col.hi == col.lo then return 1 end
  return math.floor(x/tmp + .5)*tmp end
-- For every column in `cols`, find its best-scoring range and pass the
-- collected ranges to selects().
--   rowss : table mapping klass -> list of rows
--   best  : the klass we want to select for
-- Per column:
--   1. every non-"?" cell is binned; each bin tracks the x values seen
--      (a NUM) and the klasses seen (a SYM);
--   2. bins are sorted by their lowest x value;
--   3. for non-SYM columns, neighbouring bins are merged via merges(),
--      with nSmall = n/i.bins and nFar = i.cohen * div(col);
--   4. the top entry of sorted() is kept.
function bins(i,cols,best,rowss)
  local out = {}
  for _,col in pairs(cols) do
    local n, xys = 0, {}
    for klass,rows in pairs(rowss) do
      for _,row in pairs(rows) do
        local x = row.cells[col.at]
        if x ~= "?" then
          n = n + 1
          local k  = bin(i,col,x)
          local xy = xys[k]
          if not xy then
            xy = {x=NUM(col.at,col.txt), y=SYM(col.at,col.txt)}
            xys[k] = xy end
          add(xy.y, klass) -- klasses seen in this bin
          add(xy.x, x)     -- x values seen in this bin
        end end end
    xys = sort(map(xys,lib.itself), function(a,b) return a.x.lo < b.x.lo end)
    if not isSym(col) then
      xys = merges(xys, n/i.bins, i.cohen*div(col)) end
    push(out, sorted(i,best,rowss,xys)[1]) end
  return selects(i,out,best,rowss) end
-- Repeatedly merge neighbouring bins of `xys0` until a full pass merges
-- nothing, then return the list.
-- Each pass walks left to right; when merged() accepts a pair, the pair is
-- replaced by one bin (merged x, merged y) and the walk skips ahead.
function merges(xys0,nSmall,nFar)
  local xys1, j = {}, 1
  while j <= #xys0 do
    local one, two = xys0[j], xys0[j+1] -- xys0[j+1] is nil past the end of the array
    local x = two and merged(one.x, two.x, nSmall, nFar)
    if x then                           -- the x-values can be combined...
      one = {x=x, y=merge(one.y, two.y)} -- ...so combine the y-values too
      j = j + 1 end                      -- and skip over item two
    push(xys1, one)                      -- every step keeps exactly one thing
    j = j + 1 end
  if #xys0 == #xys1 then return xys0 end -- nothing merged: done
  return merges(xys1,nSmall,nFar) end
-- Stretch a sorted list of bins so that together they cover the whole
-- number line: the first bin starts at -infinity, the last ends at
-- +infinity, and every other bin starts where its predecessor ends.
-- The bins are edited in place and the same list is returned.
-- An empty list is returned untouched (there is no first/last bin to edit).
function noGaps(xys)
  local n = #xys
  if n == 0 then return xys end
  xys[1].x.lo = -math.huge
  xys[n].x.hi =  math.huge
  for j=2,n do xys[j].x.lo = xys[j-1].x.hi end
  return xys end
-- Score every bin in `xys` and return the bins with a positive score,
-- sorted using gt"score".
--   B, R : number of rows in the best klass / in all other klasses
--   b, r : the same two counts inside one bin, read from its klass SYM
-- Each bin's score is stored on the bin as `xy.score`.
function sorted(i,best,rowss,xys)
  local total = {[true]=0, [false]=0}   -- [true] = best, [false] = rest
  for klass,rows in pairs(rowss) do
    local isBest = klass==best
    total[isBest] = total[isBest] + #rows end
  for _,xy in pairs(xys) do
    local seen = {[true]=0, [false]=0}
    for klass,n in pairs(xy.y.has) do
      local isBest = klass==best
      seen[isBest] = seen[isBest] + n end
    xy.score = score(i, seen[true], seen[false], total[true], total[false]) end
  local keep = function(xy) if xy.score>0 then return xy end end -- drop bad scores
  return sort(map(xys,keep), gt"score") end
-- Does `row` fall inside the range described by bin `xy`?
-- True when the row's cell for that column is "?", when the range is a
-- single point equal to the cell, or when lo <= cell <= hi.
-- Declared `local`: as a plain `function select` this definition replaced
-- Lua's built-in global `select` (used for varargs) for the whole program.
-- Code later in this file still sees this local version.
local function select(xy,row)
  local x     = row.cells[xy.x.at]
  local lo,hi = xy.x.lo, xy.x.hi
  if x=="?" then return true end
  if lo==hi and x==lo then return true end
  return x>=lo and x<=hi end
-- Try rules built from the top 1, 2, ... #xys ranges and return the
-- resulting {rule=..., score=...} records sorted using gt"score".
-- Every row is tagged with its klass (row.klass) and pooled into one list
-- before the rules are tried.
-- NOTE(review): the io.write/print calls look like leftover tracing.
function selects(i,xys,best,rowss)
  local pool, scored = {}, {}
  local B, R = 0, 0
  for klass,rows in pairs(rowss) do
    if klass==best then B = B + #rows else R = R + #rows end
    for j,row in pairs(rows) do
      io.write(j," ")
      row.klass = klass
      push(pool,row) end end
  print("all",#pool)
  xys = sort(xys,gt"score")
  for stop=1,#xys do
    local rule = select1(i,stop,xys,best,pool,B,R)
    if rule then push(scored,rule) end end
  map(scored,function(z) print("zz",#z.rule,z.score) end)
  return sort(scored,gt"score") end
-- Apply the first `stop` ranges of `xys` one after another to the rows in
-- `all`, keeping only rows every range accepts, and score what is left.
-- Filtering stops early as soon as no rows survive; the range that emptied
-- the set is not recorded in the rule.
-- Returns {rule = list of {at,txt,lo,hi}, score = score of survivors}.
-- NOTE(review): the print call looks like leftover tracing.
function select1(i,stop,xys,best,all,B,R)
  local tmp, how = all, {}
  for j=1,stop do
    local xy = xys[j]
    tmp = map(tmp, function(row) if select(xy,row) then return row end end)
    if #tmp == 0 then break end
    local x = xy.x
    how[j] = {at=x.at, txt=x.txt, lo=x.lo, hi=x.hi} end
  local b, r = 0, 0
  for _,row in pairs(tmp) do
    if row.klass==best then b = b + 1 else r = r + 1 end end
  print("br",b,r,B,R)
  return {rule=how, score=score(i,b,r,B,R)} end
-- local rule,rule1,branches,numBranches,symBranches |
-- function rule(i,cols,B,R,rows) |
-- return rule1(cols, rows, |
-- {}, |
-- (#rows)^i.min, |
-- function(b,r) return score(i,b,r,B,R) end) end |
-- |
-- function rule1(cols,rows,path,stop,val) |
-- if #rows < 2*stop then return path end |
-- tests = {} |
-- map(cols, function(col) branches(col,rows,val,tests) end) |
-- test = sort(tests,gt"val")[1] |
-- local rest,b,r = {},0,0 |
-- for _,row in pairs(rows) do |
-- if accept(test,row) then |
-- if row.best then b=b+1 else r=r+1 end |
-- else push(rest,row) end end |
-- push(path,{test=test,b=b,r=r,val=val}) |
-- if #no < #rows1 then return rule1(cols,no,path,stop,val) end end |
-- |
-- function numBranches(col,rows,val,out) |
-- local b1,r1=br(rows) |
-- local most,b0,r0 = val(b1,r1),0,0 |
-- local cut,op |
-- for j,row in pairs(rows) do |
-- if row.y then b0=b0+1; b1=b1-1 else r0=r0+1; r1=r1-1 end |
-- local x=row.cells[col.at] |
-- if j < #rows-stop and j> stop and x ~= rows[j+1].cells[col.at] then |
-- local v1 = val(b0,r0) |
-- local v2 = val(b1,r1) |
-- if v1 > most then most,cut,op = v1,x,"<=" end |
-- if v2 > most then most,cut,op = v2,x,">" end end end |
-- if cut then |
-- push(out, {val=most, test={at=col.at, txt=col.txt, op=op, x=cut}}) end end |
-- |
-- function symBranches(col,rows,val,out) |
-- local bs,rs = {},{} |
-- for j,row in pairs(rows) do |
-- local x=row.cells[col.at] |
-- bs[x] = bs[x] or 0 |
-- rs[x] = rs[x] or 0 |
-- if row.best then bs[x]=bs[x]+1 else rs[x]=rs[x]+1 end end |
-- for x1,br in pairs(bs) do |
-- push(out, {val=val(br,rs[x1]), test={at=col.at,txt=col.txt,op="=",x=x1}}) end end |
-- |
-- function branches(col,rows,val,out) |
-- local function good(row) if row.cells[col.at] ~= "?" then return row end end |
-- rows = sort(map(rows,good), function(r1,r2) return r1.cells[col.at] < r2.cells[col.at] end) |
-- (isNum(col) and numBranches or symBranches)(col,rows,vals,out) end |
-- |
--------------------------------------------------------------------------------------------------- |
return pcall(debug.getlocal,4,1) and lib.locals() or {} |