
@Daenyth
Last active November 9, 2023 17:14
Designing an fs2 `Pull` from scratch

The problem

I have some data with adjacent entries that I want to group together and perform actions on. I know roughly that fs2.Pull can be used to "step" through a stream and do more complicated logic than the built-in combinators allow. I don't know how to write one, though!

In the end, we should have something like:

def combineAdjacent[F[_], A](
    shouldCombine: (A, A) => Boolean,
    combine: (A, A) => A
): Pipe[F, A, A] =
  ???

The rest of this document walks through the steps I followed to implement the right behavior.

About Pull

Stream[F, A] is a stream with effects in F and elements in A. Pull[F, A, R] is a pull with effects in F, elements in A, and a return value in R. Streams are implemented in terms of Pull[F, A, Unit] (more or less; handwaving some details). Streams are monads in A, meaning they have flatMap(f: A => Stream[F, B]): Stream[F, B]. Pull is a monad in R, so it has flatMap(f: R => Pull[F, A, R2]): Pull[F, A, R2].
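To make the "monad in R" versus "monad in A" distinction concrete, here is a minimal sketch, assuming a recent fs2 (2.x or 3.x) on the classpath. Pull's flatMap receives the pull's result value, while Stream's flatMap is called once per emitted element; the value names are illustrative only.

```scala
import fs2.{Pull, Pure, Stream}

// Pull is a monad in its result type R: flatMap receives the returned value (41)
// and continues with a pull that emits one element.
val pull: Pull[Pure, Int, Unit] =
  Pull.pure[Pure, Int](41).flatMap(n => Pull.output1(n + 1))

// Stream is a monad in its element type A: flatMap is called once per element.
val stream: Stream[Pure, Int] =
  Stream(1, 2).flatMap(i => Stream(i, i * 10))

// A Pull[F, A, Unit] can be turned back into a Stream[F, A] with .stream
val fromPull: List[Int] = pull.stream.toList   // List(42)
val fromStream: List[Int] = stream.toList      // List(1, 10, 2, 20)
```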

You might think of List's :: (pronounced "Cons") case as "The next element, and then the rest of the list", and Pull's uncons1 helper allows us to talk about "The next stream element, and the rest of the stream".

We will mainly be using the stream's pull.uncons1, which gives us a Pull whose result is Option[(A, Stream[F, A])]. Working with it is very similar to using List recursion.

| Type | Operation | Result |
|---|---|---|
| List[A] | match | Nil or ::(A, List[A]) |
| Stream[F, A] | .pull.uncons1 | Pull[F, A, Option[(A, Stream[F, A])]] |

Note the parallel: matching on a List[A] gives either Nil or an (A, List[A]) pair, while stream.pull.uncons1 gives an Option[(A, Stream[F, A])], with None playing the role of Nil.
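Here is a small sketch of that correspondence in practice, assuming fs2 2.x or 3.x. The helper name headOnly is hypothetical: it matches on the result of uncons1 exactly the way we would match on a List, emitting the head and ignoring the rest.

```scala
import fs2.{Pull, Pure, Stream}

// Hypothetical helper: emit only the first element of a stream.
// uncons1's result is matched like a List: None ~ Nil, Some((head, rest)) ~ head :: rest.
def headOnly[F[_], A](s: Stream[F, A]): Stream[F, A] = {
  val pull: Pull[F, A, Unit] = s.pull.uncons1.flatMap {
    case None            => Pull.done          // empty stream: emit nothing
    case Some((head, _)) => Pull.output1(head) // emit the head, ignore the rest
  }
  pull.stream
}

val firstOfThree: List[Int] = headOnly(Stream(1, 2, 3)).toList        // List(1)
val firstOfNone: List[Int] = headOnly[Pure, Int](Stream.empty).toList // List()
```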

We could write an uncons1 for List:

def uncons1[A](as: List[A]): Option[(A, List[A])] =
  as.headOption.map(a => (a, as.tail))
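As a quick illustration of recursing with only this helper, here is a hypothetical sum function that never matches on :: directly. It has the same "next element and the rest" shape that Pull gives us for a Stream.

```scala
// Hypothetical example: summing a List using only the uncons1 helper.
def uncons1[A](as: List[A]): Option[(A, List[A])] =
  as.headOption.map(a => (a, as.tail))

def sum(as: List[Int]): Int = uncons1(as) match {
  case None               => 0                // like case Nil
  case Some((head, rest)) => head + sum(rest) // like case head :: rest
}

val total: Int = sum(List(1, 2, 3)) // 6
```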

With this in mind, let's solve our Pull problem by first figuring out how we'd write the equivalent function for List.

Testing

First, let's write some test cases for the behavior we want to have.

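  // Assumes a ScalaTest AnyFunSpec with Matchers, and combineAdjacent in scope
  import cats.implicits._ // Monoid[Int] instance used by foldMonoid
  import fs2.{Pure, Stream}

  describe("combineAdjacent") {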
    it("doesn't combine when condition is false") {
      val s1 = Stream.range(0, 10)
      val noCombine =
        combineAdjacent[Pure, Int]((_, _) => false, null)
      val s2 = s1.through(noCombine)
      s2.toVector shouldEqual s1.toVector
    }

    it("combines when condition is true") {
      val s1 = Stream(1, 3, 5, 7, 11, 13, 17, 23, 27)
      val s2 = s1.flatMap(i => Stream.emits(Seq(i, i)))
      val combineEqual = combineAdjacent[Pure, Int](_ == _, _ + _)
      s2.through(combineEqual).toVector shouldEqual s1.map(_ * 2).toVector
    }

    it("combines repeatedly") {
      val s1 = Stream(1, 1, 1)
      val combineEqual =
        combineAdjacent[Pure, Int]((_, _) => true, _ + _)
      s1.through(combineEqual).toVector shouldEqual s1.foldMonoid.toVector
    }
  }

Now let's translate the problem, and the test cases, from Stream to List.

  import cats.implicits._ // provides the combineAll syntax used below

  describe("combineAdjacent for lists") {
    it("doesn't combine when condition is false") {
      val s1 = (0 to 10).toList
      val noCombine =
        combineAdjacent2[Int]((_, _) => false, null)
      val s2 = noCombine(s1)
      s2 shouldEqual s1
    }

    it("combines when condition is true") {
      val s1 = List(1, 3, 5, 7, 11, 13, 17, 23, 27)
      val s2 = s1.flatMap(i => List(i, i))
      val combineEqual = combineAdjacent2[Int](_ == _, _ + _)
      combineEqual(s2) shouldEqual s1.map(_ * 2)
    }

    it("combines repeatedly") {
      val s1 = List(1, 1, 1)
      val combineEqual = combineAdjacent2[Int]((_, _) => true, _ + _)
      combineEqual(s1) shouldEqual List(s1.combineAll)
    }
  }

Solving the List case

Here's my first draft at a recursive list solution to the problem, and it passes the test cases:

    def combineAdjacent2[A](
        shouldCombine: (A, A) => Boolean,
        combine: (A, A) => A
    ): List[A] => List[A] = {
      case Nil             => Nil
      case xs @ (_ :: Nil) => xs
      case current :: next :: tail =>
        if (shouldCombine(current, next)) {
          val newLst = combine(current, next) :: tail
          combineAdjacent2[A](shouldCombine, combine)(newLst)
        } else
          current :: combineAdjacent2[A](shouldCombine, combine)(next :: tail)
    }

This works, but let's get rid of that current :: next :: tail case, since Pull only lets us talk about "the next element and the rest".

    def combineAdjacent2[A](
        shouldCombine: (A, A) => Boolean,
        combine: (A, A) => A
    ): List[A] => List[A] = as => as match {
      case Nil => Nil
      case current :: rest =>
        rest match {
          case Nil => current :: Nil
          case next :: nextTail =>
            if (shouldCombine(current, next))
              combineAdjacent2(shouldCombine, combine)(
                combine(current, next) :: nextTail)
            else
              current :: combineAdjacent2(shouldCombine, combine)(rest)
        }
    }

Tests are still passing! 🎉

Now we can mechanically translate this, remembering that calling flatMap on stream.pull.uncons1 works a lot like our recursive List match.

Translating the List solution to Pull

  def combineAdjacent[F[_], A](
      shouldCombine: (A, A) => Boolean,
      combine: (A, A) => A
  ): Pipe[F, A, A] = { in =>
    def go(s: Stream[F, A]): Pull[F, A, Unit] =
      s.pull.uncons1.flatMap {
        case None => Pull.done
        case Some((current, rest)) =>
          rest.pull.uncons1.flatMap {
            case None => Pull.output1(current)
            case Some((next, nextTail)) =>
              if (shouldCombine(current, next)) {
                val s2 = Stream.emit(combine(current, next)) ++ nextTail
                go(s2)
              } else
                Pull.output1(current) >> go(rest)
          }
      }
    go(in).stream
  }

| List | Stream + Pull |
|---|---|
| match | s.pull.uncons1.flatMap |
| case Nil | case None |
| case current :: rest | case Some((current, rest)) |
| current :: (recurse on rest) | Pull.output1(current) >> go(rest) |

And this version passes all our tests!

Note also that the list-based version is not stack-safe, because it is not tail recursive: each non-combining step keeps a stack frame open for current :: (recursive call). The flatMap (>>) in Pull, however, is stack-safe, because Pull/Stream is implemented with a "trampolined" construct, which gives us stack-safe recursion.
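If we wanted the List version itself to be stack-safe, one common approach is an accumulator plus @tailrec. This is a sketch of that standard technique, not part of the original solution; the name combineAdjacentTailRec is hypothetical. It carries the value being built up and a reversed list of finished outputs.

```scala
import scala.annotation.tailrec

// Hypothetical stack-safe variant of combineAdjacent2: carry the element being
// built up ("current") and an accumulator of finished outputs in reverse order.
def combineAdjacentTailRec[A](
    shouldCombine: (A, A) => Boolean,
    combine: (A, A) => A
)(as: List[A]): List[A] = {
  @tailrec
  def go(current: A, rest: List[A], acc: List[A]): List[A] =
    rest match {
      case Nil => (current :: acc).reverse
      case next :: nextTail =>
        if (shouldCombine(current, next)) go(combine(current, next), nextTail, acc)
        else go(next, nextTail, current :: acc)
    }

  as match {
    case Nil          => Nil
    case head :: tail => go(head, tail, Nil)
  }
}

// A million-element input could overflow the stack with the non-tail-recursive version.
val combined: List[Int] =
  combineAdjacentTailRec[Int]((_, _) => true, _ + _)(List.fill(1000000)(1)) // List(1000000)
```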

One final detail

While fs2 streams are purely functional, they can model side effects. We can observe a problem when the source stream we pass to combineAdjacent is generated by effects:

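it("doesn't omit records when used effectfully") {
  import cats.effect.IO
  import cats.effect.unsafe.implicits.global // cats-effect 3 runtime for unsafeRunSync

  val s = Stream.range(1, 20)

  // prefetchN puts the stream elements into a queue, so from that point
  // on the stream is not repeatable: each element can be pulled only once
  val result = s
    .covary[IO]
    .prefetchN(20)
    .through(combineAdjacent[IO, Int]((_, _) => false, (a, _) => a))
    .compile
    .toVector
    .unsafeRunSync() // run the IO; a synchronous spec would silently discard an unrun IO

  result shouldEqual s.toVector
}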

Against the previous version, this test case fails because of the way we reused part of the data. Running rest.pull.uncons1 in our flatMap chain pulled next out of rest; when shouldCombine is false we then call go(rest), which pulls from rest a second time, and there's no guarantee that the element we get back is next again. Consider what it means to call .pull.uncons1 on Stream.repeatEval(IO(System.currentTimeMillis())): each pull re-runs the effect, so there's no reason to expect the same element twice.
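The underlying behavior is easy to see in isolation: a Stream is a description, so running the same Stream value twice runs its effects twice. A minimal sketch, assuming fs2 3.x with cats-effect 3; the counter and value names are illustrative.

```scala
import cats.effect.{IO, Ref}
import cats.effect.unsafe.implicits.global // cats-effect 3 runtime for unsafeRunSync
import fs2.Stream

// Running the same Stream value twice re-runs its effects. Here every run
// increments a shared counter, so the two runs see different elements.
val program: IO[(List[Int], List[Int])] =
  Ref.of[IO, Int](0).flatMap { counter =>
    val ticks: Stream[IO, Int] = Stream.eval(counter.updateAndGet(_ + 1))
    for {
      firstRun  <- ticks.compile.toList
      secondRun <- ticks.compile.toList
    } yield (firstRun, secondRun)
  }

val runs: (List[Int], List[Int]) = program.unsafeRunSync() // (List(1), List(2))
```

The rest stream handed back by uncons1 is also such a description, which is why pulling from it twice is not guaranteed to see the same element.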

When the stream is driven by a queue, this is a real problem: we pulled two elements out but only emitted one, so next is lost. We need to fix it by re-emitting the element we just pulled out in front of the remaining stream.

Fortunately, we can state that simply as Pull.output1(current) >> go(Stream.emit(next) ++ nextTail). With that change the test passes, leaving our final implementation as:

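  import fs2.{Pipe, Pull, Stream}

  def combineAdjacent[F[_], A](
      shouldCombine: (A, A) => Boolean,
      combine: (A, A) => A
  ): Pipe[F, A, A] = { in =>
    def go(s: Stream[F, A]): Pull[F, A, Unit] =
      s.pull.uncons1.flatMap {
        case None => Pull.done
        case Some((current, rest)) =>
          rest.pull.uncons1.flatMap {
            case None => Pull.output1(current) // last element: emit it as-is
            case Some((next, nextTail)) =>
              if (shouldCombine(current, next)) {
                // Merge the pair and recurse, so the merged value can absorb later elements
                val s2 = Stream.emit(combine(current, next)) ++ nextTail
                go(s2)
              } else
                // Emit current, then push next back in front of the tail instead of
                // re-pulling rest, which could re-run effects and lose next
                Pull.output1(current) >> go(Stream.emit(next) ++ nextTail)
          }
      }
    go(in).stream
  }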
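For the special case where shouldCombine only compares a key of each element, fs2's built-in groupAdjacentBy can do similar work without a hand-written Pull. This is a different technique from the one above, sketched assuming fs2 3.x (where the Eq[String] instance comes from cats' implicit scope); the record values are made up for illustration. Unlike combineAdjacent, it collects each group into a Chunk rather than combining pairwise, and it cannot express conditions that depend on the already-combined value.

```scala
import fs2.{Pure, Stream}

// Key-based grouping with fs2's built-in groupAdjacentBy, as an alternative
// to combineAdjacent when shouldCombine only compares a key.
val records: Stream[Pure, (String, Int)] =
  Stream(("a", 1), ("a", 2), ("b", 3))

val totals: List[(String, Int)] =
  records
    .groupAdjacentBy(_._1) // Stream of (key, Chunk of adjacent records with that key)
    .map { case (key, chunk) => (key, chunk.toList.map(_._2).sum) }
    .toList // List(("a", 3), ("b", 3))
```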
@chuwy

chuwy commented Jun 29, 2018

In the "doesn't combine when condition is false" specification, it should be:

s2.toVector shouldEqual s1.toVector

instead of

s2.toVector shouldEqual s2.toVector

Great post, thank you!

@Daenyth (author)

Daenyth commented Aug 1, 2018

@chuwy fixed, thanks!

@dsebban

dsebban commented Aug 1, 2018

Great post! Thanks

@kubukoz
Copy link

kubukoz commented Aug 6, 2018

Nice post! It's really cool that we can do this kind of lookahead and re-emit the items we've seen (as in the final snippet) to make sure they stay the same.

This post really cleared up the idea of a Pull for me. Would you consider posting it on the typelevel blog?
