Skip to content

Instantly share code, notes, and snippets.

@logicalguess
Last active January 16, 2020 03:46
Show Gist options
  • Save logicalguess/fac10b55d95fa17f4ffe1a33f3155057 to your computer and use it in GitHub Desktop.
Save logicalguess/fac10b55d95fa17f4ffe1a33f3155057 to your computer and use it in GitHub Desktop.
State Monad in Go.
package main
import (
"fmt"
"reflect"
)
type (
	// G is the dynamically typed result value of a state computation.
	G interface{}

	// S is the state threaded through a computation. It is a distinct named
	// type whose underlying type is interface{}, so it can hold any value.
	S G

	// StateMonad wraps a state transition: given an input state it yields
	// the next state together with a result value.
	StateMonad struct {
		run func(S) (S, G)
	}
)

// Apply runs the computation starting from state s and returns both the
// resulting state and the result value.
func (sm StateMonad) Apply(s S) (S, G) {
	next, val := sm.run(s)
	return next, val
}

// Eval runs the computation starting from state s and returns only the
// result value; the final state is discarded.
func (sm StateMonad) Eval(s S) G {
	_, val := sm.Apply(s)
	return val
}

// Map returns a computation that runs sm and transforms its result value
// with f. The state produced by sm is passed through unchanged.
func (sm StateMonad) Map(f func(G) G) StateMonad {
	return StateMonad{run: func(s S) (S, G) {
		next, val := sm.run(s)
		return next, f(val)
	}}
}

// FlatMap returns a computation that runs sm and passes its result value to
// f to obtain the next computation. It then runs that computation from the
// state sm produced.
func (sm StateMonad) FlatMap(f func(G) StateMonad) StateMonad {
	return StateMonad{run: func(s S) (S, G) {
		next, val := sm.run(s)
		return f(val).Apply(next)
	}}
}
// Map2 sequences two computations. ma runs first, mb runs from the state ma
// produced, and f combines their result values (ma's first, then mb's). The
// final state is the one produced by mb.
func Map2(ma StateMonad, mb StateMonad, f func(G, G) G) StateMonad {
	// Equivalent to ma.FlatMap(a => mb.Map(b => f(a, b))), written out as a
	// single closure.
	return StateMonad{
		run: func(s0 S) (S, G) {
			s1, a := ma.run(s0)
			s2, b := mb.run(s1)
			return s2, f(a, b)
		},
	}
}
// Sequence runs the computations in lma from left to right, threading the
// state from each into the next. It collects every computation's result
// value, in order, into a []G. The final state is the one produced by the
// last computation. With no arguments, the state is returned unchanged and
// the value is an empty, non-nil []G.
func Sequence(lma ...StateMonad) StateMonad {
	collect := func(acc G, v G) G {
		return append(acc.([]G), v)
	}
	out := StateMonad{run: func(s S) (S, G) { return s, make([]G, 0) }}
	for i := range lma {
		out = Map2(out, lma[i], collect)
	}
	return out
}
// Traverse passes every element of la through f, in order, when Traverse is
// called (not when the result is run). This yields one computation per
// element. Those computations then run exactly as in Sequence: state is
// threaded left to right and each result value is collected, in order, into
// a []G.
func Traverse(la []G, f func(G) StateMonad) StateMonad {
	steps := make([]StateMonad, len(la))
	for i, x := range la {
		steps[i] = f(x)
	}
	return Sequence(steps...)
}
// Chain passes every element of la through f, in order, when Chain is
// called, and runs the resulting computations left to right, threading the
// state. Intermediate result values are ignored. The result value is the
// last computation's, or nil when la is empty (the state is then unchanged).
func Chain(la []G, f func(G) StateMonad) StateMonad {
	keepLast := func(_ G, latest G) G { return latest }
	chained := StateMonad{run: func(s S) (S, G) { return s, nil }}
	for _, x := range la {
		chained = Map2(chained, f(x), keepLast)
	}
	return chained
}
// main demonstrates the state monad on integer counters, then uses Chain to
// curry ordinary Go functions one argument at a time.
func main() {
	// id leaves the state unchanged and also returns it as the value.
	id := StateMonad{
		run: func(s S) (S, G) {
			return s, s
		},
	}
	// succ increments an int state and returns the new state as the value.
	succ := func(s S) (S, G) {
		return s.(int) + 1, s.(int) + 1
	}
	m1 := StateMonad{run: succ}
	m2 := id.Map(func(a G) G { return a.(int) + 2 })
	s0 := 0
	fmt.Println(id.Eval(s0)) //0
	fmt.Println(m1.Eval(s0)) //1
	fmt.Println(m2.Eval(s0)) //2
	f3 := func(G) StateMonad {
		return StateMonad{
			run: func(s S) (S, G) {
				return s.(int) + 3, s.(int) + 3
			},
		}
	}
	fmt.Println(m1.FlatMap(f3).FlatMap(f3).Eval(s0)) //7 = 0 + 1 + 3 + 3
	// mFactory builds a step that adds a to an int state and returns the new state.
	mFactory := func(a G) StateMonad {
		return StateMonad{
			run: func(s S) (S, G) {
				return s.(int) + a.(int), s.(int) + a.(int)
			},
		}
	}
	fmt.Println(m1.FlatMap(mFactory).FlatMap(mFactory).FlatMap(mFactory).Eval(s0)) //8 = 0 + 1 + 1 + 2 + 4
	m := Map2(m1, m2, func(a G, b G) G {
		return a.(int) + 5*b.(int)
	})
	fmt.Println(m.Eval(s0))                                        //16 = (0 + 1) + 5*(0 + 1 + 2)
	fmt.Println(Sequence(id, m1, mFactory(3)).Eval(s0))            //[0 1 4]
	fmt.Println(Traverse([]G{1, 2, 3}, mFactory).Eval(s0))         //[1 3 6]

	// Currying as a state-monad chain: the state is the (partially applied)
	// function, and each step supplies one more argument.

	// invoke calls fn via reflection with args and returns its first result.
	// reflect.Value.Call panics if fn is not a func or if the arguments do
	// not match its parameters. Indexing [0] panics if fn returns nothing.
	invoke := func(fn G, args []G) G {
		vs := make([]reflect.Value, 0, len(args))
		for _, arg := range args {
			vs = append(vs, reflect.ValueOf(arg))
		}
		return reflect.ValueOf(fn).Call(vs)[0].Interface()
	}
	// wrap adapts an arbitrary func to the uniform func(...G) G shape.
	wrap := func(fn G) func(...G) G {
		return func(a ...G) G {
			return invoke(fn, a)
		}
	}
	// curryFactory returns a step that supplies a as the next argument of the
	// function held in the state. The new state is the partially applied
	// function. The value is the call's result once enough arguments have
	// been supplied; otherwise it is the partial function itself.
	curryFactory := func(a G) StateMonad {
		return StateMonad{
			run: func(fn S) (S, G) {
				// Earlier steps leave a func(...G) G in the state; any other
				// func is adapted via reflection. A comma-ok type assertion
				// (rather than Type.IsVariadic) decides this. IsVariadic panics
				// on non-func values, and a variadic func of another signature,
				// e.g. func(...int) int, still needs wrapping.
				call, ok := fn.(func(...G) G)
				if !ok {
					if t := reflect.TypeOf(fn); t == nil || t.Kind() != reflect.Func {
						panic(fmt.Sprintf("curryFactory: state must be a func, got %T", fn))
					}
					call = wrap(fn)
				}
				f := func(values ...G) G {
					return call(append([]G{a}, values...)...)
				}
				// Probe: call with the arguments gathered so far. A panic (for
				// example too few arguments for reflect.Value.Call) is treated
				// as "not yet fully applied", and f itself becomes the value.
				// NOTE(review): this also swallows panics raised inside the
				// target function and argument type mismatches.
				var v G
				func() {
					defer func() {
						if r := recover(); r != nil {
							v = f
						}
					}()
					v = f()
				}()
				return f, v
			},
		}
	}
	sum := func(a int, b int, c int) int {
		return a + b + c
	}
	size := func(a int, b string) int {
		return a + len(b)
	}
	fmt.Println(Chain([]G{1, 2, 3}, curryFactory).Eval(sum))   //6
	fmt.Println(Chain([]G{1, "ab"}, curryFactory).Eval(size)) //3
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment