Skip to content

Instantly share code, notes, and snippets.

@castortech
Last active January 10, 2025 12:03
Show Gist options
  • Save castortech/2cc20df8da329a3585135753ea27ab6b to your computer and use it in GitHub Desktop.
Save castortech/2cc20df8da329a3585135753ea27ab6b to your computer and use it in GitHub Desktop.
MDC ForkJoinPool
/**
 * Keys of the context map that {@code WrappedCall.preCall} and {@code WrappedCall.postCall} read and write.
 * <p>
 * Declaration order is kept stable because {@link java.util.EnumMap} iterates in this order.
 */
public enum ArgType {
    /** The MDC context map ({@code Map<String, String>}) to install while the wrapped work runs. */
    MDC_KEY,
    /** The MDC context map ({@code Map<String, String>}) that was in place before the work and is restored after it. */
    OLD_MAP
}
import java.lang.Thread.UncaughtExceptionHandler;
import java.security.AccessControlContext;
import java.security.AccessController;
import java.security.Permission;
import java.security.Permissions;
import java.security.PrivilegedAction;
import java.security.ProtectionDomain;
import java.util.EnumMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicReference;

import org.slf4j.MDC;

import com.castortech.util.session.WrappedCall;
/**
 * A {@link ForkJoinPool} that carries the SLF4J {@link MDC} context of the submitting thread over to the worker
 * thread that runs each submitted task.
 * <p>
 * Each overridden {@code execute}/{@code submit}/{@code invoke} method takes a copy of the caller's MDC context map
 * and wraps the task so that this map is installed while the task runs and the worker's previous MDC context is put
 * back afterwards. Only tasks that go through these overridden methods are wrapped.
 * <p>
 * The convenience constructors use a thread factory whose workers start with a copy of the MDC of the thread that
 * caused them to be created.
 */
public class MdcForkJoinPool extends ForkJoinPool implements WrappedCall {
    /** Parallelism level used by the convenience constructors. */
    private static final int POOL_SIZE = 5;

    /**
     * Creates a new MdcForkJoinPool.
     *
     * @param parallelism the parallelism level. For default value, use {@link java.lang.Runtime#availableProcessors}.
     * @param factory the factory for creating new threads. For default value, use
     *            {@link #defaultForkJoinWorkerThreadFactory}.
     * @param handler the handler for internal worker threads that terminate due to unrecoverable errors encountered
     *            while executing tasks. For default value, use {@code null}.
     * @param asyncMode if true, establishes local first-in-first-out scheduling mode for forked tasks that are never
     *            joined. This mode may be more appropriate than default locally stack-based mode in applications
     *            in which worker threads only process event-style asynchronous tasks. For default value, use
     *            {@code false}.
     * @throws IllegalArgumentException if parallelism less than or equal to zero, or greater than implementation limit
     * @throws NullPointerException if the factory is null
     * @throws SecurityException if a security manager exists and the caller is not permitted to modify threads
     *             because it does not hold
     *             {@link java.lang.RuntimePermission}{@code ("modifyThread")}
     */
    public MdcForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory,
            UncaughtExceptionHandler handler, boolean asyncMode) {
        super(parallelism, factory, handler, asyncMode);
    }

    /**
     * Creates a pool with a parallelism of 5, MDC-aware worker threads and no uncaught-exception handler.
     *
     * @param asyncMode if true, use local first-in-first-out scheduling for forked tasks that are never joined;
     *            otherwise use the default locally stack-based mode
     */
    public MdcForkJoinPool(boolean asyncMode) {
        super(POOL_SIZE, new MDCForkJoinWorkerThreadFactory(), null, asyncMode);
    }

    /**
     * Creates a pool with a parallelism of 5, MDC-aware worker threads, no uncaught-exception handler and the
     * default locally stack-based scheduling mode.
     */
    public MdcForkJoinPool() {
        this(false);
    }

    /**
     * Runs {@code task} with the caller's MDC context installed on the worker thread.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public void execute(ForkJoinTask<?> task) {
        super.execute(wrap(task, getContextForTask()));
    }

    /**
     * Runs {@code task} with the caller's MDC context installed on the worker thread.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public void execute(Runnable task) {
        super.execute(wrap(task, getContextForTask()));
    }

    /**
     * Submits {@code task} so that it runs with the caller's MDC context installed.
     * <p>
     * Note that the returned task is a wrapper, not {@code task} itself.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        return super.submit(wrap(task, getContextForTask()));
    }

    /**
     * Submits {@code task} so that it runs with the caller's MDC context installed.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(Callable<T> task) {
        return super.submit(wrap(task, getContextForTask()));
    }

    /**
     * Submits {@code task} so that it runs with the caller's MDC context installed.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(Runnable task, T result) {
        return super.submit(wrap(task, getContextForTask()), result);
    }

    /**
     * Submits {@code task} so that it runs with the caller's MDC context installed.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public ForkJoinTask<?> submit(Runnable task) {
        return super.submit(wrap(task, getContextForTask()));
    }

    /**
     * Performs {@code task} with the caller's MDC context installed on the worker thread and returns its result.
     * Without this override, {@code invoke} would bypass MDC propagation.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> T invoke(ForkJoinTask<T> task) {
        return super.invoke(wrap(task, getContextForTask()));
    }

    /**
     * Captures the calling thread's MDC.
     *
     * @return a new map holding a copy of the current MDC context under {@link ArgType#MDC_KEY} (the copy may be null)
     */
    private static Map<ArgType, Object> getContextForTask() {
        Map<ArgType, Object> args = new EnumMap<>(ArgType.class);
        args.put(ArgType.MDC_KEY, MDC.getCopyOfContextMap());
        return args;
    }

    /**
     * Returns a private copy of {@code args} for one execution of a wrapped task. The previous MDC saved under
     * {@link ArgType#OLD_MAP} is then never shared between concurrent or repeated runs of the same wrapper.
     */
    private static Map<ArgType, Object> perRunContext(Map<ArgType, Object> args) {
        Map<ArgType, Object> copy = new EnumMap<>(ArgType.class);
        copy.putAll(args);
        return copy;
    }

    /**
     * Wraps {@code runnable} so that each run installs the MDC stored in {@code args} and restores the thread's
     * previous MDC afterwards, even if {@code runnable} throws.
     *
     * @param runnable the work to run
     * @param args context holding the MDC map to install under {@link ArgType#MDC_KEY}; it is not modified
     * @return the wrapping runnable
     * @throws NullPointerException if {@code runnable} or {@code args} is null
     */
    public Runnable wrap(final Runnable runnable, final Map<ArgType, Object> args) {
        Objects.requireNonNull(runnable, "runnable");
        Objects.requireNonNull(args, "args");
        return () -> {
            Map<ArgType, Object> context = perRunContext(args);
            context.put(ArgType.OLD_MAP, preCall(context));
            try {
                runnable.run();
            }
            finally {
                postCall(context);
            }
        };
    }

    /**
     * Wraps {@code task} so that each call installs the MDC stored in {@code args} and restores the thread's
     * previous MDC afterwards, even if {@code task} throws.
     *
     * @param task the work to call
     * @param args context holding the MDC map to install under {@link ArgType#MDC_KEY}; it is not modified
     * @return the wrapping callable, which returns whatever {@code task} returns
     * @throws NullPointerException if {@code task} or {@code args} is null
     */
    public <T> Callable<T> wrap(Callable<T> task, final Map<ArgType, Object> args) {
        Objects.requireNonNull(task, "task");
        Objects.requireNonNull(args, "args");
        return () -> {
            Map<ArgType, Object> context = perRunContext(args);
            context.put(ArgType.OLD_MAP, preCall(context));
            try {
                return task.call();
            }
            finally {
                postCall(context);
            }
        };
    }

    /**
     * Wraps {@code task} in a {@link ForkJoinTask} whose execution invokes {@code task} with the MDC stored in
     * {@code args} installed and then restores the worker's previous MDC.
     * <p>
     * The wrapper's result is {@code task}'s raw result, unless a non-null value has been set through the wrapper's
     * {@code setRawResult}.
     *
     * @throws NullPointerException if {@code task} or {@code args} is null
     */
    private <T> ForkJoinTask<T> wrap(ForkJoinTask<T> task, final Map<ArgType, Object> args) {
        Objects.requireNonNull(task, "task");
        Objects.requireNonNull(args, "args");
        return new ForkJoinTask<T>() {
            private static final long serialVersionUID = 1L;

            /**
             * If non-null, overrides the value returned by the underlying task.
             */
            private final AtomicReference<T> override = new AtomicReference<>();

            @Override
            public T getRawResult() {
                T result = override.get();
                if (result != null) {
                    return result;
                }
                return task.getRawResult();
            }

            @Override
            protected void setRawResult(T value) {
                override.set(value);
            }

            @Override
            protected boolean exec() {
                // A fresh per-run copy keeps the saved OLD_MAP private to this execution, so a task that is
                // reinitialized and forked again cannot restore another run's MDC.
                Map<ArgType, Object> context = perRunContext(args);
                context.put(ArgType.OLD_MAP, preCall(context));
                try {
                    task.invoke();
                    return true;
                }
                finally {
                    postCall(context);
                }
            }
        };
    }

    /**
     * Thread factory for the convenience constructors. Each new worker receives a copy of the MDC of the thread
     * calling {@link #newThread}, and {@link MdcForkJoinWorkerThread} installs it when the worker starts.
     */
    private static final class MDCForkJoinWorkerThreadFactory implements ForkJoinWorkerThreadFactory {
        @SuppressWarnings("nls")
        @Override
        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
            // Create the thread with only the class-loader permissions it needs, whatever the caller's privileges.
            final AccessControlContext acc = contextWithPermissions(
                    new RuntimePermission("getClassLoader"),
                    new RuntimePermission("setContextClassLoader"));
            return AccessController.doPrivileged((PrivilegedAction<ForkJoinWorkerThread>)
                    () -> new MdcForkJoinWorkerThread(pool, getContextForTask()), acc);
        }
    }

    /**
     * Builds an access-control context whose single protection domain grants exactly {@code perms}.
     * <p>
     * NOTE(review): {@link AccessController} and {@link AccessControlContext} are deprecated for removal since
     * Java 17 (JEP 411).
     *
     * @param perms the permissions to grant
     * @return the restricted context
     */
    static AccessControlContext contextWithPermissions(Permission... perms) {
        Permissions permissions = new Permissions();
        for (Permission perm : perms) {
            permissions.add(perm);
        }
        return new AccessControlContext(new ProtectionDomain[] { new ProtectionDomain(null, permissions) });
    }
}
import java.util.Map;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
/**
 * A {@link ForkJoinWorkerThread} that installs a fixed MDC context for its whole lifetime.
 * <p>
 * When the thread starts, the MDC map stored under {@code ArgType.MDC_KEY} in its context is installed and the
 * thread's previous MDC is saved under {@code ArgType.OLD_MAP}. That saved MDC is restored when the thread terminates.
 */
public class MdcForkJoinWorkerThread extends ForkJoinWorkerThread implements WrappedCall {
    /** Context map read on start and termination; {@code OLD_MAP} is written into it on start. */
    private final Map<ArgType, Object> context;

    /**
     * @param pool the pool this worker belongs to
     * @param context context holding the MDC map to install under {@code MDC_KEY}; this thread also writes into it
     */
    public MdcForkJoinWorkerThread(ForkJoinPool pool, Map<ArgType, Object> context) {
        super(pool);
        this.context = context;
    }

    /** Installs the configured MDC after the standard start-up work and remembers what was there before. */
    @Override
    protected void onStart() {
        super.onStart();
        final Map<String, String> previousMdc = preCall(this.context);
        this.context.put(ArgType.OLD_MAP, previousMdc);
    }

    /** Restores the MDC saved at start-up after the standard termination work. */
    @Override
    protected void onTermination(Throwable exception) {
        super.onTermination(exception);
        postCall(this.context);
    }
}
import java.util.Map;
import org.slf4j.MDC;
/**
 * Mix-in that swaps the SLF4J {@link MDC} context of the current thread around a unit of work.
 * <p>
 * Call {@link #preCall} before the work and store its return value under {@link ArgType#OLD_MAP} in the same map.
 * Call {@link #postCall} with that map afterwards, typically in a {@code finally} block.
 */
public interface WrappedCall {
    /**
     * Replaces the current thread's MDC with the map stored under {@link ArgType#MDC_KEY}, or clears the MDC when that
     * entry is absent or null.
     *
     * @param args context holding the MDC map to install under {@code MDC_KEY}
     * @return a copy of the MDC context that was in place before this call (may be {@code null})
     */
    default Map<String, String> preCall(Map<ArgType, Object> args) {
        final Map<String, String> incoming = uncheckedCast(args.get(ArgType.MDC_KEY));
        final Map<String, String> previous = MDC.getCopyOfContextMap();
        if (incoming != null) {
            MDC.setContextMap(incoming);
        }
        else {
            MDC.clear();
        }
        return previous;
    }

    /**
     * Restores the MDC stored under {@link ArgType#OLD_MAP}, or clears the MDC when that entry is absent or null.
     *
     * @param args context holding the MDC map to restore under {@code OLD_MAP}
     */
    default void postCall(Map<ArgType, Object> args) {
        final Map<String, String> saved = uncheckedCast(args.get(ArgType.OLD_MAP));
        if (saved != null) {
            MDC.setContextMap(saved);
            return;
        }
        MDC.clear();
    }

    /**
     * Casts {@code o} to whatever type the call site expects, without a compile-time check.
     *
     * @param o the value to cast (may be null)
     * @return {@code o}, typed as the caller's target type
     */
    @SuppressWarnings("unchecked")
    static <T, X> X uncheckedCast(T o) {
        final X cast = (X) o;
        return cast;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment