Skip to content

Instantly share code, notes, and snippets.

@hector6872
Last active February 22, 2023 09:49
Show Gist options
  • Save hector6872/fe5956ec71576213b4408f47f5cd05f8 to your computer and use it in GitHub Desktop.
Save hector6872/fe5956ec71576213b4408f47f5cd05f8 to your computer and use it in GitHub Desktop.
Redux implementation for Kotlin, with Jetpack Compose bindings and thread-safety decorators :)
/**
 * Composition local carrying the ambient [Store].
 * Reading it with no provided value invokes the default factory, which fails with "Store not provided".
 */
private val LocalStore: ProvidableCompositionLocal<Store<*>> = compositionLocalOf<Store<*>> {
    error("Store not provided")
}
/**
 * Installs [store] as the value of [LocalStore] for [content] and its descendants,
 * then composes [content] with [store] as its receiver.
 */
@Composable
fun <STATE : State> StoreProvider(store: Store<STATE>, content: @Composable Store<STATE>.() -> Unit) {
    val provided = LocalStore provides store
    CompositionLocalProvider(provided) {
        content(store)
    }
}
/**
 * Returns the store provided through [LocalStore], viewed as `Store<STATE>`.
 * The cast is unchecked: a STATE that does not match the provided store is not detected here.
 */
@Composable
@Suppress("UNCHECKED_CAST")
fun <STATE : State> store(): Store<STATE> {
    val provided: Store<*> = LocalStore.current
    return provided as Store<STATE>
}
/**
 * Selects a value from the ambient store's state.
 * Resolves the store via `store<STATE>()` and delegates to the `Store.subscribe(selector)` extension.
 */
@Composable
inline fun <reified STATE : State, VALUE> subscribe(
    crossinline selector: @DisallowComposableCalls STATE.() -> VALUE
): MutableState<VALUE> {
    val ambient: Store<STATE> = store<STATE>()
    return ambient.subscribe(selector)
}
/**
 * Exposes `selector(state)` of this store as Compose [MutableState].
 *
 * The initial value is computed from the current [Store.state]. While this call stays in the
 * composition, a store subscription writes the re-selected value into the returned state on every
 * notification. The subscription is removed (via the returned [Unsubscribe]) when the effect is disposed.
 *
 * The remembered state is keyed on this store. Because the effect is keyed on that state, supplying
 * a different store instance creates a fresh holder and moves the subscription to the new store,
 * instead of staying attached to the old one.
 */
@Composable
inline fun <STATE : State, VALUE> Store<STATE>.subscribe(
    crossinline selector: @DisallowComposableCalls (STATE.() -> VALUE)
): MutableState<VALUE> {
    // Key on the receiver so a new store instance does not reuse the previous store's holder/subscription.
    val result: MutableState<VALUE> = remember(this) { mutableStateOf(state.selector()) }
    DisposableEffect(result) {
        val unsubscribe: Unsubscribe = subscribe { state, _ -> result.value = state.selector() }
        onDispose(unsubscribe)
    }
    return result
}
/** Observes the whole state of the ambient store, resolved via `store<STATE>()`. */
@Composable
inline fun <reified STATE : State> subscribe(): MutableState<STATE> {
    val ambient: Store<STATE> = store<STATE>()
    return ambient.subscribe()
}
/**
 * Exposes the whole state of this store as Compose [MutableState].
 *
 * The holder starts from the current [Store.state] and is overwritten on every store notification
 * while this call stays in the composition. The subscription is removed (via the returned
 * [Unsubscribe]) when the effect is disposed.
 *
 * The holder is keyed on this store, so switching to a different store instance re-creates it
 * and re-subscribes, instead of observing the old store.
 */
@Composable
inline fun <STATE : State> Store<STATE>.subscribe(): MutableState<STATE> {
    // Key on the receiver so a new store instance does not reuse the previous store's holder/subscription.
    val result: MutableState<STATE> = remember(this) { mutableStateOf(state) }
    DisposableEffect(result) {
        val unsubscribe: Unsubscribe = subscribe { state, _ -> result.value = state }
        onDispose(unsubscribe)
    }
    return result
}
// Shared test fixtures wired into every store produced by createFakeStore().
val fakeReducer: FakeReducer = FakeReducer()
val fakeMiddleware: FakeMiddleware = FakeMiddleware()

/** Minimal state used by the tests; `counter` defaults to null. */
data class TestState(val counter: Int? = null) : State

/** Shortcut to the counter held in the store's current state. */
val Store<TestState>.counter: Int?
    get() = state.counter

/** Binder that only admits transitions in which the counter changed. */
fun counterBinder(old: TestState, new: TestState): Boolean = new.counter != old.counter

/** Builds a plain (non-thread-safe) store with [fakeReducer] and [fakeMiddleware] and an empty [TestState]. */
fun createFakeStore(): Store<TestState> = createStore(
    initialState = TestState(),
    reducers = combineReducers(fakeReducer::reduce),
    middlewares = combineMiddlewares(fakeMiddleware::run)
)
/** Closed set of actions used by the test reducer and middleware. */
sealed class TestActions : Action {
    /** Handled by the reducer: increments the counter. */
    object IncrementByOne : TestActions()

    /** Handled by the middleware: dispatches [IncrementByOne] from a coroutine launched in the given scope. */
    class IncrementByOneAsync(val coroutineScopeForTestingUse: CoroutineScope) : TestActions()
}
/** Test reducer: `IncrementByOne` adds one to the counter (an unset counter counts as 0); other actions are ignored. */
class FakeReducer {
    fun reduce(state: TestState, action: Action): TestState {
        if (action != IncrementByOne) return state
        val incremented = state.counter?.plus(1) ?: 1
        return state.copy(counter = incremented)
    }
}
/** Test middleware that turns `IncrementByOneAsync` into a launched `IncrementByOne` dispatch. */
class FakeMiddleware {
    fun run(state: TestState, action: Action, dispatcher: Dispatcher, next: NextMiddleware<TestState>): Action {
        if (action is IncrementByOneAsync) {
            // Fire-and-forget: the increment runs whenever the test-supplied scope executes the coroutine.
            action.coroutineScopeForTestingUse.launch { dispatcher(IncrementByOne) }
        }
        // Every action, including the async one, continues down the chain unchanged.
        return next(state, action, dispatcher)
    }
}
/** Verifies that only the thread-enforcing decorator rejects calls coming from a foreign thread. */
@OptIn(ExperimentalCoroutinesApi::class)
class SameThreadEnforcedStoreTest {
    @Test(expected = IllegalStateException::class)
    fun `run an Action from a thread other than thread from which the SameThreadEnforcedStore was created will throw an IllegalStateException`() {
        val guarded = createSameThreadEnforcedStore(createFakeStore())
        val creatorThread = currentThreadName()
        runTest {
            withContext(Dispatchers.Default) {
                // Precondition: we really are on another thread.
                assertThat(currentThreadName()).isNotEqualTo(creatorThread)
                // Expected to throw IllegalStateException.
                guarded.dispatch(IncrementByOne)
            }
        }
    }

    @Test
    fun `run an Action from a thread other than thread from which the Store was created will not throw an exception`() {
        val plain = createFakeStore()
        val creatorThread = currentThreadName()
        runTest {
            withContext(Dispatchers.Default) {
                assertThat(currentThreadName()).isNotEqualTo(creatorThread)
                // A plain store does no thread checking, so this must complete normally.
                plain.dispatch(IncrementByOne)
            }
        }
    }
}
/** Name of the thread that is executing the caller. */
private fun currentThreadName(): String {
    val running = Thread.currentThread()
    return running.name
}
/** Marker for anything that can be dispatched to a [Store]. */
interface Action

/** Marker for the state type held by a [Store]. */
interface State

/** Given the previous and the next state, decides whether a subscription is notified. */
typealias Binder<STATE> = (STATE, STATE) -> Boolean

/** Binder that always answers true, so its subscription is notified on every update. */
@Suppress("UNUSED_PARAMETER")
private fun <STATE> defaultBinder(old: STATE, new: STATE): Boolean = true

/** Callback used to feed an [Action] into a store. */
typealias Dispatcher = (Action) -> Unit

/** Computes the next state from the current state and an [Action]. */
typealias Reducer<STATE> = (STATE, Action) -> STATE

/** Packs the given [reducers] into a list, keeping argument order. */
fun <STATE> combineReducers(vararg reducers: Reducer<STATE>): List<Reducer<STATE>> = reducers.toList()

/**
 * One link of the middleware chain. It receives the current state, the action, a [Dispatcher], and the
 * remaining chain ([NextMiddleware]), and returns the action that continues onward.
 */
typealias Middleware<STATE> = (STATE, Action, Dispatcher, NextMiddleware<STATE>) -> Action

/** The rest of the middleware chain as seen from a single middleware. */
typealias NextMiddleware<STATE> = (STATE, Action, Dispatcher) -> Action

/** Packs the given [middlewares] into a list, keeping argument order. */
fun <STATE> combineMiddlewares(vararg middlewares: Middleware<STATE>): List<Middleware<STATE>> = middlewares.toList()

/** Subscriber callback: receives the new state and the binder that admitted the notification. */
internal typealias Subscription<STATE> = (state: STATE, binder: Binder<STATE>) -> Unit

/** Handle returned by subscribe calls; invoking it removes the subscription. */
typealias Unsubscribe = () -> Unit

/** A subscription paired with the binder that gates it. */
private class SubscriptionWithBinder<STATE>(val binder: Binder<STATE>, val subscription: Subscription<STATE>?)
/**
 * Holder of a [STATE] value that accepts [Action]s, lets callers replace the state wholesale,
 * and manages subscriptions that are told about state updates.
 */
interface Store<STATE : State> {
    /** Current state snapshot. */
    val state: STATE

    /** Registers [subscription]; the returned [Unsubscribe] removes it again. */
    fun subscribe(subscription: Subscription<STATE>): Unsubscribe

    /** Registers [subscription] gated by [binder]; the returned [Unsubscribe] removes it again. */
    fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe

    /** Registers [subscription] once per entry of [binders]; the returned [Unsubscribe] removes all of them. */
    fun subscribe(binders: List<Binder<STATE>>, subscription: Subscription<STATE>): Unsubscribe

    /** Dispatches a single [action]. */
    fun dispatch(action: Action)

    /** Dispatches [actions] in order. */
    fun dispatch(vararg actions: Action)

    /** Replaces the current state with [newState]. */
    fun rehydrate(newState: STATE)
}
/**
 * Creates a NON-THREADSAFE store. NOT RECOMMENDED.
 * It is highly recommended to use [createThreadSafeStore] to ensure thread safety.
 *
 * Every dispatched action flows through: middlewares (in list order) -> reducers (folded in list order)
 * -> state update -> subscriber notification. Dispatched actions are appended to an internal queue,
 * and the queue is drained before `dispatch` returns. A dispatch issued from inside a middleware or a
 * subscriber drains the queue itself.
 *
 * Subscribers receive the current state immediately on subscription. After that, they receive every
 * update their binder admits, most recently registered subscriber first. The new state is published
 * before subscribers are called, so reading [Store.state] inside a callback returns the new state.
 *
 * @param initialState state held before any action is dispatched.
 * @param reducers applied in order; each receives the previous reducer's output.
 * @param middlewares invoked in order; each receives the next link of the chain.
 */
@Suppress("ObjectPropertyName")
fun <STATE : State> createStore(
    initialState: STATE,
    reducers: List<Reducer<STATE>> = listOf(),
    middlewares: List<Middleware<STATE>> = listOf()
): Store<STATE> = object : Store<STATE> {
    private val _state: AtomicReference<STATE> = AtomicReference(initialState)
    override val state: STATE get() = _state.get()
    private val actions = LinkedBlockingQueue<Action>()
    private val subscriptions = CopyOnWriteArrayList<SubscriptionWithBinder<STATE>>()

    override fun subscribe(subscription: Subscription<STATE>): Unsubscribe = subscribe(::defaultBinder, subscription)

    override fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe {
        val subscriptionWithBinder = SubscriptionWithBinder(binder, subscription)
        subscriptions.add(subscriptionWithBinder)
        // Deliver the current state right away.
        subscription(state, binder)
        return { subscriptions.remove(subscriptionWithBinder) }
    }

    override fun subscribe(binders: List<Binder<STATE>>, subscription: Subscription<STATE>): Unsubscribe {
        val unsubscribes = binders.map { subscribe(it, subscription) }
        return { unsubscribes.forEach { it() } }
    }

    @Suppress("RemoveRedundantSpreadOperator")
    override fun dispatch(action: Action) = dispatch(*arrayOf(action))

    override fun dispatch(vararg actions: Action) {
        actions.forEach { this.actions.offer(it) }
        // Drain the queue; poll() returns null once it is empty.
        while (true) {
            val pending = this.actions.poll() ?: break
            handle(pending)
        }
    }

    override fun rehydrate(newState: STATE) = notify(newState)

    private fun handle(action: Action) {
        // Run the middleware chain BEFORE reading `state`. A middleware may dispatch synchronously,
        // and reducing a snapshot taken earlier would overwrite that nested update.
        val processed = dispatchToMiddleware(action)
        notify(reduce(state, processed))
    }

    private fun reduce(current: STATE, action: Action): STATE = reducers.fold(current) { state, reducer -> reducer(state, action) }

    private fun dispatchToMiddleware(action: Action): Action = next(0)(state, action, ::dispatch)

    // Builds the chain from `index` onward; past the last middleware the action is returned unchanged.
    private fun next(index: Int): NextMiddleware<STATE> = when (index) {
        middlewares.size -> { _, action, _ -> action }
        else -> { state, action, dispatch -> middlewares[index](state, action, dispatch, next(index + 1)) }
    }

    private fun notify(newState: STATE) {
        // Publish first, so subscribers (and any actions they dispatch) observe the new state rather than the old one.
        val oldState = _state.getAndSet(newState)
        subscriptions.reversed()
            .filter { it.binder(oldState, newState) }
            .forEach { it.subscription?.invoke(newState, it.binder) }
    }
}
/** Behavioural tests for the plain store produced by createFakeStore(). */
class StoreTest {
    private val store: Store<TestState> = createFakeStore()

    @Test
    fun `get initial status upon subscription`() {
        // Subscribing delivers the initial state synchronously; its counter is unset.
        store.subscribe { state, _ -> assertThat(state.counter).isNull() }
    }

    @Test
    fun `subscribe with no binding`() {
        store.subscribe { state, _ ->
            // Skip the initial null delivery; after one increment the counter must be 1.
            state.counter?.let { assertThat(state.counter).isEqualTo(1) }
        }
        store.dispatch(IncrementByOne)
    }

    @Test
    fun `subscribe via binding`() {
        store.subscribe(::counterBinder) { state, binder ->
            assertThat(binder).isEqualTo(::counterBinder)
            state.counter?.let { assertThat(state.counter).isEqualTo(1) }
        }
        store.dispatch(IncrementByOne)
    }

    @Test
    fun `subscribe via a list of bindings`() {
        store.subscribe(listOf(::counterBinder)) { state, binder ->
            assertThat(binder).isEqualTo(::counterBinder)
            state.counter?.let { assertThat(state.counter).isEqualTo(1) }
        }
        store.dispatch(IncrementByOne)
    }

    @Test
    fun `rehydrate state`() {
        store.subscribe { state, _ -> state.counter?.let { assertThat(state.counter).isEqualTo(100) } }
        store.rehydrate(TestState(counter = 100))
        assertThat(store.counter).isEqualTo(100)
    }
}
/**
 * Wraps [store] in a [ThreadSafeStore], a decorator that synchronizes every call on its own monitor.
 * This may have performance impact on JVM/Native.
 */
fun <STATE : State> createThreadSafeStore(store: Store<STATE>): ThreadSafeStore<STATE> = ThreadSafeStore(store)

/**
 * Builds a plain store via [createStore] from the given arguments and wraps it in a [ThreadSafeStore].
 */
fun <STATE : State> createThreadSafeStore(
    initialState: STATE,
    reducers: List<Reducer<STATE>> = listOf(),
    middlewares: List<Middleware<STATE>> = listOf()
): ThreadSafeStore<STATE> {
    val plain = createStore(initialState, reducers, middlewares)
    return ThreadSafeStore(plain)
}
/**
 * Decorator that makes [store] usable from several threads. Every member, including the [state]
 * getter, runs while holding this wrapper's monitor, and each call acquires the monitor exactly once.
 *
 * Any callback the wrapped store invokes synchronously from these calls (for example subscribers
 * notified during [dispatch]) runs while the lock is held. The [Unsubscribe] functions returned by
 * `subscribe` are passed through as-is and run without taking this lock.
 */
class ThreadSafeStore<STATE : State>(private val store: Store<STATE>) : Store<STATE> by store {
    override val state: STATE
        get() = synchronized(this) { store.state }

    override fun subscribe(subscription: Subscription<STATE>): Unsubscribe =
        synchronized(this) { store.subscribe(subscription) }

    override fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe =
        synchronized(this) { store.subscribe(binder, subscription) }

    override fun subscribe(binders: List<Binder<STATE>>, subscription: Subscription<STATE>): Unsubscribe =
        synchronized(this) { store.subscribe(binders, subscription) }

    override fun dispatch(action: Action) = synchronized(this) { store.dispatch(action) }

    override fun dispatch(vararg actions: Action) = synchronized(this) { store.dispatch(*actions) }

    override fun rehydrate(newState: STATE) = synchronized(this) { store.rehydrate(newState) }
}
/**
 * Wraps [store] in a [SameThreadEnforcedStore], whose functions may only be called from the thread
 * that created the wrapper. Calls from any other thread throw an IllegalStateException.
 */
fun <STATE : State> createSameThreadEnforcedStore(store: Store<STATE>): SameThreadEnforcedStore<STATE> =
    SameThreadEnforcedStore(store)

/**
 * Builds a plain store via [createStore] from the given arguments and wraps it in a [SameThreadEnforcedStore]
 * bound to the calling thread.
 */
fun <STATE : State> createSameThreadEnforcedStore(
    initialState: STATE,
    reducers: List<Reducer<STATE>> = listOf(),
    middlewares: List<Middleware<STATE>> = listOf()
): SameThreadEnforcedStore<STATE> {
    val plain = createStore(initialState, reducers, middlewares)
    return SameThreadEnforcedStore(plain)
}
/**
 * Decorator that only allows [subscribe], [dispatch] and [rehydrate] to be called from the thread
 * on which this wrapper was constructed. Calls from any other thread fail with IllegalStateException.
 *
 * Reads of [state] are delegated to the wrapped store without a thread check.
 */
class SameThreadEnforcedStore<STATE : State>(private val store: Store<STATE>) : Store<STATE> by store {
    // The Thread object is what gets enforced; the name is captured only for the error message.
    private val storeThread: Thread = Thread.currentThread()
    private val storeThreadName = currentThreadName()

    override fun subscribe(subscription: Subscription<STATE>): Unsubscribe {
        checkIsSameThread()
        return store.subscribe(subscription)
    }

    override fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe {
        checkIsSameThread()
        return store.subscribe(binder, subscription)
    }

    override fun subscribe(binders: List<Binder<STATE>>, subscription: Subscription<STATE>): Unsubscribe {
        checkIsSameThread()
        return store.subscribe(binders, subscription)
    }

    override fun dispatch(action: Action) {
        checkIsSameThread()
        store.dispatch(action)
    }

    override fun dispatch(vararg actions: Action) {
        checkIsSameThread()
        store.dispatch(*actions)
    }

    override fun rehydrate(newState: STATE) {
        checkIsSameThread()
        store.rehydrate(newState)
    }

    private fun currentThreadName(): String = Thread.currentThread().name.stripCoroutineName()

    // Identity, not name: thread names are not unique (distinct threads can share one), and
    // coroutine debug mode temporarily appends '@coroutine#n' to the running thread's name.
    private fun isSameThread(): Boolean = Thread.currentThread() === storeThread

    private fun checkIsSameThread() = check(isSameThread()) {
        "You may not call the store from a thread other than the thread on which it was created.\n" +
            "This store was created on: '$storeThreadName' and current thread is '${currentThreadName()}'"
    }

    /**
     * Thread name may have '@coroutine#n' appended to it
     * https://kotlinlang.org/docs/coroutine-context-and-dispatchers.html#debugging-using-logging
     * Drops everything from the last '@' onward (if any) and trims surrounding whitespace.
     */
    private fun String.stripCoroutineName(): String = substringBeforeLast('@').trim()
}
/**
 * Contrasts a plain store with a [ThreadSafeStore] under heavy concurrent dispatching.
 * Sample timings from one run (performance impact example):
 *   thread safe: completed 100000 actions in 122 ms
 *   non-thread safe: completed 100000 actions in 60 ms
 */
@OptIn(ExperimentalCoroutinesApi::class)
class ThreadSafeStoreTest {
    @Test
    fun `massive action dispatching using a non-thread safe Store should return an unexpected state`() {
        val unsafe = createFakeStore()
        runTest {
            withContext(Dispatchers.Default) {
                massiveRun("non-thread safe", SYNC_NUM_COROUTINES, SYNC_NUM_REPEATS) { unsafe.dispatch(IncrementByOne) }
                // NOTE(review): passes only if the race actually loses at least one update, so it can be flaky.
                val expectedTotal = SYNC_NUM_COROUTINES * SYNC_NUM_REPEATS
                assertThat(unsafe.counter).isNotEqualTo(expectedTotal)
            }
        }
    }

    @Test
    fun `massive action dispatching using a thread safe Store should return an expected state`() {
        val safe = createThreadSafeStore(createFakeStore())
        runTest {
            withContext(Dispatchers.Default) {
                massiveRun("thread safe", SYNC_NUM_COROUTINES, SYNC_NUM_REPEATS) { safe.dispatch(IncrementByOne) }
                val expectedTotal = SYNC_NUM_COROUTINES * SYNC_NUM_REPEATS
                assertThat(safe.counter).isEqualTo(expectedTotal)
            }
        }
    }
}
/**
 * Launches [numCoroutines] child coroutines that each run [block] [numRepeats] times, waits for all of them,
 * then prints the elapsed wall-clock time labelled with [tag].
 */
private suspend fun massiveRun(tag: String, numCoroutines: Int, numRepeats: Int, block: suspend () -> Unit) {
    val totalActions = numCoroutines * numRepeats
    val elapsedMs = measureTimeMillis {
        coroutineScope {
            repeat(numCoroutines) {
                launch { repeat(numRepeats) { block() } }
            }
        }
    }
    println("$tag: completed $totalActions actions in $elapsedMs ms")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment