Skip to content

Instantly share code, notes, and snippets.

@benjchristensen
Created May 11, 2013 14:55
Show Gist options
  • Save benjchristensen/5560178 to your computer and use it in GitHub Desktop.
Save benjchristensen/5560178 to your computer and use it in GitHub Desktop.
Uses a queue instead of `synchronized` keywords to serialize events. This is slower than using `synchronized` for handling contended access (as it now has overhead of queue etc) but it may make sense when a sequence has long-running events where we'd rather pay the queueing cost rather than blocking threads.
/**
* Copyright 2013 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rx.util;
import static org.junit.Assert.*;
import static org.mockito.Matchers.*;
import static org.mockito.Mockito.*;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import rx.Notification;
import rx.Observable;
import rx.Observer;
import rx.Subscription;
import rx.util.functions.Func1;
/**
* Serialize onNext/onCompleted/onError execution and ensure no events are propagated after unsubscribe/onCompleted/onError
*
* @param <T>
*/
public final class SerializedObserver<T> implements Observer<T> {
private final Observer<T> actualObserver;
private final AtomicObservableSubscription s;
private volatile boolean isFinished = false;
private final ConcurrentLinkedQueue<Notification<T>> queue = new ConcurrentLinkedQueue<Notification<T>>();
private final AtomicInteger counter = new AtomicInteger(0);
public SerializedObserver(Observer<T> observer, AtomicObservableSubscription s) {
this.actualObserver = observer;
this.s = s;
}
@Override
public void onCompleted() {
if (s.isUnsubscribed() || isFinished) {
return;
}
isFinished = true;
enqueue(new Notification<T>());
}
@Override
public void onError(final Exception e) {
if (s.isUnsubscribed() || isFinished) {
return;
}
isFinished = true;
enqueue(new Notification<T>(e));
}
@Override
public void onNext(final T args) {
if (s.isUnsubscribed() || isFinished) {
return;
}
enqueue(new Notification<T>(args));
}
/**
* Allow one thread to do actual work, all others will put the event in a queue.
* <p>
* We rely on the {@link ConcurrentLinkedQueue} to provide memory visibility across threads.
*
* @param notification
*/
private void enqueue(Notification<T> notification) {
// this must happen before 'counter' is used to provide synchronization between threads
queue.offer(notification);
// we now use counter to atomically determine if we need to start processing or not
// it will be 0 if it's the first notification or the scheduler has finished processing work
// and we need to start doing it again
if (counter.getAndIncrement() == 0) {
processQueue();
}
}
private void processQueue() {
// keep processing as long as 'counter' is not 0
do {
Notification<T> not = queue.poll();
switch (not.getKind()) {
case OnNext:
actualObserver.onNext(not.getValue());
break;
case OnError:
actualObserver.onError(not.getException());
break;
case OnCompleted:
actualObserver.onCompleted();
break;
default:
throw new IllegalStateException("Unknown kind of notification " + not);
}
} while (counter.decrementAndGet() > 0);
}
public static class UnitTest {
@Mock
Observer<String> aObserver;
@Before
public void before() {
MockitoAnnotations.initMocks(this);
}
@Test
public void testSingleThreadedBasic() {
Subscription s = mock(Subscription.class);
TestSingleThreadedObservable onSubscribe = new TestSingleThreadedObservable(s, "one", "two", "three");
Observable<String> w = Observable.create(onSubscribe);
AtomicObservableSubscription as = new AtomicObservableSubscription(s);
SerializedObserver<String> aw = new SerializedObserver<String>(aObserver, as);
w.subscribe(aw);
onSubscribe.waitToFinish();
verify(aObserver, times(1)).onNext("one");
verify(aObserver, times(1)).onNext("two");
verify(aObserver, times(1)).onNext("three");
verify(aObserver, never()).onError(any(Exception.class));
verify(aObserver, times(1)).onCompleted();
// non-deterministic because unsubscribe happens after 'waitToFinish' releases
// so commenting out for now as this is not a critical thing to test here
// verify(s, times(1)).unsubscribe();
}
@Test
public void testMultiThreadedBasic() {
Subscription s = mock(Subscription.class);
TestMultiThreadedObservable onSubscribe = new TestMultiThreadedObservable(s, "one", "two", "three");
Observable<String> w = Observable.create(onSubscribe);
AtomicObservableSubscription as = new AtomicObservableSubscription(s);
BusyObserver busyObserver = new BusyObserver();
SerializedObserver<String> aw = new SerializedObserver<String>(busyObserver, as);
w.subscribe(aw);
onSubscribe.waitToFinish();
assertEquals(3, busyObserver.onNextCount.get());
assertFalse(busyObserver.onError);
assertTrue(busyObserver.onCompleted);
// non-deterministic because unsubscribe happens after 'waitToFinish' releases
// so commenting out for now as this is not a critical thing to test here
// verify(s, times(1)).unsubscribe();
// we can have concurrency ...
assertTrue(onSubscribe.maxConcurrentThreads.get() > 1);
// ... but the onNext execution should be single threaded
assertEquals(1, busyObserver.maxConcurrentThreads.get());
}
@Test
public void testMultiThreadedWithNPE() {
Subscription s = mock(Subscription.class);
TestMultiThreadedObservable onSubscribe = new TestMultiThreadedObservable(s, "one", "two", "three", null);
Observable<String> w = Observable.create(onSubscribe);
AtomicObservableSubscription as = new AtomicObservableSubscription(s);
BusyObserver busyObserver = new BusyObserver();
SerializedObserver<String> aw = new SerializedObserver<String>(busyObserver, as);
w.subscribe(aw);
onSubscribe.waitToFinish();
System.out.println("maxConcurrentThreads: " + onSubscribe.maxConcurrentThreads.get());
// we can't know how many onNext calls will occur since they each run on a separate thread
// that depends on thread scheduling so 0, 1, 2 and 3 are all valid options
// assertEquals(3, busyObserver.onNextCount.get());
assertTrue(busyObserver.onNextCount.get() < 4);
assertTrue(busyObserver.onError);
// no onCompleted because onError was invoked
assertFalse(busyObserver.onCompleted);
// non-deterministic because unsubscribe happens after 'waitToFinish' releases
// so commenting out for now as this is not a critical thing to test here
//verify(s, times(1)).unsubscribe();
// we can have concurrency ...
assertTrue(onSubscribe.maxConcurrentThreads.get() > 1);
// ... but the onNext execution should be single threaded
assertEquals(1, busyObserver.maxConcurrentThreads.get());
}
@Test
public void testMultiThreadedWithNPEinMiddle() {
Subscription s = mock(Subscription.class);
TestMultiThreadedObservable onSubscribe = new TestMultiThreadedObservable(s, "one", "two", "three", null, "four", "five", "six", "seven", "eight", "nine");
Observable<String> w = Observable.create(onSubscribe);
AtomicObservableSubscription as = new AtomicObservableSubscription(s);
BusyObserver busyObserver = new BusyObserver();
SerializedObserver<String> aw = new SerializedObserver<String>(busyObserver, as);
w.subscribe(aw);
onSubscribe.waitToFinish();
System.out.println("maxConcurrentThreads: " + onSubscribe.maxConcurrentThreads.get());
// this should not be the full number of items since the error should stop it before it completes all 9
System.out.println("onNext count: " + busyObserver.onNextCount.get());
assertTrue(busyObserver.onNextCount.get() < 9);
assertTrue(busyObserver.onError);
// no onCompleted because onError was invoked
assertFalse(busyObserver.onCompleted);
// non-deterministic because unsubscribe happens after 'waitToFinish' releases
// so commenting out for now as this is not a critical thing to test here
// verify(s, times(1)).unsubscribe();
// we can have concurrency ...
assertTrue(onSubscribe.maxConcurrentThreads.get() > 1);
// ... but the onNext execution should be single threaded
assertEquals(1, busyObserver.maxConcurrentThreads.get());
}
/**
* A non-realistic use case that tries to expose thread-safety issues by throwing lots of out-of-order
* events on many threads.
*
* @param w
* @param tw
*/
@Test
public void runConcurrencyTest() {
ExecutorService tp = Executors.newFixedThreadPool(20);
try {
TestConcurrencyObserver tw = new TestConcurrencyObserver();
AtomicObservableSubscription s = new AtomicObservableSubscription();
SerializedObserver<String> w = new SerializedObserver<String>(tw, s);
Future<?> f1 = tp.submit(new OnNextThread(w, 12000));
Future<?> f2 = tp.submit(new OnNextThread(w, 5000));
Future<?> f3 = tp.submit(new OnNextThread(w, 75000));
Future<?> f4 = tp.submit(new OnNextThread(w, 13500));
Future<?> f5 = tp.submit(new OnNextThread(w, 22000));
Future<?> f6 = tp.submit(new OnNextThread(w, 15000));
Future<?> f7 = tp.submit(new OnNextThread(w, 7500));
Future<?> f8 = tp.submit(new OnNextThread(w, 23500));
Future<?> f10 = tp.submit(new CompletionThread(w, TestConcurrencyObserverEvent.onCompleted, f1, f2, f3, f4));
try {
Thread.sleep(1);
} catch (InterruptedException e) {
// ignore
}
Future<?> f11 = tp.submit(new CompletionThread(w, TestConcurrencyObserverEvent.onCompleted, f4, f6, f7));
Future<?> f12 = tp.submit(new CompletionThread(w, TestConcurrencyObserverEvent.onCompleted, f4, f6, f7));
Future<?> f13 = tp.submit(new CompletionThread(w, TestConcurrencyObserverEvent.onCompleted, f4, f6, f7));
Future<?> f14 = tp.submit(new CompletionThread(w, TestConcurrencyObserverEvent.onCompleted, f4, f6, f7));
// // the next 4 onError events should wait on same as f10
Future<?> f15 = tp.submit(new CompletionThread(w, TestConcurrencyObserverEvent.onError, f1, f2, f3, f4));
Future<?> f16 = tp.submit(new CompletionThread(w, TestConcurrencyObserverEvent.onError, f1, f2, f3, f4));
Future<?> f17 = tp.submit(new CompletionThread(w, TestConcurrencyObserverEvent.onError, f1, f2, f3, f4));
Future<?> f18 = tp.submit(new CompletionThread(w, TestConcurrencyObserverEvent.onError, f1, f2, f3, f4));
waitOnThreads(f1, f2, f3, f4, f5, f6, f7, f8, f10, f11, f12, f13, f14, f15, f16, f17, f18);
@SuppressWarnings("unused")
int numNextEvents = tw.assertEvents(null); // no check of type since we don't want to test barging results here, just interleaving behavior
// System.out.println("Number of events executed: " + numNextEvents);
} catch (Exception e) {
fail("Concurrency test failed: " + e.getMessage());
e.printStackTrace();
} finally {
tp.shutdown();
try {
tp.awaitTermination(5000, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
private static void waitOnThreads(Future<?>... futures) {
for (Future<?> f : futures) {
try {
f.get(10, TimeUnit.SECONDS);
} catch (Exception e) {
System.err.println("Failed while waiting on future.");
e.printStackTrace();
}
}
}
/**
* A thread that will pass data to onNext
*/
public static class OnNextThread implements Runnable {
private final Observer<String> Observer;
private final int numStringsToSend;
OnNextThread(Observer<String> Observer, int numStringsToSend) {
this.Observer = Observer;
this.numStringsToSend = numStringsToSend;
}
@Override
public void run() {
for (int i = 0; i < numStringsToSend; i++) {
Observer.onNext("aString");
}
}
}
/**
* A thread that will call onError or onNext
*/
public static class CompletionThread implements Runnable {
private final Observer<String> Observer;
private final TestConcurrencyObserverEvent event;
private final Future<?>[] waitOnThese;
CompletionThread(Observer<String> Observer, TestConcurrencyObserverEvent event, Future<?>... waitOnThese) {
this.Observer = Observer;
this.event = event;
this.waitOnThese = waitOnThese;
}
@Override
public void run() {
/* if we have 'waitOnThese' futures, we'll wait on them before proceeding */
if (waitOnThese != null) {
for (Future<?> f : waitOnThese) {
try {
f.get();
} catch (Exception e) {
System.err.println("Error while waiting on future in CompletionThread");
}
}
}
/* send the event */
if (event == TestConcurrencyObserverEvent.onError) {
Observer.onError(new RuntimeException("mocked exception"));
} else if (event == TestConcurrencyObserverEvent.onCompleted) {
Observer.onCompleted();
} else {
throw new IllegalArgumentException("Expecting either onError or onCompleted");
}
}
}
private static enum TestConcurrencyObserverEvent {
onCompleted, onError, onNext
}
private static class TestConcurrencyObserver implements Observer<String> {
/** used to store the order and number of events received */
private final LinkedBlockingQueue<TestConcurrencyObserverEvent> events = new LinkedBlockingQueue<TestConcurrencyObserverEvent>();
private final int waitTime;
@SuppressWarnings("unused")
public TestConcurrencyObserver(int waitTimeInNext) {
this.waitTime = waitTimeInNext;
}
public TestConcurrencyObserver() {
this.waitTime = 0;
}
@Override
public void onCompleted() {
events.add(TestConcurrencyObserverEvent.onCompleted);
}
@Override
public void onError(Exception e) {
events.add(TestConcurrencyObserverEvent.onError);
}
@Override
public void onNext(String args) {
events.add(TestConcurrencyObserverEvent.onNext);
// do some artificial work to make the thread scheduling/timing vary
int s = 0;
for (int i = 0; i < 20; i++) {
s += s * i;
}
if (waitTime > 0) {
try {
Thread.sleep(waitTime);
} catch (InterruptedException e) {
// ignore
}
}
}
/**
* Assert the order of events is correct and return the number of onNext executions.
*
* @param expectedEndingEvent
* @return int count of onNext calls
* @throws IllegalStateException
* If order of events was invalid.
*/
public int assertEvents(TestConcurrencyObserverEvent expectedEndingEvent) throws IllegalStateException {
int nextCount = 0;
boolean finished = false;
for (TestConcurrencyObserverEvent e : events) {
if (e == TestConcurrencyObserverEvent.onNext) {
if (finished) {
// already finished, we shouldn't get this again
throw new IllegalStateException("Received onNext but we're already finished.");
}
nextCount++;
} else if (e == TestConcurrencyObserverEvent.onError) {
if (finished) {
// already finished, we shouldn't get this again
throw new IllegalStateException("Received onError but we're already finished.");
}
if (expectedEndingEvent != null && TestConcurrencyObserverEvent.onError != expectedEndingEvent) {
throw new IllegalStateException("Received onError ending event but expected " + expectedEndingEvent);
}
finished = true;
} else if (e == TestConcurrencyObserverEvent.onCompleted) {
if (finished) {
// already finished, we shouldn't get this again
throw new IllegalStateException("Received onCompleted but we're already finished.");
}
if (expectedEndingEvent != null && TestConcurrencyObserverEvent.onCompleted != expectedEndingEvent) {
throw new IllegalStateException("Received onCompleted ending event but expected " + expectedEndingEvent);
}
finished = true;
}
}
return nextCount;
}
}
/**
* This spawns a single thread for the subscribe execution
*
*/
private static class TestSingleThreadedObservable implements Func1<Observer<String>, Subscription> {
final Subscription s;
final String[] values;
private Thread t = null;
public TestSingleThreadedObservable(final Subscription s, final String... values) {
this.s = s;
this.values = values;
}
public Subscription call(final Observer<String> observer) {
System.out.println("TestSingleThreadedObservable subscribed to ...");
t = new Thread(new Runnable() {
@Override
public void run() {
try {
System.out.println("running TestSingleThreadedObservable thread");
for (String s : values) {
System.out.println("TestSingleThreadedObservable onNext: " + s);
observer.onNext(s);
}
observer.onCompleted();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
});
System.out.println("starting TestSingleThreadedObservable thread");
t.start();
System.out.println("done starting TestSingleThreadedObservable thread");
return s;
}
public void waitToFinish() {
try {
t.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}
/**
* This spawns a thread for the subscription, then a separate thread for each onNext call.
*
*/
private static class TestMultiThreadedObservable implements Func1<Observer<String>, Subscription> {
final Subscription s;
final String[] values;
Thread t = null;
AtomicInteger threadsRunning = new AtomicInteger();
AtomicInteger maxConcurrentThreads = new AtomicInteger();
ExecutorService threadPool;
public TestMultiThreadedObservable(Subscription s, String... values) {
this.s = s;
this.values = values;
this.threadPool = Executors.newCachedThreadPool();
}
@Override
public Subscription call(final Observer<String> observer) {
System.out.println("TestMultiThreadedObservable subscribed to ...");
t = new Thread(new Runnable() {
@Override
public void run() {
try {
System.out.println("running TestMultiThreadedObservable thread");
for (final String s : values) {
threadPool.execute(new Runnable() {
@Override
public void run() {
threadsRunning.incrementAndGet();
try {
// perform onNext call
System.out.println("TestMultiThreadedObservable onNext: " + s);
if (s == null) {
// force an error
throw new NullPointerException();
}
observer.onNext(s);
// capture 'maxThreads'
int concurrentThreads = threadsRunning.get();
int maxThreads = maxConcurrentThreads.get();
if (concurrentThreads > maxThreads) {
maxConcurrentThreads.compareAndSet(maxThreads, concurrentThreads);
}
} catch (Exception e) {
observer.onError(e);
} finally {
threadsRunning.decrementAndGet();
}
}
});
}
// we are done spawning threads
threadPool.shutdown();
} catch (Exception e) {
throw new RuntimeException(e);
}
// wait until all threads are done, then mark it as COMPLETED
try {
// wait for all the threads to finish
threadPool.awaitTermination(2, TimeUnit.SECONDS);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
observer.onCompleted();
}
});
System.out.println("starting TestMultiThreadedObservable thread");
t.start();
System.out.println("done starting TestMultiThreadedObservable thread");
return s;
}
public void waitToFinish() {
try {
t.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}
private static class BusyObserver implements Observer<String> {
volatile boolean onCompleted = false;
volatile boolean onError = false;
AtomicInteger onNextCount = new AtomicInteger();
AtomicInteger threadsRunning = new AtomicInteger();
AtomicInteger maxConcurrentThreads = new AtomicInteger();
@Override
public void onCompleted() {
System.out.println(">>> BusyObserver received onCompleted");
onCompleted = true;
}
@Override
public void onError(Exception e) {
System.out.println(">>> BusyObserver received onError: " + e.getMessage());
onError = true;
}
@Override
public void onNext(String args) {
threadsRunning.incrementAndGet();
try {
onNextCount.incrementAndGet();
System.out.println(">>> BusyObserver received onNext: " + args);
try {
// simulate doing something computational
Thread.sleep(200);
} catch (InterruptedException e) {
e.printStackTrace();
}
} finally {
// capture 'maxThreads'
int concurrentThreads = threadsRunning.get();
int maxThreads = maxConcurrentThreads.get();
if (concurrentThreads > maxThreads) {
maxConcurrentThreads.compareAndSet(maxThreads, concurrentThreads);
}
threadsRunning.decrementAndGet();
}
}
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment