/*
 * Naive GP (genetic programming) — GitHub gist einblicker/2936392,
 * created June 15, 2012 13:05.
 *
 * Note carried over from the gist page: this file may contain hidden or
 * bidirectional Unicode text that can be interpreted or compiled differently
 * from how it appears. To review it, open the file in an editor that reveals
 * hidden Unicode characters.
 */
package ; | |
import js.JQuery; | |
import js.Lib; | |
/**
 * A generic two-element tuple.
 * The single constructor `pair` carries a first component of type A and a
 * second component of type B; project them with Util.fst / Util.snd.
 */
enum Pair<A, B> {
    pair(a : A, b : B); // (first, second)
}
/**
 * Projection helpers for Pair.
 */
class Util {
    /** Returns the first component of `p`. */
    public static function fst<A, B>(p : Pair<A, B>) : A {
        return switch (p) {
            case pair(first, _):
                first;
        };
    }

    /** Returns the second component of `p`. */
    public static function snd<A, B>(p : Pair<A, B>) : B {
        return switch (p) {
            case pair(_, second):
                second;
        };
    }
}
/**
 * State of a suspendable computation.
 * `result` means it finished with a value of type A. `cont` means it paused,
 * exposing a value of type B plus a continuation that resumes it when given a C.
 */
enum EitherWithCont<A, B, C> {
    result(value : A);                                  // finished
    cont(value : B, func : C -> EitherWithCont<A, B, C>); // paused; call func to resume
}
/**
 * Monad-like helpers for EitherWithCont: build finished computations, chain
 * computations, pause at a value, and drive a computation to completion.
 */
class Suspend {
    /** Wraps `x` as an already-finished computation. */
    public static function ret<A, B, C>(x : A) : EitherWithCont<A, B, C> {
        return EitherWithCont.result(x);
    }

    /**
     * Sequences `x` then `f`.
     * If `x` finished, `f` is applied to its value. If `x` is paused, the pause
     * is kept and `f` is chained onto its continuation.
     */
    public static function bind<A, B, C, D>(x : EitherWithCont<A, B, C>, f : A -> EitherWithCont<D, B, C>) : EitherWithCont<D, B, C> {
        return switch (x) {
            case result(done):
                f(done);
            case cont(paused, resume):
                EitherWithCont.cont(paused, function(input) {
                    return bind(resume(input), f);
                });
        };
    }

    /** Pauses at `value`; resuming with some y finishes with y. */
    public static function suspend<T>(value : T) {
        return EitherWithCont.cont(value, Suspend.ret);
    }

    /**
     * Runs `m` to completion.
     * Every pause is resumed with the very value it paused at.
     */
    public static function getResult(m) {
        return switch (m) {
            case result(finished):
                finished;
            case cont(paused, resume):
                getResult(resume(paused));
        };
    }
}
/**
 * Expression tree evolved by the GP.
 * Constructor order is significant (enum indices); do not reorder.
 */
enum Ast {
    vari(name : String);                                 // variable, resolved by AstUtil.eval's env
    int(val : Int);                                      // integer literal
    add(lhs : Ast, rhs : Ast);                           // lhs + rhs
    mul(lhs : Ast, rhs : Ast);                           // lhs * rhs
    sub(lhs : Ast, rhs : Ast);                           // lhs - rhs
    div(lhs : Ast, rhs : Ast);                           // lhs / rhs (eval: divisor 0 yields 1)
    iflt(lhs : Ast, rhs : Ast, then : Ast, else_ : Ast); // lhs < rhs ? then : else_
}
/**
 * Helpers over Ast: node counting, pretty-printing and integer evaluation.
 */
class AstUtil {
    /**
     * Counts every node of `ast`; leaves and operator nodes each count 1.
     * Main.fitness uses this as a size penalty.
     */
    public static function treeSize(ast : Ast) : Int {
        return switch (ast) {
            case vari(_):
                1;
            case int(_):
                1;
            case add(lhs, rhs):
                1 + treeSize(lhs) + treeSize(rhs);
            case mul(lhs, rhs):
                1 + treeSize(lhs) + treeSize(rhs);
            case sub(lhs, rhs):
                1 + treeSize(lhs) + treeSize(rhs);
            case div(lhs, rhs):
                1 + treeSize(lhs) + treeSize(rhs);
            case iflt(lhs, rhs, then, else_):
                1 + treeSize(lhs) + treeSize(rhs) + treeSize(then) + treeSize(else_);
        };
    }

    /** Returns `s` concatenated `i` times (empty when i <= 0). */
    public static function strRepeat(s, i) {
        var b = new StringBuf();
        for (x in 0...i) {
            b.add(s);
        }
        return b.toString();
    }

    /**
     * Older infix pretty-printer.
     * Arithmetic prints fully parenthesised. An iflt prints as an if/else block
     * whose branches are prefixed by (tab + 1) * 4 dots as a visible indent.
     * Recurses into itself, not into pprint.
     */
    public static function pprint_org(ast : Ast, tab = 0) : String {
        var br = "\n"; //"<br>";
        return switch (ast) {
            case vari(name):
                name;
            case int(val):
                Std.string(val);
            case add(lhs, rhs):
                "(" + pprint_org(lhs, tab+1) + " + " + pprint_org(rhs, tab+1) + ")";
            case mul(lhs, rhs):
                "(" + pprint_org(lhs, tab+1) + " * " + pprint_org(rhs, tab+1) + ")";
            case sub(lhs, rhs):
                "(" + pprint_org(lhs, tab+1) + " - " + pprint_org(rhs, tab+1) + ")";
            case div(lhs, rhs):
                "(" + pprint_org(lhs, tab+1) + " / " + pprint_org(rhs, tab+1) + ")";
            case iflt(lhs, rhs, then, else_):
                br + " if (" + pprint_org(lhs, tab + 1) + " < " + pprint_org(rhs, tab + 1) + ") { " + br +
                strRepeat(".", (tab+1)*4) + pprint_org(then, tab + 1) + br + " } else { " + br +
                strRepeat(".", (tab+1)*4) + pprint_org(else_, tab + 1) + br + " } ";
        };
    }

    /**
     * Renders `ast` as a C-like function `int Signal(int x)`.
     * The traversal is in continuation-passing style. Each iflt is hoisted into
     * a declared temporary plus an if/else block that assigns it, and its use
     * site refers to the temporary. Temporaries are numbered by a counter
     * shared across the whole tree, so each name is declared exactly once.
     * Division is printed with the same protection eval applies (divisor 0
     * yields 1). `tab` is accepted for compatibility and is unused.
     */
    public static function pprint(ast : Ast, tab = 0) : String {
        var br = "\n";
        var indent = strRepeat(" ", 4);
        var indent2 = strRepeat(" ", 8);
        // Shared by every iflt in the tree; sibling conditionals must not reuse a name.
        var nextVar = 0;
        function iter(node : Ast, cont : String -> String) : String {
            return switch (node) {
                case vari(name):
                    cont(name);
                case int(val):
                    cont(Std.string(val));
                case add(lhs, rhs):
                    iter(lhs, function(val0) {
                        return iter(rhs, function(val1) {
                            return cont("(" + val0 + " + " + val1 + ")");
                        });
                    });
                case mul(lhs, rhs):
                    iter(lhs, function(val0) {
                        return iter(rhs, function(val1) {
                            return cont("(" + val0 + " * " + val1 + ")");
                        });
                    });
                case sub(lhs, rhs):
                    iter(lhs, function(val0) {
                        return iter(rhs, function(val1) {
                            return cont("(" + val0 + " - " + val1 + ")");
                        });
                    });
                case div(lhs, rhs):
                    iter(lhs, function(val0) {
                        return iter(rhs, function(val1) {
                            // Mirror eval's protected division so the printed code
                            // computes what was scored (and never divides by zero).
                            return cont("(" + val1 + " == 0 ? 1 : " + val0 + " / " + val1 + ")");
                        });
                    });
                case iflt(lhs, rhs, then, else_):
                    {
                        // Allocate before descending so an outer conditional gets a
                        // lower number than the conditionals nested inside it.
                        var tmp = "val" + Std.string(nextVar++);
                        iter(lhs, function(val0) {
                            return iter(rhs, function(val1) {
                                return iter(then, function(val2) {
                                    return iter(else_, function(val3) {
                                        return
                                            indent + "int " + tmp + ";" + br +
                                            indent + "if (" + val0 + " < " + val1 + ") { " + br +
                                            indent2 + tmp + " = " + val2 + ";" + br +
                                            indent + "} else { " + br +
                                            indent2 + tmp + " = " + val3 + ";" + br +
                                            indent + "} " + br + br +
                                            cont(tmp);
                                    });
                                });
                            });
                        });
                    }
            };
        }
        return "int Signal(int x) " + br + "{" + br +
            iter(ast, function(x) { return indent + "return " + x + ";"; }) + br + "}";
    }

    /**
     * Evaluates `ast` with integer arithmetic; `env` maps a variable name to its value.
     * Division truncates through Std.int. A zero divisor yields 1 instead of
     * dividing (protected division).
     */
    public static function eval(ast, env) {
        return switch (ast) {
            case vari(name):
                env(name);
            case int(val):
                val;
            case add(lhs, rhs):
                eval(lhs, env) + eval(rhs, env);
            case mul(lhs, rhs):
                eval(lhs, env) * eval(rhs, env);
            case sub(lhs, rhs):
                eval(lhs, env) - eval(rhs, env);
            case div(lhs, rhs):
                var y = eval(rhs, env);
                if (y == 0) 1 else Std.int(eval(lhs, env) / y);
            case iflt(lhs, rhs, then, else_):
                if (eval(lhs, env) < eval(rhs, env)) eval(then, env) else eval(else_, env);
        };
    }
}
/**
 * GP driver: builds a random population, evolves it once per timer tick and
 * renders progress into the page through jQuery.
 */
class Main {
    public static inline var poolSize = 100;
    /** Best-ranked individuals copied unchanged into each new generation. */
    static inline var eliteSize = 10;
    /** Individuals drawn (with replacement) for each tournament. */
    static inline var tournamentSize = 20;
    /** Optional target series; only referenced from commented-out code. */
    public static var series : Array<Int>;

    /**
     * Builds a random expression tree over the variable "x".
     * With probability 0.3, or once `depth` reaches 0, it returns a leaf: an
     * integer literal in [0, 100) or `x`, 50/50. Otherwise it returns
     * add/mul/sub/div with probability 0.2 each or iflt with probability 0.1.
     * In the remaining 0.1 it recurses one level deeper without creating a node.
     */
    public static function genAst(depth : Int = 6) : Ast {
        if (Math.random() < 0.3 || depth <= 0) {
            return if (Math.random() < 0.5) int(Std.int(Math.random() * 100)) else vari("x");
        }
        // Chained upper bounds cover every r in [0, 1), boundaries included.
        var r = Math.random();
        return if (r < 0.2) {
            add(genAst(depth-1), genAst(depth-1));
        } else if (r < 0.4) {
            mul(genAst(depth-1), genAst(depth-1));
        } else if (r < 0.6) {
            sub(genAst(depth-1), genAst(depth-1));
        } else if (r < 0.8) {
            div(genAst(depth-1), genAst(depth-1));
        } else if (r < 0.9) {
            iflt(genAst(depth-1), genAst(depth-1), genAst(depth-1), genAst(depth-1));
        } else {
            genAst(depth-1);
        };
    }

    /**
     * Returns a mutated copy of `ast`.
     * At every node, with probability 0.03, the whole subtree is replaced by a
     * fresh genAst() tree. Otherwise the node is rebuilt from mutated children.
     */
    public static function mutation(ast : Ast) : Ast {
        return if (Math.random() < 0.03) {
            genAst();
        } else {
            switch (ast) {
                case vari(name):
                    vari(name);
                case int(val):
                    int(val);
                case add(lhs, rhs):
                    add(mutation(lhs), mutation(rhs));
                case mul(lhs, rhs):
                    mul(mutation(lhs), mutation(rhs));
                case sub(lhs, rhs):
                    sub(mutation(lhs), mutation(rhs));
                case div(lhs, rhs):
                    div(mutation(lhs), mutation(rhs));
                case iflt(lhs, rhs, then, else_):
                    iflt(mutation(lhs), mutation(rhs),
                        mutation(then), mutation(else_));
            }
        };
    }

    /**
     * Subtree crossover of `ast1` and `ast2` built on Suspend.
     * Each tree is walked by `iter`, which pauses at a node with probability 0.3
     * and exposes that whole subtree. While both walks are paused, each resumes
     * with the other's subtree (a swap) and walking continues. Once one walk
     * finishes, the other is completed with getResult, which makes no more swaps.
     */
    public static function crossover(ast1 : Ast, ast2 : Ast) : Pair<Ast, Ast> {
        function iter(ast : Ast) : EitherWithCont<Ast,Ast,Ast> {
            return if (Math.random() < 0.3)
                Suspend.suspend(ast)
            else switch ast {
                case add(lhs, rhs):
                    Suspend.bind(iter(lhs), function(lhs) {
                        return Suspend.bind(iter(rhs), function(rhs) {
                            return Suspend.ret(add(lhs, rhs));
                        });
                    });
                case mul(lhs, rhs):
                    Suspend.bind(iter(lhs), function(lhs) {
                        return Suspend.bind(iter(rhs), function(rhs) {
                            return Suspend.ret(mul(lhs, rhs));
                        });
                    });
                case div(lhs, rhs):
                    Suspend.bind(iter(lhs), function(lhs) {
                        return Suspend.bind(iter(rhs), function(rhs) {
                            return Suspend.ret(div(lhs, rhs));
                        });
                    });
                case sub(lhs, rhs):
                    Suspend.bind(iter(lhs), function(lhs) {
                        return Suspend.bind(iter(rhs), function(rhs) {
                            return Suspend.ret(sub(lhs, rhs));
                        });
                    });
                case iflt(lhs, rhs, then, else_):
                    Suspend.bind(iter(lhs), function(lhs) {
                        return Suspend.bind(iter(rhs), function(rhs) {
                            return Suspend.bind(iter(then), function(then) {
                                return Suspend.bind(iter(else_), function(else_) {
                                    return Suspend.ret(iflt(lhs, rhs, then, else_));
                                });
                            });
                        });
                    });
                default:
                    Suspend.ret(ast);
            }
        };
        // Both paused: swap the exposed subtrees. One finished: drain the other unchanged.
        function loop(ast1, ast2) return switch ast1 {
            case result(value1):
                switch ast2 {
                    case result(value2):
                        pair(value1, value2);
                    case cont(value, func):
                        pair(value1, Suspend.getResult(func(value)));
                };
            case cont(value1, func1):
                switch ast2 {
                    case result(value2):
                        pair(Suspend.getResult(func1(value1)), value2);
                    case cont(value2, func2):
                        loop(func1(value2), func2(value1));
                };
        };
        return loop(iter(ast1), iter(ast2));
    }

    /** Target function the population is evolved to approximate. */
    public static var goal = function(x) { return if (x > 3) x * 2 else x*x+3; };

    /**
     * Error of individual `ind` against `goal`: the sum over x in 0...100 of the
     * squared difference. With `sizePenalty`, 0.001 per node is added. Lower is better.
     */
    public static function fitness(ind, sizePenalty = true) {
        var f = 0.0;
        //var goal = function(x) { return series[x]; }
        for (i in 0...100) {//series.length) {
            f += Math.pow(goal(i) -
                AstUtil.eval(ind, function(_) { return i; } )
            , 2.0);
        }
        return f + if (sizePenalty) AstUtil.treeSize(ind) * 0.001 else 0.0;
    }

    /** Ascending order on the fitness component of (index, fitness) pairs. */
    static function compareByFitness(x : Pair<Int, Float>, y : Pair<Int, Float>) : Int {
        var fx = Util.snd(x);
        var fy = Util.snd(y);
        return if (fx < fy) -1 else if (fx == fy) 0 else 1;
    }

    /**
     * Produces the next population from `pool`.
     * Individuals are ranked by fitness, and up to `eliteSize` of the best are
     * copied unchanged. The rest are bred in pairs. Each pair is the two best of
     * `tournamentSize` uniform draws with replacement. Each member is mutated
     * with probability 0.3, then the pair is crossed over with probability 0.3.
     * An empty pool yields an empty result.
     */
    public static function makeNextGeneration(pool : Array<Ast>) {
        var l = pool.length;
        var next : Array<Ast> = [];
        if (l == 0) return next; // nothing to rank or sample from
        var fitnesses = Lambda.array(Lambda.mapi(pool, function(i, ind) {
            return pair(i, fitness(ind));
        }));
        fitnesses.sort(compareByFitness);
        // Clamp so a pool smaller than eliteSize cannot index past the ranking.
        var elites = if (l < eliteSize) l else eliteSize;
        for (i in 0...elites) {
            next.push(pool[Util.fst(fitnesses[i])]);
        }
        for (i in 0...Std.int((poolSize - eliteSize) / 2)) {
            var buf = new Array();
            for (j in 0...tournamentSize) {
                buf.push(fitnesses[Std.int(Math.random() * l)]);
            }
            buf.sort(compareByFitness);
            var ind1 = pool[Util.fst(buf[0])];
            var ind2 = pool[Util.fst(buf[1])];
            if (Math.random() < 0.3) ind1 = mutation(ind1);
            if (Math.random() < 0.3) ind2 = mutation(ind2);
            var p = if (Math.random() < 0.3) crossover(ind1, ind2) else pair(ind1, ind2);
            next.push(Util.fst(p));
            next.push(Util.snd(p));
        }
        return next;
    }

    public static var pool = new Array();
    public static var count = 0;
    /** Handle returned by setInterval in main; cleared by evolve after 1000 ticks. */
    public static var id;

    /** Shows the goal function in #answer and fills the pool with random trees. */
    public static function init() {
        new JQuery("#answer").text(Std.string(goal));
        for (i in 0...poolSize) {
            pool.push(genAst());
        }
    }

    /**
     * One timer tick: advances the population by one generation and renders the
     * generation count, pool[0]'s unpenalised fitness, the mean fitness and
     * pool[0]'s program. Stops the timer once count reaches 1000.
     */
    public static function evolve() {
        count += 1;
        if (!(count < 1000)) untyped __js__("clearInterval(Main.id)");
        pool = makeNextGeneration(pool);
        var fs = Lambda.map(pool, function(x) { return fitness(x, false); } );
        new JQuery("#count").text("count:" + Std.string(count));
        new JQuery("#best").text("best:" + Std.string(fs.first()));
        new JQuery("#mean").text(
            "mean:" + Std.string(Lambda.fold(fs, function(x, y) { return x + y; }, 0) / fs.length)
        );
        new js.JQuery("#result").text(
            /* "count:" + Std.string(count) + "<br>" +
            "best fitness:" + Std.string(fs.first()) + "<br>" +
            "mean fitness:" + Std.string(Lambda.fold(fs, function(x, y) { return x + y;}, 0)/fs.length) + "<br>" +
            */ AstUtil.pprint(pool[0])
        );
    }

    /** Entry point: initialises the pool and runs evolve every 100 ms. */
    public static function main() {
        // series = Lambda.array(Lambda.map(new JQuery("textarea").html().split("\n"), Std.parseInt)).slice(0, 100);
        init();
        id = untyped __js__("setInterval(Main.evolve, 100)");
    }
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.