Skip to content

Instantly share code, notes, and snippets.

@zach-klippenstein
Last active December 19, 2024 17:47
Show Gist options
  • Save zach-klippenstein/63e17047ade83d78d24130232bf33fa9 to your computer and use it in GitHub Desktop.
Save zach-klippenstein/63e17047ade83d78d24130232bf33fa9 to your computer and use it in GitHub Desktop.
Prototype implementation of TaskEffect (https://issuetracker.google.com/issues/361645776)
@file:Suppress(
"CANNOT_OVERRIDE_INVISIBLE_MEMBER",
"INVISIBLE_MEMBER",
"INVISIBLE_REFERENCE",
)
import androidx.collection.ObjectIntMap
import androidx.compose.runtime.DerivedState
import androidx.compose.runtime.DerivedStateObserver
import androidx.compose.runtime.SnapshotMutationPolicy
import androidx.compose.runtime.derivedStateObservers
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.snapshots.StateObject
import androidx.compose.runtime.structuralEqualityPolicy
/**
 * Helpers for working with [derivedStateOf] state objects. All these APIs are internal to the Compose
 * runtime, so we can't access them normally. We cheat by disabling internal-visibility checks on this
 * file, but that's really dangerous since it allows us to access tons of stuff, so these helpers are
 * in their own file to isolate them from the higher-level code.
 */
internal object DerivedStateHelper {

    /** Returns a [DerivedStateProxy] for this object if it is a derived state, otherwise `null`. */
    fun StateObject.asMaybeDerivedState(): DerivedStateProxy<*>? =
        if (this is DerivedState<*>) DerivedStateProxy(this)
        else null

    /**
     * Returns a [DerivedStateProxy] for this object.
     *
     * @throws ClassCastException if this object is not a derived state.
     */
    fun StateObject.asDerivedState(): DerivedStateProxy<*> =
        DerivedStateProxy(this as DerivedState<*>)

    /**
     * Runs [block] with a derived-state observer installed, and returns [block]'s result.
     *
     * The observer calls [onStart] when a derived state starts a calculation and [onDone] when it
     * finishes.
     *
     * The observer is pushed onto the list returned by the runtime's `derivedStateObservers()`. It is
     * popped from the end of that list in a `finally` block, so it is removed even if [block] throws.
     * Because removal is by position, any observers that [block] adds itself must be removed before
     * [block] returns.
     *
     * NOTE(review): presumably that observer list is per-thread, so only calculations running on the
     * calling thread while [block] executes are reported. Confirm before relying on this.
     */
    inline fun <R> observeDerivedStateRecalculations(
        crossinline onStart: () -> Unit,
        crossinline onDone: () -> Unit,
        block: () -> R
    ): R {
        val observers = derivedStateObservers()
        // Push before entering the try block. If the push itself failed, the finally block would
        // otherwise pop an observer that belongs to somebody else.
        observers.add(object : DerivedStateObserver {
            override fun start(derivedState: DerivedState<*>) {
                onStart()
            }

            override fun done(derivedState: DerivedState<*>) {
                onDone()
            }
        })
        try {
            return block()
        } finally {
            observers.removeAt(observers.lastIndex)
        }
    }

    /**
     * No-overhead wrapper of a [DerivedState] that allows other code in this module to access properties
     * of a [DerivedState] without enabling unsafe access.
     */
    @JvmInline
    value class DerivedStateProxy<T>(private val derivedState: DerivedState<T>) {
        /** The derived state's mutation policy, or `null` if it was created without one. */
        val policy: SnapshotMutationPolicy<T>? get() = derivedState.policy

        /** The value held by the derived state's current record. */
        val currentValue: T get() = derivedState.currentRecord.currentValue

        /** The state objects the derived state's current record depends on. */
        val dependencies: ObjectIntMap<StateObject> get() = derivedState.currentRecord.dependencies

        /**
         * Compares [value] to this state's current value using its [SnapshotMutationPolicy].
         *
         * Falls back to [structuralEqualityPolicy] when the state has no policy.
         */
        fun isCurrentValueEquivalentTo(value: Any?): Boolean {
            val policy = policy ?: structuralEqualityPolicy()
            @Suppress("UNCHECKED_CAST")
            return policy.equivalent(value as T, currentValue)
        }
    }
}
import android.util.Log
import androidx.compose.runtime.Composable
import androidx.compose.runtime.DisposableEffect
import androidx.compose.runtime.ExperimentalComposeApi
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.snapshotFlow
import androidx.compose.runtime.snapshots.MutableSnapshot
import androidx.compose.runtime.snapshots.Snapshot
import com.squareup.ui.internal.utils.SnapshotObservingElement.CurrentSnapshot
import kotlinx.coroutines.ThreadContextElement
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlin.coroutines.CoroutineContext
/**
 * An effect that is similar to a [LaunchedEffect]. Like [LaunchedEffect], it launches a coroutine in
 * the current composition's coroutine context. Unlike [LaunchedEffect], it automatically cancels the
 * coroutine (if not complete) and restarts it when any snapshot state object read in the coroutine is
 * written.
 *
 * It's a bit like a [LaunchedEffect] combined with a [snapshotFlow], but even better. State reads and
 * suspending function calls can be freely mixed, instead of having to be separated into different
 * observation lambdas.
 *
 * ## Coroutine behavior
 *
 * As soon as [block] reads a snapshot state object, that state object will be monitored for changes
 * outside your coroutine. Monitoring will continue as long as the [TaskEffect] is present in the
 * composition, even after the coroutine completes.
 *
 * If a monitored state is changed while the coroutine is suspended, the coroutine is first cancelled
 * and allowed to finish cancellation, including any suspend calls in `finally` blocks. Once it's fully
 * cancelled, the effect is restarted. If the coroutine had already completed, there's nothing to
 * cancel and it is simply restarted.
 *
 * Each time the effect is restarted, the set of monitored state objects is cleared, so if you read
 * different state objects on the next invocation, those will be monitored going forward. This is the
 * same behavior as other restartable functions in Compose (composition, [snapshotFlow], etc.).
 *
 * ## Snapshot behavior
 *
 * Unlike other effects ([LaunchedEffect], [DisposableEffect]), but similar to composition,
 * [TaskEffect]s run inside snapshots. Each time the effect's coroutine resumes, a new mutable
 * snapshot is taken and entered. Each time the coroutine suspends, that snapshot is left and applied.
 * This has three consequences:
 * - Between two suspension points, the effect function works against a snapshot of state.
 * - Changes the effect function makes to state objects become visible outside the effect as soon as
 *   it next suspends (or returns), not only once the whole effect function finishes.
 * - Changes made to state objects outside the effect become visible to the effect function the next
 *   time it resumes after a suspension.
 *
 * Because changes are applied at every suspension point, restarting the effect before it finishes
 * does not roll anything back. This is the case, for example, when a state object it read changes
 * while it is suspended: everything it wrote before its last suspension has already been applied.
 * Writing to a state object that the effect has already read can also cause the effect to restart
 * itself once that write is applied.
 *
 * ## Further reading
 *
 * For a more general discussion of “restartable functions” in Compose, see
 * [this article](https://blog.zachklipp.com/restartable-functions-from-first-principles/).
 *
 * There is [an issue](https://issuetracker.google.com/issues/361645776) on the Google tracker to add
 * this API, or something similar, to Compose, but it's not clear when it will be implemented.
 *
 * ## Recipes
 *
 * If you're using the value of a state object to launch a service call using [snapshotFlow], you can
 * replace it with a [TaskEffect] as follows:
 * ```kotlin
 * // Old:
 * @Composable fun MyComposable(viewModel: …) {
 *     val state by remember { mutableStateOf(…) }
 *     LaunchedEffect(viewModel) {
 *         snapshotFlow { state }.collectLatest { state ->
 *             viewModel.doSomethingSuspending(state)
 *         }
 *     }
 * }
 *
 * // New:
 * @Composable fun MyComposable(viewModel: …) {
 *     val state by remember { mutableStateOf(…) }
 *     TaskEffect(viewModel) {
 *         viewModel.doSomethingSuspending(state)
 *     }
 * }
 * ```
 *
 * The benefit of this conversion is even more apparent when you're reading multiple states:
 * ```kotlin
 * // Old:
 * @Composable fun MyComposable(viewModel: …) {
 *     …
 *     LaunchedEffect(viewModel) {
 *         snapshotFlow {
 *             Triple(state1, state2, state3)
 *         }.collectLatest { (state1, state2, state3) ->
 *             viewModel.doSomethingSuspending(state1, state2, state3)
 *         }
 *     }
 * }
 *
 * // New:
 * @Composable fun MyComposable(viewModel: …) {
 *     …
 *     TaskEffect(viewModel) {
 *         viewModel.doSomethingSuspending(state1, state2, state3)
 *     }
 * }
 * ```
 *
 * If you're reading a state object in composition just to capture it in a [LaunchedEffect], convert
 * your effect to a [TaskEffect] and move the read into it:
 * ```kotlin
 * // Bad:
 * @Composable fun MyComposable() {
 *     val state by remember { mutableStateOf(…) }
 *     val currentState = state
 *     LaunchedEffect(currentState) {
 *         viewModel.doSomethingSuspending(currentState)
 *     }
 * }
 *
 * // Good:
 * @Composable fun MyComposable() {
 *     val state by remember { mutableStateOf(…) }
 *     TaskEffect {
 *         viewModel.doSomethingSuspending(state)
 *     }
 * }
 * ```
 *
 * Be careful when you use a state object as the key to a [LaunchedEffect] and also read the same
 * state object in the effect. Your effect is probably not doing what you think. You can fix it by
 * using [TaskEffect] without the state object keys. For more on this anti-pattern, see
 * [Compose Chronicles #3](https://www.notion.so/square-seller/Chronicle-Edition-3-10e70293beed809385e3de61602001eb?pvs=4#15270293beed80f1aa49d03cd1e34e24):
 * ```kotlin
 * // Bad:
 * @Composable fun MyComposable() {
 *     val state by remember { mutableStateOf(…) }
 *     LaunchedEffect(state) {
 *         viewModel.doSomethingSuspending(state)
 *     }
 * }
 *
 * // Good:
 * @Composable fun MyComposable() {
 *     val state by remember { mutableStateOf(…) }
 *     TaskEffect {
 *         viewModel.doSomethingSuspending(state)
 *     }
 * }
 * ```
 *
 * @param keys Values that, when any of them changes, restart the effect from scratch, exactly like
 * the keys of a [LaunchedEffect]. `null` keys are allowed, as with [LaunchedEffect].
 * @param block The effect function. [block] itself is also used as a key, so passing a different
 * [block] restarts the effect too.
 */
@Composable
public fun TaskEffect(
    vararg keys: Any?,
    block: suspend () -> Unit
) {
    // `block` is deliberately passed as an extra key. There's no value in using rememberUpdatedState
    // here: a new block would effectively require restarting the entire effect anyway.
    LaunchedEffect(*keys, block) {
        restartOnStateChanged(block)
    }
}
/**
 * Core loop behind [TaskEffect]. Each iteration does four things:
 * 1. Runs [block] in a child coroutine whose snapshot state reads are observed.
 * 2. Suspends until one of the observed states changes.
 * 3. Cancels that child and waits for it to finish cancelling.
 * 4. Starts over.
 *
 * It never returns normally; it only ends through cancellation or through a failure of [block]
 * propagating out of the scope.
 *
 * In theory this could be public API itself, but that would allow it to be nested, which is not
 * currently supported. Supporting it would require a lot more design and testing.
 */
internal suspend fun restartOnStateChanged(block: suspend () -> Unit): Nothing = coroutineScope {
    // Outline of each iteration:
    //  - hand a read observer to the task's snapshots,
    //  - record every state object the task reads,
    //  - listen for changes to those state objects,
    //  - cancel and restart the task once any of them changes.
    // TODO If block() doesn't actually read any state, we could let the loop exit and clean
    //  everything up, making this _almost_ as cheap as just using LaunchedEffect directly.
    val stateObserver = TaskEffectSnapshotObserver()
    while (true) {
        val runningTask = stateObserver.runThenAwaitStateChange { onRead ->
            val observingElement = SnapshotObservingElement(onRead)
            launch {
                // TODO Tests fail if the element is passed directly to launch instead of withContext –
                //  why? Passing to launch should be enough and is more efficient.
                withContext(observingElement) { block() }
            }
        }
        // A watched state changed. Stop the current run if it is still suspended (a completed run
        // has nothing to cancel) before looping around to start a fresh one.
        runningTask.cancelAndJoin()
    }
    // The lambda's result type isn't inferred as Nothing from an infinite loop alone, so end it with
    // an explicit throw.
    @Suppress("UNREACHABLE_CODE")
    throw AssertionError("Should never leave while loop")
}
/**
 * A [ThreadContextElement] that makes the coroutine it is installed in run under snapshots. Every
 * snapshot state read inside those snapshots is reported to [readObserver].
 *
 * Each time the coroutine resumes, a new mutable snapshot is taken with [readObserver] and entered.
 * Each time the coroutine suspends, that snapshot is left, applied and then disposed.
 */
// TODO use SnapshotContextElement as the key, so they overwrite each other in the context?
// TODO It's now possible to restart yourself by writing to a state after suspending that you read
//  before suspending. We might be able to avoid this by detecting our own snapshots in the apply
//  observer.
@OptIn(ExperimentalComposeApi::class)
private class SnapshotObservingElement(
    private val readObserver: (Any) -> Unit,
) : ThreadContextElement<CurrentSnapshot>, CoroutineContext.Element {

    override val key get() = Key

    override fun updateThreadContext(context: CoroutineContext): CurrentSnapshot {
        val freshSnapshot = Snapshot.takeMutableSnapshot(readObserver)
        // unsafeEnter hands back whichever snapshot was current before. Remember it so that
        // restoreThreadContext can reinstate it when leaving freshSnapshot.
        return CurrentSnapshot(
            snapshot = freshSnapshot,
            oldSnapshot = freshSnapshot.unsafeEnter(),
        )
    }

    override fun restoreThreadContext(
        context: CoroutineContext,
        oldState: CurrentSnapshot,
    ) {
        val entered = oldState.snapshot
        if (Snapshot.current === entered) {
            entered.unsafeLeave(oldState.oldSnapshot)
            // TODO Shouldn't call check() here since it throws, and we can't throw from this method. But
            //  Compose's SnapshotContextElement throws here and the exception shows up in all the expected
            //  places so it seems fine? Also I'm not sure how else to report failure: cancel context?
            try {
                entered.apply().check()
            } finally {
                entered.dispose()
            }
        } else {
            // Update/restore calls didn't pair up: the current snapshot isn't the one we entered, so
            // only log and touch nothing.
            // NOTE(review): `entered` is neither left nor disposed on this path — confirm that can't leak.
            Log.w("TaskEffect", "detected uneven update/restore calls")
        }
    }

    /**
     * State passed from [updateThreadContext] to [restoreThreadContext]. It holds the snapshot that
     * was entered and the snapshot that was current before it.
     */
    class CurrentSnapshot(
        val snapshot: MutableSnapshot,
        val oldSnapshot: Snapshot?,
    )

    companion object Key : CoroutineContext.Key<SnapshotObservingElement>
}
import androidx.collection.MutableScatterSet
import androidx.collection.ScatterSet
import androidx.collection.mutableScatterMapOf
import androidx.collection.mutableScatterSetOf
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.snapshots.ObserverHandle
import androidx.compose.runtime.snapshots.Snapshot
import androidx.compose.runtime.snapshots.SnapshotStateObserver
import androidx.compose.runtime.snapshots.StateObject
import com.squareup.ui.internal.utils.DerivedStateHelper.asDerivedState
import com.squareup.ui.internal.utils.DerivedStateHelper.asMaybeDerivedState
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
/**
 * Helper class for observing changes to snapshot state objects. Similar to [SnapshotStateObserver],
 * but exposes the read observer, which we need for [TaskEffect].
 *
 * This class supports [derivedStateOf]. It won't return from [runThenAwaitStateChange] when a derived
 * state's dependencies change but the result of its calculation function is unchanged.
 *
 * An instance is meant to be reused for successive calls to [runThenAwaitStateChange], and all calls
 * share its tracking state, so calls must not overlap. Reads reported to the read observer outside of
 * a call are ignored, so they can't leak into the next call's tracked set. For example, reads from
 * `finally` blocks of a task that is still being cancelled after the call returned.
 *
 * NOTE(review): [derivedStateNestingLevel] is only adjusted while the `block` passed to
 * [runThenAwaitStateChange] is itself executing. If `block` just launches a coroutine (as
 * `restartOnStateChanged` does), reads made later inside derived-state calculations may be recorded
 * as direct dependencies. That would cause a restart even when the derived value is unchanged.
 * Confirm before relying on derived-state filtering.
 */
internal class TaskEffectSnapshotObserver {
    /**
     * Change sets delivered by the global apply observer.
     *
     * The channel is unbounded so that [applyObserver]'s `trySend` can never fail. The apply observer
     * runs inside whichever code applied a snapshot. With a bounded channel, a backlog of pending
     * change sets would make it throw, and that exception would crash unrelated code.
     */
    private val appliedChanges = Channel<Set<Any>>(Channel.UNLIMITED)

    /** Every state object read directly by the observed code. Also serves as the monitor for [lock]. */
    private val trackedStateObjects = mutableScatterSetOf<Any>()

    /** The value each tracked derived state had when it was read. */
    private val derivedStateValues = mutableScatterMapOf<StateObject, Any?>()

    /** Reverse index from a derived state's dependency to the tracked derived states that use it. */
    private val dependenciesToDerivedStates =
        mutableScatterMapOf<Any, MutableScatterSet<StateObject>>()

    /** Scratch set used by [waitForChange] to batch applied change sets together. */
    private val changes = MutableScatterSet<Any>()

    /** Number of derived-state calculations currently in progress (calculations can nest). */
    private var derivedStateNestingLevel = 0

    private var observerHandle: ObserverHandle? = null

    /**
     * Whether reads are currently being recorded. True only inside [runThenAwaitStateChange].
     * Only read or written under [lock].
     */
    private var isObserving = false

    private val readObserver: (Any) -> Unit = { stateObject ->
        // The read observer can be called from any thread, so changes to the tracking state must be
        // synchronized.
        lock {
            // Outside of runThenAwaitStateChange, drop the read instead of letting it leak into the
            // next call's tracked set. This happens e.g. while a previous task finishes cancelling.
            if (isObserving) {
                recordRead(stateObject as StateObject)
            }
        }
    }

    private val applyObserver: (Set<Any>, Snapshot) -> Unit = { changedSet, _ ->
        // Can't fail on an unlimited channel that is never closed. getOrThrow only guards that invariant.
        appliedChanges.trySend(changedSet).getOrThrow()
    }

    /**
     * Runs [block], then suspends until any of the state objects reported to the `readObserver`
     * passed to [block] are changed, and finally returns [block]'s result.
     *
     * The `readObserver` function passed to [block] is intended to be passed to any [Snapshot]s created
     * inside [block]. All tracking state is reset before this function returns or throws.
     */
    suspend inline fun <T : Any> runThenAwaitStateChange(
        block: (readObserver: (Any) -> Unit) -> T
    ): T {
        lateinit var result: T
        startObserving()
        try {
            DerivedStateHelper.observeDerivedStateRecalculations(
                onStart = { derivedStateNestingLevel++ },
                onDone = { derivedStateNestingLevel-- }
            ) {
                result = block(readObserver)
            }
            // Wait for the apply observer to detect a change that we care about.
            waitForChange()
        } finally {
            resetState()
        }
        return result
    }

    /** Registers the global apply observer, then opens the window in which reads are recorded. */
    private fun startObserving() {
        // Re-register the apply observer every time. Registering an observer advances the global
        // snapshot, so we don't catch any changes that occurred before the call. This "observation
        // barrier" makes our code more correct. There might be other ways to do this if it turns out
        // that constantly registering/unregistering a global observer is too expensive.
        observerHandle = Snapshot.registerApplyObserver(applyObserver)
        lock { isObserving = true }
    }

    /** Records a single read. Must be called while holding [lock]. */
    private fun recordRead(stateObject: StateObject) {
        if (derivedStateNestingLevel > 0) {
            // This read is happening from a derived state's calculation block. The derived state object
            // itself will record and cache the dependency. When the derived state itself is read, we
            // track all of its dependencies explicitly below.
            return
        }
        val derivedState = stateObject.asMaybeDerivedState()
        if (derivedState != null) {
            // Derived states are special! See
            // https://blog.zachklipp.com/how-derivedstateof-works-a-deep-d-er-ive/.
            // We need to track the actual value, so we can compare it to the new value when a
            // dependency is written. We also need to track its dependencies manually, since we might
            // not actually run the calculation function here.
            // All these APIs are internal to the Compose runtime, so we use a helper to sneak past the
            // visibility guards and keep that dangerous code isolated.

            // Cache the derived state's value so we can check later whether it has changed.
            derivedStateValues[stateObject] = derivedState.currentValue
            // Manually track each of its dependencies, as recorded by the derived state's current record.
            derivedState.dependencies.forEachKey { dependency ->
                val dependents = dependenciesToDerivedStates.getOrPut(dependency) { MutableScatterSet() }
                dependents += stateObject
            }
        }
        // Always track the object itself. If it's a derived state, it may still send a write if something
        // else causes a new value to be calculated.
        trackedStateObjects.add(stateObject)
    }

    /** Suspends until an applied change set contains a change that requires a restart. */
    private suspend fun waitForChange() {
        do {
            // Don't call the suspending receive() if there's already something ready in the channel.
            appliedChanges.drainTo(changes)
            if (changes.isEmpty()) {
                changes.addAll(appliedChanges.receive())
                appliedChanges.drainTo(changes)
            }
            val needsRestart = checkNeedsRestart(changes)
            changes.clear()
        } while (!needsRestart)
    }

    /**
     * Returns true if [changes] contains a tracked state object. Also returns true if it contains a
     * dependency of a tracked derived state whose current value is no longer equivalent to the cached
     * one.
     */
    private fun checkNeedsRestart(changes: ScatterSet<Any>): Boolean {
        lock {
            changes.forEach { stateObject ->
                if (stateObject in trackedStateObjects) {
                    // This is a normal state object that we're tracking directly – always invalidate.
                    return true
                }
                // We aren't tracking this object directly, it may just be a dependency of derived states
                // we are tracking. Only invalidate if a derived state's value is actually different.
                dependenciesToDerivedStates[stateObject]?.forEach { derivedState ->
                    val oldValue = derivedStateValues[derivedState]
                    if (!derivedState.asDerivedState().isCurrentValueEquivalentTo(oldValue)) {
                        return true
                    }
                }
            }
        }
        return false
    }

    /**
     * Unregisters the apply observer, stops recording reads and clears all tracking state and
     * pending change sets.
     */
    private fun resetState() {
        observerHandle?.dispose()
        observerHandle = null
        lock {
            isObserving = false
            trackedStateObjects.clear()
            derivedStateValues.clear()
            dependenciesToDerivedStates.clear()
        }
        changes.clear()
        appliedChanges.drainTo()
    }

    /**
     * Receives every change set currently buffered in the channel without suspending. Each set is
     * added to [set] when [set] is non-null; otherwise the sets are simply discarded.
     */
    private fun <T> ReceiveChannel<Set<T>>.drainTo(set: MutableScatterSet<T>? = null) {
        do {
            val result = tryReceive()
            if (set != null) {
                result.getOrNull()?.let(set::addAll)
            }
        } while (result.isSuccess)
    }

    /** Runs [block] while holding this instance's monitor. */
    private inline fun lock(block: () -> Unit) {
        synchronized(trackedStateObjects, block)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment