Last active
December 17, 2023 18:16
-
-
Save Daenyth/67575575b5c1acc1d6ea100aae05b3a9 to your computer and use it in GitHub Desktop.
Cats-effect IOSpec for scalatest / TestContext usage
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cats.Eq | |
import cats.effect.{ContextShift, IO, Timer} | |
import org.scalactic.Prettifier | |
import org.scalactic.source.Position | |
import org.scalatest.exceptions.TestFailedException | |
import org.scalatest.{Assertion, AsyncTestSuite} | |
import scala.concurrent.Future | |
import scala.concurrent.duration._ | |
import scala.reflect.ClassTag | |
/** Helpers for writing ScalaTest async tests whose bodies are cats-effect `IO` values.
  *
  * Must be mixed into an [[org.scalatest.AsyncTestSuite]] (see the self-type); the implicit
  * conversion below lets an `IO[Assertion]` test body be returned where ScalaTest expects a
  * `Future[Assertion]`.
  */
trait IOAssertions { self: AsyncTestSuite with org.scalatest.Assertions =>

  // It's "extremely unsafe" because it's an implicit conversion that takes a pure value
  // and silently converts it to a side effecting value that begins running on a thread immediately.
  implicit def extremelyUnsafeIOAssertionToFuture(ioa: IO[Assertion])(
      implicit pos: Position
  ): Future[Assertion] = {
    val _ = pos // unused here; exists for override to use
    ioa.unsafeToFuture()
  }

  /** Assertion syntax available on any `IO[A]`.
    *
    * @param io  the effect under test
    * @param pos call-site position; failures raised via `fail` in this class are attributed to it
    */
  implicit protected class IOAssertionOps[A](val io: IO[A])(
      implicit pos: Position
  ) {

    /** Same as shouldNotFail, but preferable in cases where explicit type signatures are not used,
      * as `shouldNotFail` will discard any result, and this method will only compile where the author
      * intends `io` to be made from assertions.
      *
      * @example {{{
      * List(1,2,3,4,5,6).traverse { n =>
      *   databaseCheck(n).map { result =>
      *     result shouldEqual GoodValue
      *   }
      * }.flattenAssertion
      * }}}
      */
    def flattenAssertion(implicit ev: A <:< Seq[Assertion]): IO[Assertion] = {
      val _ = ev // "unused implicit" warning
      io.shouldNotFail
    }

    /** Runs `io` and asserts that its result equals `expected` according to `eq`.
      *
      * The clue renders both values with `prettifier`, so a failure shows what was actually
      * produced instead of only reporting that the `eqv` comparison returned false.
      */
    def shouldResultIn(
        expected: A
    )(implicit eq: Eq[A], prettifier: Prettifier): IO[Assertion] =
      io.flatMap { actual =>
        IO(
          assert(
            eq.eqv(expected, actual),
            s"(expected: ${prettifier(expected)}, actual: ${prettifier(actual)})"
          )
        )
      }

    /** Runs `io`, discards its result, and succeeds unless it raised an error.
      *
      * A `TestFailedException` (for example from an assertion inside `io`) is re-raised unchanged so
      * its original message and position survive. Any other error fails the test at `pos`, with the
      * error attached as the cause.
      */
    def shouldNotFail: IO[Assertion] = io.attempt.flatMap {
      case Left(failed: TestFailedException) =>
        IO.raiseError(failed)
      case Left(err) =>
        IO(fail(s"IO Failed with ${err.getMessage}", err)(pos))
      case Right(_) =>
        IO.pure(succeed)
    }

    /** Like `shouldNotFail`, but also fails the test if `io` has not completed within `within`.
      *
      * `timeoutTo` switches to the failing fallback once `within` elapses, so a hung effect is
      * reported as a test failure.
      */
    def shouldTerminate(
        within: FiniteDuration = 5.seconds
    )(implicit timer: Timer[IO], CS: ContextShift[IO]): IO[Assertion] =
      io.shouldNotFail.timeoutTo(within, IO(fail(s"IO didn't terminate within $within")))

    /**
      * Equivalent to [[org.scalatest.Assertions.assertThrows]] for [[cats.effect.IO]]
      *
      * Succeeds only if `io` fails with an exception of type `T`; a successful `io` fails the test.
      */
    def shouldFailWith[T <: AnyRef](
        implicit classTag: ClassTag[T]
    ): IO[Assertion] =
      io.attempt.flatMap { attempt =>
        // Re-throw the captured error (if any) inside assertThrows so it can check its type.
        IO(assertThrows[T](attempt.toTry.get))
      }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cats.effect.laws.util.TestContext
import cats.effect.{Blocker, ContextShift, IO, Timer}
import org.scalactic.source.Position
import org.scalatest.{Assertion, AsyncTestSuite}
import org.scalatest.matchers.should.Matchers
import org.scalatest.funspec.AsyncFunSpec
import fs2.Stream
import scala.concurrent.duration._
import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
import scala.reflect.ClassTag
/** Derives `ContextShift[IO]` and `Timer[IO]` from whichever implicit `ExecutionContext` is in
  * scope at the use site, mirroring how cats-effect 0.10 located them.
  */
trait ContextShiftTest {

  /** A `ContextShift[IO]` built from `ec` via `IO.contextShift`. */
  implicit protected def CS(implicit ec: ExecutionContext): ContextShift[IO] = {
    val shift: ContextShift[IO] = IO.contextShift(ec)
    shift
  }

  /** A `Timer[IO]` built from `ec` via `IO.timer`. */
  implicit protected def timer(implicit ec: ExecutionContext): Timer[IO] = {
    val clock: Timer[IO] = IO.timer(ec)
    clock
  }
}
/** Overrides ContextShiftTest behaviors to be provided by cats-effect `TestContext`
  *
  * Every `ContextShift[IO]` / `Timer[IO]` summoned in the suite is backed by the single simulated
  * context `ctx`, whatever `ExecutionContext` is in scope. A test body is run by ticking `ctx`
  * forward by `maxSimulatedTime`; a body that still has not completed is reported as a probable
  * deadlock.
  */
trait TestContextShiftTest extends ContextShiftTest with IOAssertions {
  this: AsyncTestSuite =>

  final protected val ctx = TestContext()

  /** Amount of simulated time `ctx` is advanced by when running a test body. Override to change it. */
  protected def maxSimulatedTime: FiniteDuration = 1000.days

  // `ec` is deliberately ignored: all shifting goes through `ctx`. The parameter only keeps the
  // override signature compatible with ContextShiftTest.
  implicit final override protected def CS(
      implicit ec: ExecutionContext
  ): ContextShift[IO] =
    ctx.contextShift[IO](IO.ioEffect)

  // As with `CS`, `ec` is ignored; timers are driven by `ctx`'s simulated clock.
  implicit final override protected def timer(
      implicit ec: ExecutionContext
  ): Timer[IO] = ctx.timer[IO]

  /** Starts `test`, then synchronously ticks `ctx` by `maxSimulatedTime`.
    *
    * Returns the (already completed) future if the test finished within that simulated window;
    * otherwise fails immediately at `pos`, listing the tasks still queued in `ctx`.
    */
  implicit final override def extremelyUnsafeIOAssertionToFuture(
      test: IO[Assertion]
  )(implicit pos: Position): Future[Assertion] = {
    val result: Future[Assertion] = test.unsafeToFuture()
    ctx.tick(maxSimulatedTime) // Advance the clock
    if (result.value.isDefined)
      result
    else
      fail(
        s"""Test probably deadlocked. Test `IO` didn't resolve after simulating $maxSimulatedTime of time.
           | Remaining tasks: ${ctx.state.tasks}""".stripMargin
      )(pos)
  }
}
/** Version of IOSpec not requiring a particular syntax (like FunSpec)
  *
  * Bundles ScalaTest `Matchers`, [[IOAssertions]] and [[ContextShiftTest]] for use with any async
  * ScalaTest suite style.
  */
trait IOSpecBase extends Matchers with IOAssertions with ContextShiftTest {
  self: AsyncTestSuite =>

  /** Streams a test resource stored alongside this spec's package.
    *
    * @param fileName resource file name, without the package path
    * @param blocker Blocking pool for thread-blocking IO to run on, for example as provided by the [[BlockingTest]] mixin
    * @return bytes from `src/test/resources/path/to/this/spec/package/$fileName`
    */
  protected def testResource(fileName: String, blocker: Blocker): Stream[IO, Byte] = {
    // Anchor the lookup on the concrete spec class, so the resource resolves relative to its package.
    val specClass = getClass
    val shift = ContextShift[IO]
    TestResource.readFileBytes(fileName, blocker)(ClassTag(specClass), shift)
  }
}
/** Spec base class making it easier to standardize IO-based tests
  *
  * Combines ScalaTest's `AsyncFunSpec` (describe/it syntax) with [[IOSpecBase]], which supplies
  * matchers, `IO` assertion syntax, and `ContextShift`/`Timer` derived from the suite's
  * `ExecutionContext`.
  */
trait IOSpec extends AsyncFunSpec with IOSpecBase
/** Supplies a cats-effect `Blocker` built from an abstract blocking pool.
  *
  * Implementations pick the pool by defining `blockingExecutionContext`. `blocker` wraps it for
  * APIs that take a `Blocker`, such as `TestResource.readFileBytes`.
  * (`ExecutionContextExecutor` must be imported from `scala.concurrent` at the top of this file.)
  */
trait BlockingTestLike {

  /** Pool on which thread-blocking work should run. */
  protected def blockingExecutionContext: ExecutionContextExecutor

  /** A `Blocker` over `blockingExecutionContext`; a fresh wrapper is built on every call. */
  final protected def blocker: Blocker =
    Blocker.liftExecutionContext(blockingExecutionContext)
}
/** Mixin for tests to help ease blocking management
  *
  * Uses `ExecutionContext.global` as the blocking pool behind `blocker`.
  * NOTE(review): `global` is also the default compute pool; heavy blocking I/O submitted through a
  * `Blocker` over it may compete with CPU-bound tasks. Confirm this is acceptable for the suite,
  * or override `blockingExecutionContext` with a dedicated pool.
  */
trait BlockingTest extends BlockingTestLike {
  override protected def blockingExecutionContext: ExecutionContextExecutor =
    ExecutionContext.global
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.scalatestplus.scalacheck | |
// Note ^ important for stealing private access, see https://github.com/scalatest/scalatestplus-scalacheck/issues/40 | |
import cats.effect.IO | |
import org.scalacheck.Test | |
import org.scalacheck.effect.PropF | |
import org.scalactic.source.Position | |
import org.scalactic.{Prettifier, source} | |
import org.{scalacheck, scalatest} | |
import org.scalatest.{Assertion, Assertions} | |
import org.scalatest.exceptions.{ | |
GeneratorDrivenPropertyCheckFailedException, | |
StackDepth, | |
StackDepthException | |
} | |
import scala.concurrent.Future | |
import IOAssertions | |
/** Mixin to support scalatest + scalacheck-effect PropF
  *
  * NB: This lives in the `org.scalatestplus.scalacheck` package to steal access to package-private
  * helpers. This is not especially safe, but I don't want to re-implement the world - I already have
  * to copy/paste a lot of code.
  *
  * The self-type requires [[IOAssertions]], whose implicit `IO[Assertion] => Future[Assertion]`
  * conversion is what lets `scalacheckResultToScalatest` return a `Future`.
  */
trait ScalacheckEffectAssertions extends ScalaCheckDrivenPropertyChecks {
  this: IOAssertions =>
  import CheckerAsserting._

  /** Support `PropF.forAllF { _ => ???: IO[Assertion] }` via already-supported IO[Unit] */
  implicit final protected def ioAssertionToProp(test: IO[scalatest.Assertion]): PropF[IO] =
    // `void` discards the Assertion value; the resulting IO[Unit] is presumably lifted to PropF[IO]
    // by scalacheck-effect's own conversion, so a thrown assertion failure fails the property.
    test.void

  // Logic of this is mainly copied from `CheckerAsserting#check`.
  // Here we return `Future` using the `IOAssertions` implicit IO~>Future
  /** Translates a scalacheck `Test.Result` computed in `IO` into a ScalaTest assertion.
    *
    * `Passed` and `Proved` map to `succeed`. `Exhausted`, `Failed` and `PropException` call
    * `indicateFailure`, which throws inside the `map`, so the resulting effect/future fails with a
    * `GeneratorDrivenPropertyCheckFailedException` positioned at `pos`.
    */
  implicit final protected def scalacheckResultToScalatest(
      result: IO[scalacheck.Test.Result]
  )(implicit prettifier: Prettifier, pos: Position): Future[scalatest.Assertion] = result.map {
    result =>
      // scalatest's impl takes this as a function arg, but all call sites pass None
      val argNames = None
      result.status match {
        case Test.Passed => Assertions.succeed
        case _: Test.Proved => Assertions.succeed
        // Generator discarded too many inputs before reaching the required number of successes.
        case Test.Exhausted =>
          val failureMsg =
            if (result.succeeded == 1)
              FailureMessages.propCheckExhaustedAfterOne(prettifier, result.discarded)
            else
              FailureMessages.propCheckExhausted(prettifier, result.succeeded, result.discarded)
          val (args, labels) = argsAndLabels(result)
          indicateFailure(
            _ => failureMsg,
            failureMsg,
            args,
            labels,
            None,
            pos
          )
        // The property evaluated to false for some generated arguments.
        case Test.Failed(scalaCheckArgs, scalaCheckLabels) =>
          indicateFailure(
            sde =>
              FailureMessages.propertyException(prettifier,
                UnquotedString(sde.getClass.getSimpleName)) + "\n" +
                (sde.failedCodeFileNameAndLineNumberString match {
                  case Some(s) => " (" + s + ")";
                  case None => ""
                }) + "\n" +
                " " + FailureMessages.propertyFailed(prettifier, result.succeeded) + "\n" +
                (
                  sde match {
                    case sd: StackDepth if sd.failedCodeFileNameAndLineNumberString.isDefined =>
                      " " + FailureMessages.thrownExceptionsLocation(
                        prettifier,
                        UnquotedString(sd.failedCodeFileNameAndLineNumberString.get)) + "\n"
                    case _ => ""
                  }
                ) +
                " " + FailureMessages.occurredOnValues + "\n" +
                prettyArgs(getArgsWithSpecifiedNames(argNames, scalaCheckArgs), prettifier) + "\n" +
                " )" +
                getLabelDisplay(scalaCheckLabels),
            FailureMessages.propertyFailed(prettifier, result.succeeded),
            scalaCheckArgs,
            scalaCheckLabels.toList,
            None,
            pos
          )
        // The property threw `e` instead of returning a result; `e` becomes the failure's cause.
        case Test.PropException(scalaCheckArgs, e, scalaCheckLabels) =>
          indicateFailure(
            _ =>
              FailureMessages.propertyException(prettifier,
                UnquotedString(e.getClass.getSimpleName)) + "\n" +
                " " + FailureMessages.thrownExceptionsMessage(
                  prettifier,
                  if (e.getMessage == null) "None" else UnquotedString(e.getMessage)) + "\n" +
                (
                  e match {
                    case sd: StackDepth if sd.failedCodeFileNameAndLineNumberString.isDefined =>
                      " " + FailureMessages.thrownExceptionsLocation(
                        prettifier,
                        UnquotedString(sd.failedCodeFileNameAndLineNumberString.get)) + "\n"
                    case _ => ""
                  }
                ) +
                " " + FailureMessages.occurredOnValues + "\n" +
                prettyArgs(getArgsWithSpecifiedNames(argNames, scalaCheckArgs), prettifier) + "\n" +
                " )" +
                getLabelDisplay(scalaCheckLabels),
            FailureMessages.propertyException(prettifier, UnquotedString(e.getClass.getName)),
            scalaCheckArgs,
            scalaCheckLabels.toList,
            Some(e),
            pos
          )
      }
  }

  // copy/pasted from CheckerAsserting
  /** Never returns normally: always throws a `GeneratorDrivenPropertyCheckFailedException` built
    * from the message function, undecorated message, generated args, labels, optional cause and
    * source position. The `Assertion` result type only lets calls sit in match arms.
    */
  private[scalacheck] def indicateFailure(
      messageFun: StackDepthException => String,
      undecoratedMessage: => String,
      scalaCheckArgs: List[Any],
      scalaCheckLabels: List[String],
      optionalCause: Option[Throwable],
      pos: source.Position
  ): Assertion =
    throw new GeneratorDrivenPropertyCheckFailedException(
      messageFun,
      optionalCause,
      pos,
      None,
      undecoratedMessage,
      scalaCheckArgs,
      None,
      scalaCheckLabels
    )
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.FileNotFoundException | |
import java.nio.file.{Path, Paths} | |
import cats.effect.{Blocker, ContextShift, IO} | |
import cats.implicits._ | |
import fs2.Stream | |
import scala.reflect.ClassTag | |
object TestResource {

  /** Resolves a test resource to a filesystem `Path`, relative to the package of `C`.
    *
    * @param classPathResource resource name passed to `Class.getResource` on `C`'s runtime class
    *                          (resolved relative to `C`'s package unless it starts with `/`)
    * @tparam C class whose package anchors the lookup
    * @return the `Path` for file named `classPathResource` in the resources path of `C`, or a
    *         `FileNotFoundException` when no such resource exists
    */
  def resourceAsPath[C](
      classPathResource: String
  )(implicit tag: ClassTag[C]): Either[FileNotFoundException, Path] = {
    val anchor = tag.runtimeClass
    // NOTE(review): `Paths.get(uri)` throws (FileSystemNotFoundException) for resources packaged in
    // a jar; this assumes test resources are exploded on disk.
    Option(anchor.getResource(classPathResource))
      .map(url => Paths.get(url.toURI))
      .toRight {
        // `getPackage` may return null (array/primitive classes, or no package info from the class
        // loader); guard it so a missing file is reported as such instead of as a NullPointerException.
        val packageName = Option(anchor.getPackage).fold("<unknown>")(_.getName)
        new FileNotFoundException(
          s"No file '$classPathResource' in test resources for package $packageName"
        )
      }
  }

  /** Looks for a file with `name` in src/test/resources/${package-path-of-T}
    * and reads Bytes from it
    *
    * @param name The resource name (Without package path) to read from
    * @param blocker Pool for thread-blocking work, for example as provided by the [[BlockingTest]] mixin
    * @param chunkSize Number of bytes read per chunk; defaults to 4096
    * @return the file's bytes; if the resource is missing, the stream raises the
    *         `FileNotFoundException` produced by `resourceAsPath`
    */
  def readFileBytes[T: ClassTag](
      name: String,
      blocker: Blocker,
      chunkSize: Int = 4096
  )(implicit CS: ContextShift[IO]): Stream[IO, Byte] =
    resourceAsPath[T](name)
      .liftTo[Stream[IO, *]]
      .flatMap(path => fs2.io.file.readAll[IO](path, blocker, chunkSize))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment