Skip to content

Instantly share code, notes, and snippets.

@tifletcher
Last active May 2, 2017 04:43
Show Gist options
  • Save tifletcher/9771c28e37370c0a5cb784de6f0040df to your computer and use it in GitHub Desktop.
Save tifletcher/9771c28e37370c0a5cb784de6f0040df to your computer and use it in GitHub Desktop.
package com.dntty.util
object DependencyResolver {

  /**
   * Recursively resolve an element's dependencies into a flattened set of elements.
   *
   * @param rootElement    element whose dependencies will be resolved
   * @param allElements    set of elements in which dependencies are looked up by id
   * @param getId          function which extracts an id of type K from an element of type T
   * @param getIncludedIds function which extracts the (possibly empty) collection of ids the element depends on
   * @tparam T element type
   * @tparam K element id type
   * @return flattened set of elements consisting of the root element and all its transitive dependencies
   * @throws DependencyNotFound       if allElements does not contain all required dependencies
   * @throws CyclicDependencyDetected if a cycle is detected while traversing the dependency graph
   */
  def apply[T, K](rootElement: T, allElements: Set[T], getId: (T) => K, getIncludedIds: (T) => Traversable[K]): Set[T] =
    new Worker[T, K](getIncludedIds, getId, allElements).dependencies(rootElement)

  /**
   * Resolves elements against a fixed set, caching each element's resolved dependencies by id so that
   * repeated calls (and `dependencyMap`) do not redo work. Use it to resolve several elements from the
   * same set, or to build a pre-computed map of the entire set's dependencies.
   *
   * A Worker is NOT thread safe: its cache lives in unsynchronized mutable collections, so don't share
   * references to it across threads. Prefer sharing its immutable `dependencyMap` instead.
   *
   * A failed resolution (missing dependency or cycle) leaves the Worker usable. Elements fully resolved
   * before the failure stay cached, and nothing is cached for the elements whose resolution failed.
   *
   * @param getIncludedIds function which extracts the (possibly empty) collection of ids the element depends on
   * @param getId          function which extracts an id of type K from an element of type T
   * @param allElements    set of elements in which dependencies are looked up by id
   * @tparam T element type
   * @tparam K element id type
   */
  class Worker[T, K](getIncludedIds: (T) => Traversable[K], getId: (T) => K, allElements: Set[T]) {
    private val cachingCycleDetector = new CachingCycleDetector[T, K](allElements)

    // Id -> element index, built on first lookup so each dependency is found in O(1) rather than by a
    // linear scan of allElements per edge. If several elements share an id, the first one in
    // allElements' iteration order wins, which is what the previous `allElements.find` returned.
    private lazy val elementsById: Map[K, T] =
      allElements.foldLeft(Map.empty[K, T]) { (index, element) =>
        val elementId = getId(element)
        if (index.contains(elementId)) index else index.updated(elementId, element)
      }

    /**
     * Fetch a single element's dependencies.
     *
     * @param rootElement the element whose dependencies will be calculated
     * @return the minimum subset of elements required to resolve rootElement and its dependencies
     * @throws DependencyNotFound       if allElements does not contain all required dependencies
     * @throws CyclicDependencyDetected if a cycle is detected while traversing the dependency graph
     */
    def dependencies(rootElement: T): Set[T] =
      cachingCycleDetector.getCachedResultOrRunWithCycleDetection(
        cacheKey = getId(rootElement),
        cacheVal = {
          val includedIds: Traversable[K] = getIncludedIds(rootElement)
          val includedElements = includedIds.flatMap(getIncludedElements).toSet
          includedElements + rootElement
        }
      )

    /**
     * Map from each element's id to a set composed of the element and all its dependencies.
     *
     * Computing this value resolves every element in allElements. If it completes without throwing,
     * every dependency was found and no cycle was reachable, so the map is complete.
     *
     * @throws DependencyNotFound       if allElements does not contain all required dependencies
     * @throws CyclicDependencyDetected if a cycle is detected while traversing the dependency graph
     */
    lazy val dependencyMap: Map[K, Set[T]] = allElements
      .map { el =>
        getId(el) -> dependencies(el)
      }
      .toMap

    // Looks up the element with the given id and resolves it recursively (through the cache).
    private def getIncludedElements(includedId: K): Set[T] =
      elementsById.get(includedId) match {
        case Some(included) => dependencies(included)
        case None           => throw DependencyNotFound(includedId, allElements)
      }
  }

  /**
   * Memoizes resolved dependency sets by id and detects cycles by tracking which ids are currently
   * being resolved further up the call stack.
   */
  private class CachingCycleDetector[T, K](allElements: Set[T]) {
    import scala.collection.mutable

    private val resolutionMap = mutable.Map.empty[K, Set[T]]
    private val inProgressIds = mutable.Set.empty[K]

    /**
     * Return the cached result for `cacheKey`, or evaluate `cacheVal` while `cacheKey` is marked
     * in progress and cache its result.
     *
     * @throws CyclicDependencyDetected if `cacheKey` is already being resolved further up the stack
     */
    def getCachedResultOrRunWithCycleDetection(cacheKey: K, cacheVal: => Set[T]): Set[T] =
      resolutionMap.getOrElse(cacheKey, {
        // Claim outside the try: if this throws, the claim belongs to an outer frame, which releases it.
        claimIdForCycleDetection(cacheKey)
        // Always release the claim, even when resolution throws. Otherwise a later call on the same
        // Worker would mistake the stale claim for a cycle.
        val result = try cacheVal finally releaseId(cacheKey)
        resolutionMap.update(cacheKey, result)
        result
      })

    private def claimIdForCycleDetection(elementId: K): Unit = {
      if (inProgressIds.contains(elementId)) throw CyclicDependencyDetected(elementId, allElements)
      inProgressIds += elementId
    }

    private def releaseId(elementId: K): Unit =
      inProgressIds -= elementId
  }

  /** Thrown when an element depends on an id that no element in `allElements` carries. */
  case class DependencyNotFound[K, T](dependencyId: K, allElements: Set[T]) extends Exception(
    s"Dependency with id $dependencyId not found in set $allElements"
  )

  /** Thrown when resolving `dependencyId` requires resolving `dependencyId` itself. */
  case class CyclicDependencyDetected[K, T](dependencyId: K, allElements: Set[T]) extends Exception(
    s"Cycle detected at id $dependencyId in set $allElements"
  )
}
package com.dntty.util
import com.dntty.util.DependencyResolver.{ CyclicDependencyDetected, DependencyNotFound }
import org.scalatest.FunSpec
import scala.util.Random
class DependencyResolverSpec extends FunSpec {

  /** The resolver under test, wired to the accessors of [[Test]]. */
  def testResolver: (Test, Set[Test]) => Set[Test] =
    (root, pool) => DependencyResolver(root, pool, id, includes)

  /** Id extractor handed to the resolver. */
  def id: Test => Int = element => element.id

  /** Dependency-id extractor handed to the resolver. */
  def includes: Test => Seq[Int] = element => element.includes

  /** Fresh random id for a test element. */
  def randomId: Int = Random.nextInt

  /** Minimal element type: an id plus the ids it depends on. */
  case class Test(id: Int, includes: Seq[Int])

  object Test {
    /** An element with a random id and no dependencies. */
    def apply(): Test = Test(randomId, Seq.empty)

    /** An element with a random id that depends on the given ids. */
    def apply(includes: Int*): Test = Test(randomId, includes)
  }

  /** Unrelated elements mixed into most pools (third depends on first and second, second on first). */
  val extras: Set[Test] = {
    val first = Test()
    val second = Test(first.id)
    val third = Test(first.id, second.id)
    Set(first, second, third)
  }

  describe("DependencyResolver") {

    it("should resolve an element with no dependencies") {
      val lone = Test()
      val resolved = testResolver(lone, Set(lone) ++ extras)
      assert(resolved.size == 1)
      assert(resolved.contains(lone))
    }

    it("should resolve a simple dependency") {
      val dependency = Test()
      val dependent = Test(dependency.id)
      val resolved = testResolver(dependent, Set(dependency, dependent) ++ extras)
      assert(resolved.size == 2)
      Seq(dependent, dependency).foreach(expected => assert(resolved.contains(expected)))
    }

    it("should fully resolve dependencies of dependencies") {
      val leaf = Test()
      val inner = Test(includes = leaf.id)
      val outer = Test(includes = inner.id)
      val resolved = testResolver(outer, Set(leaf, inner, outer) ++ extras)
      assert(resolved.size == 3)
      Seq(leaf, inner, outer).foreach(expected => assert(resolved.contains(expected)))
    }

    it("should only resolve a dependency once") {
      import org.mockito.Mockito._
      val sharedId = randomId
      val shared = mock(classOf[Test])
      when(shared.id).thenReturn(sharedId)
      when(shared.includes).thenReturn(Seq.empty)
      val left = Test(includes = shared.id)
      val right = Test(includes = shared.id)
      val top = Test(includes = left.id, right.id)
      val resolved = testResolver(top, Set(shared, left, right, top) ++ extras)
      assert(resolved.size == 4)
      Seq(shared, left, right, top).foreach(expected => assert(resolved.contains(expected)))
      // Both branches reach `shared`; the resolver's cache must read its includes exactly once.
      verify(shared, times(1)).includes
    }

    it("should throw when a dependency cannot be found") {
      val orphan = Test(includes = randomId)
      assertThrows[DependencyNotFound[_, _]](testResolver(orphan, Set(orphan)))
    }

    it("should throw when a simple cycle exists") {
      val one = Test(1, Seq(2))
      val two = Test(2, Seq(3))
      val three = Test(3, Seq(1))
      assertThrows[CyclicDependencyDetected[_, _]](testResolver(one, Set(one, two, three)))
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment