Last active
January 2, 2022 22:56
-
-
Save clintval/f37798af8c4572aa6999c38dfc124567 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.cvbio.collection | |
import io.cvbio.io.Io | |
import com.fulcrumgenomics.commons.collection.SelfClosingIterator | |
import java.util.concurrent.atomic.AtomicReference | |
import java.util.concurrent._ | |
import scala.concurrent.duration.{Duration, DurationInt} | |
import scala.concurrent.{Await, ExecutionContext, Future} | |
/** Helpers for parallel work over iterators. */
object ParIterator {

  /** The default maximum size for capacity queues used in parallel iteration. */
  val DefaultQueueCapacity: Int = 128000

  /** A helper function for mapping over objects in an iterator in parallel while caching only <capacity> result objects
    * at a time. It is recommended to only pass thread-safe functions to the <fn> parameter. When exactly one thread is
    * requested the mapping is a plain, lazy `iterator.map(fn)` on the calling thread. When more than one thread is
    * requested a fixed thread pool of <threads> threads is created and results are buffered in a blocking queue. The
    * pool is shut down when the returned iterator is exhausted, and is shut down immediately when iteration fails
    * because either the source <iterator> or the applied function <fn> raised an exception. Such exceptions are
    * re-raised on the thread that consumes the returned iterator.
    *
    * @param iterator the source of elements to map over.
    * @param fn the function to apply to every element.
    * @param threads the number of threads to use, must be at least one.
    * @param capacity the number of results to buffer at a time, or <None> for an unbounded buffer.
    * @throws IllegalArgumentException if fewer than one thread is requested.
    */
  def map[A, B](
    iterator: Iterator[A],
    fn: A => B,
    threads: Int = Io.AvailableProcessors,
    capacity: Option[Int] = Some(DefaultQueueCapacity)
  ): Iterator[B] = {
    require(threads >= 1, s"At least one thread must be requested, found: $threads")
    if (threads == 1) { iterator.map(fn) } else {
      val pool    = Executors.newFixedThreadPool(threads)
      val context = ExecutionContext.fromExecutorService(pool)
      // A failed iterator never reaches exhaustion, so the self-closing hook below would never fire. Release the
      // pool through the failure hook instead so its threads are not leaked.
      val results = parallelMap(iterator, fn, capacity, timeOut = 1.hour, onFailure = () => { pool.shutdownNow(); () })(context)
      new SelfClosingIterator[B](results, pool.shutdown)
    }
  }

  /** The shared implementation behind [[map]] and [[ParIteratorImpl.parMap]].
    *
    * A dedicated single-thread executor walks the source <iterator>, schedules one future per element on <executor>,
    * and pushes each future onto a blocking queue followed by a terminating `None`. The returned iterator takes
    * futures from the queue in order and awaits each one for at most <timeOut>.
    *
    * @param onFailure called at most once, on the consuming thread, when iteration fails permanently: either an
    *                  awaited future completed with a failure or the source iterator raised an exception.
    */
  private def parallelMap[A, B](
    iterator: Iterator[A],
    fn: A => B,
    capacity: Option[Int],
    timeOut: Duration,
    onFailure: () => Unit
  )(implicit executor: ExecutionContext): Iterator[B] = {
    val throwable = new AtomicReference[Throwable](null) // A place for any exceptions raised in the source iterator.
    val finished  = new CountDownLatch(1) // Set this to zero when we have finished sending jobs to the thread pool.
    val cancelled = new java.util.concurrent.atomic.AtomicBoolean(false) // Set once iteration has failed for good.
    val ioPool    = Executors.newSingleThreadExecutor
    val ioContext = ExecutionContext.fromExecutorService(ioPool)

    // Use a dynamically-expanding queue if no capacity was explicitly asked for, otherwise pre-allocate an array.
    val queue: BlockingQueue[Option[Future[B]]] = capacity match {
      case Some(size) => new ArrayBlockingQueue(size)
      case None       => new LinkedBlockingQueue()
    }

    // Fill the queue from the IO thread and terminate it with a final `None` once the source is exhausted. Any
    // exception raised by the source iteration itself is saved so the consumer can re-raise it; every `Throwable` is
    // caught on purpose because anything dropped here would silently truncate the output. After a cancellation the
    // terminating `None` is skipped: nothing will read it and a blocking `put` on a full queue would never return.
    Future {
      try {
        try {
          while (!cancelled.get && iterator.hasNext) {
            val elem = iterator.next()
            queue.put(Some(Future(fn(elem))(executor)))
          }
        } finally {
          if (!cancelled.get) queue.put(None)
        }
      }
      catch { case thr: Throwable => throwable.set(thr) }
      finally { finished.countDown() }
    }(ioContext)

    // The producer task has been handed to the executor already, so an orderly shutdown lets it run to completion
    // and then releases the IO thread without relying on the consumer exhausting the iterator.
    ioPool.shutdown()

    // Build the return iterator which will await results from the queue until the queue is empty.
    new Iterator[B] {

      /** Whether or not there is still pending work that is filling the queue with results. */
      private var alive: Boolean = true

      /** The next element in the queue as that element is pulled from this thread. */
      private var nextFuture: Option[Future[B]] = None

      /** Interrupt the producer and run the failure hook, at most once. */
      private def abort(): Unit = {
        if (cancelled.compareAndSet(false, true)) {
          ioPool.shutdownNow()
          onFailure()
        }
      }

      /** If the iterator still has more object to yield. */
      override def hasNext: Boolean = {
        alive && {
          if (nextFuture.isEmpty) {
            nextFuture = queue.take() match {
              case None => alive = false; None
              case some => some
            }
          }
          // The terminating `None` may be visible before the producer has saved an exception raised by the source
          // iterator. Awaiting the latch first guarantees that a saved exception is seen and re-raised here rather
          // than being silently dropped.
          if (!alive) {
            finished.await()
            Option(throwable.get).foreach { thr => abort(); throw thr }
          }
          alive
        }
      }

      /** Return the next object in the iterator or raise an exception if there are no more objects. */
      override def next(): B = {
        if (!hasNext) { Iterator.empty.next() } else {
          nextFuture match {
            case None         => Iterator.empty.next() // Not reachable: a true `hasNext` always leaves a future here.
            case Some(future) =>
              val value = try Await.result(future, atMost = timeOut) catch {
                case thr: Throwable =>
                  // A future that completed with a failure stays failed, so this iterator can never advance past
                  // it: stop the producer. A timeout or interrupt leaves the future pending and is not treated as
                  // terminal, so the caller may try again.
                  if (future.value.exists(_.isFailure)) abort()
                  throw thr
              }
              nextFuture = None
              value
          }
        }
      }
    }
  }

  /** Implicitly add parallel operations onto Scala's base iterator class. */
  implicit class ParIteratorImpl[A](private val iterator: Iterator[A]) {

    /** Parallelize work over an iterator using a given execution context buffering <capacity> results at a time. An
      * additional single-thread executor is created to manage the side-effect of submitting all work to the primary
      * <executor>; it is released when the source iterator has been fully submitted or when iteration fails. If
      * <capacity> is set to <None> then a dynamically expanding linked blocking queue is used but if <capacity> is
      * set to a fixed size then an array blocking queue is used. Any exception caused by the source <iterator> or by
      * the applied function <fn> will be raised on the consuming thread. Once a mapped element has failed the
      * iterator re-raises that failure on every later call to `next()` and submits no further work.
      *
      * @param fn the method to map over the elements in the iterator.
      * @param capacity the number of results to buffer at a time in the underlying blocking queue.
      * @param timeOut await each result this amount of time before raising an exception. The pending computation
      *                is not cancelled by a timeout.
      */
    def parMap[B](fn: A => B, capacity: Option[Int] = Some(DefaultQueueCapacity), timeOut: Duration = 1.hour)(
      implicit executor: ExecutionContext
    ): Iterator[B] = parallelMap(iterator, fn, capacity, timeOut, onFailure = () => ())(executor)
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.cvbio.collection | |
import io.cvbio.collection.ParIterator.ParIteratorImpl | |
import io.cvbio.io.Io | |
import io.cvbio.testing.UnitSpec | |
import java.util.concurrent.Executors | |
import scala.concurrent.ExecutionContext | |
/** Unit tests for [[ParIterator]]. */
class ParIteratorTest extends UnitSpec {

  /** The number of threads to use in all thread pools. */
  private val ThreadCount = Io.AvailableProcessors

  /** The input shared by the `parMap` tests. */
  private val Integers: Range = Range(1, 10)

  /** A cheap, pure function to map over test inputs. */
  private def addTen(int: Int): Int = int + 10

  /** A function that always fails, used to check that exceptions are propagated to the caller. */
  private def raise(num: Int): Int = throw new IllegalArgumentException(num.toString)

  /** Run <body> with an execution context backed by a fixed pool of <threads> threads. The pool is shut down even
    * when <body> throws so a failing test does not leak threads. Results must be fully materialized inside <body>
    * (for example with `toList`) because some Scala versions return a lazy sequence from `Iterator.toSeq`.
    */
  private def withContext[T](threads: Int)(body: ExecutionContext => T): T = {
    val pool = Executors.newFixedThreadPool(threads)
    try body(ExecutionContext.fromExecutorService(pool)) finally pool.shutdown()
  }

  "ParIterator.map" should "return elements in the correct order when parallelized" in {
    val expected = Range(1, 1000).inclusive
    val actual   = ParIterator.map[Int, Int](
      expected.iterator,
      identity,
      threads  = ThreadCount,
      capacity = Some(100)
    ).toList
    actual should contain theSameElementsInOrderAs expected
  }

  if (ThreadCount > 1) { // To run these tests you must have more than one available processor.
    it should "raise exceptions that occur within passed function running in threads, but only if multiple threads are used" in {
      an[IllegalArgumentException] shouldBe thrownBy {
        ParIterator.map(
          iterator = Range(1, 10).iterator,
          fn       = raise,
          threads  = ThreadCount
        ).toList
      }
    }

    it should "raise exceptions that occur within the input iterator running in threads, but only if multiple threads are used" in {
      an[IllegalArgumentException] shouldBe thrownBy {
        ParIterator.map(
          iterator = Range(1, 10).iterator.map(raise),
          fn       = identity[Int],
          threads  = ThreadCount
        ).toList
      }
    }
  }

  "ParIterator.parMap" should "map over elements using a fixed size thread pool and a near-unlimited buffer" in {
    val actual = withContext(ThreadCount) { context =>
      Integers.iterator.parMap(addTen, capacity = None)(context).toList
    }
    actual should contain theSameElementsInOrderAs Integers.map(addTen)
  }

  it should "not deadlock if a fixed thread pool with one thread is requested" in {
    val actual = withContext(threads = 1) { context =>
      Integers.iterator.parMap(addTen, capacity = None)(context).toList
    }
    actual should contain theSameElementsInOrderAs Integers.map(addTen)
  }

  it should "map over elements using a fixed size thread pool and a buffer of a size smaller than the collection" in {
    val actual = withContext(ThreadCount) { context =>
      Integers.iterator.parMap(addTen, capacity = Some(1))(context).toList
    }
    actual should contain theSameElementsInOrderAs Integers.map(addTen)
  }

  it should "map over elements using the user-defined execution context and a right-sized buffer" in {
    val actual = withContext(ThreadCount) { context =>
      Integers.iterator.parMap(addTen, capacity = Some(Integers.length))(context).toList
    }
    actual should contain theSameElementsInOrderAs Integers.map(addTen)
  }

  if (ThreadCount > 1) { // To run these tests you must have more than one available processor.
    it should "raise exceptions that occur within passed function running in threads, but only if multiple threads are used" in {
      an[IllegalArgumentException] shouldBe thrownBy {
        withContext(ThreadCount) { context => Range(1, 10).iterator.parMap(raise)(context).toList }
      }
    }

    it should "raise exceptions that occur within the input iterator running in threads, but only if multiple threads are used" in {
      an[IllegalArgumentException] shouldBe thrownBy {
        withContext(ThreadCount) { context => Range(1, 10).iterator.map(raise).parMap(identity[Int])(context).toList }
      }
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment