
@wildfield
Last active July 12, 2024 07:46

Snake With FRP

Introduction

The Game

I've been working on a side project for the past several days: making a Snake game using purely functional reactive programming (FRP) with ScalaJS. You can play it at https://wildfield.github.io/snake/. The source code is at https://github.com/wildfield/snake-frp

Reader PSA: a lot of wheel reinvention and non-standard terminology ahead. This is not meant to be a tutorial of any sort; I just wanted to share a cool concept I've been working on.

[Figure: the Snake game]

It's a basic Snake game, but it has all the necessary features:

  • Use arrow keys to control the snake
  • P for Pause
  • R for Reset
  • The snake moves faster as the score goes up
  • Collision with a wall or the snake itself is a Game Over

Background

I didn't like the idea of having discrete events, so this will be a simple implementation of a pull-based continuous time FRP system (with some discretization). In hindsight, this is somewhat similar to Yampa, but simpler.

Let's consider a simple Behavior that describes whether the Up button was pressed at a given time:

def didPressUp(time: Double): Boolean

The idea is that events would be represented as point-like spikes over continuous time:

[Figure: point-like events over continuous time]

We want to build an accumulator that counts how many times we pressed the Up button:

[Figure: accumulator of point-like events]

Assuming we sample the events at fixed intervals, the initial idea was to have the function recurse on itself until it reaches t = 0:

def accumulate(time: Double): Int =
  if (time <= 0) 0
  else if (didPressUp(time)) accumulate(time - DELTA_T) + 1
  else accumulate(time - DELTA_T)
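To make the cost concrete, here is a runnable version of the naive recursion. The DELTA_T value and the press times are hypothetical, chosen only for illustration; each call at time t recurses roughly t / DELTA_T levels deep before it reaches the base case.

```scala
// Hypothetical sampling step and press times, for illustration only
val DELTA_T = 1.0
val upPresses = Set(2.0, 3.0, 5.0)

def didPressUp(time: Double): Boolean = upPresses.contains(time)

// Naive version: walks back one sample at a time until t <= 0,
// so every call costs O(time / DELTA_T) nested calls
def accumulate(time: Double): Int =
  if (time <= 0) 0
  else if (didPressUp(time)) accumulate(time - DELTA_T) + 1
  else accumulate(time - DELTA_T)
```

Here accumulate(4.0) counts the presses at t = 2 and t = 3 and returns 2, but it re-walks the whole history on every call.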

This would work, except that computing the value takes longer and longer as time goes on, since the number of recursive calls grows linearly with t / DELTA_T. If we introduce memoization, the memory requirements also grow over time. In addition, we would have to add memoization to every function, which is a lot of boilerplate. So, let's extract the memory part out:

def accumulate(time: Double, pastValue: Int): Int =
   if (didPressUp(time)) then pastValue + 1 else pastValue

Then we can create a function that iterates over this and memoizes the values. I called this function a "backstep". Here is a Scala sketch:

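// Memo table shared across calls: stores every computed sample, keyed by
// time only, so each f needs its own table
val memoized = scala.collection.mutable.Map.empty[Double, Int]

def backstep(time: Double, timeDelta: Double, f: (Double, Int) => Int): Int = {
  if (time <= 0) {
    0
  } else if (memoized.contains(time)) {
    memoized(time)
  } else {
    // Compute the previous sample first, then step it forward with f
    val value = f(time, backstep(time - timeDelta, timeDelta, f))
    memoized(time) = value
    value
  }
}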

The trick here is that we don't need to store values for all possible times. Assuming monotonically increasing time, we only need to store the last computed value. We can make it more generic by introducing a type parameter T and capturing all mutable state in a closure:

def backstep[T](timeDelta: Double, f: (Double, T) => T): Double => T
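As a sketch of the closure idea, here is a version that keeps only the last sampled time and value. It steps forward with a loop rather than recursing backward, which gives the same result under the monotonically increasing time assumption and has no recursion depth to worry about. The initial parameter is an assumption of this sketch; the post later handles the missing initial value with Option.

```scala
// Sketch of a last-value-memoizing backstep. `initial` is a hypothetical
// parameter supplying the value at t = 0.
// Only the last sampled time and value are kept in the closure.
def backstep[T](timeDelta: Double, initial: T, f: (Double, T) => T): Double => T = {
  var lastTime = 0.0
  var lastValue = initial
  (time: Double) => {
    // Advance from the last memoized sample up to `time`; calls with
    // non-increasing times just return the stored value
    while (lastTime + timeDelta <= time) {
      lastTime += timeDelta
      lastValue = f(lastTime, lastValue)
    }
    lastValue
  }
}

// Hypothetical input: Up pressed at t = 2, 3 and 5
val upPresses = Set(2.0, 3.0, 5.0)
val pressCount = backstep[Int](1.0, 0, (time, past) => if (upPresses.contains(time)) past + 1 else past)
```

Calling pressCount(3.0) and then pressCount(10.0) only computes the samples between 3 and 10 on the second call, because everything before that is already folded into lastValue.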

The issue is that if we start introducing these functions all over the code, we MUST call every single one of them on every sample; otherwise, the values won't get stored and we fall back into the slow recursive case. Let's instead keep a single instance of backstep and focus on the function f. What if we could combine all these stateful functions together and recurse over them?

Accumulators all the way down

Let's consider a contrived example: an accumulator of an accumulator! This new function accumulates the values of another accumulator, so it grows as a triangular sequence.

[Figure: accumulator of an accumulator]

In fact, if we change our function signatures slightly, we can reuse the function we already have:

def didPressUp(time: Double): Int // 0 if false, 1 if true

def accumulate(time: Double, valueToAdd: Int, pastValue: Int): Int =
  pastValue + valueToAdd

// didPressUp -> accumulate -> accumulate

Combining a function like this with a constant is pretty simple: we partially apply it (fixing the value argument) and then plug the result back into backstep:

def apply(f: (Double, Int, Int) => Int, value: Int): (Double, Int) => Int =
  (time: Double, past: Int) => f(time, value, past)
...
backstep(time, timeDelta, apply(accumulate, MY_CONSTANT))

To combine two such functions, we need one last trick: mapping and flattening. We can derive a very simple flatMap for our use case; a very similar definition works as a flatMap for didPressUp.

def flatMap(
  f: (Double, Int, Int) => Int, 
  mapping: Int => ((Double, Int) => Int)
): (Double, Int, (Int, Int)) => (Int, Int) = {
  (time: Double, valueToAdd: Int, pastValue: (Int, Int)) =>
    val firstAccValue = f(time, valueToAdd, pastValue._1)
    val mappedFunc = mapping(firstAccValue)
    val secondAccValue = mappedFunc(time, pastValue._2)
    (firstAccValue, secondAccValue)
}
...
backstep(time, timeDelta, flatMap(
  didPressUp,
  didPressUpValue => {
    apply(
      flatMap(accumulate, value => apply(accumulate, value)),
      didPressUpValue
    )
  }
))
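Here is a runnable sketch of apply and flatMap for the plain accumulate signature, with the triangular sequence checked by hand. Threading the tuple memory with scanLeft stands in for backstep, and the input (a press on every sample) is hypothetical.

```scala
// Accumulator shape from above: (time, valueToAdd, pastValue) => newValue
type Acc = (Double, Int, Int) => Int

def accumulate(time: Double, valueToAdd: Int, pastValue: Int): Int =
  pastValue + valueToAdd

// Fix the value argument, leaving a (time, past) => next function
def apply(f: Acc, value: Int): (Double, Int) => Int =
  (time, past) => f(time, value, past)

// Run f, feed its new value into the mapped function, and keep both
// values as the combined memory
def flatMap(f: Acc, mapping: Int => ((Double, Int) => Int)): (Double, Int, (Int, Int)) => (Int, Int) =
  (time, valueToAdd, pastValue) => {
    val firstAccValue = f(time, valueToAdd, pastValue._1)
    val secondAccValue = mapping(firstAccValue)(time, pastValue._2)
    (firstAccValue, secondAccValue)
  }

val accumulatorOfAccumulator = flatMap(accumulate, value => apply(accumulate, value))

// Hypothetical input: Up is pressed (value 1) on each of four samples;
// the tuple memory is threaded by hand here instead of by backstep
val history = (1 to 4).scanLeft((0, 0)) { (memory, i) =>
  accumulatorOfAccumulator(i.toDouble, 1, memory)
}.tail
// history == Seq((1, 1), (2, 3), (3, 6), (4, 10))
```

The first tuple member counts presses, while the second one produces the triangular numbers.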

General Streams

The above works, but the output is now a tuple instead of a single value, because we need to keep track of both past accumulators, which diverge: (1 -> 2 -> 3 ...) and (1 -> 3 -> 6 ...). When we apply a time to the backstep result, we can take the second tuple member and discard the first. As we combine more and more functions, this tuple quickly gets out of hand. Let's call it the Memory type and derive a general stream signature:

type Time = Double

type MemoryTuple[T] = (Option[Double], Option[T])

type ReactiveStream[Memory, Input, Output] =
  (Time, Input, MemoryTuple[Memory]) => (Output, Memory)

This is similar to our accumulate function, with a few additions. Past values are now Options to handle the case where there are no past values at t = 0. The two members of MemoryTuple, (Option[Double], Option[Memory]), are the last time the value was recorded and the recorded value, respectively. Output is separated from Memory because, at the end of the day, we only care about the Output, not the Memory.

Let's also mention its little brother: Source. It's a function that doesn't take any input besides the time and its own memory:

type Source[Memory, Output] =
  (Time, MemoryTuple[Memory]) => (Output, Memory)

We get a Source when we apply a constant to a stream's input (again, by partial application):

def apply[Memory, Input, Output](
    f1: ReactiveStream[Memory, Input, Output],
    a: Input
): Source[Memory, Output]

Source is useful because it lets us define this version of flatMap, which gives us access to the values inside a stream and lets us apply them to other streams:

def flatMapSource[Memory1, Input1, Output1, Memory2, Output2](
    f: ReactiveStream[Memory1, Input1, Output1],
    map: Output1 => Source[Memory2, Output2]
): ReactiveStream[(Memory1, Memory2), Input1, Output2]

If we define accumulate as a ReactiveStream[Int, Int, Int], it looks like this:

def accumulate(time: Double, valueToAdd: Int, pastValue: (Option[Double], Option[Int])): (Int, Int) =
  val value = pastValue._2.getOrElse(0) + valueToAdd
  (value, value)
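With accumulate in this form, the apply and flatMapSource signatures above can be filled in. The bodies below are a sketch under my own assumptions (the post only gives the signatures): flatMapSource runs the first stream, uses its output to pick a Source, runs that Source, and stores both memories as a pair, passing the same last-sample time to each. The memory tuples are threaded by hand in place of backstep.

```scala
type Time = Double
type MemoryTuple[T] = (Option[Double], Option[T])
type ReactiveStream[Memory, Input, Output] =
  (Time, Input, MemoryTuple[Memory]) => (Output, Memory)
type Source[Memory, Output] = (Time, MemoryTuple[Memory]) => (Output, Memory)

// Fix the stream's input to a constant, leaving a Source
def apply[Memory, Input, Output](
    f1: ReactiveStream[Memory, Input, Output],
    a: Input
): Source[Memory, Output] =
  (time, past) => f1(time, a, past)

// Run f, let `map` choose a Source from f's output, run that Source,
// and keep both memories side by side; both share the last-sample time
def flatMapSource[Memory1, Input1, Output1, Memory2, Output2](
    f: ReactiveStream[Memory1, Input1, Output1],
    map: Output1 => Source[Memory2, Output2]
): ReactiveStream[(Memory1, Memory2), Input1, Output2] =
  (time, input, past) => {
    val (pastTime, pastMemory) = past
    val (output1, memory1) = f(time, input, (pastTime, pastMemory.map(_._1)))
    val (output2, memory2) = map(output1)(time, (pastTime, pastMemory.map(_._2)))
    (output2, (memory1, memory2))
  }

def accumulate(time: Double, valueToAdd: Int, pastValue: MemoryTuple[Int]): (Int, Int) = {
  val value = pastValue._2.getOrElse(0) + valueToAdd
  (value, value)
}

val accumulateStream: ReactiveStream[Int, Int, Int] = accumulate
val accumulatorOfAccumulator: ReactiveStream[(Int, Int), Int, Int] =
  flatMapSource(accumulateStream, (value: Int) => apply(accumulateStream, value))

// An Up press (input 1) at t = 1, 2, 3, with the memory threaded by hand
val (out1, mem1) = accumulatorOfAccumulator(1.0, 1, (None, None))
val (out2, mem2) = accumulatorOfAccumulator(2.0, 1, (Some(1.0), Some(mem1)))
val (out3, _) = accumulatorOfAccumulator(3.0, 1, (Some(2.0), Some(mem2)))
// out1, out2, out3 == 1, 3, 6
```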

There's some duplication because each stream must return both its output and its memory ((value, value)), but thankfully we can derive a series of lift functions that take simpler functions and lift them to ReactiveStream when needed. Assuming we also have an apply between a ReactiveStream and a Behavior, we can combine everything with a backstep:

def apply[T1, T2, T3](
    f1: ReactiveStream[T1, T2, T3],
    f2: Behavior[T2]
): Source[T1, T3]
...
val accumulatorAccumulator = backstep(
    timeDelta, 
    apply(
      flatMap(lift(accumulate), value => apply(lift(accumulate), value)),
      didPressUp
    )
)
...
val currentCount: Int = accumulatorAccumulator(time) 

To understand what's going on, read from the inside out: lift(accumulate) creates a ReactiveStream from a simpler signature, and apply(lift(accumulate), value) applies a value from another accumulator (this gives us an accumulator of an accumulator). The outer apply feeds in values from didPressUp (this gives us an accumulator of Up presses). The result is passed into backstep, which gives us a memoized Double => Int function. Calling accumulatorAccumulator(time) gives us the final value we are interested in, and accumulatorAccumulator can be called repeatedly with increasing times to get the desired behavior.
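To tie these pieces together, here is a sketch of lift and of a backstep for Sources. Both bodies are assumptions: this lift is a hypothetical single-input variant, and this backstep takes both the current time and a sampling interval, keeps only the last sample time and memory in a closure, and steps forward with a loop instead of recursing. Applying the constant 1 to a lifted accumulator gives a Source that simply counts samples, which makes the sampling visible.

```scala
type Time = Double
type MemoryTuple[T] = (Option[Double], Option[T])
type ReactiveStream[Memory, Input, Output] =
  (Time, Input, MemoryTuple[Memory]) => (Output, Memory)
type Source[Memory, Output] = (Time, MemoryTuple[Memory]) => (Output, Memory)

// Hypothetical lift: the output doubles as the memory, so the lifted
// function only sees its own past value
def lift[I, O](f: (Double, I, Option[O]) => O): ReactiveStream[O, I, O] =
  (time, input, past) => {
    val value = f(time, input, past._2)
    (value, value)
  }

def apply[M, I, O](f1: ReactiveStream[M, I, O], a: I): Source[M, O] =
  (time, past) => f1(time, a, past)

// Sketch of backstep for a Source: keeps only the last sample time and
// memory. Assumes strictly increasing call times.
def backstep[M, O](source: Source[M, O]): (Double, Double) => O = {
  var lastTime: Option[Double] = None
  var lastMemory: Option[M] = None

  def sample(time: Double): O = {
    val (output, memory) = source(time, (lastTime, lastMemory))
    lastTime = Some(time)
    lastMemory = Some(memory)
    output
  }

  (time: Double, timeDelta: Double) => {
    // Intermediate samples every timeDelta after the last sampled time
    var t = lastTime.map(_ + timeDelta)
    while (t.exists(_ < time)) {
      sample(t.get)
      t = t.map(_ + timeDelta)
    }
    // The final sample lands exactly on `time`
    sample(time)
  }
}

// Accumulating a constant 1 on every sample counts the samples taken
val countSamples = backstep(
  apply(lift((time: Double, valueToAdd: Int, past: Option[Int]) => past.getOrElse(0) + valueToAdd), 1)
)
```

Calling countSamples(0.5, 1.0) returns 1; a following countSamples(3.0, 1.0) samples at 1.5, 2.5 and 3.0, so the count reaches 4.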

This gives us most of the primitives we need to implement the Snake.

Implementing the Snake

Handling Input

Since we are making a game in JavaScript, we need to react to document events. We subscribe to the "keydown" event and bridge the imperative world with our FRP system:

document.addEventListener(
  "keydown",
  (e: dom.KeyboardEvent) =>
    e.key match {
      ...
      case "ArrowLeft" =>
        onPressLeft()
      ...
    }
)

When we press a button, we want to record the time of the event in the FRP system's time. Since time in our system starts at 0, we can just subtract the application's start time:

var leftPressTime: Double = Double.MinValue
val start = new Date()

def curTime(): Double =
  (new Date()).getTime() - start.getTime()

def onPressLeft(): Unit = {
  leftPressTime = curTime()
}

This lets us determine whether the event happened with the following general function:

  def didPress(
      getTime: () => Double
  ): Source[Unit, Option[Double]] = {
    def _didPress(time: Double, past: MemoryTuple[Unit]): (Option[Double], Unit) = {
      val pressedTime = getTime()
      val lastTime = past._1.getOrElse(0.0)
      val pressDetected = (pressedTime >= lastTime) && (pressedTime <= time)
      (
        if pressDetected then Some(pressedTime) else None,
        ()
      )
    }
    _didPress
  }
...
val leftPress = didPress(() => this.leftPressTime)

getTime() is specific to each key, but didPress itself is general. Each time we invoke this function, we effectively pass both the current time and the previous sample time (the first member of the memory tuple). We are not interested in any actual memory value, so Memory = Unit. Given the interval of interest, we can determine whether an event happened just by checking whether the press time falls within the interval bounds. The return value is Option[Double] because we might want to order events even when they happen within the same sample interval.

[Figure: point-like events within a sampling interval]

The sampling interval is dynamic, so this function doesn't refer to any DELTA_T constant; the interval can change between invocations.
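The interval check can be exercised without a browser. In this sketch, a plain var stands in for the field that the keydown handler updates, and the sample times (in milliseconds) are made up; the memory tuple is built by hand with the previous sample time, as backstep would do. Note that the intervals differ in length, which didPress handles without any DELTA_T.

```scala
type Time = Double
type MemoryTuple[T] = (Option[Double], Option[T])
type Source[Memory, Output] = (Time, MemoryTuple[Memory]) => (Output, Memory)

def didPress(getTime: () => Double): Source[Unit, Option[Double]] = {
  def _didPress(time: Double, past: MemoryTuple[Unit]): (Option[Double], Unit) = {
    val pressedTime = getTime()
    val lastTime = past._1.getOrElse(0.0)
    // Detected if the press falls within [previous sample, current sample]
    val pressDetected = (pressedTime >= lastTime) && (pressedTime <= time)
    (if pressDetected then Some(pressedTime) else None, ())
  }
  _didPress
}

// Stand-in for the mutable field updated by the keydown handler
var leftPressTime: Double = Double.MinValue
val leftPress = didPress(() => leftPressTime)

val (beforePress, _) = leftPress(100.0, (Some(84.0), Some(())))
leftPressTime = 110.0 // simulate a key event at t = 110 ms
val (inInterval, _) = leftPress(116.0, (Some(100.0), Some(())))
val (nextInterval, _) = leftPress(132.0, (Some(116.0), Some(())))
// beforePress == None, inInterval == Some(110.0), nextInterval == None
```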

Other keys, including key releases, can be implemented in the same fashion:

val upPress = didPress(() => this.upPressTime)
val downPress = didPress(() => this.downPressTime)
val leftPress = didPress(() => this.leftPressTime)
val rightPress = didPress(() => this.rightPressTime)

val upRelease = didPress(() => this.upReleaseTime)
val downRelease = didPress(() => this.downReleaseTime)
val leftRelease = didPress(() => this.leftReleaseTime)
val rightRelease = didPress(() => this.rightReleaseTime)

To hold a key press signal until the key is released, we introduce a button latch:

def buttonStateLatch(
    time: Double,
    arg: (Option[Double], Option[Double]),
    past: MemoryTuple[Option[Double]]
): (Option[Double], Option[Double]) = {
  val (press, release) = arg
  val result = if (!release.isEmpty) {
    None
  } else if (!press.isEmpty) {
    press
  } else {
    past._2.flatten
  }
  (result, result)
}

The code is relatively straightforward, but since button presses and releases are represented as Option[Double], checking whether one happened means checking !press.isEmpty or !release.isEmpty. To apply the latch, we introduce a helper function makeButtonLatch and apply it to each key pair:

def makeButtonLatch[T1](
    press: Source[T1, Option[Double]],
    release: Source[T1, Option[Double]]
): Source[((T1, T1), Option[Double]), Option[Double]] = {
  flatMapSource(
    pair(press, release),
    keyPair => {
      apply(buttonStateLatch, keyPair)
    }
  )
}
  
val upLatch = makeButtonLatch(upPress, upRelease)
val downLatch = makeButtonLatch(downPress, downRelease)
val leftLatch = makeButtonLatch(leftPress, leftRelease)
val rightLatch = makeButtonLatch(rightPress, rightRelease)
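The latch itself can be exercised on its own. Driving buttonStateLatch through a few hypothetical samples shows the hold-until-release behavior; the memory is threaded by hand here, standing in for backstep.

```scala
type MemoryTuple[T] = (Option[Double], Option[T])

def buttonStateLatch(
    time: Double,
    arg: (Option[Double], Option[Double]),
    past: MemoryTuple[Option[Double]]
): (Option[Double], Option[Double]) = {
  val (press, release) = arg
  val result = if (!release.isEmpty) {
    None
  } else if (!press.isEmpty) {
    press
  } else {
    past._2.flatten
  }
  (result, result)
}

// Hypothetical samples: (sample time, (press, release)), in milliseconds
val samples: List[(Double, (Option[Double], Option[Double]))] = List(
  (16.0, (None, None)),
  (32.0, (Some(20.0), None)), // key pressed at 20 ms
  (48.0, (None, None)),       // still held
  (64.0, (None, Some(60.0)))  // released at 60 ms
)

// Thread (last sample time, latch memory) through the samples
var past: MemoryTuple[Option[Double]] = (None, None)
val outputs = samples.map { (time, keys) =>
  val (output, memory) = buttonStateLatch(time, keys, past)
  past = (Some(time), Some(memory))
  output
}
// outputs == List(None, Some(20.0), Some(20.0), None)
```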

I kept the memory type generic because makeButtonLatch doesn't depend on it, which makes the implementation more flexible. To finish it off, we put all the keys into a Keys case class:

case class Keys(
    left: Option[Double],
    right: Option[Double],
    down: Option[Double],
    up: Option[Double]
)

object Keys {
  def from_tuples(
      arguments: ((Option[Double], Option[Double]), (Option[Double], Option[Double]))
  ): Keys = {
    val ((left, right), (down, up)) = arguments
    Keys(left, right, down, up)
  }
}

val keyTuples = pair(pair(leftLatch, rightLatch), pair(downLatch, upLatch))
val keyValues = map(keyTuples, Keys.from_tuples)

To use the keys for movement, we need to convert them into cardinal directions. This method returns a list of directions sorted by key press time, with the latest first. We will later filter out invalid directions from this list.

def desiredDirections(
    keys: Keys
): List[Direction]
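The body of desiredDirections isn't shown, but the behavior described (held keys sorted by press time, latest first) can be sketched as follows. The Direction enum, its x/y components, and the choice of screen coordinates (y grows downward) are assumptions for illustration.

```scala
// Hypothetical Direction type; the post does not show its definition
enum Direction(val x: Int, val y: Int):
  case Left extends Direction(-1, 0)
  case Right extends Direction(1, 0)
  case Down extends Direction(0, 1)
  case Up extends Direction(0, -1)

case class Keys(
    left: Option[Double],
    right: Option[Double],
    down: Option[Double],
    up: Option[Double]
)

// Sketch: keep only held keys and sort by press time, most recent first
def desiredDirections(keys: Keys): List[Direction] =
  List(
    keys.left.map(pressTime => (pressTime, Direction.Left)),
    keys.right.map(pressTime => (pressTime, Direction.Right)),
    keys.down.map(pressTime => (pressTime, Direction.Down)),
    keys.up.map(pressTime => (pressTime, Direction.Up))
  ).flatten
    .sortBy(_._1)(Ordering.Double.TotalOrdering.reverse)
    .map(_._2)
```

With Left held since t = 10 and Up pressed at t = 25, the result is List(Up, Left), so the most recent key wins unless it is filtered out as invalid.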

Handling Time

The snake moves in a jumpy way: it waits n milliseconds before each move. To simulate this, we introduce a tick function that fires a time pulse every n milliseconds, keeping the accumulated time in between:

def tick(
    time: Double,
    args: (Boolean, Int),
    past: MemoryTuple[Double]
): (Option[Double], Double) = {
  ...
  val pastTime = past._1.getOrElse(0.0)
  val accumulatedTime = past._2.getOrElse(0.0)
  val totalTime = (time - pastTime) + accumulatedTime
  ...
  val pulses = (totalTime / effectivePulseTime).toInt
  ...
}

The Option[Double] output holds a duration when a tick is present, while the memory tracks the leftover time that isn't yet enough for a full tick. The arguments are "stop" and "score", respectively, and they determine the behavior of the tick function: we want to stop generating ticks if "stop" is true, and we want to generate ticks faster the higher the score. This function is central to synchronizing movement across the whole game.
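The tick body is elided above. A sketch consistent with the description: the pulse length shrinks as the score grows, whole pulses are emitted as a duration, and the remainder stays in memory. The constants, the linear speed-up formula, and the choice not to accumulate elapsed time while stopped are all assumptions, not the post's actual code.

```scala
type MemoryTuple[T] = (Option[Double], Option[T])

// Hypothetical speed constants, in milliseconds
val BASE_PULSE_TIME = 200.0
val MIN_PULSE_TIME = 50.0
val SPEED_UP_PER_POINT = 10.0

def tick(time: Double, args: (Boolean, Int), past: MemoryTuple[Double]): (Option[Double], Double) = {
  val (stop, score) = args
  val pastTime = past._1.getOrElse(0.0)
  val accumulatedTime = past._2.getOrElse(0.0)
  if (stop) {
    // Paused or game over: no ticks, and the elapsed time is not accumulated
    (None, accumulatedTime)
  } else {
    val totalTime = (time - pastTime) + accumulatedTime
    // Higher score => shorter pulse => faster snake
    val effectivePulseTime = (BASE_PULSE_TIME - score * SPEED_UP_PER_POINT).max(MIN_PULSE_TIME)
    val pulses = (totalTime / effectivePulseTime).toInt
    // Whatever doesn't fill a whole pulse is carried over in memory
    val leftover = totalTime - pulses * effectivePulseTime
    val output = if (pulses > 0) Some(pulses * effectivePulseTime) else None
    (output, leftover)
  }
}
```

For example, with a score of 0, 150 ms of elapsed time yields no tick and 150 ms of memory; 100 ms more yields one 200 ms tick and 50 ms of leftover.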

The movement is calculated from a tick and a direction:

def movement(
    arguments: (Option[Double], Direction)
): Option[Vect2d] = {
  val (speedPulse, direction) = arguments
  speedPulse.map((speed) => Vect2d(direction.x * speed, direction.y * speed))
}

This brings us to the next problem: the tick generates time pulses, the time pulses move the snake, and moving the snake changes the score. So how can the tick get the score before the tick has even happened? This is an apparent circular dependency, but there is a solution: we can use the score from the past!

Handling Circular Dependencies

To get the score before we process a tick, we need to look at the values from the previous iteration. The combinator that helps us here is pastFlatMap:

def pastFlatMap[T1, T2, T3, T4, T5, T6, T7](
    f1: ReactiveStream[T1, T2, T3],
    map: (ReactiveStream[T1, T2, T3], Option[T5]) => ReactiveStream[T4, T7, (T6, T5)]
): ReactiveStream[(T4, T5), T7, T6]

The trick here is that we can "borrow" a value from further down the chain, Option[T5], which holds what was handed back on the previous iteration, with the contract that we return a new value when we are done with the transformation (the T5 in the (T6, T5) output type). Option is used because on the very first iteration there is no previous value, so it will be None.
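A possible implementation of pastFlatMap, written only from the signature above (the body is my assumption, not the post's code): the borrowed Option[T5] comes from the memory stored on the previous iteration, and whatever T5 the inner stream hands back is written into memory for the next one. The demo uses a plain accumulator and outputs the total from one iteration earlier, with the memory threaded by hand.

```scala
type Time = Double
type MemoryTuple[T] = (Option[Double], Option[T])
type ReactiveStream[Memory, Input, Output] =
  (Time, Input, MemoryTuple[Memory]) => (Output, Memory)

def pastFlatMap[T1, T2, T3, T4, T5, T6, T7](
    f1: ReactiveStream[T1, T2, T3],
    map: (ReactiveStream[T1, T2, T3], Option[T5]) => ReactiveStream[T4, T7, (T6, T5)]
): ReactiveStream[(T4, T5), T7, T6] =
  (time, input, past) => {
    val (pastTime, pastMemory) = past
    // The value handed back on the previous iteration (None on the first one)
    val borrowed: Option[T5] = pastMemory.map(_._2)
    val inner = map(f1, borrowed)
    val ((output, returned), innerMemory) =
      inner(time, input, (pastTime, pastMemory.map(_._1)))
    // Store the returned value so the next iteration can borrow it
    (output, (innerMemory, returned))
  }

// Example stream: a plain accumulator
val accumulate: ReactiveStream[Int, Int, Int] = (time, valueToAdd, past) => {
  val value = past._2.getOrElse(0) + valueToAdd
  (value, value)
}

// Outputs the total from the previous iteration and hands back the current one
val previousTotal = pastFlatMap[Int, Int, Int, Int, Int, Int, Int](
  accumulate,
  (stream, borrowed) =>
    (time, input, past) => {
      val (total, memory) = stream(time, input, past)
      ((borrowed.getOrElse(0), total), memory)
    }
)

val (out1, mem1) = previousTotal(1.0, 5, (None, None))
val (out2, mem2) = previousTotal(2.0, 3, (Some(1.0), Some(mem1)))
val (out3, _) = previousTotal(3.0, 2, (Some(2.0), Some(mem2)))
// out1, out2, out3 == 0, 5, 8
```

In the Snake, the borrowed value plays the role of the past (gameOver, score) pair that the tick needs.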

Here, we first combine the tick with a function that stores the direction. This is needed to validate the direction, which I'll skip. This gives us a new stream that outputs the new direction but still accepts the (Boolean, Int) input that the tick needs:

val timeWithPause = ...

val resultingMovement =
      flatMapSource(
        timeWithPause, {
        ...
      })

val gameState = pastFlatMapSource(
  resultingMovement,
  (resultingMovement, memory: Option[(Boolean, Int)]) => {
    val pastGameOver = memory.map(_._1).getOrElse(false)
    val pastScore = memory.map(_._2).getOrElse(0)
    flatMapSource(
      apply(
        resultingMovement,
        (pastGameOver, pastScore)
      ),
      ...
      (
        ...,
        (gameOver, score)
      )
    )
  }
)

The rest of the snake code applies the same combinators to various streams to get the information we need where we need it. For example, to display food, we need to know the game bounds and the past snake. food returns the food position as a Vect2d and whether it was eaten as a Boolean.

def food(bounds: Rect): ReactiveStream[
  (Vect2d, Boolean),
  List[Vect2d],
  (Vect2d, Boolean)
] = {
  ...
  def _food(
        time: Double,
        snake: List[Vect2d],
        past: MemoryTuple[Vect2d]
    ): ((Vect2d, Boolean), Vect2d) = {
      val oldFoodPosition = past._2.getOrElse(Vect2d(SNAKE_SIZE * 3, SNAKE_SIZE * 4))
      snake.headOption match {
        case None => ((oldFoodPosition, false), oldFoodPosition)
        case Some(head) =>
          if (
            (head.x / SNAKE_SIZE).toInt == (oldFoodPosition.x / SNAKE_SIZE).toInt
            && (head.y / SNAKE_SIZE).toInt == (oldFoodPosition.y / SNAKE_SIZE).toInt
          ) { ... }
      }
  }
  _food
}

Then we apply it to the past value of the snake in the same fashion. We need pastFlatMap because the snake itself also depends on the food: the snake grows longer when it eats food.

...
pastFlatMapSource(
  food(bounds),
  (food, pastSnakeOption: Option[List[Vect2d]]) => {
    val pastSnake = pastSnakeOption.getOrElse(List())
    flatMapSource(
      apply(
        food,
        pastSnake
      ),
      { ... }
    )
  }
)
...

The hierarchies can become pretty deep, but to understand what's going on, start from the inside: domain-specific functions and their arguments, then their immediate combinators, and then up the chain.

Drawing

At the end of the chain we have gameState, which contains all the necessary information about the game. We use a DrawOp sealed trait to encode each drawing operation we need, and then we combine everything into one big list of operations:

sealed trait DrawOp
case class DrawRect(x: Double, y: Double, w: Double, h: Double, color: String) extends DrawOp
case class DrawText(x: Double, y: Double, text: String, font: String, color: String) extends DrawOp
...
def drawSnake(snake: List[Vect2d]): List[DrawOp] = {
  snake.zipWithIndex.map((elem, idx) => {
    val color = if idx == 0 then "#ffff00" else "#ff0000"
    DrawRect(elem.x, elem.y, SNAKE_SIZE, SNAKE_SIZE, color)
  })
}
...
drawClearScreen(bounds) ::: drawFood(food)
  ::: drawSnake(snake) ::: drawScore(score)
  ::: drawHighScore(highScore) ::: drawPause(pause) ::: drawGameOver(gameOver)
...
for (op <- drawOpsValue) {
  op match {
    case DrawRect(x, y, w, h, color) => { ... }
    case DrawText(x, y, text, font, color) => { ... }
  }
}

We get drawOpsValue by applying backstep to our final combined function. Sometimes the time delta between frames can become very large (e.g., when we switch tabs), so to avoid a stack overflow in the recursion, we choose the sampling interval so that it produces at most 200 iterations.

val draws = flatMapSource(
  stateWithHighScore,
  { ... }
)
val drawOps = backstep(draws)
...
val time = (new Date()).getTime() - start.getTime()
val timeSampling = (delta / 200.0).max(DELTA_T)
val drawOpsValue = drawOps(time, timeSampling)
...
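To see how the cap works, here are a few numbers. With a hypothetical DELTA_T of 10 ms, a normal 16 ms frame is sampled in 10 ms steps (two samples), while a 60-second gap after a tab switch is sampled in 300 ms steps, i.e. 200 samples. The iterations helper is illustrative only and is not part of the post's code.

```scala
// Hypothetical lower bound on the sampling step, in milliseconds
val DELTA_T = 10.0

// Sampling interval as in the code above: at most 200 steps per frame
def timeSampling(delta: Double): Double = (delta / 200.0).max(DELTA_T)

// Illustrative helper: number of samples needed to cover `delta`
def iterations(delta: Double): Int = math.ceil(delta / timeSampling(delta)).toInt
```

Short frames keep the fine DELTA_T resolution, and only long gaps trade resolution for a bounded number of steps.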

Nifty Features

Reset

Because our memory tuples have the type MemoryTuple[T] = (Option[Double], Option[T]), resetting the game state only requires passing None as the second member. Functions that care about the memory then initialize it with default values. This is what the listenToReset function does:


def listenToReset[T1, T3](
    f1: Source[T1, (T3, Boolean)]
): Source[Option[T1], T3] = {
  def _listenToReset(
      time: Double,
      past: MemoryTuple[Option[T1]]
  ): (T3, Option[T1]) = {
    val ((output, reset), memory) =
      f1(time, (past._1, past._2.flatten))
    if (reset) {
      (output, None)
    } else {
      (output, Some(memory))
    }
  }
  _listenToReset
}

Then we hook it up with the R key:

val stateWithReset = listenToReset(
  map(
    pair(rPress, gameState),
    output => {
      val (rPress, gameState) = output
      (gameState, !rPress.isEmpty)
    }
  )
)

Because this operates on the combined Memory, we don't care what exactly is inside it. This function resets everything composed up to the point where it is applied. Any streams attached later (e.g., highScore) are not subject to the reset.

Conceptually, the reset works like switching streams, except that in this case we switch the stream with itself, starting again from empty memory.
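A small demonstration of the reset, using a hypothetical counter Source that asks for a reset at t = 3. The memory is threaded by hand in place of backstep; after the reset, the counter starts again from its default.

```scala
type Time = Double
type MemoryTuple[T] = (Option[Double], Option[T])
type Source[Memory, Output] = (Time, MemoryTuple[Memory]) => (Output, Memory)

// listenToReset as defined above
def listenToReset[T1, T3](
    f1: Source[T1, (T3, Boolean)]
): Source[Option[T1], T3] = {
  def _listenToReset(
      time: Double,
      past: MemoryTuple[Option[T1]]
  ): (T3, Option[T1]) = {
    val ((output, reset), memory) =
      f1(time, (past._1, past._2.flatten))
    // On reset, drop the memory so the next sample starts from defaults
    if (reset) (output, None) else (output, Some(memory))
  }
  _listenToReset
}

// Hypothetical stream: counts its samples and requests a reset at t = 3
val counter: Source[Int, (Int, Boolean)] = (time, past) => {
  val count = past._2.getOrElse(0) + 1
  ((count, time == 3.0), count)
}
val resettable = listenToReset(counter)

// Thread (last sample time, memory) by hand, as backstep would
var past: MemoryTuple[Option[Int]] = (None, None)
val outputs = (1 to 5).map { i =>
  val time = i.toDouble
  val (output, memory) = resettable(time, past)
  past = (Some(time), Some(memory))
  output
}
// outputs == Seq(1, 2, 3, 1, 2)
```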

Caching

Because we keep track of the inputs at every iteration, we know exactly when the state is updated. State updates mostly depend on whether the tick function has produced a tick. There are two exceptions: changes to the paused and gameOver states, both of which stop ticks from coming. A third exception is the JavaScript focusin event, which fires when we switch back to the tab; if we don't redraw the screen then, the canvas will be out of date. Putting this together, we can implement a simple shouldRedraw function:

def shouldRedraw(
    time: Double,
    input: (Option[Double], Boolean, Boolean, Boolean),
    past: MemoryTuple[(Boolean, Boolean)]
): (Boolean, (Boolean, Boolean)) = {
  val (tick, paused, gameOver, focusIn) = input
  val (pastPaused, pastGameOver) = past._2.getOrElse((false, false))
  val should = focusIn || !(tick.isEmpty && paused == pastPaused && gameOver == pastGameOver)
  (should, (paused, gameOver))
}
...
shouldRedraw => {
  if (shouldRedraw) {
    drawClearScreen(bounds) ::: drawFood(food)
      ::: drawSnake(snake) ::: drawScore(score)
      ::: drawHighScore(highScore) ::: drawPause(pause) ::: drawGameOver(gameOver)
  } else {
    List()
  }
}
...

As a result, we spend barely any time drawing on the screen.
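A few hypothetical inputs show which frames trigger a redraw with the shouldRedraw function above: a tick, a change in the paused flag, and a focusin event do; an idle frame or a frame that stays paused does not. The sample times and tick durations are made up.

```scala
type MemoryTuple[T] = (Option[Double], Option[T])

def shouldRedraw(
    time: Double,
    input: (Option[Double], Boolean, Boolean, Boolean),
    past: MemoryTuple[(Boolean, Boolean)]
): (Boolean, (Boolean, Boolean)) = {
  val (tick, paused, gameOver, focusIn) = input
  val (pastPaused, pastGameOver) = past._2.getOrElse((false, false))
  // Redraw on focus, on a tick, or when paused/gameOver changed
  val should = focusIn || !(tick.isEmpty && paused == pastPaused && gameOver == pastGameOver)
  (should, (paused, gameOver))
}

// Inputs are (tick, paused, gameOver, focusIn)
val noTick = shouldRedraw(1.0, (None, false, false, false), (Some(0.0), Some((false, false))))._1
val ticked = shouldRedraw(2.0, (Some(200.0), false, false, false), (Some(1.0), Some((false, false))))._1
val justPaused = shouldRedraw(3.0, (None, true, false, false), (Some(2.0), Some((false, false))))._1
val stillPaused = shouldRedraw(4.0, (None, true, false, false), (Some(3.0), Some((true, false))))._1
val refocused = shouldRedraw(5.0, (None, true, false, true), (Some(4.0), Some((true, false))))._1
```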

Performance

Conclusions

There are some surprising benefits to this approach: we know exactly when the state is updated, so we can decide whether to redraw the canvas or skip it; resets are simple; and all state that depends on other state is always in sync.

The drawbacks are:

  1. Type errors are nearly impossible to debug. I had to put in a lot of ??? placeholders and spend time figuring out why the types didn't match.
  2. Every combinator essentially creates a new function. So to get the information we need, we have to construct a bunch of functions and call them on every iteration. That's not very efficient.
  3. Passing state around can sometimes be a pain. If we introduce new state at the beginning of our streams, it can be difficult to pass it all the way to the end.
  4. Passing streams (rather than map or flatMap arguments) around can result in memory duplication if the same stream is referenced twice. Sometimes this is what we want (e.g., key presses), but other times we have to remember to wrap the state extraction in map or flatMap if we want to use it multiple times.

Is it a practical way to write software? Probably not at this iteration. Was it fun writing this? Yep, it was.

What's Next?

I might spend more time extracting this into a separate framework if there's any value in it. Otherwise, feel free to check out the code at https://github.com/wildfield/snake-frp and reuse some of it yourself.
