@johnynek
Last active May 13, 2024 02:10
A generalization of tail recursion for stack safety in Scala
/*
* Consider a simple recursive function like:
* f(x) = if (x > 1) f(x - 1) + x
* else 0
*
* This function isn't tail recursive (it could be, but let's set that aside for a moment).
* How can we mechanically, which is to say without thinking about it, convert this into a stack-safe recursion?
* One approach is to model everything that happens after the recursive call as a continuation, and to build up that
* continuation in a stack-safe manner. Here is some example code:
*/
object TailRec {
  /**
   * Represents a function that is "almost" tail recursive:
   * after the recursive call, we allow one function B => B to be
   * applied to its result. We return a stack-safe version of it.
   */
  def almostTailRec[A, B](fn: A => (Either[A, B], B => B)): A => B = {
    // Applies the pending continuations, most recently pushed (innermost call) first.
    @annotation.tailrec
    def applyAll(b: B, stack: List[B => B]): B =
      stack match {
        case Nil    => b
        case h :: t => applyAll(h(b), t)
      }

    // Left(a1): recurse on a1 and remember bfn for later; Right(b): base case, so unwind.
    @annotation.tailrec
    def loop(a: A, finish: List[B => B]): B =
      fn(a) match {
        case (Right(b), bfn) => applyAll(bfn(b), finish)
        case (Left(a1), bfn) => loop(a1, bfn :: finish)
      }

    (a: A) => loop(a, Nil)
  }

  /**
   * f(x) = if (x > 1) f(x - 1) + x
   *        else 0
   */
  def example1NotTail(x: Int): Int =
    if (x > 1) example1NotTail(x - 1) + x
    else 0

  val example1Tail: Int => Int =
    almostTailRec { (x: Int) =>
      if (x > 1) (Left(x - 1), { (x1: Int) => x1 + x })
      else (Right(0), identity[Int](_))
    }
}
/*
* But how fast is it? It's slower, but it works. Timing it naively, by running each version 100 times
* and taking the average:
scala> timeit(example1NotTail(10000))
took: 50554ns
res13: Int = 50004999
scala> timeit(example1Tail(10000))
took: 300386ns
res14: Int = 50004999
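The `timeit` helper used in these REPL sessions is not included in the gist. Here is a minimal sketch of what it might look like. It assumes the helper simply averages wall-clock time over 100 runs, as described above. `Timing` and the `runs` parameter are hypothetical names.

```scala
// Hypothetical reconstruction of the `timeit` helper; the gist does not define it.
object Timing {
  // Evaluates `thunk` `runs` times, prints the mean wall-clock time per run
  // in nanoseconds, and returns the value from the last run.
  def timeit[A](thunk: => A, runs: Int = 100): A = {
    require(runs > 0, "runs must be positive")
    val start = System.nanoTime()
    var result = thunk
    var i = 1
    while (i < runs) {
      result = thunk
      i += 1
    }
    val elapsed = System.nanoTime() - start
    println(s"took: ${elapsed / runs}ns")
    result
  }
}
```

This version does no JIT warm-up, so the numbers it prints are only a rough guide.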
That makes the stack-safe version about 6x slower, but what happens if we need to go to 100k?
scala> example1NotTail(100000)
java.lang.StackOverflowError
at org.bykn.bosatsu.TailRec$.example1NotTail(TailRec.scala:34)
at org.bykn.bosatsu.TailRec$.example1NotTail(TailRec.scala:34)
at org.bykn.bosatsu.TailRec$.example1NotTail(TailRec.scala:34)
scala> timeit(example1Tail(100000))
took: 3757148ns
res45: Int = 705082703
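Note that the 100k result has overflowed. The true value is f(100000) = 100000 * 100001 / 2 - 1 = 5000049999, which exceeds Int.MaxValue, so 705082703 is that value wrapped modulo 2^32. Below is a sketch of the same recursion with a Long result. It reuses a copy of `almostTailRec` so the block stands alone. `TailRecLong` and `example1TailLong` are hypothetical names.

```scala
// Self-contained copy of almostTailRec, reused with B = Long so the sum fits.
object TailRecLong {
  def almostTailRec[A, B](fn: A => (Either[A, B], B => B)): A => B = {
    @annotation.tailrec
    def applyAll(b: B, stack: List[B => B]): B =
      stack match {
        case Nil    => b
        case h :: t => applyAll(h(b), t)
      }
    @annotation.tailrec
    def loop(a: A, finish: List[B => B]): B =
      fn(a) match {
        case (Right(b), bfn) => applyAll(bfn(b), finish)
        case (Left(a1), bfn) => loop(a1, bfn :: finish)
      }
    (a: A) => loop(a, Nil)
  }

  // Same recursion as example1Tail, but the continuations add into a Long.
  val example1TailLong: Int => Long =
    almostTailRec[Int, Long] { (x: Int) =>
      if (x > 1) (Left(x - 1), (acc: Long) => acc + x)
      else (Right(0L), (acc: Long) => acc)
    }
}
```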
Now, this example could have been made tail recursive:
@annotation.tailrec
final def exampleRec(x: Int, acc: Int = 0): Int =
  if (x > 1) exampleRec(x - 1, acc + x)
  else acc
My naive benchmark runs this in 51us, 74x faster than the general approach above.
But this relies on the fact that the continuation functions g(x) = { y: Int => y + x } not only compose associatively (as all functions do) but also commute:
`g(x1).andThen(g(x2)) == g(x2).andThen(g(x1))`
and that their compositions have a simple representation: just a single integer we add to. What we have seen above is the general case, where
we use the free monoid on the continuations (a List of B => B) and apply them in the right order at the end. If we can find
a simpler representation of the monoid of composed continuations, we can instead use an accumulator, as exampleRec does. That will be faster,
since we don't need to build up a list of continuations and then apply it.
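To make the idea of a "simpler representation" concrete, here is an illustrative sketch of our own (not from the gist). It uses continuations of the form y => a * y + b. Composing two such affine maps gives another affine map, so a hypothetical `Affine(a, b)` pair can stand in for the whole list of continuations, even though this composition is not commutative. When a is always 1, as in example1, the pair reduces to the single integer accumulator that exampleRec uses. `Affine`, `CompactContinuation`, `hList` and `hAffine` are hypothetical names.

```scala
// Hypothetical illustration: a continuation y => a * y + b stored as a pair.
final case class Affine(a: Long, b: Long) {
  def apply(y: Long): Long = a * y + b
  // (this andThen that)(y) == that(this(y)) == that.a * (a * y + b) + that.b
  def andThen(that: Affine): Affine = Affine(that.a * a, that.a * b + that.b)
}

object Affine {
  val identity: Affine = Affine(1L, 0L)
}

object CompactContinuation {
  // h(x) = if (x > 1) 2 * h(x - 1) + x else 0, keeping the free-monoid
  // representation: a list of continuations applied at the end, innermost first.
  def hList(x: Int): Long = {
    @annotation.tailrec
    def loop(x: Int, finish: List[Long => Long]): Long =
      if (x > 1) loop(x - 1, ((y: Long) => 2 * y + x) :: finish)
      else finish.foldLeft(0L)((acc, f) => f(acc))
    loop(x, Nil)
  }

  // The same recursion, composing each new continuation into a single Affine.
  // The new step runs before the already-pending outer steps in k.
  def hAffine(x: Int): Long = {
    @annotation.tailrec
    def loop(x: Int, k: Affine): Long =
      if (x > 1) loop(x - 1, Affine(2L, x.toLong).andThen(k))
      else k(0L)
    loop(x, Affine.identity)
  }
}
```

`hList` mirrors the general approach, while `hAffine` needs only one pair of Longs of extra state, however deep the recursion goes.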
Hopefully this idea can help you think about how to make more recursions stack safe. I noticed this while making
a function stack safe in the paiges project: https://github.com/typelevel/paiges
Of course, we can also use a Free Monad approach to model recursive functions. That can handle a more general
class of recursions, including this one; however, the performance seems worse:
*/
import scala.util.control.TailCalls.{ done, tailcall, TailRec => Tail }

def example1TR(x: Int): Tail[Int] =
  if (x > 1) tailcall(example1TR(x - 1)).map(_ + x)
  else done(0)
/*
scala> timeit(example1TR(100000).result)
took: 9060604ns
res15: Int = 705082703
That's 2.4x slower than example1Tail above. This is because the TailCalls code in Scala is more general. Note that we
don't need flatMap here, only map, but we still pay for TailCalls' support of flatMap.
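Here is an illustration of a recursion that does need flatMap, not just map. It is our own example, not from the gist: a function with two recursive calls, such as naive Fibonacci. The work after the first call is itself a recursive call rather than a pure B => B, so it doesn't fit `almostTailRec` directly. TailCalls can express it with flatMap. `NeedsFlatMap` is a hypothetical name.

```scala
import scala.util.control.TailCalls.{ done, tailcall, TailRec => Tail }

object NeedsFlatMap {
  // Two recursive calls per step: the continuation after fib(n - 1) must
  // itself recurse on fib(n - 2), which requires flatMap, not only map.
  def fib(n: Int): Tail[Int] =
    if (n < 2) done(n)
    else
      for {
        x <- tailcall(fib(n - 1))
        y <- tailcall(fib(n - 2))
      } yield x + y
}
```

This naive version takes exponential time. It is only meant to show the shape of a recursion that needs flatMap; run it with `NeedsFlatMap.fib(20).result`.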
==================================
MIT License
Copyright (c) 2019 P. Oscar Boykin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/