Last active
August 29, 2015 14:09
-
-
Save dwickern/4d93a2bdf993a9e59ba3 to your computer and use it in GitHub Desktop.
Akka Testkit mixin for testing unordered messages
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import akka.testkit._ | |
import org.scalatest._ | |
import scala.collection.mutable | |
import scala.concurrent.duration._ | |
import scala.reflect.ClassTag | |
import scala.util._ | |
/**
 * Extensions to [[akka.testkit.TestKit]] for testing actors
 * which publish messages in no particular order.
 *
 * APIs are analogous to the `TestKit` expectation methods, but calling them only
 * registers an expectation. All registered expectations are evaluated together at
 * the end of each test, and only if the test body itself succeeded:
 *
 *  - each received message is tried against the pending expectations in
 *    registration order; the first expectation that accepts it consumes it,
 *  - messages accepted by no pending expectation are collected and listed in the
 *    failure report,
 *  - this phase must finish within the sum of the individual expectation timeouts
 *    (dilated once, like other `TestKit` timeouts),
 *  - afterwards, if `expectUnorderedNoMsg` was called, no further message may arrive
 *    within the requested (dilated) duration.
 */
trait UnorderedTestKit extends SuiteMixin { self: Suite with TestKitBase =>

  /** Expects a message equal to `value`, using the default single-expect timeout. */
  def expectUnorderedMsg(value: Any): Unit = expectUnorderedMsg(defaultTimeout, null, value)

  /** Expects a message equal to `value`; `max` is this expectation's share of the time budget. */
  def expectUnorderedMsg(max: FiniteDuration, value: Any): Unit = expectUnorderedMsg(max, null, value)

  /** Expects a message equal to `value`; `hint` is appended to the failure report. */
  def expectUnorderedMsg(hint: String, value: Any): Unit = expectUnorderedMsg(defaultTimeout, hint, value)

  /**
   * Expects a message equal (`==`) to `value`.
   *
   * @param max   this expectation's contribution to the overall time budget
   * @param hint  optional text appended to the failure report; may be `null`
   * @param value the expected message
   */
  def expectUnorderedMsg(max: FiniteDuration, hint: String, value: Any): Unit = {
    expectInternal(max, s"Did not encounter message equal to $value", hint) {
      case `value` =>
    }
  }

  /** Requires that no further message arrives, for the default timeout, once all expectations are met. */
  def expectUnorderedNoMsg(): Unit = expectUnorderedNoMsg(defaultTimeout)

  /**
   * Requires that no further message arrives within `max` (dilated) after all message
   * expectations of the test have been met. The check also applies when no message
   * expectation was registered. Calling this again replaces the previous duration.
   */
  def expectUnorderedNoMsg(max: FiniteDuration): Unit = {
    expectNoMsgFor = max
  }

  /** Expects a message that is an instance of `T`, using the default single-expect timeout. */
  def expectUnorderedMsgType[T](implicit t: ClassTag[T]): Unit = expectUnorderedMsgClass(defaultTimeout, t.runtimeClass)

  /** Expects a message that is an instance of `T`; `max` is this expectation's share of the time budget. */
  def expectUnorderedMsgType[T](max: FiniteDuration)(implicit t: ClassTag[T]): Unit = expectUnorderedMsgClass(max, t.runtimeClass)

  /** Expects a message that is an instance of `c`, using the default single-expect timeout. */
  def expectUnorderedMsgClass(c: Class[_]): Unit = expectUnorderedMsgClass(defaultTimeout, c)

  /**
   * Expects a message that is an instance of `c`. Primitive classes (e.g. from
   * `ClassTag[Int]`) are matched against their boxed counterparts.
   *
   * @param max this expectation's contribution to the overall time budget
   * @param c   the required runtime class of the message
   */
  def expectUnorderedMsgClass(max: FiniteDuration, c: Class[_]): Unit = {
    // Received messages are always objects, so a primitive `int` class would never match.
    val target = boxed(c)
    expectInternal(max, s"Did not encounter message of type ${c.getName}", null) {
      case msg if target.isInstance(msg) =>
    }
  }

  /**
   * Expects a message for which `pf` is defined and whose evaluation does not throw.
   * If `pf` throws for a message (e.g. a failed assertion), that message does not
   * satisfy this expectation, and the most recent such error is shown in the
   * failure report.
   *
   * @param max  this expectation's contribution to the overall time budget
   * @param hint optional text appended to the failure report; may be `null`
   */
  def expectUnorderedMsgPF(max: FiniteDuration = defaultTimeout, hint: String = null)(pf: PartialFunction[Any, Unit]): Unit = {
    expectInternal(max, "Did not encounter message matching the partial function", hint)(pf)
  }

  /** The expectations to evaluate at the end of the current test, kept in registration order. */
  private val expectations = mutable.LinkedHashSet[Expectation]()

  /** Undilated "no further message" window; zero disables the check. */
  private var expectNoMsgFor: FiniteDuration = Duration.Zero

  /** Undilated default; dilation is applied once, where the timeouts are actually awaited. */
  private def defaultTimeout: FiniteDuration = testKitSettings.SingleExpectDefaultTimeout

  /** Boxed counterparts of the primitive classes (and `void` for `Unit`). */
  private val boxedPrimitives = Map[Class[_], Class[_]](
    classOf[Boolean] -> classOf[java.lang.Boolean],
    classOf[Byte] -> classOf[java.lang.Byte],
    classOf[Char] -> classOf[java.lang.Character],
    classOf[Short] -> classOf[java.lang.Short],
    classOf[Int] -> classOf[java.lang.Integer],
    classOf[Long] -> classOf[java.lang.Long],
    classOf[Float] -> classOf[java.lang.Float],
    classOf[Double] -> classOf[java.lang.Double],
    classOf[Unit] -> classOf[scala.runtime.BoxedUnit])

  /** Maps a primitive class to its boxed class; any other class is returned unchanged. */
  private def boxed(c: Class[_]): Class[_] = boxedPrimitives.getOrElse(c, c)

  /** Registers an expectation satisfied by the first message on which `pf` is defined and does not throw. */
  private def expectInternal(max: FiniteDuration, message: String, hint: String)(pf: PartialFunction[Any, Unit]): Unit = {
    expectations += new Expectation {
      /** Most recent error thrown by `pf` for a message it was defined at; null if none. */
      var lastError: Throwable = _

      def execute(msg: Any): Boolean = {
        if (pf.isDefinedAt(msg)) {
          Try(pf(msg)) match {
            case Success(_) => true
            case Failure(t) =>
              lastError = t
              false
          }
        } else false
      }

      def duration = max

      // Used verbatim as this expectation's line in the failure report.
      override def toString = {
        if (lastError eq null) {
          Option(hint).fold(message) { h => s"$message ($h)" }
        } else lastError.toString
      }
    }
  }

  /**
   * Evaluates the registered message expectations, then the optional
   * "no further message" check.
   *
   * @throws AssertionError if an expectation is unmet within the time budget, or if
   *                        a message arrives during the no-message window
   */
  private def verifyExpectations(): Unit = {
    if (expectations.nonEmpty) awaitExpectations()
    // Deliberately outside `within`: this wait is not part of the expectations' budget,
    // and it must run even when no message expectation was registered.
    if (expectNoMsgFor > Duration.Zero) awaitNoMsg(expectNoMsgFor.dilated)
  }

  /** Consumes messages until every expectation has accepted one, or the time budget runs out. */
  private def awaitExpectations(): Unit = {
    // use the sum of all expectations as the maximum duration; `within` dilates it
    val max = expectations.foldLeft(Duration.Zero)(_ + _.duration)
    // a buffer rather than a set, so repeated unhandled messages are all counted
    val unhandled = mutable.ArrayBuffer[Any]()
    try {
      within(max) {
        while (expectations.nonEmpty) {
          fishForMessage() {
            // A single catch-all case keeps `isDefinedAt` trivial, so each message runs
            // through the (possibly side-effecting) expectation checks exactly once.
            case msg => msg match {
              case Expectation(xp) =>
                expectations -= xp
                true
              case other =>
                unhandled += other
                false
            }
          }
        }
      }
    } catch {
      case ex: AssertionError if expectations.nonEmpty =>
        throw new AssertionError(unmetReport(unhandled), ex)
    }
  }

  /** Fails if any message is received within `max`. */
  private def awaitNoMsg(max: FiniteDuration): Unit = {
    val msg = receiveOne(max)
    if (msg != null) {
      throw new AssertionError(s"Expected no additional messages but received: $msg")
    }
  }

  /** Builds the failure report listing the still-pending expectations and the unhandled messages. */
  private def unmetReport(unhandled: Iterable[Any]): String = {
    s"""
       |${expectations.size} expectation(s) were unmet:
       |${expectations.toSeq.zipWithIndex.map { case (exp, i) => s"\t(${i + 1}) $exp" }.mkString("\n")}
       |${unhandled.size} message(s) were unhandled:
       |${unhandled.map("\t" + _).mkString("\n")}
       |""".stripMargin
  }

  /**
   * Resets the registered expectations, runs the test, and verifies the
   * expectations if the test body succeeded.
   */
  protected abstract override def withFixture(test: NoArgTest): Outcome = {
    expectations.clear()
    expectNoMsgFor = Duration.Zero
    val result = super.withFixture(test)
    // A failed/canceled/pending test keeps its own outcome: verifying afterwards would
    // only wait out the timeouts and then throw, masking the original failure.
    if (result.isSucceeded) verifyExpectations()
    result
  }

  /** A pending expectation registered by one of the `expectUnordered*` methods. */
  trait Expectation {
    /** Returns true if `msg` satisfies this expectation; may run user-supplied code. */
    def execute(msg: Any): Boolean

    /** This expectation's contribution to the overall (undilated) time budget. */
    def duration: FiniteDuration
  }

  object Expectation {
    /** Tests the `msg` against the pending expectations in registration order, and extracts the first one that accepts it */
    def unapply(msg: Any): Option[Expectation] = expectations.find(_.execute(msg))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment