Last active
October 4, 2015 21:08
-
-
Save einblicker/2700474 to your computer and use it in GitHub Desktop.
トーナメント方式のGP (genetic programming with tournament selection)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
open System | |
open Microsoft.FSharp.Quotations | |
open Microsoft.FSharp.Quotations.ExprShape | |
open Microsoft.FSharp.Linq.QuotationEvaluation | |
/// Small helpers shared by the rest of the file.
module Util =
    /// Single random-number generator shared by every module in this file.
    let rnd = Random()

    /// Partial active pattern that matches when lo <= value <= hi
    /// (both bounds inclusive).
    let (|Range|_|) lo hi value =
        match lo <= value, value <= hi with
        | true, true -> Some ()
        | _ -> None
/// Random-choice helpers for F# lists, driven by the shared Util.rnd generator.
module List =
    open Util

    /// Returns a uniformly chosen element of xs.
    /// Raises ArgumentException when xs is empty.
    let rndPick (xs : list<'T>) : 'T =
        match xs with
        | [] -> invalidArg "xs" "Cannot pick a random element from an empty list."
        | _ ->
            // xs.[i] replaces the deprecated List.nth; both are O(i) on a linked list.
            xs.[rnd.Next(xs.Length)]

    /// Applies f to exactly one uniformly chosen element of xs and leaves the
    /// other elements unchanged. An empty list is returned as-is.
    let rndAppOnce (f : 'T -> 'T) (xs : list<'T>) =
        let pos = rnd.Next(xs.Length)
        List.mapi (fun i e -> if i = pos then f e else e) xs
/// A tiny resumable-computation monad.
/// A computation either finishes (Result) or pauses (Suspend), exposing a
/// payload of type 'e to whoever drives it and waiting for a value of type 'v
/// to continue with.
module ErrorEx =
    /// 'v: type of the value a paused computation is resumed with;
    /// 'e: type of the payload exposed while paused; 'a: final result type.
    type ErrorEx<'v, 'e, 'a> =
        | Result of 'a
        | Suspend of 'e * ('v -> ErrorEx<'v, 'e, 'a>)

    /// Computation-expression builder for ErrorEx (supports let! and return).
    type ErrorExBuilder() =
        member self.Return(value) =
            Result(value)
        member self.Bind(computation, binder) =
            match computation with
            | Suspend(payload, resume) ->
                // Stay paused with the same payload; once resumed, keep binding.
                Suspend(payload, fun v -> self.Bind(resume v, binder))
            | Result(value) -> binder value

    /// Builder instance, used as `error { ... }`.
    let error = ErrorExBuilder()

    /// Pauses with the given payload; the computation's value is whatever it is resumed with.
    let suspend payload = Suspend(payload, error.Return)

    /// Runs a computation to the end, resuming every pause with its own payload
    /// (so nothing is substituted). Requires the resume type to equal the payload type.
    let rec getContent computation =
        match computation with
        | Suspend(payload, resume) -> getContent (resume payload)
        | Result(value) -> value

    /// Combines a list of computations into one that yields the list of their
    /// results, running (and pausing) them from left to right.
    let rec sequence xs =
        error {
            match xs with
            | [] -> return []
            | first :: rest ->
                let! firstValue = first
                let! restValues = sequence rest
                return firstValue :: restValues
        }
/// Tree-based genetic programming over float-valued F# quotations.
module GP =
    open ErrorEx
    open Util

    /// Raised inside makeNextGeneration to stop evolution early once some
    /// individual's error drops below 0.005.
    exception private Done

    /// Genetic-programming engine with elitism and tournament selection.
    ///
    /// mutProb        - per-node probability of replacing a subtree during mutation,
    ///                  per-node probability of choosing a crossover point, and
    ///                  per-offspring probability of applying mutation at all.
    /// crossProb      - probability that a selected parent pair is crossed over.
    /// poolSize       - population size.
    /// initTreeSize   - depth budget given to the expression generator for the
    ///                  initial population and for subtrees created by mutation.
    /// termFn         - termFn depth gen builds a leaf expression (gen is the generator).
    /// nonTermFn      - nonTermFn depth gen builds an operator node; it is expected
    ///                  to recurse through gen with a smaller depth.
    /// fitnessFn      - fitnessFn expr self returns the error of expr (lower is better);
    ///                  self is the fitness function itself.
    /// generations    - optional maximum number of generations (default 100).
    /// tournamentSize - optional number of entrants per tournament (default 101, >= 1).
    type GP(mutProb, crossProb, poolSize, initTreeSize, termFn, nonTermFn, fitnessFn,
            ?generations : int, ?tournamentSize : int) =
        let maxGenerations = defaultArg generations 100
        // Default kept at 101 entrants, matching the original 0..100 inclusive loop.
        let entrantsPerTournament = defaultArg tournamentSize 101

        do
            if entrantsPerTournament < 1 then
                invalidArg "tournamentSize" "tournamentSize must be at least 1."

        /// Generates a random expression. A non-terminal is chosen about half the
        /// time while depth > 0; keeping the tree within depth relies on nonTermFn
        /// recursing with a smaller depth.
        let rec genExpr (depth : int) : Expr<float> =
            match rnd.NextDouble() with
            | Range 0.0 0.5 when depth > 0 -> nonTermFn depth genExpr
            | _ -> termFn depth genExpr

        /// Mutates an expression and returns the new expression: at each node, with
        /// probability mutProb, the whole subtree is replaced by a freshly generated
        /// one; otherwise mutation recurses into nodes with exactly two children.
        let rec mutation (expr : Expr<float>) : Expr<float> =
            if rnd.NextDouble() < mutProb then
                genExpr initTreeSize
            else
                match expr with
                | ShapeCombination(shape, ([_; _] as children)) ->
                    let mutated =
                        children
                        |> List.map (fun child -> mutation (Expr.Cast<float> child) :> Expr)
                    RebuildShapeCombination(shape, mutated) |> Expr.Cast
                | _ -> expr

        /// Crosses two expressions by swapping randomly chosen subtrees and returns
        /// the offspring as (child derived from lhs, child derived from rhs).
        let crossover (lhs : Expr<float>) (rhs : Expr<float>) : Expr<float> * Expr<float> =
            // Walks the tree; at each node, with probability mutProb, it pauses and
            // hands this subtree to the driver, which resumes it with the subtree to
            // put in its place. Only nodes with exactly two children are descended into.
            // NOTE(review): the per-node selection probability reuses mutProb rather
            // than crossProb; preserved as-is — confirm intent.
            let rec extract expr = error {
                if rnd.NextDouble() < mutProb then
                    let! replacement = suspend expr
                    return replacement
                else
                    match expr with
                    | ShapeCombination(shape, ([_; _] as children)) ->
                        let! children' = children |> List.map extract |> sequence
                        return RebuildShapeCombination(shape, children')
                    | _ -> return expr
            }
            // Drives both walks in lock step. While both are paused, each is resumed
            // with the other's subtree (the actual swap), keeping the (left, right)
            // order. Once one side has finished, the other keeps its own subtrees.
            let rec drive = function
                | Result l, Result r -> (l, r)
                | Suspend(l, resumeL), Suspend(r, resumeR) -> drive (resumeL r, resumeR l)
                | Result l, Suspend(r, resumeR) -> (l, getContent (resumeR r))
                | Suspend(l, resumeL), Result r -> (getContent (resumeL l), r)
            let (left, right) = drive (extract lhs, extract rhs)
            (Expr.Cast left, Expr.Cast right)

        /// Error of an expression according to fitnessFn (lower is better).
        let rec fitness (expr : Expr<float>) : float = fitnessFn expr fitness

        /// Builds the next generation from pool and prints score statistics.
        /// Raises Done as soon as some individual's error is below 0.005.
        let makeNextGeneration (pool : Expr<float> []) : Expr<float> [] =
            // Picks the highest-scoring individual among entrantsPerTournament
            // entrants drawn uniformly with replacement.
            let tournamentSelect (scored : (float * Expr<float>) []) =
                Array.init entrantsPerTournament (fun _ -> scored.[rnd.Next(scored.Length)])
                |> Array.maxBy fst
                |> snd
            // Score = 100 / error, so a higher score is better.
            let scored =
                pool
                |> Array.map (fun e ->
                    let err = fitness e
                    if err < 0.005 then raise Done
                    (100.0 / err, e))
            let scores = Array.map fst scored
            let average = Array.sum scores / float pool.Length
            let upper = Array.max scores
            let lower = Array.min scores
            printfn "average: %A, upper: %A, lower: %A" average upper lower
            // Elitism: the 3% with the HIGHEST score survive unchanged, so sort descending.
            let eliteSize = int (float pool.Length * 0.03)
            let elite =
                Array.sub (Array.sortBy (fun (score, _) -> -score) scored) 0 eliteSize
                |> Array.map snd
            // The rest of the population are tournament winners, possibly crossed and mutated.
            let childCount = pool.Length - eliteSize
            let children =
                [| for _ in 1 .. (childCount + 1) / 2 do
                       let parent1 = tournamentSelect scored
                       let parent2 = tournamentSelect scored
                       let (child1, child2) =
                           if rnd.NextDouble() < crossProb then crossover parent1 parent2
                           else (parent1, parent2)
                       yield if rnd.NextDouble() < mutProb then mutation child1 else child1
                       yield if rnd.NextDouble() < mutProb then mutation child2 else child2 |]
            // Children come in pairs; drop the surplus one when childCount is odd so the
            // population size stays the same from one generation to the next.
            Array.append elite (Array.sub children 0 childCount)

        /// Error of e according to fitnessFn (lower is better).
        member this.Fitness e = fitness e

        /// Evolves a random initial population for up to `generations` generations
        /// (stopping early when an individual's error drops below 0.005) and
        /// returns the lowest-error expression of the last evaluated population.
        member this.Evolve () : Expr<float> =
            let mutable pool = Array.init poolSize (fun _ -> genExpr initTreeSize)
            try
                for _ in 1 .. maxGenerations do
                    pool <- makeNextGeneration pool
            with Done ->
                // pool still holds the generation that contained the near-perfect individual.
                ()
            Array.minBy fitness pool
[<EntryPoint>]
let main _ =
    let varX = Var("x", typeof<float>)
    // Target function the GP tries to rediscover.
    let f x = (x * x) - (2.0 * x) + 12.0
    // Points at which candidates are compared with f: x = 0, 10, ..., 1000 (101 points).
    let samples = [ 0.0 .. 10.0 .. 1000.0 ]
    let gp =
        new GP.GP(0.3, 0.3, 1000, 3,
                  // Terminal: a random constant in [0, 10) with probability ~0.33, otherwise x.
                  (fun _depth _ ->
                      let num = Util.rnd.NextDouble() * 10.0
                      match Util.rnd.NextDouble() with
                      | Util.Range 0.0 0.33 -> <@ num @>
                      | _ -> <@ %(Expr.Cast (Expr.Var(varX))) @>),
                  // Non-terminal: +, - or * over two subtrees generated with depth - 1.
                  (fun depth self ->
                      let rhs = depth - 1 |> self
                      let lhs = depth - 1 |> self
                      match Util.rnd.NextDouble() with
                      | Util.Range 0.0 0.33 -> <@ %lhs + %rhs @>
                      | Util.Range 0.33 0.66 -> <@ %lhs - %rhs @>
                      | _ -> <@ %lhs * %rhs @>),
                  // Fitness: mean squared error against f over the sample points
                  // (lower is better), averaged over the actual number of samples.
                  (fun expr _ ->
                      let lambda : Expr<float -> float> =
                          Expr.Lambda(varX, expr)
                          |> Expr.Cast
                      let compiled = lambda.Compile() ()
                      samples |> List.averageBy (fun x -> (f x - compiled x) ** 2.0)))
    let result = gp.Evolve ()
    printfn "%A" result
    printfn "%A" <| gp.Fitness result
    0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment