Last active
January 16, 2020 03:46
-
-
Save logicalguess/fac10b55d95fa17f4ffe1a33f3155057 to your computer and use it in GitHub Desktop.
State Monad in Go.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"reflect" | |
) | |
type (
	// G is the type of the values produced by a state computation. It is an
	// empty interface, so any Go value may be used.
	G interface{}

	// S is the type of the state threaded through a computation. It has the
	// same (empty) method set as G, so it also accepts any Go value.
	S G

	// StateMonad wraps a state transition: given an incoming state, run
	// returns the outgoing state together with a produced value.
	StateMonad struct {
		run func(S) (S, G)
	}
)

// Apply runs the computation from state s and returns both the final state
// and the produced value.
func (sm StateMonad) Apply(s S) (S, G) {
	next, val := sm.run(s)
	return next, val
}

// Eval runs the computation from state s and returns only the produced
// value, discarding the final state.
func (sm StateMonad) Eval(s S) G {
	_, val := sm.Apply(s)
	return val
}

// Map returns a computation that runs sm and then transforms its value with
// f. The state produced by sm is passed through unchanged.
func (sm StateMonad) Map(f func(G) G) StateMonad {
	return StateMonad{run: func(s S) (S, G) {
		next, val := sm.run(s)
		return next, f(val)
	}}
}

// FlatMap sequences two computations: it runs sm, passes the produced value
// to f to obtain the next computation, and runs that one starting from the
// state sm left behind.
func (sm StateMonad) FlatMap(f func(G) StateMonad) StateMonad {
	return StateMonad{run: func(s S) (S, G) {
		mid, val := sm.run(s)
		next := f(val)
		return next.run(mid)
	}}
}
func Map2(ma StateMonad, mb StateMonad, f func(G, G) G) StateMonad { | |
//ma.FlatMap(a => mb.Map(b => f(a, b))) | |
r := func(a G) StateMonad { | |
f1 := func(b G) G { | |
return f(a, b) | |
} | |
return mb.Map(f1) | |
} | |
return ma.FlatMap(r) | |
} | |
// the value is a list of intermediate calculations (scan) | |
func Sequence(lma ...StateMonad) StateMonad { | |
var m StateMonad = StateMonad{ | |
run: func(s S) (S, G) { | |
return s, make([]G, 0) | |
}, | |
} | |
for _, ma := range lma { | |
m = Map2(m, ma, func(a G, b G) G { | |
return append(a.([]G), b) | |
}) | |
} | |
return m | |
} | |
// the value is a list of intermediate calculations (scan) | |
// state monads are generated from inputs | |
func Traverse(la []G, f func(G) StateMonad) StateMonad { | |
var m StateMonad = StateMonad{ | |
run: func(s S) (S, G) { | |
return s, make([]G, 0) | |
}, | |
} | |
for _, a := range la { | |
m = Map2(m, f(a), func(a G, b G) G { | |
return append(a.([]G), b) | |
}) | |
} | |
return m | |
} | |
// the value is the last result of chained calculations (fold) | |
// state monads are generated from inputs | |
func Chain(la []G, f func(G) StateMonad) StateMonad { | |
var m StateMonad = StateMonad{ | |
run: func(s S) (S, G) { | |
return s, nil | |
}, | |
} | |
for _, a := range la { | |
m = Map2(m, f(a), func(a G, b G) G { | |
return b | |
}) | |
} | |
return m | |
} | |
func main() { | |
id := StateMonad{ | |
run: func(s S) (S, G) { | |
return s.(int), s.(int) | |
}, | |
} | |
succ := func(s S) (S, G) { | |
return s.(int) + 1, s.(int) + 1 | |
} | |
m1 := StateMonad{ run: succ } | |
m2 := id.Map(func(a G) G {return a.(int) + 2}) | |
s0 := 0 | |
fmt.Println(id.Eval(s0)) //0 | |
fmt.Println(m1.Eval(s0)) //1 | |
fmt.Println(m2.Eval(s0)) //2 | |
f3 := func(G) StateMonad { | |
return StateMonad{ | |
run: func(s S) (S, G) { | |
return s.(int) + 3, s.(int) + 3 | |
}, | |
} | |
} | |
fmt.Println(m1.FlatMap(f3).FlatMap(f3).Eval(s0)) //7 = 0 + 1 + 3 + 3 | |
mFactory := func(a G) StateMonad { | |
return StateMonad{ | |
run: func(s S) (S, G) { | |
return s.(int) + a.(int), s.(int) + a.(int) | |
}, | |
} | |
} | |
fmt.Println(m1.FlatMap(mFactory).FlatMap(mFactory).FlatMap(mFactory).Eval(s0)) //8 = 0 + 1 + 1 + 2 + 4 | |
m := Map2(m1, m2, func(a G, b G) G { | |
return a.(int) + 5*b.(int) | |
}) | |
fmt.Println(m.Eval(s0)) //16 = (0 + 1) + 5*(0 + 1 + 2) | |
fmt.Println(Sequence(id, m1, mFactory(3)).Eval(s0)) //[0 1 4] | |
fmt.Println(Traverse([]G{1, 2, 3}, mFactory).Eval(s0)) //[1 3 6] | |
// currying as a state monad traverse | |
invoke := func(fn G, args []G) G { | |
var vs []reflect.Value | |
for _, arg := range args { | |
vs = append(vs, reflect.ValueOf(arg)) | |
} | |
fun := reflect.ValueOf(fn) | |
return fun.Call(vs)[0].Interface() | |
} | |
wrap := func(fn G) func(...G) G { | |
return func(a ...G) G { | |
return invoke(fn, a) | |
} | |
} | |
curryFactory := func(a G) StateMonad { | |
return StateMonad{ | |
run: func(fn S) (S, G) { | |
if !reflect.TypeOf(fn).IsVariadic() { | |
fn = wrap(fn) | |
} | |
f := func(values ...G) G { | |
values = append([]G{a}, values...) | |
return fn.(func(...G) G)(values...) | |
} | |
var v G | |
setValue := func() { | |
defer func() { | |
if r := recover(); r != nil { | |
v = f | |
} | |
}() | |
v = f() | |
} | |
setValue() | |
return f, v | |
}, | |
} | |
} | |
sum := func(a int, b int, c int) int { | |
return a + b + c | |
} | |
size := func(a int, b string) int { | |
return a + len(b) | |
} | |
fmt.Println(Chain([]G{1, 2, 3}, curryFactory).Eval(sum)) //6 | |
fmt.Println(Chain([]G{1, "ab"}, curryFactory).Eval(size)) //3 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment