Last active
April 12, 2021 21:16
-
-
Save isopropylcyanide/b345af36179ee522d4b2152a502e174c to your computer and use it in GitHub Desktop.
Concurrently execute work and aggregate the results using a completion service. Tests are included; everything is kept in one file for brevity.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.random.aman; | |
import com.google.common.collect.Lists;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@Slf4j | |
@AllArgsConstructor | |
public class WorkExecutor { | |
private final MathService mathService; | |
private final ConcurrentWorkExecutor concurrentWorkExecutor; | |
public List<Integer> processExpensiveFunction(List<Integer> inputs, int firstK) { | |
try { | |
List<Integer> outputs = new ArrayList<>(); | |
ExecutorService executor = Executors.newCachedThreadPool(); | |
ExecutorCompletionService<Integer> executorService = new ExecutorCompletionService<>(executor); | |
for (Integer input : inputs) { | |
executorService.submit(() -> { | |
log.info("Submitting task f({}) on {}", input, Thread.currentThread().getName()); | |
return mathService.process(input); | |
}); | |
} | |
executor.shutdown(); | |
for (int outputCount = 1; outputCount <= firstK; outputCount++) { | |
Integer output = executorService.take().get(); | |
log.info("Finished {}/{}: {} on {}", outputCount, firstK, output, Thread.currentThread().getName()); | |
outputs.add(output); | |
} | |
return outputs; | |
} catch (Exception ex) { | |
log.error("Error processing expensive function"); | |
return null; | |
} | |
} | |
public List<Integer> processExpensiveFunctionElegant(List<Integer> inputs, int firstK) { | |
try { | |
return concurrentWorkExecutor.splitJoin(firstK, | |
inputs, | |
mathService::process, | |
Collectors.toList()); | |
} catch (Exception ex) { | |
log.error("Error processing expensive function", ex); | |
return null; | |
} | |
} | |
interface MathService {
    /**
     * Evaluates an expensive function f(x) for the given integer.
     *
     * @param X the value to evaluate the function for
     * @return the result of f(X)
     */
    Integer process(Integer X);
}
static class RemoteMathService implements MathService { | |
private final Random delay = new Random(); | |
@Override | |
public Integer process(Integer X) { | |
try { | |
TimeUnit.MILLISECONDS.sleep(delayIteration()); | |
} catch (InterruptedException ignored) { | |
} | |
return X; | |
} | |
private long delayIteration() { | |
return 100L * delay.nextInt(10); | |
} | |
} | |
interface ConcurrentWorkExecutor {
    /**
     * Splits the work represented by an iterable into one task per item, where each task's
     * result is obtained by applying the mapper function to that item. The intermediate
     * non-null results are joined until {@code size} results have been obtained, after which
     * they are combined using the user-supplied collector.
     * <p>
     * If all intermediate results are required before collecting the final result,
     * {@code size} must be equal to the number of items represented by the iterable.
     *
     * @param size      the number of results to wait for
     * @param iterable  the work, one item per task
     * @param mapper    the transformation of an individual item {@code U} into a result {@code T}
     * @param collector the collector that folds the {@code T} results into a value of type {@code V}
     * @param <T>       the type of an intermediate result
     * @param <U>       the type of a work item
     * @param <V>       the type of the collected result
     * @return the value of type {@code V} produced by the collector from the intermediate results
     * @throws ExecutionException   if computing an awaited intermediate result failed
     * @throws InterruptedException if the calling thread is interrupted while waiting for a result
     */
    <T, U, V> V splitJoin(int size,
                          Iterable<U> iterable,
                          Function<U, T> mapper,
                          Collector<T, ?, V> collector) throws ExecutionException, InterruptedException;
}
/** | |
* A concurrent work executor that blocks for the final result after individual execution results | |
* are obtained. The results are fed into the queue represented by completion service as they are | |
* getting completed. Note that this behavior can be changed by using a single threaded executor | |
* <p> | |
* If execution of any individual work results in an exception, an exception is raised | |
*/ | |
static class OutOfOrderConcurrentWorkExecutor implements ConcurrentWorkExecutor { | |
@Override | |
public <T, U, V> V splitJoin(int size, Iterable<U> iterable, Function<U, T> mapper, Collector<T, ?, V> collector) throws ExecutionException, InterruptedException { | |
ThreadFactory threadFactory = Executors.defaultThreadFactory(); | |
ExecutorService executor = Executors.newCachedThreadPool(threadFactory); //create a cached thread pool ideal for short lived tasks | |
AtomicInteger submitTaskCount = new AtomicInteger(0); //to keep track of submitted tasks | |
CompletionService<T> service = new ExecutorCompletionService<>(executor); | |
try { | |
iterable.forEach(u -> { //note that iterable.size() must not be really large as cached thread pool is unbounded | |
//we typically restrict this size bound when we accept the request itself. | |
service.submit(() -> { | |
log.info("Submitting task: splitJoin f({}) on {}", u, Thread.currentThread().getName()); | |
return mapper.apply(u); | |
}); //user specified mapper is invoked here to create a task and submitted | |
submitTaskCount.incrementAndGet(); | |
}); | |
} finally { | |
executor.shutdown(); // stop accepting any more tasks except for the ones that are submitted | |
} | |
List<T> results = new ArrayList<>(); | |
for (int i = 0; i < Math.min(submitTaskCount.get(), size); i++) { //if we didn't do min(size, submitted) & size > submitted, it would lead to an | |
//infinite loop as take() blocks on an empty queue | |
T t; | |
try { | |
t = service.take().get(); | |
log.info("Finished {}/{}: {} on {}", i, submitTaskCount.get(), t, Thread.currentThread().getName()); | |
} catch (ExecutionException ex) { | |
log.error("Received error during computation for result in Thread [{}] {}", Thread.currentThread().getId(), ex.getMessage()); | |
throw ex; //letting the application handle it | |
} | |
if (t != null) { | |
results.add(t); //collecting the out of order result | |
} | |
} | |
return results.stream().collect(collector); //applying user specified collector to the collection of intermediate results | |
} | |
} | |
public static void main(String[] args) { | |
MathService mathService = new RemoteMathService(); | |
ConcurrentWorkExecutor concurrentWorkExecutor = new OutOfOrderConcurrentWorkExecutor(); | |
WorkExecutor workExecutor = new WorkExecutor(mathService, concurrentWorkExecutor); | |
List<Integer> input = Lists.newArrayList(10, 20, 7, 13, 34, -1, 14); | |
// List<Integer> output = workExecutor.processExpensiveFunction(input, 3); | |
List<Integer> output_ = workExecutor.processExpensiveFunctionElegant(input, 3); | |
// log.info("Result: {}", output); | |
log.info("Result: {}", output_); | |
} | |
class OutOfOrderConcurrentWorkExecutorTest { | |
private OutOfOrderConcurrentWorkExecutor executor; | |
@BeforeEach | |
public void setUp() { | |
this.executor = new OutOfOrderConcurrentWorkExecutor(); | |
} | |
@Test | |
public void testSplitJoinWaitsForAllResultsWhenSizeIsEqualToTheWorkSet() throws Exception { | |
List<Integer> userIds = IntStream.range(1, 6).boxed().collect(Collectors.toList()); | |
Map<Integer, User> userIdUserMap = executor.splitJoin( | |
userIds.size(), userIds, | |
(id) -> new User(RandomStringUtils.randomAlphanumeric(2, 5), id), | |
Collectors.toMap(User::getId, Function.identity()), | |
(u, ex) -> { | |
throw new RuntimeException(ex); | |
}); | |
assertEquals(userIds.size(), userIdUserMap.size()); | |
} | |
@Test | |
public void testSplitJoinWaitsForOnlyTheRequiredResultsWhenSizeIsLessThanWorkSet() throws Exception { | |
List<Integer> userIds = IntStream.range(0, 6).boxed().collect(Collectors.toList()); | |
Map<Integer, User> userIdUserMap = executor.splitJoin( | |
userIds.size() / 2, userIds, | |
(id) -> new User(RandomStringUtils.randomAlphanumeric(2, 5), id), | |
Collectors.toMap(User::getId, Function.identity()), | |
(u, ex) -> { | |
throw new RuntimeException(ex); | |
}); | |
assertEquals(3, userIdUserMap.size()); | |
} | |
@Test | |
public void testSplitJoinWaitsForAllResultsWhenSizeIsMoreThanTheWorkSet() throws Exception { | |
List<Integer> userIds = IntStream.range(0, 4).boxed().collect(Collectors.toList()); | |
List<User> processedUsers = executor.splitJoin( | |
userIds.size() + 1, userIds, | |
(id) -> new User(RandomStringUtils.randomAlphanumeric(2, 5), id), | |
Collectors.toList(), | |
(u, ex) -> { | |
throw new RuntimeException(ex); | |
}); | |
assertEquals(4, processedUsers.size()); | |
} | |
@Test | |
public void testSplitJoinWaitsForAllResultsWhenIntermediateResultThrowsException() throws Exception { | |
List<Integer> userIds = IntStream.range(0, 4).boxed().collect(Collectors.toList()); | |
Function<Integer, User> mapper = (id) -> { | |
if (id == 2) { | |
throw new IllegalArgumentException("User 2 is blacklisted"); | |
} | |
return new User(RandomStringUtils.randomAlphanumeric(2, 5), id); | |
}; | |
try { | |
executor.splitJoin(userIds.size(), userIds, mapper, Collectors.toList(), | |
(u, ex) -> { | |
throw new RuntimeException(ex); | |
}); | |
} catch (RuntimeException ex) { | |
assertTrue(ex.getMessage().contains("User 2 is blacklisted")); | |
assertTrue(Throwables.getRootCause(ex) instanceof IllegalArgumentException); | |
} | |
} | |
@Test | |
public void testSplitJoinReturnsSuccessfullyWhenOneResultThrowsExceptionButItIsNotRequired() throws Exception { | |
List<Integer> userIds = IntStream.range(0, 4).boxed().collect(Collectors.toList()); | |
Function<Integer, User> mapper = (id) -> { | |
if (id == 2) { | |
throw new IllegalArgumentException("User 2 is blacklisted"); | |
} | |
System.out.println("Processing " + id); | |
return new User(RandomStringUtils.randomAlphanumeric(2, 5), id); | |
}; | |
try { | |
executor.splitJoin(userIds.size() - 1, userIds, mapper, Collectors.toList(), | |
(u, ex) -> { | |
throw new RuntimeException(ex); | |
}); | |
} catch (RuntimeException ex) { | |
assertTrue(ex.getMessage().contains("User 2 is blacklisted")); | |
assertTrue(Throwables.getRootCause(ex) instanceof IllegalArgumentException); | |
} | |
} | |
@AllArgsConstructor | |
@Getter | |
private static class User { | |
private final String name; | |
private final Integer id; | |
@Override | |
public String toString() { | |
return "name='" + name + '\'' + ", id=" + id + '}'; | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment