Last active
May 2, 2017 04:43
-
-
Save tifletcher/9771c28e37370c0a5cb784de6f0040df to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.dntty.util | |
object DependencyResolver {

  /**
   * Recursively resolve an element's dependencies into a flattened set of elements.
   *
   * @param rootElement    element whose dependencies will be resolved
   * @param allElements    set of elements from which dependencies will be extracted
   * @param getId          function which extracts an id of type K from an element of type T
   * @param getIncludedIds function which extracts the ids of the elements the given element directly
   *                       depends on (may be empty)
   * @tparam T element type
   * @tparam K element id type
   * @return flattened set of elements consisting of the root element and all its transitive dependencies
   * @throws DependencyNotFound       if allElements does not contain all required dependencies
   * @throws CyclicDependencyDetected if a cycle is detected while traversing the dependency graph
   */
  def apply[T, K](rootElement: T, allElements: Set[T], getId: (T) => K, getIncludedIds: (T) => Traversable[K]): Set[T] =
    new Worker[T, K](getIncludedIds, getId, allElements).dependencies(rootElement)

  /**
   * Construct a DependencyResolver.Worker to resolve multiple elements from the same set, or to construct a
   * pre-cached map of the entire set's dependencies. Results are memoized per id, so shared dependencies are
   * resolved only once per Worker.
   *
   * Worker is NOT thread safe: it caches results in unsynchronized mutable collections, so don't share
   * references to it. Prefer instead sharing its (immutable) dependency map when multiple threads need access.
   *
   * @param getIncludedIds function which extracts the ids of the elements the given element directly
   *                       depends on (may be empty)
   * @param getId          function which extracts an id of type K from an element of type T
   * @param allElements    set of elements from which dependencies will be extracted
   * @tparam T element type
   * @tparam K element id type
   */
  class Worker[T, K](getIncludedIds: (T) => Traversable[K], getId: (T) => K, allElements: Set[T]) {

    private val cachingCycleDetector = new CachingCycleDetector[T, K](allElements)

    // Id -> element index so each dependency lookup is a hash lookup rather than a linear scan of
    // allElements. Built on the first lookup only. If several elements share an id, one of them is
    // picked; which one is unspecified (as with a `find` over an unordered Set).
    private lazy val elementsById: Map[K, T] =
      allElements.iterator.map(el => getId(el) -> el).toMap

    /**
     * Fetch a single element's dependencies.
     *
     * @param rootElement the element whose dependencies will be calculated
     * @return the minimum subset of elements required to resolve rootElement and its dependencies
     *         (always including rootElement itself)
     * @throws DependencyNotFound       if allElements does not contain all required dependencies
     * @throws CyclicDependencyDetected if a cycle is detected while traversing the dependency graph
     */
    def dependencies(rootElement: T): Set[T] =
      cachingCycleDetector.getCachedResultOrRunWithCycleDetection(
        cacheKey = getId(rootElement),
        // By-name: only evaluated on a cache miss, so getIncludedIds runs at most once per id.
        cacheVal = getIncludedIds(rootElement).foldLeft(Set(rootElement)) { (acc, includedId) =>
          acc ++ getIncludedElements(includedId)
        }
      )

    /**
     * Return a map from each element's id to a set composed of the element and all its dependencies.
     * Computing this value resolves every element in allElements, so a successfully returned map proves
     * that the set has no missing dependencies and no cycles.
     *
     * @throws DependencyNotFound       if allElements does not contain all required dependencies
     * @throws CyclicDependencyDetected if a cycle is detected while traversing the dependency graph
     */
    lazy val dependencyMap: Map[K, Set[T]] =
      allElements.iterator.map(el => getId(el) -> dependencies(el)).toMap

    /** Look up the element with the given id and resolve it (and, recursively, its dependencies). */
    private def getIncludedElements(includedId: K): Set[T] =
      dependencies(elementsById.getOrElse(includedId, throw DependencyNotFound(includedId, allElements)))
  }

  /**
   * Memoizes per-id resolution results and detects cycles by tracking ids whose resolution is currently
   * in progress further up the call stack.
   */
  private class CachingCycleDetector[T, K](allElements: Set[T]) {
    import scala.collection.mutable

    private val resolutionMap = mutable.Map.empty[K, Set[T]]
    private val inProgressIds = mutable.Set.empty[K]

    /**
     * Return the cached result for cacheKey, or compute it with cacheVal and cache it.
     *
     * @throws CyclicDependencyDetected if cacheKey is already being resolved (re-entered during its own resolution)
     */
    def getCachedResultOrRunWithCycleDetection(cacheKey: K, cacheVal: => Set[T]): Set[T] =
      resolutionMap.getOrElse(cacheKey, {
        claimIdForCycleDetection(cacheKey)
        // Release the claim even if resolution throws; otherwise a Worker reused after catching e.g.
        // DependencyNotFound would report a spurious cycle for this id on its next call.
        val result = try cacheVal finally inProgressIds.remove(cacheKey)
        resolutionMap.update(cacheKey, result)
        result
      })

    /** Mark elementId as in progress; seeing it already in progress means we looped back to it. */
    private def claimIdForCycleDetection(elementId: K): Unit =
      if (!inProgressIds.add(elementId)) throw CyclicDependencyDetected(elementId, allElements)
  }

  /** Thrown when an element depends on an id that no element of allElements has. */
  case class DependencyNotFound[K, T](dependencyId: K, allElements: Set[T]) extends Exception(
    s"Dependency with id $dependencyId not found in set $allElements"
  )

  /** Thrown when resolving dependencyId requires, directly or transitively, resolving dependencyId itself. */
  case class CyclicDependencyDetected[K, T](dependencyId: K, allElements: Set[T]) extends Exception(
    s"Cycle detected at id $dependencyId in set $allElements"
  )
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.dntty.util | |
import com.dntty.util.DependencyResolver.{ CyclicDependencyDetected, DependencyNotFound } | |
import org.scalatest.FunSpec | |
import scala.util.Random | |
class DependencyResolverSpec extends FunSpec {

  def testResolver: (Test, Set[Test]) => Set[Test] = DependencyResolver(_, _, id, includes)
  def id: Test => Int = _.id
  def includes: Test => Seq[Int] = _.includes

  // Fixture ids come from a counter rather than Random.nextInt so two fixtures can never collide
  // (a collision could, e.g., turn the "not found" test into a self-cycle). The counter starts well
  // above the hard-coded ids 1..3 used by the cycle test. Must be declared before `extras`, which
  // consumes ids during construction.
  private val idCounter = new java.util.concurrent.atomic.AtomicInteger(1000)
  def freshId: Int = idCounter.getAndIncrement()

  case class Test(id: Int, includes: Seq[Int])
  object Test {
    def apply(): Test = Test(freshId, Seq.empty)
    def apply(includes: Int*): Test = Test(freshId, includes)
  }

  // Unrelated elements added to pools to check that resolution picks only what is needed.
  val extras: Set[Test] = {
    val a = Test()
    val b = Test(a.id)
    val c = Test(a.id, b.id)
    Set(a, b, c)
  }

  describe("DependencyResolver") {
    it("should resolve an element with no dependencies") {
      val el1 = Test()
      val pool = Set(el1) ++ extras
      assert(testResolver(el1, pool) == Set(el1))
    }

    it("should resolve a simple dependency") {
      val included = Test()
      val root = Test(included.id)
      val pool = Set(included, root) ++ extras
      assert(testResolver(root, pool) == Set(included, root))
    }

    it("should fully resolve dependencies of dependencies") {
      val base = Test()
      val middle = Test(base.id)
      val root = Test(middle.id)
      val pool = Set(base, middle, root) ++ extras
      assert(testResolver(root, pool) == Set(base, middle, root))
    }

    it("should only resolve a dependency once") {
      import org.mockito.Mockito._
      val baseId = freshId
      val base = mock(classOf[Test])
      when(base.id).thenReturn(baseId)
      when(base.includes).thenReturn(Seq.empty)
      val middle1 = Test(base.id)
      val middle2 = Test(base.id)
      val root = Test(middle1.id, middle2.id)
      val pool = Set(base, middle1, middle2, root) ++ extras
      assert(testResolver(root, pool) == Set(base, middle1, middle2, root))
      // base is reachable through two paths but its includes must be read only once (memoization).
      verify(base, times(1)).includes
    }

    it("should throw when a dependency cannot be found") {
      val el = Test(freshId)
      assertThrows[DependencyNotFound[_, _]] {
        testResolver(el, Set(el))
      }
    }

    it("should throw when an element depends on itself") {
      val selfId = freshId
      val el = Test(selfId, Seq(selfId))
      assertThrows[CyclicDependencyDetected[_, _]] {
        testResolver(el, Set(el) ++ extras)
      }
    }

    it("should throw when a simple cycle exists") {
      val el1 = Test(1, Seq(2))
      val el2 = Test(2, Seq(3))
      val el3 = Test(3, Seq(1))
      assertThrows[CyclicDependencyDetected[_, _]] {
        testResolver(el1, Set(el1, el2, el3))
      }
    }

    it("should build a dependency map covering every element of the pool") {
      val base = Test()
      val middle = Test(base.id)
      val root = Test(middle.id)
      val pool = Set(base, middle, root)
      val worker = new DependencyResolver.Worker[Test, Int](includes, id, pool)
      assert(worker.dependencyMap == Map(
        base.id -> Set(base),
        middle.id -> Set(base, middle),
        root.id -> Set(base, middle, root)
      ))
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment