Skip to content

Instantly share code, notes, and snippets.

@shyiko
Created May 17, 2012 20:19
Show Gist options
  • Save shyiko/2721351 to your computer and use it in GitHub Desktop.
Save shyiko/2721351 to your computer and use it in GitHub Desktop.
InvocationLocal<T>
/**
 * Provides method-invocation-scoped variables.
 * <p/>
 * A value stored with {@link #set(Object)} is visible to {@link #get()} calls made from the method that invoked
 * {@code set} and from every method it calls (directly or transitively), for as long as that invocation is still
 * active. Once it returns, {@link #get()} falls back to the default value supplied at construction time (directly or
 * through a {@link DefaultValueProvider}).
 * <p/>
 * Scope is detected by comparing call stacks. {@link #set(Object)} records the stack <em>below</em> the method that
 * called it, as {@link StackTraceElement#toString()} strings, which include line numbers. {@link #get()} treats the
 * stored value as in scope only while that recorded stack is a prefix of the stack below the method calling
 * {@code get}. Consequences:
 * <ul>
 * <li>Every {@code get}/{@code set}/{@code updateScope} call captures the full stack trace, which is comparatively
 * expensive.</li>
 * <li>If the setting method is invoked again from the very same call site (e.g. in a loop), the previous value is
 * still considered in scope until it is overwritten.</li>
 * <li>A positive {@code stackShift} widens the scope by that many frames toward the bottom of the stack. Negative
 * shifts are treated as 0. If the shift reaches past the bottom of the stack when {@code set} is called, no frames
 * are recorded and the value stays in scope on that thread until {@link #remove()} or another {@code set}.</li>
 * </ul>
 * <p/>
 * Values are stored per thread (backed by a {@link ThreadLocal}), so one instance may be shared between threads.
 * Call {@link #remove()} when a pooled thread finishes its work to avoid retaining values.
 *
 * @param <T> type of the stored value
 * @author <a href="mailto:[email protected]">sshyiko</a>
 */
public class InvocationLocal<T> {

    /**
     * Number of innermost frames that never belong to the scope key. These are {@code Thread.getStackTrace},
     * {@link #getStackTrace()}, the public {@code get}/{@code set}/{@code updateScope} method, and the method that
     * called it.
     * NOTE(review): assumes the JVM reports all of these frames; {@link Thread#getStackTrace()} permits a VM to omit
     * frames.
     */
    private static final int DEFAULT_STACK_SHIFT = 4;

    /** Scope of a freshly created default value: a zero-length stack is a prefix of every stack. */
    private static final String[] EMPTY_STACK_TRACE = new String[0];

    /** Per-thread value plus the stack it is bound to; final so the instance is safely published across threads. */
    private final ThreadLocal<State<T>> threadLocalHolder;

    /** Total number of innermost frames dropped by {@link #getStackTrace()}; always >= DEFAULT_STACK_SHIFT. */
    private final int stackShift;

    /**
     * Creates an instance whose default is a fixed value.
     *
     * @param defaultValue value returned by {@link #get()} when nothing is in scope
     */
    public InvocationLocal(final T defaultValue) {
        this(constantProvider(defaultValue), 0);
    }

    /**
     * Creates an instance whose default is a fixed value and whose scope is widened by {@code stackShift} frames.
     *
     * @param defaultValue value returned by {@link #get()} when nothing is in scope
     * @param stackShift   extra frames to widen the scope by; negative values are treated as 0
     */
    public InvocationLocal(final T defaultValue, final int stackShift) {
        this(constantProvider(defaultValue), stackShift);
    }

    /**
     * Creates an instance whose default is produced by {@code defaultValueProvider}.
     *
     * @param defaultValueProvider called each time a thread needs a fresh default value
     */
    public InvocationLocal(final DefaultValueProvider<T> defaultValueProvider) {
        this(defaultValueProvider, 0);
    }

    /**
     * Creates an instance whose default is produced by {@code defaultValueProvider} and whose scope is widened by
     * {@code stackShift} frames.
     *
     * @param defaultValueProvider called each time a thread needs a fresh default value
     * @param stackShift           extra frames to widen the scope by; negative values are treated as 0
     */
    public InvocationLocal(final DefaultValueProvider<T> defaultValueProvider, final int stackShift) {
        this.threadLocalHolder = new ThreadLocal<State<T>>() {
            @Override
            protected State<T> initialValue() {
                return new State<T>(EMPTY_STACK_TRACE, defaultValueProvider.get());
            }
        };
        // Computed in long and saturated so huge shifts cannot overflow back into a small (or negative) int.
        long shift = (long) DEFAULT_STACK_SHIFT + Math.max(stackShift, 0);
        this.stackShift = (int) Math.min(shift, Integer.MAX_VALUE);
    }

    /**
     * Returns the value in scope for the calling method, or the default when the invocation that stored a value has
     * ended. An out-of-scope value is discarded, and a fresh default is obtained from the provider.
     *
     * @return the in-scope value or the default
     */
    public T get() {
        State<T> state = threadLocalHolder.get();
        // getStackTrace() must be called directly from here: DEFAULT_STACK_SHIFT counts this frame.
        if (!isNestedInvocation(state.getStackTrace(), getStackTrace())) {
            threadLocalHolder.remove();
            state = threadLocalHolder.get();
        }
        return state.getValue();
    }

    /**
     * Returns this thread's stored value without checking (or discarding) its scope. If nothing is stored, a
     * default is created.
     *
     * @return the stored value or the default
     */
    public T getIgnoringScope() {
        return threadLocalHolder.get().getValue();
    }

    /**
     * Stores {@code value} for the current invocation of the calling method (see class documentation).
     *
     * @param value value to expose to this invocation and the methods it calls
     */
    public void set(T value) {
        // getStackTrace() must be called directly from here: DEFAULT_STACK_SHIFT counts this frame.
        threadLocalHolder.set(new State<T>(getStackTrace(), value));
    }

    /**
     * Rebinds the currently stored value (or a fresh default, if none is stored) to the calling method's current
     * invocation. The old scope is not checked first, so a value whose scope has already ended is revived.
     */
    public void updateScope() {
        // getStackTrace() must be called directly from here: DEFAULT_STACK_SHIFT counts this frame.
        threadLocalHolder.get().setStackTrace(getStackTrace());
    }

    /**
     * Discards this thread's stored value, so the next read creates a fresh default. Useful before returning a
     * pooled thread, because out-of-scope values are otherwise only dropped on the next {@link #get()}.
     */
    public void remove() {
        threadLocalHolder.remove();
    }

    /**
     * Returns true when {@code stackTrace} is an element-wise prefix of {@code nestedStackTrace}. Both arrays are
     * ordered outermost frame first.
     */
    private static boolean isNestedInvocation(String[] stackTrace, String[] nestedStackTrace) {
        int stackTraceLength = stackTrace.length;
        if (stackTraceLength > nestedStackTrace.length) {
            return false;
        }
        for (int i = 0; i < stackTraceLength; i++) {
            if (!stackTrace[i].equals(nestedStackTrace[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Captures the current stack minus the innermost {@link #stackShift} frames, outermost frame first. Returns an
     * empty array (instead of failing) when the stack is not deeper than the shift.
     */
    private String[] getStackTrace() {
        StackTraceElement[] frames = Thread.currentThread().getStackTrace();
        int count = Math.max(frames.length - stackShift, 0);
        String[] result = new String[count];
        for (int i = 0; i < count; i++) {
            // frames[stackShift] is the innermost kept frame; store it last so index 0 is the outermost frame.
            result[count - 1 - i] = frames[stackShift + i].toString();
        }
        return result;
    }

    /** Adapts a fixed value to a {@link DefaultValueProvider}. */
    private static <V> DefaultValueProvider<V> constantProvider(final V value) {
        return new DefaultValueProvider<V>() {
            @Override
            public V get() {
                return value;
            }
        };
    }

    /** A stored value together with the recorded stack that defines its scope. */
    private static final class State<T> {
        private String[] stackTrace;
        private final T value;

        State(String[] stackTrace, T value) {
            this.stackTrace = stackTrace;
            this.value = value;
        }

        String[] getStackTrace() {
            return stackTrace;
        }

        void setStackTrace(String[] stackTrace) {
            this.stackTrace = stackTrace;
        }

        T getValue() {
            return value;
        }
    }

    /**
     * Supplies the default value; invoked whenever a thread needs a fresh default.
     *
     * @param <T> type of the supplied value
     */
    public interface DefaultValueProvider<T> {
        T get();
    }
}
@shyiko
Copy link
Author

shyiko commented May 17, 2012

/**
 * Checks that a value set inside {@code runNested} is visible there and in the helper it calls, and that the
 * default (1) is seen again once {@code runNested} has returned.
 * NOTE(review): relies on JUnit's {@code @Test} and a static {@code assertEquals} import that are not visible here.
 */
@Test
public void testNestedCall() {
    InvocationLocal<Integer> invocationScopeLocal = new InvocationLocal<Integer>(1);
    runNested(invocationScopeLocal);
    // Must stay on a different source line from the runNested call above. Scope matching compares stack frames
    // (including line numbers), so a get() reached from the same call-site line would still see the value 2.
    assertEqualsNested(invocationScopeLocal, 1);
}

/**
 * Sets the value to 2, then checks it is returned both by a direct {@code get()} in this method and by a
 * {@code get()} made one call level deeper.
 */
private void runNested(InvocationLocal<Integer> invocationScopeLocal) {
    invocationScopeLocal.set(2);
    assertEquals(2, (int) invocationScopeLocal.get());
    assertEqualsNested(invocationScopeLocal, 2);
}

/**
 * Asserts the value read through one extra call frame. Because {@code get()} ignores its immediate caller
 * (this helper), the scope check is made against the stack of whoever called this helper.
 */
private void assertEqualsNested(InvocationLocal<Integer> invocationScopeLocal, int value) {
    assertEquals(value, (int) invocationScopeLocal.get());
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment