Created
October 30, 2014 23:22
-
-
Save elandau/38a28ffab5ad6566f166 to your computer and use it in GitHub Desktop.
Rx-based state machine
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.netflix.experiments.rx; | |
import java.util.HashMap; | |
import java.util.Map; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import rx.Observable; | |
import rx.Observable.OnSubscribe; | |
import rx.Subscriber; | |
import rx.functions.Action1; | |
import rx.functions.Action2; | |
import rx.subjects.PublishSubject; | |
public class StateMachine<T, E> implements Action1<E> { | |
private static final Logger LOG = LoggerFactory.getLogger(StateMachine.class); | |
public static class State<T, E> { | |
private String name; | |
private Action2<T, State<T, E>> enter; | |
private Action2<T, State<T, E>> exit; | |
private Map<E, State<T, E>> transitions = new HashMap<E, State<T, E>>(); | |
public State(String name) { | |
this.name = name; | |
} | |
public State<T, E> onEnter(Action2<T, State<T, E>> func) { | |
this.enter = func; | |
return this; | |
} | |
public State<T, E> onExit(Action2<T, State<T, E>> func) { | |
this.exit = func; | |
return this; | |
} | |
public void enter(T context) { | |
enter.call(context, this); | |
} | |
public void exit(T context) { | |
exit.call(context, this); | |
} | |
public State<T, E> transition(E event, State<T, E> state) { | |
transitions.put(event, state); | |
return this; | |
} | |
public State<T, E> next(E event) { | |
return transitions.get(event); | |
} | |
public String toString() { | |
return name; | |
} | |
} | |
private volatile State<T, E> state; | |
private final T context; | |
private final PublishSubject<E> events = PublishSubject.create(); | |
protected StateMachine(T context, State<T, E> initial) { | |
this.state = initial; | |
this.context = context; | |
} | |
public Observable<Void> connect() { | |
return Observable.create(new OnSubscribe<Void>() { | |
@Override | |
public void call(Subscriber<? super Void> sub) { | |
state.enter(context); | |
sub.add(events.collect(context, new Action2<T, E>() { | |
@Override | |
public void call(T context, E event) { | |
final State<T, E> next = state.next(event); | |
if (next != null) { | |
state.exit(context); | |
state = next; | |
next.enter(context); | |
} | |
else { | |
LOG.info("Invalid event : " + event); | |
} | |
} | |
}) | |
.subscribe()); | |
} | |
}); | |
} | |
@Override | |
public void call(E event) { | |
events.onNext(event); | |
} | |
public State<T, E> getState() { | |
return state; | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.netflix.experiments.rx; | |
import org.junit.BeforeClass; | |
import org.junit.Test; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import rx.functions.Action2; | |
import com.netflix.experiments.rx.StateMachine.State; | |
public class StateMachineTest { | |
private static final Logger LOG = LoggerFactory.getLogger(StateMachineTest.class); | |
public static enum Event { | |
IDLE, | |
CONNECT, | |
CONNECTED, | |
FAILED, | |
UNQUARANTINE, | |
REMOVE | |
} | |
public static Action2<SomeContext, State<SomeContext, Event>> log(final String text) { | |
return new Action2<SomeContext, State<SomeContext, Event>>() { | |
@Override | |
public void call(SomeContext t1, State<SomeContext, Event> state) { | |
LOG.info("" + t1 + ":" + state + ":" + text); | |
} | |
}; | |
} | |
public static class SomeContext { | |
@Override | |
public String toString() { | |
return "Foo []"; | |
} | |
} | |
public static State<SomeContext, Event> IDLE = new State<SomeContext, Event>("IDLE"); | |
public static State<SomeContext, Event> CONNECTING = new State<SomeContext, Event>("CONNECTING"); | |
public static State<SomeContext, Event> CONNECTED = new State<SomeContext, Event>("CONNECTED"); | |
public static State<SomeContext, Event> QUARANTINED = new State<SomeContext, Event>("QUARANTINED"); | |
public static State<SomeContext, Event> REMOVED = new State<SomeContext, Event>("REMOVED"); | |
@BeforeClass | |
public static void beforeClass() { | |
IDLE | |
.onEnter(log("enter")) | |
.onExit(log("exit")) | |
.transition(Event.CONNECT, CONNECTING) | |
.transition(Event.REMOVE, REMOVED); | |
CONNECTING | |
.onEnter(log("enter")) | |
.onExit(log("exit")) | |
.transition(Event.CONNECTED, CONNECTED) | |
.transition(Event.FAILED, QUARANTINED) | |
.transition(Event.REMOVE, REMOVED); | |
CONNECTED | |
.onEnter(log("enter")) | |
.onExit(log("exit")) | |
.transition(Event.IDLE, IDLE) | |
.transition(Event.FAILED, QUARANTINED) | |
.transition(Event.REMOVE, REMOVED); | |
QUARANTINED | |
.onEnter(log("enter")) | |
.onExit(log("exit")) | |
.transition(Event.IDLE, IDLE) | |
.transition(Event.REMOVE, REMOVED); | |
REMOVED | |
.onEnter(log("enter")) | |
.onExit(log("exit")) | |
.transition(Event.CONNECT, CONNECTING); | |
} | |
@Test | |
public void test() { | |
StateMachine<SomeContext, Event> sm = new StateMachine<SomeContext, Event>(new SomeContext(), IDLE); | |
sm.connect().subscribe(); | |
sm.call(Event.CONNECT); | |
sm.call(Event.CONNECTED); | |
sm.call(Event.FAILED); | |
sm.call(Event.REMOVE); | |
} | |
} |
Thanks a lot, very impressive.
Converted to RxJava 2 and Kotlin:
/**
 * A named state with optional enter/exit handlers and a table of
 * event-to-next-state transitions.
 *
 * Every configuration method returns this instance so calls can be chained.
 *
 * @param T type of the context object passed to the handlers
 * @param E type of the events used as transition keys
 * @property name label returned by [toString]
 */
class State<T, E>(val name: String) {
    private var enterHandler: BiConsumer<T, State<T, E>>? = null
    private var exitHandler: BiConsumer<T, State<T, E>>? = null
    private val successors = mutableMapOf<E, State<T, E>>()

    /** Sets the handler run by [enter], replacing any previous one, and returns this state. */
    fun onEnter(func: BiConsumer<T, State<T, E>>): State<T, E> = apply { enterHandler = func }

    /** Sets the handler run by [exit], replacing any previous one, and returns this state. */
    fun onExit(func: BiConsumer<T, State<T, E>>): State<T, E> = apply { exitHandler = func }

    /** Invokes the enter handler with [context] and this state; does nothing if none is set. */
    fun enter(context: T) {
        val handler = enterHandler ?: return
        handler.accept(context, this)
    }

    /** Invokes the exit handler with [context] and this state; does nothing if none is set. */
    fun exit(context: T) {
        val handler = exitHandler ?: return
        handler.accept(context, this)
    }

    /** Registers [state] as the destination for [event] and returns this state. */
    fun transition(event: E, state: State<T, E>): State<T, E> = apply { successors[event] = state }

    /** Returns the destination registered for [event], or `null` when there is none. */
    fun next(event: E): State<T, E>? = successors[event]

    override fun toString(): String = name
}
/**
 * Event-driven state machine whose events are delivered through an Rx
 * [PublishSubject].
 *
 * Events are pushed in with [accept]. They are only acted on while a
 * subscription obtained from [connect] is active, because the transition
 * logic is attached to the subject inside [connect].
 *
 * NOTE(review): [state] is volatile, but the read/exit/assign/enter sequence
 * performed for each event is not atomic. Confirm that events are delivered
 * from a single thread if that matters to callers.
 *
 * @param T type of the context object handed to every enter/exit handler
 * @param E type of the events that drive transitions
 * @property context object handed to every enter/exit handler
 * @param initialState state the machine starts in
 */
class RxStateMachine<T, E>(val context: T, initialState: State<T, E>) :
    Consumer<E> {

    private val events = PublishSubject.create<E>()

    /** The state the machine is currently in. */
    @Volatile
    var state: State<T, E> = initialState

    /**
     * Builds an observable that activates the machine when subscribed to.
     *
     * On each subscription the current state's enter handler runs, then the
     * transition logic is attached to the event subject. The inner
     * subscription is registered as the emitter's disposable, so disposing
     * the outer subscription detaches the machine from its events. The
     * emitter is never sent items or terminal notifications by this method.
     */
    fun connect(): Observable<Unit> {
        return Observable.create { emitter ->
            state.enter(context)
            emitter.setDisposable(
                events.collect({
                    context
                }, { ctx: T, event: E ->
                    onEvent(ctx, event)
                }).subscribe()
            )
        }
    }

    /**
     * Applies [event] to the current state: performs the transition when one
     * is registered, otherwise logs the event as invalid and leaves the state
     * unchanged.
     */
    private fun onEvent(ctx: T, event: E) {
        val current = state
        val next = current.next(event)
        if (next != null) {
            // Order matters: leave the old state, publish the new one, then enter it.
            current.exit(ctx)
            state = next
            next.enter(ctx)
        } else {
            Log.e("STATE", "Invalid Event: $event")
        }
    }

    /** Feeds [t] into the machine by publishing it on the event subject. */
    override fun accept(t: E) {
        events.onNext(t)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Converted to RxJava 2 and Java 8 (with log4j):