Skip to content

Instantly share code, notes, and snippets.

@Daenyth
Last active December 17, 2023 18:16
Show Gist options
  • Save Daenyth/67575575b5c1acc1d6ea100aae05b3a9 to your computer and use it in GitHub Desktop.
Save Daenyth/67575575b5c1acc1d6ea100aae05b3a9 to your computer and use it in GitHub Desktop.
Cats-effect IOSpec for scalatest / TestContext usage
import cats.Eq
import cats.effect.{ContextShift, IO, Timer}
import org.scalactic.Prettifier
import org.scalactic.source.Position
import org.scalatest.exceptions.TestFailedException
import org.scalatest.{Assertion, AsyncTestSuite}
import scala.concurrent.Future
import scala.concurrent.duration._
import scala.reflect.ClassTag
/** ScalaTest syntax for asserting on [[cats.effect.IO]] values inside an `AsyncTestSuite`.
  *
  * Provides an implicit `IO[Assertion] => Future[Assertion]` conversion so a test body may end in
  * an `IO`, plus extension methods (`shouldNotFail`, `shouldResultIn`, ...) on any `IO[A]`.
  */
trait IOAssertions { self: AsyncTestSuite with org.scalatest.Assertions =>

  // It's "extremely unsafe" because it's an implicit conversion that takes a pure value
  // and silently converts it to a side effecting value that begins running on a thread immediately.
  implicit def extremelyUnsafeIOAssertionToFuture(ioa: IO[Assertion])(
      implicit pos: Position
  ): Future[Assertion] = {
    val _ = pos // unused here; exists for override to use
    ioa.unsafeToFuture()
  }

  /** Assertion syntax for `IO[A]`.
    *
    * @param io  the effect under test
    * @param pos source position captured at the call site, used when reporting failures
    */
  implicit protected class IOAssertionOps[A](val io: IO[A])(
      implicit pos: Position
  ) {

    /** Same as `shouldNotFail`, but preferable in cases where explicit type signatures are not
      * used: `shouldNotFail` discards any result, whereas this method only compiles where the
      * author intends `io` to be made from assertions.
      *
      * @example {{{
      * List(1,2,3,4,5,6).traverse { n =>
      *   databaseCheck(n).map { result =>
      *     result shouldEqual GoodValue
      *   }
      * }.flattenAssertion
      * }}}
      */
    def flattenAssertion(implicit ev: A <:< Seq[Assertion]): IO[Assertion] = {
      val _ = ev // "unused implicit" warning
      shouldNotFail
    }

    /** Asserts that `io` produces a value equal to `expected` according to `eq`.
      *
      * @param expected   the value `io` is expected to produce
      * @param eq         equality used for the comparison
      * @param prettifier used to render both values in the failure clue
      * @return an `IO` that fails with a ScalaTest assertion failure when the values differ;
      *         errors raised by `io` itself are propagated unchanged
      */
    def shouldResultIn(
        expected: A
    )(implicit eq: Eq[A], prettifier: Prettifier): IO[Assertion] =
      io.flatMap { actual =>
        // Attach both values as a clue; without it the failure only reports that the
        // `eqv` call was false, which says nothing about what was actually produced.
        IO(
          assert(
            eq.eqv(expected, actual),
            s"Expected ${prettifier(expected)} but got ${prettifier(actual)}"
          )
        )
      }

    /** Asserts that `io` completes without raising an error; the produced value is discarded.
      *
      * ScalaTest's own outcome signals (failed, canceled, pending) are re-raised untouched so the
      * framework reports them as such. Any other error is wrapped in a test failure at `pos`.
      */
    def shouldNotFail: IO[Assertion] = io.attempt.flatMap {
      case Left(
            signal @ (_: TestFailedException |
            _: org.scalatest.exceptions.TestCanceledException |
            _: org.scalatest.exceptions.TestPendingException)
          ) =>
        // Wrapping a cancel/pending signal in `fail` would report the test as failed.
        IO.raiseError(signal)
      case Left(err) =>
        IO(fail(s"IO Failed with ${err.getMessage}", err)(pos))
      case Right(_) =>
        IO.pure(succeed)
    }

    /** Asserts that `io` completes successfully within `within`.
      *
      * @param within maximum time to wait before failing the test (default 5 seconds)
      * @return `shouldNotFail`, replaced by a test failure if the timeout elapses first
      */
    def shouldTerminate(
        within: FiniteDuration = 5.seconds
    )(implicit timer: Timer[IO], CS: ContextShift[IO]): IO[Assertion] =
      shouldNotFail.timeoutTo(within, IO(fail(s"IO didn't terminate within $within")(pos)))

    /**
      * Equivalent to [[org.scalatest.Assertions.assertThrows]] for [[cats.effect.IO]]
      *
      * @tparam T the error type `io` is expected to raise
      */
    def shouldFailWith[T <: AnyRef](
        implicit classTag: ClassTag[T]
    ): IO[Assertion] =
      io.attempt.flatMap { attempt =>
        // `toTry.get` rethrows the captured error (if any) inside `assertThrows`.
        IO(assertThrows[T](attempt.toTry.get))
      }
  }
}
import cats.effect.laws.util.TestContext
import cats.effect.{Blocker, ContextShift, IO, Timer}
import org.scalactic.source.Position
import org.scalatest.{Assertion, AsyncTestSuite}
import org.scalatest.matchers.should.Matchers
import org.scalatest.funspec.AsyncFunSpec
import fs2.Stream
import scala.concurrent.duration._
import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag
/** Derives the `Timer[IO]` and `ContextShift[IO]` instances from whichever implicit
  * `ExecutionContext` is in scope, similar to how cats-effect 0.10 worked.
  */
trait ContextShiftTest {

  /** `Timer[IO]` built from the implicit execution context. */
  implicit protected def timer(implicit ec: ExecutionContext): Timer[IO] = IO.timer(ec)

  /** `ContextShift[IO]` built from the implicit execution context. */
  implicit protected def CS(implicit ec: ExecutionContext): ContextShift[IO] = IO.contextShift(ec)
}
/** Variant of [[ContextShiftTest]] whose `ContextShift` and `Timer` come from a cats-effect
  * `TestContext`, so time is simulated instead of elapsing for real.
  */
trait TestContextShiftTest extends ContextShiftTest with IOAssertions {
  this: AsyncTestSuite =>

  /** The simulated context shared by every instance provided by this trait. */
  final protected val ctx = TestContext()

  // The implicit `ec` is ignored by both overrides: scheduling always goes through `ctx`.
  implicit final override protected def CS(implicit ec: ExecutionContext): ContextShift[IO] =
    ctx.contextShift[IO](IO.ioEffect)

  implicit final override protected def timer(implicit ec: ExecutionContext): Timer[IO] =
    ctx.timer[IO]

  /** Starts `test`, advances the simulated clock by 1000 days, and returns the resulting
    * future if it has completed by then; otherwise fails the test at `pos`.
    */
  implicit final override def extremelyUnsafeIOAssertionToFuture(
      test: IO[Assertion]
  )(implicit pos: Position): Future[Assertion] = {
    val running: Future[Assertion] = test.unsafeToFuture()
    ctx.tick(1000.day) // Advance the clock
    running.value match {
      case Some(_) =>
        running
      case None =>
        fail(
          s"""Test probably deadlocked. Test `IO` didn't resolve after simulating 1000 days of time.
             | Remaining tasks: ${ctx.state.tasks}""".stripMargin
        )(pos)
    }
  }
}
/** Syntax-agnostic flavour of `IOSpec`: mixes in matchers, IO assertions and the
  * `ContextShift`/`Timer` instances without committing to a spec style such as FunSpec.
  */
trait IOSpecBase extends Matchers with IOAssertions with ContextShiftTest {
  self: AsyncTestSuite =>

  /** Streams the bytes of a test resource that lives next to this spec.
    *
    * @param fileName resource name, resolved as
    *                 `src/test/resources/path/to/this/spec/package/$fileName`
    * @param blocker  Blocking pool for thread-blocking IO to run on, for example as provided by
    *                 the [[BlockingTest]] mixin
    * @return the bytes of the resource
    */
  protected def testResource(fileName: String, blocker: Blocker): Stream[IO, Byte] =
    TestResource.readFileBytes(name = fileName, blocker = blocker)(
      ClassTag(getClass),
      implicitly[ContextShift[IO]]
    )
}
/** Spec base class making it easier to standardize IO-based tests.
  *
  * Combines ScalaTest's `AsyncFunSpec` syntax with everything mixed in by `IOSpecBase`.
  */
trait IOSpec extends AsyncFunSpec with IOSpecBase
/** Contract for tests that need a [[cats.effect.Blocker]].
  *
  * Implementors supply the execution context that blocking work should run on; `blocker` wraps
  * it for APIs that take a `Blocker`.
  */
trait BlockingTestLike {

  /** Execution context that thread-blocking work is run on. */
  // Fully qualified: `ExecutionContextExecutor` is not covered by the visible imports.
  protected def blockingExecutionContext: scala.concurrent.ExecutionContextExecutor

  /** A `Blocker` backed by `blockingExecutionContext`. */
  final protected def blocker: Blocker =
    Blocker.liftExecutionContext(blockingExecutionContext)
}
/** Mixin for tests to help ease blocking management.
  *
  * Runs blocking work on `ExecutionContext.global`.
  */
trait BlockingTest extends BlockingTestLike {
  // Fully qualified: `ExecutionContextExecutor` is not covered by the visible imports.
  override protected def blockingExecutionContext: scala.concurrent.ExecutionContextExecutor =
    ExecutionContext.global
}
package org.scalatestplus.scalacheck
// Note ^ important for stealing private access, see https://github.com/scalatest/scalatestplus-scalacheck/issues/40
import cats.effect.IO
import org.scalacheck.Test
import org.scalacheck.effect.PropF
import org.scalactic.source.Position
import org.scalactic.{Prettifier, source}
import org.{scalacheck, scalatest}
import org.scalatest.{Assertion, Assertions}
import org.scalatest.exceptions.{
GeneratorDrivenPropertyCheckFailedException,
StackDepth,
StackDepthException
}
import scala.concurrent.Future
import IOAssertions
/** Mixin to support scalatest + scalacheck-effect PropF
*
* Adds implicit conversions so that an `IO[Assertion]` can be used as a `PropF[IO]`, and so that
* an `IO[Test.Result]` can be returned from an async test as a `Future[Assertion]`.
*
* NB: This lives in the `org.scalatestplus.scalacheck` package to steal access to package-private
* helpers. This is not especially safe, but I don't want to re-implement the world - I already
* have to copy/paste a lot of code. */
trait ScalacheckEffectAssertions extends ScalaCheckDrivenPropertyChecks {
this: IOAssertions =>
import CheckerAsserting._
/** Support `PropF.forAllF { _ => ???: IO[Assertion] }` via already-supported IO[Unit] */
implicit final protected def ioAssertionToProp(test: IO[scalatest.Assertion]): PropF[IO] =
test.void
// Logic of this is mainly copied from `CheckerAsserting#check`.
// Here we return `Future` using the `IOAssertions` implicit IO~>Future
//
// Outcome mapping: `Passed` and `Proved` succeed; `Exhausted`, `Failed` and `PropException`
// each build a failure message and go through `indicateFailure`, which throws.
implicit final protected def scalacheckResultToScalatest(
result: IO[scalacheck.Test.Result]
)(implicit prettifier: Prettifier, pos: Position): Future[scalatest.Assertion] = result.map {
result =>
// scalatest's impl takes this as a function arg, but all call sites pass None
val argNames = None
result.status match {
case Test.Passed => Assertions.succeed
case _: Test.Proved => Assertions.succeed
// Generators were exhausted before enough cases passed; the message does not depend on
// the exception, hence the ignored argument in the message function below.
case Test.Exhausted =>
val failureMsg =
if (result.succeeded == 1)
FailureMessages.propCheckExhaustedAfterOne(prettifier, result.discarded)
else
FailureMessages.propCheckExhausted(prettifier, result.succeeded, result.discarded)
val (args, labels) = argsAndLabels(result)
indicateFailure(
_ => failureMsg,
failureMsg,
args,
labels,
None,
pos
)
// The property evaluated to false for the given arguments; no underlying cause is attached.
case Test.Failed(scalaCheckArgs, scalaCheckLabels) =>
indicateFailure(
sde =>
FailureMessages.propertyException(prettifier,
UnquotedString(sde.getClass.getSimpleName)) + "\n" +
(sde.failedCodeFileNameAndLineNumberString match {
case Some(s) => " (" + s + ")";
case None => ""
}) + "\n" +
" " + FailureMessages.propertyFailed(prettifier, result.succeeded) + "\n" +
(
sde match {
case sd: StackDepth if sd.failedCodeFileNameAndLineNumberString.isDefined =>
" " + FailureMessages.thrownExceptionsLocation(
prettifier,
UnquotedString(sd.failedCodeFileNameAndLineNumberString.get)) + "\n"
case _ => ""
}
) +
" " + FailureMessages.occurredOnValues + "\n" +
prettyArgs(getArgsWithSpecifiedNames(argNames, scalaCheckArgs), prettifier) + "\n" +
" )" +
getLabelDisplay(scalaCheckLabels),
FailureMessages.propertyFailed(prettifier, result.succeeded),
scalaCheckArgs,
scalaCheckLabels.toList,
None,
pos
)
// Evaluating the property threw `e`; it is reported in the message and passed on as the cause.
case Test.PropException(scalaCheckArgs, e, scalaCheckLabels) =>
indicateFailure(
_ =>
FailureMessages.propertyException(prettifier,
UnquotedString(e.getClass.getSimpleName)) + "\n" +
" " + FailureMessages.thrownExceptionsMessage(
prettifier,
if (e.getMessage == null) "None" else UnquotedString(e.getMessage)) + "\n" +
(
e match {
case sd: StackDepth if sd.failedCodeFileNameAndLineNumberString.isDefined =>
" " + FailureMessages.thrownExceptionsLocation(
prettifier,
UnquotedString(sd.failedCodeFileNameAndLineNumberString.get)) + "\n"
case _ => ""
}
) +
" " + FailureMessages.occurredOnValues + "\n" +
prettyArgs(getArgsWithSpecifiedNames(argNames, scalaCheckArgs), prettifier) + "\n" +
" )" +
getLabelDisplay(scalaCheckLabels),
FailureMessages.propertyException(prettifier, UnquotedString(e.getClass.getName)),
scalaCheckArgs,
scalaCheckLabels.toList,
Some(e),
pos
)
}
}
// copy/pasted from CheckerAsserting
// Never returns normally: always throws `GeneratorDrivenPropertyCheckFailedException`.
// The `Assertion` return type only lets it sit in the branches of the match above.
private[scalacheck] def indicateFailure(
messageFun: StackDepthException => String,
undecoratedMessage: => String,
scalaCheckArgs: List[Any],
scalaCheckLabels: List[String],
optionalCause: Option[Throwable],
pos: source.Position
): Assertion =
throw new GeneratorDrivenPropertyCheckFailedException(
messageFun,
optionalCause,
pos,
None,
undecoratedMessage,
scalaCheckArgs,
None,
scalaCheckLabels
)
}
import java.io.FileNotFoundException
import java.nio.file.{Path, Paths}
import cats.effect.{Blocker, ContextShift, IO}
import cats.implicits._
import fs2.Stream
import scala.reflect.ClassTag
/** Helpers for locating and reading files from the test resources directory. */
object TestResource {

  /** Resolves `classPathResource` relative to the class `C`.
    *
    * @param classPathResource resource name as understood by `Class#getResource`
    * @tparam C class whose resource path is searched
    * @return the `Path` for file named `classPathResource` in the resources path of `C`, or a
    *         `FileNotFoundException` naming the resource and package when it does not exist
    */
  def resourceAsPath[C](
      classPathResource: String
  )(implicit tag: ClassTag[C]): Either[FileNotFoundException, Path] = {
    val owner = tag.runtimeClass
    Option(owner.getResource(classPathResource))
      .map(url => Paths.get(url.toURI))
      .toRight {
        // `getPackage` can return null (e.g. primitive or array classes), so guard it:
        // an NPE here would hide the real "file not found" problem.
        val packageName = Option(owner.getPackage).fold("<default package>")(_.getName)
        new FileNotFoundException(
          s"No file '$classPathResource' in test resources for package $packageName"
        )
      }
  }

  /** Looks for a file with `name` in src/test/resources/${package-path-of-T}
    * and reads Bytes from it
    *
    * @param name    The resource name (Without package path) to read from
    * @param blocker Pool for thread-blocking work, for example as provided by the `BlockingTest`
    *                mixin
    * @return the file's bytes, read in 4096-byte chunks; the stream fails with
    *         `FileNotFoundException` when the resource cannot be found
    */
  def readFileBytes[T: ClassTag](
      name: String,
      blocker: Blocker
  )(implicit CS: ContextShift[IO]): Stream[IO, Byte] =
    resourceAsPath[T](name)
      .liftTo[Stream[IO, *]]
      .flatMap(path => fs2.io.file.readAll[IO](path, blocker, 4096))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment