Last active
February 22, 2023 09:49
-
-
Save hector6872/fe5956ec71576213b4408f47f5cd05f8 to your computer and use it in GitHub Desktop.
Redux implementation for Kotlin :)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Composition local carrying the store installed by [StoreProvider].
 * Reading it where no store was provided fails with "Store not provided".
 */
private val LocalStore: ProvidableCompositionLocal<Store<*>> =
    compositionLocalOf<Store<*>> { error("Store not provided") }
/**
 * Installs [store] in [LocalStore] for everything composed by [content], and runs [content] with the
 * store as its receiver.
 */
@Composable
fun <STATE : State> StoreProvider(store: Store<STATE>, content: @Composable Store<STATE>.() -> Unit) =
    CompositionLocalProvider(LocalStore provides store) {
        with(store) { content() }
    }
/**
 * Returns the store provided by the nearest enclosing [StoreProvider], viewed as `Store<STATE>`.
 *
 * The cast is unchecked: asking for the wrong STATE type does not fail here, only later when a state
 * value is used as STATE. Fails with "Store not provided" if no store has been provided.
 */
@Composable
@Suppress("UNCHECKED_CAST")
fun <STATE : State> store(): Store<STATE> {
    val provided: Store<*> = LocalStore.current
    return provided as Store<STATE>
}
/**
 * Looks up the provided store (see [store]) and observes the slice of its state chosen by [selector].
 */
@Composable
inline fun <reified STATE : State, VALUE> subscribe(
    crossinline selector: @DisallowComposableCalls STATE.() -> VALUE
): MutableState<VALUE> {
    val current: Store<STATE> = store()
    return current.subscribe(selector)
}
/**
 * Observes the slice of this store's state chosen by [selector], from composition.
 *
 * The returned [MutableState] is seeded with `state.selector()`. It is updated with the selector's result
 * for every state the store publishes while the calling composable remains in composition. The
 * subscription is removed when the composable leaves composition.
 *
 * Both the remembered value and the subscription are keyed on the store instance. If a different store is
 * supplied (for example through [StoreProvider]), the value is re-seeded from the new store and the
 * subscription moves to it, instead of staying attached to the old one.
 *
 * NOTE(review): the effect keeps the selector captured at subscription time; a selector that closes over
 * changing values will not see their updates until the store changes. Writes to the returned state are
 * overwritten by the next store update, so treat it as read-only.
 */
@Composable
inline fun <STATE : State, VALUE> Store<STATE>.subscribe(
    crossinline selector: @DisallowComposableCalls (STATE.() -> VALUE)
): MutableState<VALUE> {
    val store: Store<STATE> = this
    val result: MutableState<VALUE> = remember(store) { mutableStateOf(store.state.selector()) }
    DisposableEffect(store, result) {
        // Resolves to the Store member subscribe(subscription); the returned Unsubscribe is the dispose action.
        val unsubscribe: Unsubscribe = store.subscribe { state, _ -> result.value = state.selector() }
        onDispose(unsubscribe)
    }
    return result
}
/**
 * Looks up the provided store (see [store]) and observes its whole state.
 */
@Composable
inline fun <reified STATE : State> subscribe(): MutableState<STATE> {
    val current: Store<STATE> = store()
    return current.subscribe()
}
/**
 * Observes this store's whole state from composition.
 *
 * The returned [MutableState] starts with the current state. It receives every state the store publishes
 * while the calling composable remains in composition, and the subscription is removed when the composable
 * leaves composition.
 *
 * The remembered value and the subscription are both keyed on the store instance. Switching to a different
 * store therefore re-seeds the value and re-subscribes, instead of staying attached to the old store.
 *
 * NOTE(review): writes to the returned state are overwritten by the next store update; treat it as read-only.
 */
@Composable
inline fun <STATE : State> Store<STATE>.subscribe(): MutableState<STATE> {
    val store: Store<STATE> = this
    val result: MutableState<STATE> = remember(store) { mutableStateOf(store.state) }
    DisposableEffect(store, result) {
        // Resolves to the Store member subscribe(subscription); the returned Unsubscribe is the dispose action.
        val unsubscribe: Unsubscribe = store.subscribe { state, _ -> result.value = state }
        onDispose(unsubscribe)
    }
    return result
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Shared reducer instance wired into every store built by [createFakeStore]. */
val fakeReducer = FakeReducer()

/** Shared middleware instance wired into every store built by [createFakeStore]. */
val fakeMiddleware = FakeMiddleware()
/**
 * State used by the store tests.
 *
 * @property counter number of increments applied so far; `null` until the first increment.
 */
data class TestState(
    val counter: Int? = null
) : State
/** Shortcut for the counter of the store's current state. */
val Store<TestState>.counter: Int?
    get() = state.counter
/** Binder that lets a notification through only when the counter value changed. */
fun counterBinder(old: TestState, new: TestState): Boolean {
    val unchanged = old.counter == new.counter
    return !unchanged
}
/** Builds a fresh store starting at an empty [TestState], wired to [fakeReducer] and [fakeMiddleware]. */
fun createFakeStore(): Store<TestState> {
    val reducers = combineReducers(fakeReducer::reduce)
    val middlewares = combineMiddlewares(fakeMiddleware::run)
    return createStore(initialState = TestState(), reducers = reducers, middlewares = middlewares)
}
/** Actions understood by [FakeReducer] and [FakeMiddleware]. */
sealed class TestActions : Action {

    /** Increments the counter synchronously (handled by [FakeReducer]). */
    object IncrementByOne : TestActions()

    /**
     * Makes [FakeMiddleware] dispatch [IncrementByOne] from a coroutine launched in
     * [coroutineScopeForTestingUse].
     */
    class IncrementByOneAsync(val coroutineScopeForTestingUse: CoroutineScope) : TestActions()
}
/** Test reducer: [IncrementByOne] adds one to the counter (an unset counter counts as 0); other actions are ignored. */
class FakeReducer {
    fun reduce(state: TestState, action: Action): TestState {
        if (action != IncrementByOne) return state
        val current = state.counter ?: 0
        return state.copy(counter = current + 1)
    }
}
/**
 * Test middleware. On [IncrementByOneAsync] it launches a coroutine in the action's scope that dispatches
 * [IncrementByOne]. Every action, including that one, is then passed unchanged to the next middleware.
 */
class FakeMiddleware {
    fun run(state: TestState, action: Action, dispatcher: Dispatcher, next: NextMiddleware<TestState>): Action {
        if (action is IncrementByOneAsync) {
            action.coroutineScopeForTestingUse.launch { dispatcher(IncrementByOne) }
        }
        return next(state, action, dispatcher)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@OptIn(ExperimentalCoroutinesApi::class)
class SameThreadEnforcedStoreTest {

    @Test(expected = IllegalStateException::class)
    fun `run an Action from a thread other than thread from which the SameThreadEnforcedStore was created will throw an IllegalStateException`() {
        dispatchFromAnotherThread(createSameThreadEnforcedStore(createFakeStore()))
    }

    @Test
    fun `run an Action from a thread other than thread from which the Store was created will not throw an exception`() {
        dispatchFromAnotherThread(createFakeStore())
    }

    /**
     * Dispatches [IncrementByOne] to [target] from a [Dispatchers.Default] worker. It first asserts that
     * this worker is not the thread that called this helper (the thread that created [target]).
     */
    private fun dispatchFromAnotherThread(target: Store<TestState>) {
        val creatorThread = currentThreadName()
        runTest {
            withContext(Dispatchers.Default) {
                assertThat(currentThreadName()).isNotEqualTo(creatorThread)
                target.dispatch(IncrementByOne)
            }
        }
    }
}
/** Name of the thread running the caller, exactly as [Thread.getName] reports it. */
private fun currentThreadName(): String {
    val thread = Thread.currentThread()
    return thread.name
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Marker for anything that can be dispatched to a [Store]. */
interface Action

/** Marker for the type of state a [Store] holds. */
interface State

/** Given the previous and the next state, decides whether a subscriber is notified. */
typealias Binder<STATE> = (STATE, STATE) -> Boolean

/** Binder used when the caller supplies none: every state change is let through. */
@Suppress("UNUSED_PARAMETER")
private fun <STATE> defaultBinder(old: STATE, new: STATE): Boolean = true

/** Sends an [Action] to a store. */
typealias Dispatcher = (Action) -> Unit

/** Computes the next state from the current state and an action. */
typealias Reducer<STATE> = (STATE, Action) -> STATE

/** Collects [reducers] into a list, keeping argument order. */
fun <STATE> combineReducers(vararg reducers: Reducer<STATE>): List<Reducer<STATE>> = reducers.toList()

/** Intercepts an action before reduction; must hand on (or replace) the action through the next middleware. */
typealias Middleware<STATE> = (STATE, Action, Dispatcher, NextMiddleware<STATE>) -> Action

/** The rest of the middleware chain, as seen from one middleware. */
typealias NextMiddleware<STATE> = (STATE, Action, Dispatcher) -> Action

/** Collects [middlewares] into a list, keeping argument order. */
fun <STATE> combineMiddlewares(vararg middlewares: Middleware<STATE>): List<Middleware<STATE>> = middlewares.toList()

/** Callback receiving a published state together with the binder it was registered with. */
internal typealias Subscription<STATE> = (state: STATE, binder: Binder<STATE>) -> Unit

/** Removes the subscription it was returned for. */
typealias Unsubscribe = () -> Unit

/** Pairs a subscription with the binder that gates its notifications. */
private class SubscriptionWithBinder<STATE>(val binder: Binder<STATE>, val subscription: Subscription<STATE>?)
/**
 * Holds a [State] value. Callers can observe it, change it by dispatching [Action]s, or replace it outright.
 */
interface Store<STATE : State> {

    /** The current state. */
    val state: STATE

    /** Registers [subscription] with no binder filtering; invoke the returned function to stop receiving updates. */
    fun subscribe(subscription: Subscription<STATE>): Unsubscribe

    /** Registers [subscription], gated by [binder]; invoke the returned function to stop receiving updates. */
    fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe

    /** Registers [subscription] once per entry of [binders]; the returned function removes all of them. */
    fun subscribe(binders: List<Binder<STATE>>, subscription: Subscription<STATE>): Unsubscribe

    /** Processes one [action]. */
    fun dispatch(action: Action)

    /** Processes [actions] in the given order. */
    fun dispatch(vararg actions: Action)

    /** Replaces the current state with [newState]. */
    fun rehydrate(newState: STATE)
}
/**
 * Creates a NON-THREADSAFE store. NOT RECOMMENDED.
 * It is highly recommended to use [createThreadSafeStore] to ensure thread safety
 *
 * Each dispatched action runs through [middlewares] in list order. The action they return is reduced by
 * [reducers] in list order, starting from the latest state. The result is committed and then published
 * to subscribers, most recently subscribed first. A subscriber is skipped when its binder returns false
 * for (previous state, new state).
 *
 * Dispatching is re-entrant. An action dispatched synchronously from a middleware or from a subscriber
 * callback is fully processed before the outer `dispatch` returns, and its result is kept.
 *
 * @param initialState state held until the first action or [Store.rehydrate].
 * @param reducers reducers applied in order to every action.
 * @param middlewares middlewares applied in order before the reducers; each may replace the action.
 */
@Suppress("ObjectPropertyName")
fun <STATE : State> createStore(
    initialState: STATE,
    reducers: List<Reducer<STATE>> = listOf(),
    middlewares: List<Middleware<STATE>> = listOf()
): Store<STATE> = object : Store<STATE> {
    private val _state: AtomicReference<STATE> = AtomicReference(initialState)
    override val state: STATE get() = _state.get()

    // Actions queued by dispatch() and not yet processed.
    private val pendingActions = LinkedBlockingQueue<Action>()
    private val subscriptions = CopyOnWriteArrayList<SubscriptionWithBinder<STATE>>()

    override fun subscribe(subscription: Subscription<STATE>): Unsubscribe = subscribe(::defaultBinder, subscription)

    override fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe {
        val subscriptionWithBinder = SubscriptionWithBinder(binder, subscription)
        subscriptions.add(subscriptionWithBinder)
        // A new subscriber immediately receives the current state; the binder is not consulted for this call.
        subscription(state, binder)
        return { subscriptions.remove(subscriptionWithBinder) }
    }

    override fun subscribe(binders: List<Binder<STATE>>, subscription: Subscription<STATE>): Unsubscribe {
        val unsubscribes: List<Unsubscribe> = binders.map { subscribe(it, subscription) }
        return { unsubscribes.forEach { it() } }
    }

    // Spread keeps this resolving to the vararg overload instead of recursing into itself.
    @Suppress("RemoveRedundantSpreadOperator")
    override fun dispatch(action: Action) = dispatch(*arrayOf(action))

    override fun dispatch(vararg actions: Action) {
        actions.forEach { pendingActions.offer(it) }
        // Drain with poll() until empty, rather than iterating the queue while removing from it.
        var next: Action? = pendingActions.poll()
        while (next != null) {
            handle(next)
            next = pendingActions.poll()
        }
    }

    override fun rehydrate(newState: STATE) = notify(newState)

    private fun handle(action: Action) {
        // Run the middlewares BEFORE reading the state. A middleware may synchronously dispatch other actions,
        // and reducing from a snapshot taken earlier would silently discard their results.
        val processed = dispatchToMiddleware(action)
        notify(reduce(state, processed))
    }

    private fun reduce(current: STATE, action: Action): STATE =
        reducers.fold(current) { acc, reducer -> reducer(acc, action) }

    private fun dispatchToMiddleware(action: Action): Action = next(0)(state, action, ::dispatch)

    // Builds the chain link for middlewares[index]; past the last middleware the action is returned unchanged.
    private fun next(index: Int): NextMiddleware<STATE> = when (index) {
        middlewares.size -> { _, action, _ -> action }
        else -> { state, action, dispatch -> middlewares[index](state, action, dispatch, next(index + 1)) }
    }

    private fun notify(newState: STATE) {
        // Commit first. Subscribers that read `state`, or dispatch from their callback, then build on newState
        // instead of a value that would afterwards be overwritten.
        val oldState = _state.getAndSet(newState)
        subscriptions.reversed()
            .filter { it.binder(oldState, newState) }
            .forEach { it.subscription?.invoke(newState, it.binder) }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Behaviour tests for the store built by [createFakeStore].
 *
 * Subscribers record what they receive into a list, and the list is asserted after the fact. A test
 * therefore fails when a callback is never invoked; assertions placed only inside callbacks would pass
 * vacuously in that case.
 */
class StoreTest {
    private val store: Store<TestState> = createFakeStore()

    @Test
    fun `get initial status upon subscription`() {
        val received = mutableListOf<Int?>()
        store.subscribe { state, _ -> received.add(state.counter) }
        assertThat(received).isEqualTo(listOf<Int?>(null))
    }

    @Test
    fun `subscribe with no binding`() {
        val received = mutableListOf<Int?>()
        store.subscribe { state, _ -> received.add(state.counter) }
        store.dispatch(IncrementByOne)
        assertThat(received).isEqualTo(listOf(null, 1))
    }

    @Test
    fun `subscribe via binding`() {
        val received = mutableListOf<Int?>()
        store.subscribe(::counterBinder) { state, binder ->
            assertThat(binder).isEqualTo(::counterBinder)
            received.add(state.counter)
        }
        store.dispatch(IncrementByOne)
        assertThat(received).isEqualTo(listOf(null, 1))
    }

    @Test
    fun `subscribe via a list of bindings`() {
        val received = mutableListOf<Int?>()
        store.subscribe(listOf(::counterBinder)) { state, binder ->
            assertThat(binder).isEqualTo(::counterBinder)
            received.add(state.counter)
        }
        store.dispatch(IncrementByOne)
        assertThat(received).isEqualTo(listOf(null, 1))
    }

    @Test
    fun `binding filters out unrelated updates`() {
        val received = mutableListOf<Int?>()
        store.subscribe(::counterBinder) { state, _ -> received.add(state.counter) }
        // The counter stays null, so counterBinder rejects this update.
        store.rehydrate(TestState(counter = null))
        assertThat(received).isEqualTo(listOf<Int?>(null))
    }

    @Test
    fun `unsubscribe stops notifications`() {
        val received = mutableListOf<Int?>()
        val unsubscribe = store.subscribe { state, _ -> received.add(state.counter) }
        unsubscribe()
        store.dispatch(IncrementByOne)
        assertThat(received).isEqualTo(listOf<Int?>(null))
        assertThat(store.counter).isEqualTo(1)
    }

    @Test
    fun `rehydrate state`() {
        val received = mutableListOf<Int?>()
        store.subscribe { state, _ -> received.add(state.counter) }
        store.rehydrate(TestState(counter = 100))
        assertThat(received).isEqualTo(listOf(null, 100))
        assertThat(store.counter).isEqualTo(100)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Wraps [store] in a [ThreadSafeStore], which serializes every call made through the wrapper.
 * The locking may cost throughput on JVM/Native.
 */
fun <STATE : State> createThreadSafeStore(store: Store<STATE>): ThreadSafeStore<STATE> = ThreadSafeStore(store)

/** Builds a store with [createStore] from the given pieces and wraps it in a [ThreadSafeStore]. */
fun <STATE : State> createThreadSafeStore(
    initialState: STATE,
    reducers: List<Reducer<STATE>> = listOf(),
    middlewares: List<Middleware<STATE>> = listOf()
): ThreadSafeStore<STATE> {
    val inner = createStore(initialState = initialState, reducers = reducers, middlewares = middlewares)
    return createThreadSafeStore(inner)
}
/**
 * [Store] decorator that serializes every member call on the monitor of this decorator instance. At most
 * one thread at a time reads the state, subscribes, dispatches or rehydrates *through this wrapper*.
 *
 * NOTE(review): only calls made through this decorator are serialized. A store built by [createStore]
 * passes its own `dispatch` to middlewares. A middleware that dispatches later from another thread (for
 * example from a launched coroutine) therefore reaches the wrapped store without taking this lock. The
 * [Unsubscribe] functions returned here also run without the lock.
 */
class ThreadSafeStore<STATE : State>(private val store: Store<STATE>) : Store<STATE> by store {
    // Each member takes the monitor exactly once via @Synchronized. Additionally wrapping the body in
    // synchronized(this) re-entered the same monitor and only added overhead.

    @get:Synchronized
    override val state: STATE
        get() = store.state

    @Synchronized
    override fun subscribe(subscription: Subscription<STATE>): Unsubscribe = store.subscribe(subscription)

    @Synchronized
    override fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe =
        store.subscribe(binder, subscription)

    @Synchronized
    override fun subscribe(binders: List<Binder<STATE>>, subscription: Subscription<STATE>): Unsubscribe =
        store.subscribe(binders, subscription)

    @Synchronized
    override fun dispatch(action: Action) = store.dispatch(action)

    @Synchronized
    override fun dispatch(vararg actions: Action) = store.dispatch(*actions)

    @Synchronized
    override fun rehydrate(newState: STATE) = store.rehydrate(newState)
}
/**
 * Wraps [store] in a [SameThreadEnforcedStore] owned by the calling thread. Subscribing, dispatching or
 * rehydrating through the wrapper from any other thread throws an IllegalStateException.
 */
fun <STATE : State> createSameThreadEnforcedStore(store: Store<STATE>): SameThreadEnforcedStore<STATE> =
    SameThreadEnforcedStore(store)

/** Builds a store with [createStore] from the given pieces and wraps it in a [SameThreadEnforcedStore]. */
fun <STATE : State> createSameThreadEnforcedStore(
    initialState: STATE,
    reducers: List<Reducer<STATE>> = listOf(),
    middlewares: List<Middleware<STATE>> = listOf()
): SameThreadEnforcedStore<STATE> {
    val inner = createStore(initialState = initialState, reducers = reducers, middlewares = middlewares)
    return createSameThreadEnforcedStore(inner)
}
/**
 * [Store] decorator that accepts subscribe, dispatch and rehydrate calls only from the thread that
 * constructed it. Calls from any other thread fail [check] and throw [IllegalStateException]. Reading
 * [state] is delegated without a check.
 */
class SameThreadEnforcedStore<STATE : State>(private val store: Store<STATE>) : Store<STATE> by store {

    // The owner is compared by identity, not by name. Thread names are not unique, and the old check even
    // ignored case, so a foreign thread with a matching name passed. kotlinx.coroutines debug mode also
    // appends "@coroutine#n" to names while a coroutine runs.
    private val storeThread: Thread = Thread.currentThread()

    // Used only in the error message.
    private val storeThreadName = currentThreadName()

    override fun subscribe(subscription: Subscription<STATE>): Unsubscribe {
        checkIsSameThread()
        return store.subscribe(subscription)
    }

    override fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe {
        checkIsSameThread()
        return store.subscribe(binder, subscription)
    }

    override fun subscribe(binders: List<Binder<STATE>>, subscription: Subscription<STATE>): Unsubscribe {
        checkIsSameThread()
        return store.subscribe(binders, subscription)
    }

    override fun dispatch(action: Action) {
        checkIsSameThread()
        store.dispatch(action)
    }

    override fun dispatch(vararg actions: Action) {
        checkIsSameThread()
        store.dispatch(*actions)
    }

    override fun rehydrate(newState: STATE) {
        checkIsSameThread()
        store.rehydrate(newState)
    }

    private fun currentThreadName(): String = Thread.currentThread().name.stripCoroutineName()

    private fun isSameThread(): Boolean = Thread.currentThread() === storeThread

    private fun checkIsSameThread() = check(isSameThread()) {
        """You may not call the store from a thread other than the thread on which it was created.
          |This store was created on: '$storeThreadName' and current thread is '${currentThreadName()}'
          """.trimMargin()
    }

    /**
     * Thread name may have '@coroutine#n' appended to it
     * https://kotlinlang.org/docs/coroutine-context-and-dispatchers.html#debugging-using-logging
     */
    private fun String.stripCoroutineName(): String = substringBeforeLast('@').trim()
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Dispatches SYNC_NUM_COROUTINES * SYNC_NUM_REPEATS increments concurrently on [Dispatchers.Default],
 * against a plain store and against a [ThreadSafeStore], and checks the final counter in each case.
 *
 * test results (performance impact example):
 * thread safe: completed 100000 actions in 122 ms
 * non-thread safe: completed 100000 actions in 60 ms
 *
 * NOTE(review): the non-thread-safe test asserts that updates *were* lost. A race does not guarantee
 * that, so the test can fail spuriously (e.g. on a single-core or lightly loaded machine).
 */
@OptIn(ExperimentalCoroutinesApi::class)
class ThreadSafeStoreTest {

    @Test
    fun `massive action dispatching using a non-thread safe Store should return an unexpected state`() {
        val store = createFakeStore()
        runTest {
            assertThat(hammer("non-thread safe", store)).isNotEqualTo(SYNC_NUM_COROUTINES * SYNC_NUM_REPEATS)
        }
    }

    @Test
    fun `massive action dispatching using a thread safe Store should return an expected state`() {
        val threadSafeStore = createThreadSafeStore(createFakeStore())
        runTest {
            assertThat(hammer("thread safe", threadSafeStore)).isEqualTo(SYNC_NUM_COROUTINES * SYNC_NUM_REPEATS)
        }
    }

    /** Floods [target] with [IncrementByOne] from many coroutines on Dispatchers.Default; returns the final counter. */
    private suspend fun hammer(tag: String, target: Store<TestState>): Int? = withContext(Dispatchers.Default) {
        massiveRun(tag, SYNC_NUM_COROUTINES, SYNC_NUM_REPEATS) { target.dispatch(IncrementByOne) }
        target.counter
    }
}
/**
 * Launches [numCoroutines] coroutines, each running [block] [numRepeats] times, as children of a new
 * coroutine scope. Waits for all of them, then prints the elapsed wall-clock time prefixed with [tag].
 */
private suspend fun massiveRun(tag: String, numCoroutines: Int, numRepeats: Int, block: suspend () -> Unit) {
    val elapsedMs = measureTimeMillis {
        coroutineScope {
            (1..numCoroutines).forEach { _ ->
                launch { repeat(numRepeats) { block() } }
            }
        }
    }
    val total = numCoroutines * numRepeats
    println("$tag: completed $total actions in $elapsedMs ms")
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment