I've been working on a side project for the past several days: making a Snake game using purely reactive programming with ScalaJS. You can play it at https://wildfield.github.io/snake/. The source code is at https://github.com/wildfield/snake-frp.
Reader PSA: a lot of wheel reinvention and non-standard terminology ahead. This is not supposed to be a tutorial of any sort; I just wanted to share a cool concept I've been working on.
It's a basic snake, but has all the necessary features:
- Use arrow keys to control the snake
- P for Pause
- R for Reset
- The snake moves faster as the score goes up
- Collision with a wall or the snake itself is a Game Over
I didn't like the idea of having discrete events, so this is a simple implementation of a pull-based, continuous-time FRP system (with some discretization). In hindsight, it is somewhat similar to Yampa, but simpler.
Let's consider a simple Behavior that describes whether we pressed the Up key at a given time:
def didPressUp(time: Double): Boolean
The idea is that events are represented as point-like spikes over continuous time.
We want to build an accumulator that counts the number of times we pressed the Up key.
Assuming we sample the events at fixed intervals, the initial idea was to have the function simply recurse on itself until it hits t = 0:
def accumulate(time: Double): Int =
  if (time <= 0) 0
  else if (didPressUp(time)) accumulate(time - DELTA_T) + 1
  else accumulate(time - DELTA_T)
This would work, except that computing the value takes longer and longer as time goes on. If we introduce memoization, the memory requirements also grow over time. In addition, we would have to add memoization to every function, which is a lot of boilerplate. So, let's extract the memory part out:
def accumulate(time: Double, pastValue: Int): Int =
  if (didPressUp(time)) pastValue + 1 else pastValue
Then we can create a function that iterates over this and memoizes the values. I called this function a "backstep". Here is some Scala-like pseudocode:
def backstep(time: Double, timeDelta: Double, f: (Double, Int) => Int): Int = {
  <memoization>
  if (time <= 0) {
    0
  } else if (memoized(time)) {
    memoized(time)
  } else {
    // Compute the past value first, then apply f to step forward to `time`
    memoized(time) = f(time, backstep(time - timeDelta, timeDelta, f))
    memoized(time)
  }
}
The trick here is that we don't need to store the values at all possible times. Assuming monotonically increasing time, we only need to store the last computed value. We can make it more generic by introducing T and capturing all mutable state in a closure:
def backstep[T](timeDelta: Double, f: (Double, T) => T): Double => T
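Here is a runnable sketch of that closure-based backstep. It assumes an explicit `initial` value for t = 0 (the pseudocode above hard-codes 0), keeps only the last sampled time and value, and recurses back in `timeDelta` steps until it reaches the memoized sample. It illustrates the idea and is not the project's implementation.

```scala
// Sketch of a closure-memoized backstep; `initial` is an assumed extra parameter.
def backstep[T](timeDelta: Double, initial: T, f: (Double, T) => T): Double => T = {
  // The only memory: the last time we were sampled and the value computed for it.
  var lastTime = 0.0
  var lastValue = initial

  // Walk back in timeDelta steps until we reach the memoized sample,
  // then apply f forward on the way out of the recursion.
  def go(time: Double): T =
    if (time <= lastTime) lastValue
    else f(time, go(time - timeDelta))

  (time: Double) => {
    val value = go(time)
    lastTime = time
    lastValue = value
    value
  }
}

// Usage: count the number of samples taken so far.
val samples = backstep[Int](1.0, 0, (_, past) => past + 1)
println(samples(5.0)) // 5: computed all the way from t = 0
println(samples(8.0)) // 8: only three new steps, the rest comes from memory
```

Because only the newest sample is kept, a time earlier than the last one just returns the memoized value; the sketch relies on time increasing monotonically, as described above.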
The issue is that if we start introducing these functions all over the code, we MUST call every single one of them on every sample; otherwise, their values won't get stored and we fall back into the slow recursive case. Let's instead keep a single instance of backstep and focus on its function f. What if we could combine all these stateful functions together and recur over them?
Let's consider a contrived example: an accumulator of an accumulator! This new function accumulates the values of another accumulator, so it grows as a triangular sequence.
In fact, if we change our function signature, we can reuse the function we already have:
def didPressUp(time: Double): Int // 0 if false, 1 if true
def accumulate(time: Double, valueToAdd: Int, pastValue: Int): Int =
pastValue + valueToAdd
// didPressUp -> accumulate -> accumulate
Combining a function like this with a constant is actually pretty simple: we can curry it and then plug it back into backstep:
def apply(f: (Double, Int, Int) => Int, value: Int): (Double, Int) => Int =
  (time: Double, past: Int) => f(time, value, past)
...
backstep(time, timeDelta, apply(accumulate, MY_CONSTANT))
To combine two functions together, we need one last trick: mapping and flattening. We can derive a very simple flatMap function for our use case (a very similar definition works for a flatMap over didPressUp):
def flatMap(
  f: (Double, Int, Int) => Int,
  mapping: Int => ((Double, Int) => Int)
): (Double, Int, (Int, Int)) => (Int, Int) = {
  (time: Double, valueToAdd: Int, pastValue: (Int, Int)) =>
    val firstAccValue = f(time, valueToAdd, pastValue._1)
    val mappedFunc = mapping(firstAccValue)
    val secondAccValue = mappedFunc(time, pastValue._2)
    (firstAccValue, secondAccValue)
}
...
backstep(time, timeDelta, flatMap(
  didPressUp,
  didPressUpValue => {
    apply(
      flatMap(accumulate, value => apply(accumulate, value)),
      didPressUpValue
    )
  }
))
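To see the two accumulators diverge, here is a small runnable version of the tuple-based flatMap. It assumes a hypothetical input where Up counts as pressed on every sample, and it drives the combined function with a plain loop rather than backstep, carrying the (Int, Int) memory by hand:

```scala
// Step function: (time, valueToAdd, pastValue) => newValue
type Step = (Double, Int, Int) => Int

def accumulate(time: Double, valueToAdd: Int, pastValue: Int): Int =
  pastValue + valueToAdd

// Curry a constant input into a step, leaving (time, past) => value.
def apply(f: Step, value: Int): (Double, Int) => Int =
  (time, past) => f(time, value, past)

// Run f, feed its output into the function produced by mapping,
// and keep both results so each accumulator has its own memory.
def flatMap(f: Step, mapping: Int => ((Double, Int) => Int)): (Double, Int, (Int, Int)) => (Int, Int) =
  (time, valueToAdd, pastValue) => {
    val firstAccValue = f(time, valueToAdd, pastValue._1)
    val secondAccValue = mapping(firstAccValue)(time, pastValue._2)
    (firstAccValue, secondAccValue)
  }

val accAcc = flatMap(accumulate, value => apply(accumulate, value))

// Hypothetical input: the Up key counts as pressed (1) on every sample.
var memory = (0, 0)
val history = (1 to 5).toList.map { i =>
  memory = accAcc(i.toDouble, 1, memory)
  memory
}
println(history) // List((1,1), (2,3), (3,6), (4,10), (5,15))
```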
The above works, but now the output value is a tuple instead of a single value, because we need to keep track of both past accumulators, as they diverge: (1 -> 2 -> 3 ...) and (1 -> 3 -> 6 ...). When we apply time to backstep, we can take the second tuple member and discard the first. As we combine more and more functions, this tuple gets out of hand pretty quickly. Let's call it the Memory type and derive a general stream signature:
type Time = Double // time is a Double everywhere in this post
type MemoryTuple[T] = (Option[Double], Option[T])
type ReactiveStream[Memory, Input, Output] =
  (Time, Input, MemoryTuple[Memory]) => (Output, Memory)
This is pretty similar to our accumulate function, apart from a few additions. The past values are now Options to handle the case where there are no past values at t = 0, and the two members of MemoryTuple are the last time this value was recorded and the value itself, respectively. Output is separated from Memory because, at the end of the day, we only care about the Output, not the Memory.
Let's also mention its little brother, Source. It's a function that doesn't take any inputs besides its own memory:
type Source[Memory, Output] =
  (Time, MemoryTuple[Memory]) => (Output, Memory)
We get a Source if we apply a constant input to a ReactiveStream (again, by currying):
def apply[Memory, Input, Output](
  f1: ReactiveStream[Memory, Input, Output],
  a: Input
): Source[Memory, Output]
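A minimal sketch of these aliases plus apply, assuming Time is simply an alias for Double (as everywhere else in this post). The accumulate stream is the ReactiveStream version shown a bit further down, and the two calls thread the memory by hand to show how a Source is sampled:

```scala
type Time = Double
type MemoryTuple[T] = (Option[Double], Option[T])
type ReactiveStream[Memory, Input, Output] = (Time, Input, MemoryTuple[Memory]) => (Output, Memory)
type Source[Memory, Output] = (Time, MemoryTuple[Memory]) => (Output, Memory)

// Fix the input of a stream to a constant, turning it into a Source.
def apply[Memory, Input, Output](f1: ReactiveStream[Memory, Input, Output], a: Input): Source[Memory, Output] =
  (time, past) => f1(time, a, past)

// Accumulator as a ReactiveStream: output and memory are the same running total.
val accumulate: ReactiveStream[Int, Int, Int] = (time, valueToAdd, past) => {
  val value = past._2.getOrElse(0) + valueToAdd
  (value, value)
}

// Sample the Source twice, passing the previous time and memory back in.
val addThree = apply(accumulate, 3)
val (out1, mem1) = addThree(0.0, (None, None))
val (out2, mem2) = addThree(1.0, (Some(0.0), Some(mem1)))
println((out1, out2)) // (3,6)
```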
Source is useful because it lets us define this version of flatMap, which gives us access to the values inside a stream so we can apply them to another stream:
def flatMapSource[Memory1, Input1, Output1, Memory2, Output2](
  f: ReactiveStream[Memory1, Input1, Output1],
  map: Output1 => Source[Memory2, Output2]
): ReactiveStream[(Memory1, Memory2), Input1, Output2]
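One possible implementation of flatMapSource, written against the same aliases (repeated so the snippet stands alone): the combined memory is a pair, and each half is handed back to the stream that owns it. This is a sketch of the idea rather than the project's code; `apply` is the constant-input version from above.

```scala
type MemoryTuple[T] = (Option[Double], Option[T])
type ReactiveStream[M, I, O] = (Double, I, MemoryTuple[M]) => (O, M)
type Source[M, O] = (Double, MemoryTuple[M]) => (O, M)

def apply[M, I, O](f1: ReactiveStream[M, I, O], a: I): Source[M, O] =
  (time, past) => f1(time, a, past)

def flatMapSource[Memory1, Input1, Output1, Memory2, Output2](
    f: ReactiveStream[Memory1, Input1, Output1],
    map: Output1 => Source[Memory2, Output2]
): ReactiveStream[(Memory1, Memory2), Input1, Output2] =
  (time, input, past) => {
    val (pastTime, pastMemory) = past
    // Run the outer stream with its half of the memory.
    val (out1, mem1) = f(time, input, (pastTime, pastMemory.map(_._1)))
    // Build the inner Source from that output and run it with the other half.
    val (out2, mem2) = map(out1)(time, (pastTime, pastMemory.map(_._2)))
    (out2, (mem1, mem2))
  }

// Usage: an accumulator of an accumulator, with a press on every sample.
val accumulate: ReactiveStream[Int, Int, Int] = (_, add, past) => {
  val v = past._2.getOrElse(0) + add
  (v, v)
}
val accAcc = flatMapSource(accumulate, (v: Int) => apply(accumulate, v))

var memory: MemoryTuple[(Int, Int)] = (None, None)
val outputs = (1 to 4).toList.map { i =>
  val (out, mem) = accAcc(i.toDouble, 1, memory)
  memory = (Some(i.toDouble), Some(mem))
  out
}
println(outputs) // List(1, 3, 6, 10)
```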
If we define accumulate as a ReactiveStream[Int, Int, Int], it looks like this:
def accumulate(time: Double, valueToAdd: Int, pastValue: (Option[Double], Option[Int])): (Int, Int) =
  val value = pastValue._2.getOrElse(0) + valueToAdd
  (value, value)
There's some duplication because each stream must return both its output and its memory ((value, value)), but thankfully we can derive a series of lift functions that take simpler functions and lift them to ReactiveStream when needed. Assuming we have an apply between ReactiveStream and Behavior, we can combine everything with a backstep:
def apply[T1, T2, T3](
  f1: ReactiveStream[T1, T2, T3],
  f2: Behavior[T2]
): Source[T1, T3]
...
val accumulatorAccumulator = backstep(
  timeDelta,
  apply(
    flatMapSource(lift(accumulate), value => apply(lift(accumulate), value)),
    didPressUp
  )
)
...
val currentCount: Int = accumulatorAccumulator(time)
To understand what's going on, we need to read from the inside out. lift(accumulate) creates a ReactiveStream from a simpler signature, and apply(lift(accumulate), value) applies a value from another accumulator (this gives us an accumulator of another accumulator). Then we move to the outer apply, which accepts values from didPressUp (this gives us an accumulator of Up presses). This is passed into backstep, which gives us a Double => Int function with memoization. Calling accumulatorAccumulator(time) gives us the final value we are interested in, and accumulatorAccumulator can be called repeatedly with increasing time to get the desired behavior.
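Putting the pieces together, here is a self-contained sketch of the whole chain that reproduces the triangular sequence. Several details are assumptions rather than the project's code: lift defaults the past value to 0, Behavior is taken to be Double => T, the Behavior version of apply is renamed to the hypothetical applyBehavior to avoid overload ambiguity, backstep works on a Source, and didPressUp is a hypothetical behavior that reports a press on every sample.

```scala
type MemoryTuple[T] = (Option[Double], Option[T])
type ReactiveStream[M, I, O] = (Double, I, MemoryTuple[M]) => (O, M)
type Source[M, O] = (Double, MemoryTuple[M]) => (O, M)
type Behavior[T] = Double => T

// Lift a plain (time, input, pastValue) => value step into a stream
// whose output doubles as its memory; the past value defaults to 0.
def lift(f: (Double, Int, Int) => Int): ReactiveStream[Int, Int, Int] =
  (time, input, past) => {
    val value = f(time, input, past._2.getOrElse(0))
    (value, value)
  }

// Feed a constant into a stream.
def apply[M, I, O](f1: ReactiveStream[M, I, O], a: I): Source[M, O] =
  (time, past) => f1(time, a, past)

// Feed a time-varying Behavior into a stream (hypothetical name).
def applyBehavior[M, I, O](f1: ReactiveStream[M, I, O], f2: Behavior[I]): Source[M, O] =
  (time, past) => f1(time, f2(time), past)

def flatMapSource[M1, I1, O1, M2, O2](
    f: ReactiveStream[M1, I1, O1],
    map: O1 => Source[M2, O2]
): ReactiveStream[(M1, M2), I1, O2] =
  (time, input, past) => {
    val (out1, mem1) = f(time, input, (past._1, past._2.map(_._1)))
    val (out2, mem2) = map(out1)(time, (past._1, past._2.map(_._2)))
    (out2, (mem1, mem2))
  }

// Memoizing backstep over a Source: keeps only the last (time, output, memory)
// and recurses back in timeDelta steps until it reaches that sample.
def backstep[M, O](timeDelta: Double, source: Source[M, O]): Double => O = {
  var last: Option[(Double, O, M)] = None

  def go(time: Double): (O, M) = {
    val lastTime = last.map(_._1).getOrElse(0.0)
    val previous: MemoryTuple[M] =
      if (time - timeDelta <= lastTime) {
        (last.map(_._1), last.map(_._3))
      } else {
        val (_, memory) = go(time - timeDelta)
        (Some(time - timeDelta), Some(memory))
      }
    source(time, previous)
  }

  (time: Double) =>
    last match {
      case Some((t, output, _)) if t >= time => output // already sampled
      case _ =>
        val (output, memory) = go(time)
        last = Some((time, output, memory))
        output
    }
}

def accumulate(time: Double, valueToAdd: Int, pastValue: Int): Int = pastValue + valueToAdd
val didPressUp: Behavior[Int] = _ => 1 // hypothetical: Up is pressed on every sample

val accumulatorAccumulator = backstep(
  1.0,
  applyBehavior(
    flatMapSource(lift(accumulate), (value: Int) => apply(lift(accumulate), value)),
    didPressUp
  )
)
println(accumulatorAccumulator(3.0)) // 6
println(accumulatorAccumulator(5.0)) // 15
```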
This gives us most of the primitives we need to implement Snake.
Since we are making a game in JavaScript, we need to react to document events. We subscribe to "keydown" and bridge the imperative world with our FRP system:
document.addEventListener(
  "keydown",
  (e: dom.KeyboardEvent) =>
    e.key match {
      ...
      case "ArrowLeft" =>
        onPressLeft()
      ...
    }
)
When a key is pressed, we want to capture the time of the event in the FRP system's time. Since time in our system starts at 0, we can simply subtract the application's start time:
var leftPressTime: Double = Double.MinValue
val start = new Date()
def curTime(): Double =
  (new Date()).getTime() - start.getTime()
def onPressLeft(): Unit = {
  leftPressTime = curTime()
}
This allows us to calculate whether the event happened with the following general function:
def didPress(
  getTime: () => Double
): Source[Unit, Option[Double]] = {
  def _didPress(time: Double, past: MemoryTuple[Unit]): (Option[Double], Unit) = {
    val pressedTime = getTime()
    val lastTime = past._1.getOrElse(0.0)
    // The press belongs to this sample if it lies within [lastTime, time]
    val pressDetected = (pressedTime >= lastTime) && (pressedTime <= time)
    (
      if pressDetected then Some(pressedTime) else None,
      ()
    )
  }
  _didPress
}
...
val leftPress = didPress(() => this.leftPressTime)
getTime() is specific to each key, but didPress itself is general. Each time we invoke this function, we pass both the current time and the previous time. We are not interested in any actual memory value, therefore Memory = Unit. Given the interval of interest, we can tell whether an event happened just by comparing it against the bounds of the interval. The return value is Option[Double] because we might want to order events even if they happen within the same sample interval.
The sampling interval is dynamic, so this function doesn't refer to any DELTA_T constant. We can expect the interval to change between invocations.
Other keys can be implemented in the same fashion:
val upPress = didPress(() => this.upPressTime)
val downPress = didPress(() => this.downPressTime)
val leftPress = didPress(() => this.leftPressTime)
val rightPress = didPress(() => this.rightPressTime)
val upRelease = didPress(() => this.upReleaseTime)
val downRelease = didPress(() => this.downReleaseTime)
val leftRelease = didPress(() => this.leftReleaseTime)
val rightRelease = didPress(() => this.rightReleaseTime)
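The interval test inside didPress can be checked in isolation. The sketch below repeats didPress as written above and uses a mutable press time as a stand-in for the keydown handler (a hypothetical setup, not the project's event wiring):

```scala
type MemoryTuple[T] = (Option[Double], Option[T])
type Source[M, O] = (Double, MemoryTuple[M]) => (O, M)

def didPress(getTime: () => Double): Source[Unit, Option[Double]] = {
  def _didPress(time: Double, past: MemoryTuple[Unit]): (Option[Double], Unit) = {
    val pressedTime = getTime()
    val lastTime = past._1.getOrElse(0.0)
    // A press belongs to this sample if it falls in [lastTime, time].
    val pressDetected = (pressedTime >= lastTime) && (pressedTime <= time)
    (if pressDetected then Some(pressedTime) else None, ())
  }
  _didPress
}

// Stand-in for the keydown handler writing the press time.
var leftPressTime: Double = Double.MinValue
val leftPress = didPress(() => leftPressTime)

leftPressTime = 120.0
println(leftPress(150.0, (Some(100.0), Some(())))._1) // Some(120.0): inside [100, 150]
println(leftPress(200.0, (Some(150.0), Some(())))._1) // None: the press is older than this interval
```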
To keep a key's state from the moment it is pressed until it is released, we need to introduce a button latch:
def buttonStateLatch(
  time: Double,
  arg: (Option[Double], Option[Double]),
  past: MemoryTuple[Option[Double]]
): (Option[Double], Option[Double]) = {
  val (press, release) = arg
  val result = if (!release.isEmpty) {
    None // a release clears the latch
  } else if (!press.isEmpty) {
    press // a new press records its time
  } else {
    past._2.flatten // otherwise keep the previous state
  }
  (result, result)
}
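Since buttonStateLatch is a pure function of its input and memory, it can be exercised directly. The driver below is a hypothetical harness that feeds a press, two quiet samples and a release, threading the memory by hand:

```scala
type MemoryTuple[T] = (Option[Double], Option[T])

def buttonStateLatch(
    time: Double,
    arg: (Option[Double], Option[Double]),
    past: MemoryTuple[Option[Double]]
): (Option[Double], Option[Double]) = {
  val (press, release) = arg
  val result =
    if (release.nonEmpty) None       // a release clears the latch
    else if (press.nonEmpty) press   // a new press records its time
    else past._2.flatten             // otherwise keep the previous state
  (result, result)
}

// (time, (press, release)) samples: press at 10, nothing, nothing, release at 40.
val inputs: List[(Double, (Option[Double], Option[Double]))] = List(
  10.0 -> (Some(10.0), None),
  20.0 -> (None, None),
  30.0 -> (None, None),
  40.0 -> (None, Some(40.0))
)
var memory: MemoryTuple[Option[Double]] = (None, None)
val states = inputs.map { case (time, arg) =>
  val (out, mem) = buttonStateLatch(time, arg, memory)
  memory = (Some(time), Some(mem))
  out
}
println(states) // List(Some(10.0), Some(10.0), Some(10.0), None)
```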
The code is relatively straightforward, but since key presses and releases are represented as Option[Double], checking whether one happened means checking that the Option is non-empty (e.g. !release.isEmpty). To apply the latch, we introduce a helper function makeButtonLatch and apply it to each key pair:
def makeButtonLatch[T1](
  press: Source[T1, Option[Double]],
  release: Source[T1, Option[Double]]
): Source[((T1, T1), Option[Double]), Option[Double]] = {
  flatMapSource(
    pair(press, release),
    keyPair => {
      apply(buttonStateLatch, keyPair)
    }
  )
}
val upLatch = makeButtonLatch(upPress, upRelease)
val downLatch = makeButtonLatch(downPress, downRelease)
val leftLatch = makeButtonLatch(leftPress, leftRelease)
val rightLatch = makeButtonLatch(rightPress, rightRelease)
I kept the memory type generic because makeButtonLatch doesn't depend on it, which makes the implementation more flexible. To finish it off, we put all the keys into a Keys case class:
case class Keys(
  left: Option[Double],
  right: Option[Double],
  down: Option[Double],
  up: Option[Double]
)
object Keys {
  def from_tuples(
    arguments: ((Option[Double], Option[Double]), (Option[Double], Option[Double]))
  ): Keys = {
    val ((left, right), (down, up)) = arguments
    Keys(left, right, down, up)
  }
}
val keyTuples = pair(pair(leftLatch, rightLatch), pair(downLatch, upLatch))
val keyValues = map(keyTuples, Keys.from_tuples)
To use the keys for movement, we need to convert them into cardinal directions. This method returns a list of directions sorted by key press time, with the latest first. We will later filter invalid directions out of this List:
def desiredDirections(
  keys: Keys
): List[Direction]
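The body of desiredDirections isn't shown. A minimal sketch, assuming the Keys class above and a hypothetical Direction enum whose y axis points down the screen, collects the latched keys with their press times and sorts them with the most recent first:

```scala
case class Keys(
    left: Option[Double],
    right: Option[Double],
    down: Option[Double],
    up: Option[Double]
)

// Hypothetical Direction type; the project's own definition may differ.
enum Direction(val x: Int, val y: Int) {
  case Left extends Direction(-1, 0)
  case Right extends Direction(1, 0)
  case Down extends Direction(0, 1)
  case Up extends Direction(0, -1)
}

def desiredDirections(keys: Keys): List[Direction] =
  // Pair each held key (its press time) with its direction, skip released keys,
  // and put the most recent press first.
  List(
    keys.left.map(_ -> Direction.Left),
    keys.right.map(_ -> Direction.Right),
    keys.down.map(_ -> Direction.Down),
    keys.up.map(_ -> Direction.Up)
  ).flatten.sortWith(_._1 > _._1).map(_._2)

// Up held since t = 100, Left pressed later at t = 250.
println(desiredDirections(Keys(Some(250.0), None, None, Some(100.0)))) // List(Left, Up)
```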
The snake moves in a jumpy way: we need to wait n milliseconds before each move. To simulate this, we introduce a tick method that fires a time pulse every n milliseconds, while keeping the accumulated time in between:
def tick(
  time: Double,
  args: (Boolean, Int),
  past: MemoryTuple[Double]
): (Option[Double], Double) = {
  ...
  val pastTime = past._1.getOrElse(0.0)
  val accumulatedTime = past._2.getOrElse(0.0)
  // Time elapsed since the last sample plus the leftover from earlier samples
  val totalTime = (time - pastTime) + accumulatedTime
  ...
  val pulses = (totalTime / effectivePulseTime).toInt
  ...
}
The Option[Double] output holds a duration when a tick is present, while the memory tracks the leftover time that is not yet enough for a full tick. The arguments are "stop" and "score", respectively, and they determine the behavior of the tick function: we want to stop generating ticks when "stop" is true, and to generate ticks faster the higher the score. This function is central to synchronizing movement across the whole game.
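The elided parts of tick are not shown, so the following is only a sketch of how it could work. The pulse constants are hypothetical, the way score shortens the pulse is an assumption, the output is taken to be the duration covered by whole pulses, and when stopped the sketch keeps the leftover time without adding the elapsed time:

```scala
type MemoryTuple[T] = (Option[Double], Option[T])

// Hypothetical tuning constants (milliseconds); not taken from the project.
val BasePulseMs = 200.0
val MinPulseMs = 50.0
val SpeedupPerPointMs = 10.0

def tick(
    time: Double,
    args: (Boolean, Int),
    past: MemoryTuple[Double]
): (Option[Double], Double) = {
  val (stop, score) = args
  val pastTime = past._1.getOrElse(0.0)
  val accumulatedTime = past._2.getOrElse(0.0)
  if (stop) {
    // Stopped: emit nothing and carry the leftover time unchanged.
    (None, accumulatedTime)
  } else {
    val totalTime = (time - pastTime) + accumulatedTime
    // Higher score => shorter pulse, clamped to a minimum.
    val effectivePulseTime = (BasePulseMs - score * SpeedupPerPointMs).max(MinPulseMs)
    val pulses = (totalTime / effectivePulseTime).toInt
    val leftover = totalTime - pulses * effectivePulseTime
    // Emit the duration covered by whole pulses; keep the remainder as memory.
    (if (pulses > 0) Some(pulses * effectivePulseTime) else None, leftover)
  }
}

println(tick(250.0, (false, 0), (Some(0.0), Some(0.0)))) // (Some(200.0),50.0)
println(tick(100.0, (false, 20), (Some(0.0), None)))     // (Some(100.0),0.0)
println(tick(900.0, (true, 0), (Some(0.0), Some(30.0)))) // (None,30.0)
```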
Movement is calculated from a direction and a tick:
def movement(
  arguments: (Option[Double], Direction)
): Option[Vect2d] = {
  val (speedPulse, direction) = arguments
  // No tick means no movement; otherwise scale the direction by the pulse
  speedPulse.map((speed) => Vect2d(direction.x * speed, direction.y * speed))
}
This brings us to the next problem: the tick needs the score to decide how fast to pulse, but the pulses move the snake, and the score depends on where the snake moves. How can we get the score before the tick has even happened? This is an apparent circular dependency, but we have a solution: we can use the score from the past!
If we want score information before we process a tick, we need to look at the values from the previous iteration. The combinator that helps us here is pastFlatMap:
def pastFlatMap[T1, T2, T3, T4, T5, T6, T7](
  f1: ReactiveStream[T1, T2, T3],
  map: (ReactiveStream[T1, T2, T3], Option[T5]) => ReactiveStream[T4, T7, (T6, T5)]
): ReactiveStream[(T4, T5), T7, T6]
The trick here is that we can "borrow" a value from the future, Option[T5], with the contract that we must return it when we are done with the transformation, via the (T6, T5) output type. Option is used because at the very first iteration we don't have a previous value, so it will be None.
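A sketch of how pastFlatMap could be implemented against the aliases used above: the T5 stored in memory on the previous iteration is passed to map as Option[T5], and the T5 that the inner stream returns is stored for the next one. The scoring usage is hypothetical and only demonstrates the borrowing:

```scala
type MemoryTuple[T] = (Option[Double], Option[T])
type ReactiveStream[M, I, O] = (Double, I, MemoryTuple[M]) => (O, M)

def pastFlatMap[T1, T2, T3, T4, T5, T6, T7](
    f1: ReactiveStream[T1, T2, T3],
    map: (ReactiveStream[T1, T2, T3], Option[T5]) => ReactiveStream[T4, T7, (T6, T5)]
): ReactiveStream[(T4, T5), T7, T6] =
  (time, input, past) => {
    val (pastTime, pastMemory) = past
    // Borrow the T5 returned on the previous iteration (None on the first one).
    val borrowed: Option[T5] = pastMemory.map(_._2)
    val inner = map(f1, borrowed)
    val ((output, returned), innerMemory) = inner(time, input, (pastTime, pastMemory.map(_._1)))
    // Store the returned T5 next to the inner memory so the next iteration can borrow it.
    (output, (innerMemory, returned))
  }

// Hypothetical usage: a running score whose new value depends on the previous one.
val points: ReactiveStream[Unit, Int, Int] = (_, in, _) => (in, ())
val scoring = pastFlatMap[Unit, Int, Int, Unit, Int, Int, Int](
  points,
  (stream, pastScore) =>
    (time, in, mem) => {
      val (gained, m) = stream(time, in, mem)
      val score = pastScore.getOrElse(0) + gained
      ((score, score), m)
    }
)

var memory: MemoryTuple[(Unit, Int)] = (None, None)
val scores = List(2, 0, 3).zipWithIndex.map { case (gained, i) =>
  val (out, mem) = scoring(i.toDouble, gained, memory)
  memory = (Some(i.toDouble), Some(mem))
  out
}
println(scores) // List(2, 2, 5)
```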
Here, we first combine the tick with a direction-storing function, which is needed to validate the direction (I'll skip that part). This gives us a new stream that outputs the new direction but still accepts the (Boolean, Int) input that the tick needs:
val timeWithPause = ...
val resultingMovement =
  flatMapSource(
    timeWithPause, {
      ...
    })
val gameState = pastFlatMapSource(
  resultingMovement,
  (resultingMovement, memory: Option[(Boolean, Int)]) => {
    val pastGameOver = memory.map(_._1).getOrElse(false)
    val pastScore = memory.map(_._2).getOrElse(0)
    flatMapSource(
      apply(
        resultingMovement,
        (pastGameOver, pastScore)
      ),
      ...
      (
        ...,
        (gameOver, score)
      )
    )
  }
)
The rest of the Snake code applies the same combinators to various streams to get the information we need where we need it. For example, to display food, we need to know the game bounds and the past snake. food returns the food position as a Vect2d and whether we ate it as a Boolean:
def food(bounds: Rect): ReactiveStream[
  Vect2d,
  List[Vect2d],
  (Vect2d, Boolean)
] = {
  ...
  def _food(
    time: Double,
    snake: List[Vect2d],
    past: MemoryTuple[Vect2d]
  ): ((Vect2d, Boolean), Vect2d) = {
    val oldFoodPosition = past._2.getOrElse(Vect2d(SNAKE_SIZE * 3, SNAKE_SIZE * 4))
    snake.headOption match {
      case None => ((oldFoodPosition, false), oldFoodPosition)
      case Some(head) =>
        // The food is eaten when the head is in the same grid cell
        if (
          (head.x / SNAKE_SIZE).toInt == (oldFoodPosition.x / SNAKE_SIZE).toInt
          && (head.y / SNAKE_SIZE).toInt == (oldFoodPosition.y / SNAKE_SIZE).toInt
        ) { ... }
    }
  }
  _food
}
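The eat check in _food compares grid cells rather than raw coordinates, so the head only needs to be in the food's cell. Pulled out as a hypothetical helper (SNAKE_SIZE = 20 and this Vect2d are assumed stand-ins for the project's own definitions):

```scala
val SNAKE_SIZE = 20.0
case class Vect2d(x: Double, y: Double)

// Two positions are in the same grid cell if their integer cell indices match.
def sameCell(a: Vect2d, b: Vect2d): Boolean =
  (a.x / SNAKE_SIZE).toInt == (b.x / SNAKE_SIZE).toInt &&
    (a.y / SNAKE_SIZE).toInt == (b.y / SNAKE_SIZE).toInt

val food = Vect2d(SNAKE_SIZE * 3, SNAKE_SIZE * 4) // default food position used by _food
println(sameCell(Vect2d(65.0, 95.0), food)) // true: both are in cell (3, 4)
println(sameCell(Vect2d(80.0, 80.0), food)) // false: the head is in cell (4, 4)
```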
Then we apply it to the past value of the snake in the same fashion. We need pastFlatMap here because the snake's position also depends on its interaction with the food: the snake gets longer if it ate any food.
...
pastFlatMapSource(
  food(bounds),
  (food, pastSnakeOption: Option[List[Vect2d]]) => {
    val pastSnake = pastSnakeOption.getOrElse(List())
    flatMapSource(
      apply(
        food,
        pastSnake
      ),
      { ... }
    )
  }
)
...
The hierarchies can become pretty deep, but to understand what's going on, start from the inside: the domain-specific functions and their arguments, then their immediate combinators, and then up the chain.
At the end of the chain we have gameState, which contains all the necessary information about the game. We use a DrawOp sealed trait to encode each drawing operation we need, and then we combine everything into a big list of operations:
sealed trait DrawOp
case class DrawRect(x: Double, y: Double, w: Double, h: Double, color: String) extends DrawOp
case class DrawText(x: Double, y: Double, text: String, font: String, color: String) extends DrawOp
...
def drawSnake(snake: List[Vect2d]): List[DrawOp] = {
  snake.zipWithIndex.map((elem, idx) => {
    // The head is drawn yellow, the rest of the body red
    val color = if idx == 0 then "#ffff00" else "#ff0000"
    DrawRect(elem.x, elem.y, SNAKE_SIZE, SNAKE_SIZE, color)
  })
}
...
drawClearScreen(bounds) ::: drawFood(food)
  ::: drawSnake(snake) ::: drawScore(score)
  ::: drawHighScore(highScore) ::: drawPause(pause) ::: drawGameOver(gameOver)
...
for (op <- drawOpsValue) {
  op match {
    case DrawRect(x, y, w, h, color) => { ... }
    case DrawText(x, y, text, font, color) => { ... }
  }
}
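Because drawing is described as data, the interpreter is easy to swap out. As a hypothetical example, here is the same pattern rendering ops to text instead of a canvas, which makes a draw list easy to inspect; SNAKE_SIZE and Vect2d are assumed stand-ins:

```scala
val SNAKE_SIZE = 20.0
case class Vect2d(x: Double, y: Double)

sealed trait DrawOp
case class DrawRect(x: Double, y: Double, w: Double, h: Double, color: String) extends DrawOp
case class DrawText(x: Double, y: Double, text: String, font: String, color: String) extends DrawOp

def drawSnake(snake: List[Vect2d]): List[DrawOp] =
  snake.zipWithIndex.map { case (elem, idx) =>
    // The head (index 0) is yellow, the body red.
    val color = if (idx == 0) "#ffff00" else "#ff0000"
    DrawRect(elem.x, elem.y, SNAKE_SIZE, SNAKE_SIZE, color)
  }

// A text "interpreter": one line per operation instead of canvas calls.
def render(ops: List[DrawOp]): List[String] =
  ops.map {
    case DrawRect(x, y, w, h, color)    => s"rect $x,$y ${w}x$h $color"
    case DrawText(x, y, text, _, color) => s"text $x,$y '$text' $color"
  }

render(drawSnake(List(Vect2d(40, 40), Vect2d(20, 40)))).foreach(println)
// rect 40.0,40.0 20.0x20.0 #ffff00
// rect 20.0,40.0 20.0x20.0 #ff0000
```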
We get drawOpsValue by applying backstep to our final combined function. Sometimes delta can become very large (e.g. when we switch tabs), so to avoid a stack overflow in the recursion, we choose the time sampling so that it produces at most 200 iterations:
val draws = flatMapSource(
  stateWithHighScore,
  { ... }
)
val drawOps = backstep(draws)
...
val time = (new Date()).getTime() - start.getTime()
val timeSampling = (delta / 200.0).max(DELTA_T)
val drawOpsValue = drawOps(time, timeSampling)
...
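A quick check of that bound, assuming DELTA_T = 16 ms (a hypothetical value) and that covering delta takes ceil(delta / timeSampling) steps: large gaps are split into at most 200 steps, while small gaps keep the DELTA_T resolution.

```scala
val DELTA_T = 16.0 // hypothetical base step in milliseconds

def timeSampling(delta: Double): Double = (delta / 200.0).max(DELTA_T)

// Number of backstep iterations needed to cover `delta` milliseconds.
def iterations(delta: Double): Int = math.ceil(delta / timeSampling(delta)).toInt

println(iterations(100.0))   // 7: small gap, steps of DELTA_T
println(iterations(60000.0)) // 200: a minute in a background tab, steps of 300 ms
```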
Because our memory tuples have the type MemoryTuple[T] = (Option[Double], Option[T]), resetting the game state only requires passing None as the second member. Functions that care about the memory then initialize it with default values. This is what the listenToReset function does:
def listenToReset[T1, T3](
  f1: Source[T1, (T3, Boolean)]
): Source[Option[T1], T3] = {
  def _listenToReset(
    time: Double,
    past: MemoryTuple[Option[T1]]
  ): (T3, Option[T1]) = {
    val ((output, reset), memory) =
      f1(time, (past._1, past._2.flatten))
    // Dropping the memory makes every stream inside f1 restart from its defaults
    if (reset) {
      (output, None)
    } else {
      (output, Some(memory))
    }
  }
  _listenToReset
}
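A self-contained sketch of listenToReset in action. The counter source and the run driver are hypothetical: counter counts samples and requests a reset at the given times, and run threads the memory from one sample to the next.

```scala
type MemoryTuple[T] = (Option[Double], Option[T])
type Source[M, O] = (Double, MemoryTuple[M]) => (O, M)

def listenToReset[T1, T3](f1: Source[T1, (T3, Boolean)]): Source[Option[T1], T3] =
  (time, past) => {
    val ((output, reset), memory) = f1(time, (past._1, past._2.flatten))
    // On reset, drop the memory so every stream inside f1 restarts from its defaults.
    if (reset) {
      (output, None)
    } else {
      (output, Some(memory))
    }
  }

// Hypothetical source: counts samples and requests a reset at the given times.
def counter(resetAt: Set[Double]): Source[Int, (Int, Boolean)] =
  (time, past) => {
    val count = past._2.getOrElse(0) + 1
    ((count, resetAt.contains(time)), count)
  }

// Drive a Source by hand, remembering the last time and memory.
def run[M, O](source: Source[M, O], times: List[Double]): List[O] = {
  var memory: MemoryTuple[M] = (None, None)
  times.map { t =>
    val (out, m) = source(t, memory)
    memory = (Some(t), Some(m))
    out
  }
}

println(run(listenToReset(counter(Set(3.0))), List(1.0, 2.0, 3.0, 4.0, 5.0)))
// List(1, 2, 3, 1, 2)
```

The sample that requests the reset still reports its own output; the cleared memory only shows up on the following sample.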
Then we hook it up to the R key:
val stateWithReset = listenToReset(
  map(
    pair(rPress, gameState),
    output => {
      val (rPress, gameState) = output
      (gameState, !rPress.isEmpty)
    }
  )
)
Because this operates at the level of the combined Memory, we don't care what exactly is inside it. This function resets everything up to the point where it is applied. Any streams that are attached later (e.g. highScore) are not subject to the reset.
Conceptually, the reset works like switching streams, except that here we switch the stream with itself.
Because we keep track of the inputs at every iteration, we know exactly when the state is updated. State updates mostly depend on whether our tick function has produced a tick. There are two exceptions: the paused and gameOver states, which both stop ticks from coming. A third exception is the focusin JavaScript event, which fires when we switch back to the tab; if we don't redraw the screen then, the canvas will be out of date. Putting this together, we can implement a simple shouldRedraw function:
def shouldRedraw(
  time: Double,
  input: (Option[Double], Boolean, Boolean, Boolean),
  past: MemoryTuple[(Boolean, Boolean)]
): (Boolean, (Boolean, Boolean)) = {
  val (tick, paused, gameOver, focusIn) = input
  val (pastPaused, pastGameOver) = past._2.getOrElse((false, false))
  // Redraw on focus, on a tick, or when the paused/game-over state changes
  val should = focusIn || !(tick.isEmpty && paused == pastPaused && gameOver == pastGameOver)
  (should, (paused, gameOver))
}
...
shouldRedraw => {
  if (shouldRedraw) {
    drawClearScreen(bounds) ::: drawFood(food)
      ::: drawSnake(snake) ::: drawScore(score)
      ::: drawHighScore(highScore) ::: drawPause(pause) ::: drawGameOver(gameOver)
  } else {
    List()
  }
}
...
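Since shouldRedraw is a pure function, its cases can be checked directly. The inputs below are hypothetical samples against a steady past state in which neither pause nor game over was set:

```scala
type MemoryTuple[T] = (Option[Double], Option[T])

def shouldRedraw(
    time: Double,
    input: (Option[Double], Boolean, Boolean, Boolean),
    past: MemoryTuple[(Boolean, Boolean)]
): (Boolean, (Boolean, Boolean)) = {
  val (tick, paused, gameOver, focusIn) = input
  val (pastPaused, pastGameOver) = past._2.getOrElse((false, false))
  // Redraw on focus, on a tick, or when the paused/game-over state flips.
  val should = focusIn || !(tick.isEmpty && paused == pastPaused && gameOver == pastGameOver)
  (should, (paused, gameOver))
}

val steady: MemoryTuple[(Boolean, Boolean)] = (Some(0.0), Some((false, false)))
println(shouldRedraw(16.0, (None, false, false, false), steady)._1)        // false: nothing changed
println(shouldRedraw(16.0, (Some(200.0), false, false, false), steady)._1) // true: a tick arrived
println(shouldRedraw(16.0, (None, true, false, false), steady)._1)         // true: the game was just paused
```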
As a result, we spend barely any time drawing to the screen.
There are some surprising benefits to this approach: we know exactly when the state is updated, so we can decide whether to update the canvas or skip it; resets are simple; and all state that depends on other state is always in sync.
The drawbacks are:
- Type errors are pretty much impossible to debug. I had to put in a lot of ??? and spend time thinking about why the types didn't match.
- Every combinator is essentially a new function, so to get the information we need, we have to construct a bunch of functions and call them on every iteration. That's not very efficient.
- Passing the state around can sometimes be a pain. If we introduce new state at the beginning of our streams, it might be difficult to pass it all the way to the end.
- Passing streams (not map or flatMap arguments) around can result in memory duplication if the same stream is referenced twice. Sometimes this is what we want (e.g. key presses), but other times we have to remember to wrap the state extraction in map or flatMap if we want to use it multiple times.
Is it a practical way to write software? Probably not at this iteration. Was it fun writing this? Yep, it was.
I might spend more time extracting this into a separate framework if there's any value in it. Otherwise, feel free to check out the code at https://github.com/wildfield/snake-frp and reuse some of it yourself.