Last active
November 15, 2020 21:20
-
-
Save marc0der/d1b89b6077639fdd6ffd6df7b9b3855a to your computer and use it in GitHub Desktop.
ST Monad
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Witness (tag) type for [ST] in Arrow's higher-kinded type encoding.
 *
 * Its constructor is private, so it cannot be instantiated outside this class; it only
 * serves as the type-constructor marker in [STOf] and [STPartialOf].
 */
class ForST private constructor() {
    companion object
}
/** `ST<S, A>` seen as an Arrow `Kind2`, tagged with the [ForST] witness. */
typealias STOf<S, A> = arrow.Kind2<ForST, S, A>

/** `ST` partially applied to a state type `S`; used as the type constructor of the `Monad` instance. */
typealias STPartialOf<S> = arrow.Kind<ForST, S>
/**
 * Narrows the higher-kinded view [STOf] back to the concrete [ST] type.
 *
 * This is an unchecked cast; it is only sound as long as [ST] is the sole implementer
 * of [STOf] (which is the case in this file).
 */
@Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE")
inline fun <S, A> STOf<S, A>.fix(): ST<S, A> = this as ST<S, A>
/**
 * A state-thread action: given a state token of type `S`, it produces a value of type `A`
 * together with the (possibly updated) state token.
 *
 * Actions are composed with [map] and [flatMap] and executed through [runST].
 * Note: composition is not stack-safe — every composed step adds one nested [run] call,
 * so very deep chains can overflow the stack.
 */
abstract class ST<S, A> internal constructor() : STOf<S, A> {
    companion object {
        /**
         * Lifts the deferred value [a] into an action that passes the state through unchanged.
         *
         * [a] is evaluated at most once, on the first run; later runs of the same action
         * return the cached value.
         */
        operator fun <S, A> invoke(a: () -> A): ST<S, A> {
            val memo by lazy(a)
            return object : ST<S, A>() {
                override fun run(s: S): Pair<A, S> = Pair(memo, s)
            }
        }

        /**
         * Runs an action that is polymorphic in its state type, instantiating `S` with `Unit`,
         * and returns only the produced value.
         */
        fun <A> runST(st: RunnableST<A>): A =
            st.invoke<Unit>().run(Unit).first
    }

    /** Executes this action against state [s], returning the result and the next state. */
    protected abstract fun run(s: S): Pair<A, S>

    /** Returns an action that runs this one and transforms its result with [f]. */
    fun <B> map(f: (A) -> B): ST<S, B> = object : ST<S, B>() {
        override fun run(s: S): Pair<B, S> {
            // `this@ST` is the outer action being mapped, not this anonymous one.
            val (a, s1) = this@ST.run(s)
            return Pair(f(a), s1)
        }
    }

    /** Returns an action that runs this one, then runs the action [f] builds from its result. */
    fun <B> flatMap(f: (A) -> ST<S, B>): ST<S, B> = object : ST<S, B>() {
        override fun run(s: S): Pair<B, S> {
            // Thread the state produced by the outer action into the continuation.
            val (a, s1) = this@ST.run(s)
            return f(a).run(s1)
        }
    }
}
/**
 * An [ST] action that must be definable for every state type `S`.
 *
 * Because [invoke] is generic in `S`, an implementation cannot depend on a specific state
 * type; [ST.runST] relies on this and instantiates `S` with `Unit`.
 */
interface RunnableST<A> {
    /** Builds the action for the caller-chosen state type `S`. */
    fun <S> invoke(): ST<S, A>
}
/**
 * A mutable cell that can only be allocated, read and written through [ST] actions
 * sharing its state-thread type `S`.
 */
abstract class STRef<S, A> private constructor() {
    companion object {
        /**
         * Returns an action that allocates a new cell initialised to [a].
         *
         * A fresh cell is created every time the action is run, so running the same
         * allocation action twice yields two independent references (previously the
         * allocation was memoised and both runs aliased a single cell).
         */
        operator fun <S, A> invoke(a: A): ST<S, STRef<S, A>> = object : ST<S, STRef<S, A>>() {
            override fun run(s: S): Pair<STRef<S, A>, S> {
                val ref: STRef<S, A> = object : STRef<S, A>() {
                    override var cell: A = a
                }
                return Pair(ref, s)
            }
        }
    }

    /** The current contents; only touched from inside the actions below. */
    protected abstract var cell: A

    /**
     * Returns an action that reads the cell's value at the time the action is run.
     *
     * Deliberately not built with `ST { ... }`: that factory memoises its value, which made
     * a re-run of the same read action return a stale value after an intervening write.
     */
    fun read(): ST<S, A> = object : ST<S, A>() {
        override fun run(s: S): Pair<A, S> = Pair(cell, s)
    }

    /** Returns an action that stores [a] in the cell each time it is run. */
    fun write(a: A): ST<S, Unit> = object : ST<S, Unit>() {
        override fun run(s: S): Pair<Unit, S> {
            cell = a
            return Pair(Unit, s)
        }
    }
}
/**
 * Arrow `Monad` instance for [ST] at a fixed state type `S`.
 *
 * NOTE(review): the class-level type parameter `A` is unused (every method declares its own
 * `A`); it is kept because callers request the instance as `ST.monad<S, A>()`.
 */
@extension
interface STMonad<S, A> : Monad<STPartialOf<S>> {
    override fun <A> just(a: A): STOf<S, A> = ST { a }

    override fun <A, B> STOf<S, A>.flatMap(
        f: (A) -> STOf<S, B>
    ): STOf<S, B> =
        this.fix().flatMap { a -> f(a).fix() }

    /**
     * Repeatedly applies [f], continuing with the new seed on `Left` and finishing on `Right`.
     *
     * Replaces the former `TODO()` (which always threw). This implementation recurses through
     * [ST.flatMap], so it is not stack-safe for very long loops.
     */
    override fun <A, B> tailRecM(
        a: A,
        f: (A) -> STOf<S, Either<A, B>>
    ): STOf<S, B> =
        f(a).fix().flatMap { step ->
            step.fold(
                { next -> tailRecM(next, f).fix() },
                { b -> just(b).fix() }
            )
        }
}
// Example program: allocates two refs (10 and 20), writes y + 1 into the first and x + 1 into
// the second, then reads both back, yielding Pair(21, 11).
// TODO: express this with a monad comprehension instead of nested flatMaps.
val prog = object : RunnableST<Pair<Int, Int>> {
    override fun <S> invoke(): ST<S, Pair<Int, Int>> =
        STRef<S, Int>(10).flatMap { r1: STRef<S, Int> ->
            STRef<S, Int>(20).flatMap { r2: STRef<S, Int> ->
                r1.read().flatMap { x ->
                    r2.read().flatMap { y ->
                        // Steps that don't need an earlier result are chained flat
                        // (monad associativity keeps the effect order unchanged).
                        r1.write(y + 1)
                            .flatMap { r2.write(x + 1) }
                            .flatMap { r1.read() }
                            .flatMap { a -> r2.read().map { b -> Pair(a, b) } }
                    }
                }
            }
        }
}
/** Runs [prog] and prints the resulting pair. */
fun main() {
    println(ST.runST(prog))
}
In the end, I was missing some boilerplate code. All I needed was some sugar on `ST`'s companion object to delegate the `fx` function call through to the monad type class's `fx` function. Go figure!
/**
 * Monad-comprehension entry point for [ST]: hands the suspended block [c] to the `fx.monad`
 * syntax of the ST monad instance and narrows the result back to [ST].
 *
 * NOTE(review): `ST.monad` is not defined here; presumably it is generated from the
 * `@extension`-annotated STMonad interface — confirm.
 */
fun <S, A> ST.Companion.fx(
    c: suspend MonadSyntax<STPartialOf<S>>.() -> A
): ST<S, A> {
    val stMonad = ST.monad<S, A>()
    return stMonad.fx.monad(c).fix()
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@raulraja This is what I'd like to do instead of `prog` above.