Skip to content

Instantly share code, notes, and snippets.

@Genzer
Created September 24, 2024 23:30
Show Gist options
  • Save Genzer/39fc4af6e7c9599a36af31dd9014e3d2 to your computer and use it in GitHub Desktop.
Save Genzer/39fc4af6e7c9599a36af31dd9014e3d2 to your computer and use it in GitHub Desktop.
package com.grokhard.explorejavaconcurrency;
import static java.util.concurrent.TimeUnit.SECONDS;
import java.util.Objects;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
* <p>
* This class acts as a guard which ensures that, when the cache needs to be
* written, only one Thread is allowed to perform the cache-write.
* The rest, the reading Threads, are held back until the cache-write finishes.
* </p>
*
* <p>
* This class aims to ensure a good-enough concurrency level when it comes
* to accessing cached value.
* </p>
*
*
* <p>
* Implementation Note: internally, the class uses {@linkplain ReadWriteLock}
* to maintain concurrency between read and write.
* </p>
*
* <p>
* Typical use:
* </p>
*
* <pre>{@code
* final ConcurrentMap<String, LocalDateTime> cache = new ConcurrentHashMap<>();
*
* final CacheReadWriteGuard<LocalDateTime> cacheGuard = CacheReadWriteGuard
* .of(LocalDateTime.class)
* .readBy(() -> cache.get("some-cache-value"))
* // expiredIf() could be omitted in this case because by default
* // the builder will provide the same implementation.
* .expiredIf(() -> cache.get("some-cache-value") == null)
* .writeOnExpired(() -> cache.put("some-cache-value", LocalDateTime.now()))
* .build();
* }</pre>
*
*<p>References:</p>
*<ul>
* <li>https://stackoverflow.com/questions/25245371/lock-or-wait-cache-load</li>
*</ul>
*
* @param <T> the type of value being cached
*
* @see ReentrantReadWriteLock
*/
public class CacheReadWriteGuard<T> {

    /**
     * Maximum time, in seconds, a Thread waits for the write lock before it
     * re-checks whether the cache is still expired.
     */
    private static final long WRITE_LOCK_WAIT_SECONDS = 5;

    /** Tells the guard whether the cache has to be (re)written before it is read. */
    @FunctionalInterface
    public static interface ExpiredCacheCondition {
        /**
         * @return {@code true} if the cache is considered expired and the
         *         write action should run
         */
        boolean isExpired();
    }

    /**
     * Reads the cached value.
     *
     * @param <T> the type of value being cached
     */
    @FunctionalInterface
    public static interface CacheRead<T> {
        /**
         * @return the cached value; the default expiry condition treats
         *         {@code null} as "expired"
         */
        T read();
    }

    /** Loads (writes) the cache. Invoked while the guard holds its write lock. */
    @FunctionalInterface
    public static interface CacheWrite extends Runnable {
    }

    /**
     * Fluent builder for {@link CacheReadWriteGuard}.
     *
     * @param <T> the type of value being cached
     */
    public static class Builder<T> {
        private ExpiredCacheCondition cacheCondition;
        private CacheRead<T> readAction;
        private CacheWrite writeAction;

        public Builder() {
        }

        /**
         * Provides a way that the guard can tell if the cache is expired.
         * In case no {@code ExpiredCacheCondition} is supplied, this
         * builder by default will use {@code CacheRead.read() == null} as the condition.
         *
         * @param cacheCondition the expiry check, or {@code null} to use the default
         * @return this builder
         */
        public Builder<T> expiredIf(ExpiredCacheCondition cacheCondition) {
            this.cacheCondition = cacheCondition;
            return this;
        }

        /**
         * Sets the action used to read the cached value. Required.
         *
         * @param cacheRead the read action
         * @return this builder
         */
        public Builder<T> readBy(CacheRead<T> cacheRead) {
            this.readAction = cacheRead;
            return this;
        }

        /**
         * Sets the action that (re)loads the cache when it is expired. Required.
         *
         * @param cacheWrite the write action
         * @return this builder
         */
        public Builder<T> writeOnExpired(CacheWrite cacheWrite) {
            this.writeAction = cacheWrite;
            return this;
        }

        /**
         * Creates the guard from the actions configured so far.
         *
         * @return a new guard
         * @throws NullPointerException if the read or the write action is missing
         */
        public CacheReadWriteGuard<T> build() {
            // Snapshot the actions into locals: the default condition below must
            // keep using the read action this guard was built with, even if the
            // builder is reconfigured and reused afterwards.
            final CacheRead<T> read = Objects.requireNonNull(readAction, "CacheRead must not be null");
            final CacheWrite write = Objects.requireNonNull(writeAction, "CacheWrite must not be null");
            final ExpiredCacheCondition condition = (cacheCondition != null)
                    ? cacheCondition
                    : () -> read.read() == null;
            return new CacheReadWriteGuard<>(condition, read, write);
        }
    }

    private final ExpiredCacheCondition cacheCondition;
    private final CacheRead<T> readAction;
    private final CacheWrite writeAction;
    private final ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
    private final Lock readLock = readWriteLock.readLock();
    private final Lock writeLock = readWriteLock.writeLock();

    private CacheReadWriteGuard(
            ExpiredCacheCondition condition,
            CacheRead<T> readAction,
            CacheWrite writeAction) {
        this.cacheCondition = condition;
        this.readAction = readAction;
        this.writeAction = writeAction;
    }

    /**
     * Returns a new Builder. The {@code Class<T>} is solely
     * for ensuring the compiler can infer the type of the
     * cached value.
     *
     * @param type the class of the cached value; not used at runtime
     * @param <T> the type of value being cached
     * @return a new, empty builder
     */
    public static <T> Builder<T> of(Class<T> type) {
        return new Builder<>();
    }

    /**
     * Returns the cached value, running the write action first if the cache
     * is expired.
     *
     * <p>
     * If the calling Thread is interrupted while waiting for the write lock,
     * this method still completes and the Thread's interrupt status is set
     * again before it returns.
     * </p>
     *
     * @return whatever the read action returns (may be {@code null})
     */
    public T get() {
        loadCache();
        // Acquire outside the try block so that unlock() in finally only runs
        // when the lock is actually held.
        readLock.lock();
        try {
            return readAction.read();
        } finally {
            readLock.unlock();
        }
    }

    /**
     * Loops until the cache is no longer expired. On each pass the Thread
     * tries to take the write lock (waiting up to
     * {@link #WRITE_LOCK_WAIT_SECONDS} seconds) and, once it holds it,
     * re-checks the condition and runs the write action if still needed.
     *
     * <p>
     * The expiry condition in the loop header is evaluated without holding
     * any lock; if it returns {@code false} no attempt is made to acquire
     * the write lock.
     * </p>
     */
    private void loadCache() {
        // tryLock(timeout) throws InterruptedException immediately whenever the
        // interrupt flag is set. Re-setting the flag inside the loop would
        // therefore make every following tryLock fail at once: the Thread could
        // never load the cache and would spin at full speed. Instead the
        // interruption is remembered here and the flag is restored only after
        // the loop has finished.
        boolean interrupted = false;
        try {
            while (cacheCondition.isExpired()) {
                try {
                    if (writeLock.tryLock(WRITE_LOCK_WAIT_SECONDS, SECONDS)) {
                        try {
                            // Re-check the state of the cache in case some other
                            // Thread acquired the write lock and finished loading
                            // the cache before this Thread's turn.
                            if (cacheCondition.isExpired()) {
                                writeAction.run();
                            }
                        } finally {
                            writeLock.unlock();
                        }
                    }
                    // On timeout: fall through and let the while loop re-check.
                } catch (InterruptedException e) {
                    // We neither want to read a not-yet-loaded cache nor throw,
                    // so keep looping; the interruption is reported below.
                    interrupted = true;
                }
            }
        } finally {
            if (interrupted) {
                // Preserve evidence of the interruption for code higher up
                // on the call stack.
                Thread.currentThread().interrupt();
            }
        }
    }
}
package com.grokhard.explorejavaconcurrency;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.Collections;
/**
* This test is for regression testing the class CacheReadWriteGuard;
* it should not be used for any other purpose.
*
* This class is kept package-private (no modifier) so that classes
* outside the package cannot access this class.
*
*/
class CacheReadWriteGuardTest {
    // Number of Threads concurrently reading the cache through the guard.
    private static final int READER_COUNT = 20;

    private static final int TOTAL_TEST_RUNTIME_IN_MINUTES = 5;

    // This value is used to simulate an expensive process to load the cache.
    // 10 seconds is considered *long* for loading the cache.
    private static final int TIME_INITIALIZING_CACHE_IN_SECONDS = 10;

    // This value is used to pause the cache-invalidating thread.
    // It is recommended to put it way bigger than TIME_INITIALIZING_CACHE_IN_SECONDS.
    // Otherwise, you will get into the situation described in the invalidate
    // thread setup code down below.
    private static final int TIME_WAIT_FOR_NEXT_CACHE_INVALIDATION = 20;

    // Just a simple ConcurrentHashMap to represent typical cache implementation
    private static final ConcurrentMap<String, LocalDateTime> cache = new ConcurrentHashMap<>();

    private static final CacheReadWriteGuard<LocalDateTime> cacheGuard = CacheReadWriteGuard
            .of(LocalDateTime.class)
            .readBy(() -> cache.get("blablabla"))
            .expiredIf(() -> cache.get("blablabla") == null)
            .writeOnExpired(() -> {
                System.out.println("At " + LocalDateTime.now() + Thread.currentThread() + " -> initialize cache");
                cache.put("blablabla", LocalDateTime.now());
                try {
                    // Simulate a long running process
                    Thread.sleep(TimeUnit.SECONDS.toMillis(TIME_INITIALIZING_CACHE_IN_SECONDS));
                } catch (InterruptedException interrupted) {
                    // Keep the interrupt status visible to the pool before failing the task.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(interrupted);
                }
            })
            .build();

    /*
     * Entrypoint of the test. Run this method using your favorite
     * IDE or command line.
     */
    public static void main(String... args) {
        /*
         * The test scenario here is actually simple:
         *
         * - There are 20 threads trying to read the cache. Any Thread
         *   will try to initialize the cache if it is empty.
         *
         * - There is one single thread invalidating the cache
         *   randomly (the randomness is not that robust though)
         *
         * The test is expected to terminate gracefully
         * and does it without any exception.
         */
        final CountDownLatch startSignal = new CountDownLatch(1);
        // One count per reader plus one for the invalidating task.
        final CountDownLatch endSignal = new CountDownLatch(READER_COUNT + 1);
        final ExecutorService es = Executors.newCachedThreadPool();
        // A date-time deadline: a time-of-day deadline would wrap around at
        // midnight and end the test immediately when started late in the evening.
        final LocalDateTime deadline = LocalDateTime.now().plusMinutes(TOTAL_TEST_RUNTIME_IN_MINUTES);

        final List<Runnable> tasks = new ArrayList<>();
        for (int i = 0; i < READER_COUNT; i++) {
            tasks.add(() -> {
                try {
                    awaitStart(startSignal);
                    while (LocalDateTime.now().isBefore(deadline)) {
                        LocalDateTime cached = cacheGuard.get();
                        System.out.println("At " + LocalDateTime.now() + Thread.currentThread() + " -> read " + cached);
                    }
                } finally {
                    // Always report completion, otherwise a failing task would
                    // leave main() waiting on endSignal forever.
                    endSignal.countDown();
                }
            });
        }

        Runnable invalidate = () -> {
            try {
                awaitStart(startSignal);
                while (LocalDateTime.now().isBefore(deadline)) {
                    // It is recommended to use ThreadLocalRandom instead of
                    // Random or Math.random() because they have synchronization
                    // internally.
                    Random random = ThreadLocalRandom.current();
                    int next = random.nextInt();
                    boolean shouldClearCache = (next % 13 == 0);
                    if (shouldClearCache) {
                        System.out.println(
                                String.format("At %s %s random %s",
                                        LocalDateTime.now().toString(),
                                        Thread.currentThread().toString(),
                                        next)
                        );
                        cache.clear();
                        System.out.println("At " + LocalDateTime.now() + Thread.currentThread() + " -> clearCache");
                        try {
                            // Rest for 20 seconds because I don't think
                            // cache should be invalidated that often.
                            // Invalidating the cache too many times leads to
                            // high Thread contention though.
                            // You may comment out this try {} block or reduce the
                            // value of TIME_WAIT_FOR_NEXT_CACHE_INVALIDATION to see
                            // how it happens. No Thread will actually reach the
                            // cache because the cache keeps being cleared and the
                            // other Threads keep re-initializing it infinitely
                            // (because of the while loop).
                            Thread.sleep(TimeUnit.SECONDS.toMillis(TIME_WAIT_FOR_NEXT_CACHE_INVALIDATION));
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            throw new RuntimeException(e);
                        }
                    }
                }
            } finally {
                // Count down exactly once, after the loop. Counting down on
                // every iteration would release main() long before the readers
                // are done.
                endSignal.countDown();
            }
        };
        tasks.add(invalidate);
        Collections.shuffle(tasks);
        // execute() instead of submit(): an exception escaping a task reaches the
        // Thread's uncaught-exception handler instead of being hidden in a
        // Future nobody inspects.
        tasks.forEach(es::execute);

        // Starts all the tasks at the same time
        startSignal.countDown();
        try {
            // Each task counts down on endSignal.
            // Once every task finishes, the program
            // terminates gracefully.
            endSignal.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } finally {
            es.shutdown();
        }
    }

    /** Blocks the calling task until the shared start signal is released. */
    private static void awaitStart(CountDownLatch startSignal) {
        try {
            startSignal.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment