Last active
January 10, 2025 12:03
-
-
Save castortech/2cc20df8da329a3585135753ea27ab6b to your computer and use it in GitHub Desktop.
MDC ForkJoinPool
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Keys of the argument map handed to the {@code WrappedCall} pre/post hooks.
 * Both keys are mapped to an MDC context of type {@code Map<String, String>}, which may be {@code null}.
 */
public enum ArgType {
    /** MDC context to install on the executing thread ({@code Map<String, String>}). */
    MDC_KEY,
    /** MDC context that was active before installation and is restored afterwards ({@code Map<String, String>}). */
    OLD_MAP
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.lang.Thread.UncaughtExceptionHandler;
import java.security.AccessControlContext;
import java.security.AccessController;
import java.security.Permission;
import java.security.Permissions;
import java.security.PrivilegedAction;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;

import org.slf4j.MDC;

import com.castortech.util.session.WrappedCall;
| public class MdcForkJoinPool extends ForkJoinPool implements WrappedCall { | |
| private static final int POOL_SIZE = 5; | |
| /** | |
| * Creates a new MdcForkJoinPool. | |
| * | |
| * @param parallelism the parallelism level. For default value, use {@link java.lang.Runtime#availableProcessors}. | |
| * @param factory the factory for creating new threads. For default value, use | |
| * {@link #defaultForkJoinWorkerThreadFactory}. | |
| * @param handler the handler for internal worker threads that terminate due to unrecoverable errors encountered | |
| * while executing tasks. For default value, use {@code null}. | |
| * @param asyncMode if true, establishes local first-in-first-out scheduling mode for forked tasks that are never | |
| * joined. This mode may be more appropriate than default locally stack-based mode in applications | |
| * in which worker threads only process event-style asynchronous tasks. For default value, use | |
| * {@code false}. | |
| * @throws IllegalArgumentException if parallelism less than or equal to zero, or greater than implementation limit | |
| * @throws NullPointerException if the factory is null | |
| * @throws SecurityException if a security manager exists and the caller is not permitted to modify threads | |
| * because it does not hold | |
| * {@link java.lang.RuntimePermission}{@code ("modifyThread")} | |
| */ | |
| public MdcForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, | |
| UncaughtExceptionHandler handler, boolean asyncMode) { | |
| super(parallelism, factory, handler, asyncMode); | |
| } | |
| public MdcForkJoinPool(boolean asyncMode) { | |
| super(POOL_SIZE, new MDCForkJoinWorkerThreadFactory(), null, asyncMode); | |
| } | |
| public MdcForkJoinPool() { | |
| super(POOL_SIZE, new MDCForkJoinWorkerThreadFactory(), null, false); | |
| } | |
| @Override | |
| public void execute(ForkJoinTask<?> task) { | |
| super.execute(wrap(task, getContextForTask())); | |
| } | |
| @Override | |
| public void execute(Runnable task) { | |
| super.execute(wrap(task, getContextForTask())); | |
| } | |
| @Override | |
| public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) { | |
| return super.submit(wrap(task, getContextForTask())); | |
| } | |
| @Override | |
| public <T> ForkJoinTask<T> submit(Callable<T> task) { | |
| return super.submit(wrap(task, getContextForTask())); | |
| } | |
| @Override | |
| public <T> ForkJoinTask<T> submit(Runnable task, T result) { | |
| return super.submit(wrap(task, getContextForTask()), result); | |
| } | |
| @Override | |
| public ForkJoinTask<?> submit(Runnable task) { | |
| return super.submit(wrap(task, getContextForTask())); | |
| } | |
| private static Map<ArgType, Object> getContextForTask() { | |
| Map<ArgType, Object> args = new EnumMap<>(ArgType.class); | |
| args.put(ArgType.MDC_KEY, MDC.getCopyOfContextMap()); | |
| return args; | |
| } | |
| public Runnable wrap(final Runnable runnable, final Map<ArgType, Object> args) { | |
| return () -> { | |
| args.put(ArgType.OLD_MAP, preCall(args)); | |
| try { | |
| runnable.run(); | |
| } | |
| finally { | |
| postCall(args); | |
| } | |
| }; | |
| } | |
| public <T> Callable<T> wrap(Callable<T> task, final Map<ArgType, Object> args) { | |
| return () -> { | |
| args.put(ArgType.OLD_MAP, preCall(args)); | |
| try { | |
| return task.call(); | |
| } | |
| finally { | |
| postCall(args); | |
| } | |
| }; | |
| } | |
| private <T> ForkJoinTask<T> wrap(ForkJoinTask<T> task, final Map<ArgType, Object> args) { | |
| return new ForkJoinTask<T>() { | |
| private static final long serialVersionUID = 1L; | |
| /** | |
| * If non-null, overrides the value returned by the underlying task. | |
| */ | |
| private final AtomicReference<T> override = new AtomicReference<>(); | |
| @Override | |
| public T getRawResult() { | |
| T result = override.get(); | |
| if (result != null) | |
| return result; | |
| return task.getRawResult(); | |
| } | |
| @Override | |
| protected void setRawResult(T value) { | |
| override.set(value); | |
| } | |
| @Override | |
| protected boolean exec() { | |
| // According to ForkJoinTask.fork() "it is a usage error to fork a task more than once unless it has | |
| // completed and been reinitialized". We therefore assume that this method does not have to be | |
| // thread-safe. | |
| args.put(ArgType.OLD_MAP, preCall(args)); | |
| try { | |
| task.invoke(); | |
| return true; | |
| } | |
| finally { | |
| postCall(args); | |
| } | |
| } | |
| }; | |
| } | |
| private static final class MDCForkJoinWorkerThreadFactory implements ForkJoinWorkerThreadFactory { | |
| @SuppressWarnings("nls") | |
| @Override | |
| public final ForkJoinWorkerThread newThread(ForkJoinPool pool) { | |
| final AccessControlContext acc = contextWithPermissions( | |
| new RuntimePermission("getClassLoader"), | |
| new RuntimePermission("setContextClassLoader")); | |
| return AccessController.doPrivileged((PrivilegedAction<ForkJoinWorkerThread>) | |
| () -> new MdcForkJoinWorkerThread(pool, getContextForTask()), acc); | |
| } | |
| } | |
| static AccessControlContext contextWithPermissions(Permission... perms) { | |
| Permissions permissions = new Permissions(); | |
| for (Permission perm : perms) | |
| permissions.add(perm); | |
| return new AccessControlContext(new ProtectionDomain[] { new ProtectionDomain(null, permissions) }); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import java.util.Map; | |
| import java.util.concurrent.ForkJoinPool; | |
| import java.util.concurrent.ForkJoinWorkerThread; | |
| public class MdcForkJoinWorkerThread extends ForkJoinWorkerThread implements WrappedCall { | |
| private final Map<ArgType, Object> context; | |
| public MdcForkJoinWorkerThread(ForkJoinPool pool, Map<ArgType, Object> context) { | |
| super(pool); | |
| this.context = context; | |
| } | |
| @Override | |
| protected void onStart() { | |
| super.onStart(); | |
| context.put(ArgType.OLD_MAP, preCall(context)); | |
| } | |
| @Override | |
| protected void onTermination(Throwable exception) { | |
| super.onTermination(exception); | |
| postCall(context); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import java.util.Map; | |
| import org.slf4j.MDC; | |
| public interface WrappedCall { | |
| default Map<String, String> preCall(Map<ArgType, Object> args) { | |
| Map<String, String> mdcMap = uncheckedCast(args.get(ArgType.MDC_KEY)); | |
| Map<String, String> oldMap = MDC.getCopyOfContextMap(); | |
| if (mdcMap == null) { | |
| MDC.clear(); | |
| } | |
| else { | |
| MDC.setContextMap(mdcMap); | |
| } | |
| return oldMap; | |
| } | |
| default void postCall(Map<ArgType, Object> args) { | |
| Map<String, String> oldMap = uncheckedCast(args.get(ArgType.OLD_MAP)); | |
| if (oldMap == null) { | |
| MDC.clear(); | |
| } | |
| else { | |
| MDC.setContextMap(oldMap); | |
| } | |
| } | |
| @SuppressWarnings("unchecked") | |
| static <T, X> X uncheckedCast(T o) { | |
| return (X)o; | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment