Last active
December 19, 2024 17:47
-
-
Save zach-klippenstein/63e17047ade83d78d24130232bf33fa9 to your computer and use it in GitHub Desktop.
Prototype implementation of TaskEffect (https://issuetracker.google.com/issues/361645776)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@file:Suppress( | |
"CANNOT_OVERRIDE_INVISIBLE_MEMBER", | |
"INVISIBLE_MEMBER", | |
"INVISIBLE_REFERENCE", | |
) | |
import androidx.collection.ObjectIntMap | |
import androidx.compose.runtime.DerivedState | |
import androidx.compose.runtime.DerivedStateObserver | |
import androidx.compose.runtime.SnapshotMutationPolicy | |
import androidx.compose.runtime.derivedStateObservers | |
import androidx.compose.runtime.derivedStateOf | |
import androidx.compose.runtime.snapshots.StateObject | |
import androidx.compose.runtime.structuralEqualityPolicy | |
/**
 * Helpers for working with [derivedStateOf] state objects. All these APIs are internal to the Compose
 * runtime, so we can't access them normally. We cheat by disabling internal checks on this file, but
 * that's really dangerous since it allows us to access tons of stuff, so these helpers are in their
 * own file to isolate from the higher-level code.
 */
internal object DerivedStateHelper {

    /**
     * Returns a [DerivedStateProxy] wrapping this state object if it is a [DerivedState], otherwise
     * null.
     */
    fun StateObject.asMaybeDerivedState(): DerivedStateProxy<*>? =
        if (this is DerivedState<*>) DerivedStateProxy(this) else null

    /**
     * Returns a [DerivedStateProxy] wrapping this state object.
     *
     * @throws ClassCastException if this state object is not a [DerivedState].
     */
    fun StateObject.asDerivedState(): DerivedStateProxy<*> =
        DerivedStateProxy(this as DerivedState<*>)

    /**
     * Runs [block] with an extra [DerivedStateObserver] registered in [derivedStateObservers]. That
     * observer calls [onStart] when the runtime reports that a derived state calculation has
     * started, and [onDone] when it reports that one has finished.
     *
     * The observer is appended to the end of the observer list and the last entry is removed again
     * when [block] returns or throws, so calls to this function must be properly nested.
     *
     * @return the value returned by [block].
     */
    inline fun <R> observeDerivedStateRecalculations(
        crossinline onStart: () -> Unit,
        crossinline onDone: () -> Unit,
        block: () -> R,
    ): R {
        val observers = derivedStateObservers()
        // Register before entering the try: if add() itself fails there is nothing of ours in the
        // list, and the finally block must not pop some other observer off the end.
        observers.add(object : DerivedStateObserver {
            override fun start(derivedState: DerivedState<*>) {
                onStart()
            }

            override fun done(derivedState: DerivedState<*>) {
                onDone()
            }
        })
        try {
            return block()
        } finally {
            observers.removeAt(observers.lastIndex)
        }
    }

    /**
     * No-overhead wrapper of a [DerivedState] that allows other code in this module to access
     * properties of a [DerivedState] without enabling unsafe access.
     */
    @JvmInline
    value class DerivedStateProxy<T>(private val derivedState: DerivedState<T>) {

        /** The derived state's [SnapshotMutationPolicy], or null if it has none. */
        val policy: SnapshotMutationPolicy<T>? get() = derivedState.policy

        /** The value stored in the derived state's current record. */
        val currentValue: T get() = derivedState.currentRecord.currentValue

        /** The state objects recorded as dependencies in the derived state's current record. */
        val dependencies: ObjectIntMap<StateObject> get() = derivedState.currentRecord.dependencies

        /**
         * Compares [value] to this state's current value using its [SnapshotMutationPolicy], falling
         * back to [structuralEqualityPolicy] when the state has no policy of its own.
         */
        fun isCurrentValueEquivalentTo(value: Any?): Boolean {
            val policy = policy ?: structuralEqualityPolicy()
            @Suppress("UNCHECKED_CAST")
            return policy.equivalent(value as T, currentValue)
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import android.util.Log | |
import androidx.compose.runtime.Composable | |
import androidx.compose.runtime.DisposableEffect | |
import androidx.compose.runtime.ExperimentalComposeApi | |
import androidx.compose.runtime.LaunchedEffect | |
import androidx.compose.runtime.snapshotFlow | |
import androidx.compose.runtime.snapshots.MutableSnapshot | |
import androidx.compose.runtime.snapshots.Snapshot | |
import com.squareup.ui.internal.utils.SnapshotObservingElement.CurrentSnapshot | |
import kotlinx.coroutines.ThreadContextElement | |
import kotlinx.coroutines.cancelAndJoin | |
import kotlinx.coroutines.coroutineScope | |
import kotlinx.coroutines.launch | |
import kotlinx.coroutines.withContext | |
import kotlin.coroutines.CoroutineContext | |
/**
 * An effect that is similar to a [LaunchedEffect]—it launches a coroutine in the current composition's
 * coroutine context—but will automatically cancel the coroutine (if not complete) and restart it when
 * any snapshot state objects that have been read in the coroutine are written.
 *
 * It's a bit like a [LaunchedEffect] combined with a [snapshotFlow], but even better because both
 * state reads and suspending function calls can happen mixed in with each other and don't need to be
 * separated into separate observation lambdas.
 *
 * ## Coroutine behavior
 *
 * As soon as [block] reads a snapshot state object, that state object will be monitored for changes
 * outside your coroutine. Monitoring will continue as long as the [TaskEffect] is present in the
 * composition, even after the coroutine completes. If a monitored state is changed while the coroutine
 * is suspended, the coroutine will first be cancelled and allowed to finish cancellation (including
 * any suspend calls in `finally` blocks). Then, once it's fully cancelled, the effect will be
 * restarted. If the coroutine had already completed there's nothing to cancel and it will just be
 * restarted. Each time the effect is restarted, the set of monitored state objects is cleared, so if
 * you read different state objects on the next invocation, those will be monitored going forward. This
 * is the same behavior as other restartable functions in Compose (composition, [snapshotFlow], etc.).
 *
 * ## Snapshot behavior
 *
 * Unlike other effects ([LaunchedEffect], [DisposableEffect]), but similar to composition,
 * [TaskEffect]s run inside snapshots. Every time the effect's coroutine resumes, a new mutable
 * snapshot is taken and entered. Every time the coroutine suspends or finishes, that snapshot is left
 * and applied. This means that:
 *
 *  - Changes the effect function makes to state objects are not seen outside the effect until the
 *    coroutine next suspends or finishes, at which point everything written since the last resume is
 *    applied at once.
 *  - Changes made to state objects from outside the effect are not visible to the effect function
 *    while it runs between two suspension points. They become visible the next time it resumes.
 *  - Writing, after a suspension point, to a state object that the effect read before that
 *    suspension point can cause the effect to restart itself once the write is applied.
 *
 * If the effect function is restarted before it finishes (e.g. if a state object it read changes while
 * it's suspended), changes that were already applied at earlier suspension points are not rolled back.
 *
 * ## Further reading
 *
 * For a more general discussion of “restartable functions” in Compose, see
 * [this article](https://blog.zachklipp.com/restartable-functions-from-first-principles/).
 *
 * There is [an issue](https://issuetracker.google.com/issues/361645776) on the Google tracker to add
 * this API, or something similar, to Compose, but it's not clear when it will be implemented.
 *
 * ## Recipes
 *
 * If you're using the value of a state object to launch a service call using [snapshotFlow], you can
 * replace it with a [TaskEffect] as follows:
 * ```kotlin
 * // Old:
 * @Composable fun MyComposable(viewModel: …) {
 *   val state by remember { mutableStateOf(…) }
 *   LaunchedEffect(viewModel) {
 *     snapshotFlow { state }.collectLatest { state ->
 *       viewModel.doSomethingSuspending(state)
 *     }
 *   }
 * }
 *
 * // New:
 * @Composable fun MyComposable(viewModel: …) {
 *   val state by remember { mutableStateOf(…) }
 *   TaskEffect(viewModel) {
 *     viewModel.doSomethingSuspending(state)
 *   }
 * }
 * ```
 *
 * The benefit of this conversion is even more apparent when you're reading multiple states:
 * ```kotlin
 * // Old:
 * @Composable fun MyComposable(viewModel: …) {
 *   …
 *   LaunchedEffect(viewModel) {
 *     snapshotFlow {
 *       Triple(state1, state2, state3)
 *     }.collectLatest { (state1, state2, state3) ->
 *       viewModel.doSomethingSuspending(state1, state2, state3)
 *     }
 *   }
 * }
 *
 * // New:
 * @Composable fun MyComposable(viewModel: …) {
 *   …
 *   TaskEffect(viewModel) {
 *     viewModel.doSomethingSuspending(state1, state2, state3)
 *   }
 * }
 * ```
 *
 * If you're reading a state object in composition just to capture in a [LaunchedEffect], then just
 * convert your effect to a [TaskEffect] and move the read into it:
 * ```kotlin
 * // Bad:
 * @Composable fun MyComposable(viewModel: …) {
 *   val state by remember { mutableStateOf(…) }
 *   val currentState = state
 *   LaunchedEffect(currentState) {
 *     viewModel.doSomethingSuspending(currentState)
 *   }
 * }
 *
 * // Good:
 * @Composable fun MyComposable(viewModel: …) {
 *   val state by remember { mutableStateOf(…) }
 *   TaskEffect {
 *     viewModel.doSomethingSuspending(state)
 *   }
 * }
 * ```
 *
 * If you're using a state object as the key to a [LaunchedEffect] and reading the same state object in
 * the effect then your effect is probably not doing what you think, and you can fix it by using
 * [TaskEffect] without the state object keys (for more on this anti-pattern, see
 * [Compose Chronicles #3](https://www.notion.so/square-seller/Chronicle-Edition-3-10e70293beed809385e3de61602001eb?pvs=4#15270293beed80f1aa49d03cd1e34e24)):
 * ```kotlin
 * // Bad:
 * @Composable fun MyComposable(viewModel: …) {
 *   val state by remember { mutableStateOf(…) }
 *   LaunchedEffect(state) {
 *     viewModel.doSomethingSuspending(state)
 *   }
 * }
 *
 * // Good:
 * @Composable fun MyComposable(viewModel: …) {
 *   val state by remember { mutableStateOf(…) }
 *   TaskEffect {
 *     viewModel.doSomethingSuspending(state)
 *   }
 * }
 * ```
 *
 * @param keys Like [LaunchedEffect]'s keys: when any of them changes, the effect is torn down and
 * started from scratch. Null keys are allowed, as with [LaunchedEffect].
 * @param block The suspending effect function. [block] itself is also used as a key, so passing a
 * different lambda instance restarts the effect from scratch.
 */
@Composable
public fun TaskEffect(
    vararg keys: Any?,
    block: suspend () -> Unit,
) {
    // block is passed as one of LaunchedEffect's keys on purpose. There's no value in using
    // rememberUpdatedState here: we'd effectively just restart the entire effect in that case anyway.
    LaunchedEffect(*keys, block) {
        restartOnStateChanged(block)
    }
}
/**
 * The engine behind [TaskEffect]. It repeatedly:
 *
 * 1. launches [block] in a child coroutine whose snapshots report every state read,
 * 2. suspends until one of the recorded reads is invalidated,
 * 3. cancels the child and waits for its cancellation to finish,
 * 4. and then starts over.
 *
 * It never returns normally. It only exits by throwing, e.g. when the calling coroutine is cancelled
 * or the child fails.
 *
 * This could in theory be public API by itself, but that would allow it to be nested, which isn't
 * supported today. Supporting it would need a lot more design and testing.
 *
 * NOTE(review): [TaskEffectSnapshotObserver.runThenAwaitStateChange] wraps only the `launch` call in
 * its derived-state observation. Unless the dispatcher runs the child eagerly, [block]'s body
 * presumably executes after that observation window has closed. Confirm that reads made inside
 * derived-state calculations in [block] are filtered as intended.
 */
internal suspend fun restartOnStateChanged(block: suspend () -> Unit): Nothing = coroutineScope {
    // Outline: hand a read observer to the task's snapshots, record every state object it reads,
    // watch those objects for applied changes, and cancel + restart the task when one changes.
    // TODO if block() doesn't actually read any state, then we can just let the loop exit and clean
    //  everything up, making this _almost_ as cheap as just using LaunchedEffect directly.
    val stateObserver = TaskEffectSnapshotObserver()
    while (true) {
        val task = stateObserver.runThenAwaitStateChange { onRead ->
            val observingElement = SnapshotObservingElement(onRead)
            launch {
                // TODO Tests fail if the element is passed directly to launch instead of withContext –
                //  why? Passing to launch should be enough and is more efficient.
                withContext(observingElement) { block() }
            }
        }
        // A tracked state changed. Cancel the task if it is still suspended (a no-op if it already
        // completed) and wait for the cancellation to finish before looping around to restart it.
        task.cancelAndJoin()
    }
    // The compiler won't infer Nothing from the infinite loop above, so end the lambda with an
    // explicit throw to give coroutineScope a Nothing result.
    @Suppress("UNREACHABLE_CODE")
    throw AssertionError("Should never leave while loop")
}
/**
 * A [ThreadContextElement] that runs a coroutine under snapshots which report every snapshot state
 * read to [readObserver].
 *
 * Every time the coroutine resumes, a fresh mutable snapshot is taken with [readObserver] and entered
 * on the resuming thread. Every time the coroutine suspends or finishes, that snapshot is left,
 * applied and disposed.
 */
// TODO use SnapshotContextElement as the key, so they overwrite each other in the context?
// TODO It's now possible to restart yourself by writing to a state after suspending that you read
//  before suspending. We might be able to avoid this by detecting our own snapshots in the apply
//  observer.
@OptIn(ExperimentalComposeApi::class)
private class SnapshotObservingElement(
    private val readObserver: (Any) -> Unit,
) : ThreadContextElement<CurrentSnapshot>, CoroutineContext.Element {

    override val key get() = Key

    override fun updateThreadContext(context: CoroutineContext): CurrentSnapshot {
        val taken = Snapshot.takeMutableSnapshot(readObserver)
        // unsafeEnter() hands back whatever snapshot was current on this thread. Keep it so
        // restoreThreadContext can give it back to unsafeLeave().
        return CurrentSnapshot(snapshot = taken, oldSnapshot = taken.unsafeEnter())
    }

    override fun restoreThreadContext(
        context: CoroutineContext,
        oldState: CurrentSnapshot,
    ) {
        val entered = oldState.snapshot
        if (Snapshot.current === entered) {
            entered.unsafeLeave(oldState.oldSnapshot)
            // TODO Shouldn't call check() here since it throws, and we can't throw from this method.
            //  But Compose's SnapshotContextElement throws here and the exception shows up in all the
            //  expected places so it seems fine? Also I'm not sure how else to report failure: cancel
            //  context?
            try {
                entered.apply().check()
            } finally {
                entered.dispose()
            }
        } else {
            // Uneven update/restore calls: this thread isn't in the snapshot we entered, so touch
            // nothing. NOTE(review): the snapshot is neither applied nor disposed on this path.
            Log.w("TaskEffect", "detected uneven update/restore calls")
        }
    }

    /**
     * The snapshot entered by [updateThreadContext], together with the snapshot that was current on
     * the thread before it.
     */
    class CurrentSnapshot(
        val snapshot: MutableSnapshot,
        val oldSnapshot: Snapshot?,
    )

    companion object Key : CoroutineContext.Key<SnapshotObservingElement>
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import androidx.collection.MutableScatterSet | |
import androidx.collection.ScatterSet | |
import androidx.collection.mutableScatterMapOf | |
import androidx.collection.mutableScatterSetOf | |
import androidx.compose.runtime.derivedStateOf | |
import androidx.compose.runtime.snapshots.ObserverHandle | |
import androidx.compose.runtime.snapshots.Snapshot | |
import androidx.compose.runtime.snapshots.SnapshotStateObserver | |
import androidx.compose.runtime.snapshots.StateObject | |
import com.squareup.ui.internal.utils.DerivedStateHelper.asDerivedState | |
import com.squareup.ui.internal.utils.DerivedStateHelper.asMaybeDerivedState | |
import kotlinx.coroutines.channels.Channel | |
import kotlinx.coroutines.channels.ReceiveChannel | |
/**
 * Helper class for observing changes to snapshot state. Similar to [SnapshotStateObserver] but
 * exposes the read observer, which we need for [TaskEffect]. This class has full support for
 * [derivedStateOf] and won't return from [runThenAwaitStateChange] when a derived state's dependencies
 * change but the final result of the calculation function is unchanged.
 *
 * [runThenAwaitStateChange] keeps its per-call tracking in fields and clears them when it returns, so
 * calls on the same instance must not overlap.
 */
internal class TaskEffectSnapshotObserver {

    // Change sets reported by the global apply observer, waiting to be checked by waitForChange().
    // Unbounded so trySend() can never fail. The apply observer runs on whichever thread applied a
    // snapshot, so a failure there would surface in unrelated code. A bounded buffer would fill up
    // whenever many snapshots were applied before the waiting coroutine got to drain it.
    private val appliedChanges = Channel<Set<Any>>(Channel.UNLIMITED)

    // State objects read directly by the block (not from inside a derived-state calculation). Also
    // used as the monitor for lock().
    private val trackedStateObjects = mutableScatterSetOf<Any>()

    // Value each tracked derived state had when it was read, used to detect real value changes.
    private val derivedStateValues = mutableScatterMapOf<StateObject, Any?>()

    // Each dependency of a tracked derived state, mapped to the derived states that depend on it.
    private val dependenciesToDerivedStates =
        mutableScatterMapOf<Any, MutableScatterSet<StateObject>>()

    // Scratch set that waitForChange() merges pending change sets into.
    private val changes = MutableScatterSet<Any>()

    // Depth of the derived-state calculations reported while the block passed to
    // runThenAwaitStateChange is running.
    private var derivedStateNestingLevel = 0

    private var observerHandle: ObserverHandle? = null

    private val readObserver: (Any) -> Unit = { stateObject ->
        // The read observer can be called from any thread, so changes to the tracking collections
        // must be synchronized.
        lock {
            recordRead(stateObject as StateObject)
        }
    }

    private val applyObserver: (Set<Any>, Snapshot) -> Unit = { changedSet, _ ->
        // Cannot fail: the channel is unbounded and never closed.
        appliedChanges.trySend(changedSet).getOrThrow()
    }

    /**
     * Runs [block] and then suspends until any state objects passed to the `readObserver` function
     * passed to [block] are changed.
     *
     * The `readObserver` function passed to [block] is intended to be passed to any [Snapshot]s created
     * inside [block].
     *
     * All tracking state is cleared and the apply observer is unregistered before this returns or
     * throws.
     *
     * @return the value returned by [block].
     */
    suspend inline fun <T : Any> runThenAwaitStateChange(
        block: (readObserver: (Any) -> Unit) -> T
    ): T {
        lateinit var result: T
        // Re-register the apply observer every time since registering an observer advances the global
        // snapshot to avoid catching any changes that occurred before the call. This "observation
        // barrier" makes our code more correct. There might be other ways to do this if it turns out
        // that constantly registering/unregistering a global observer is too expensive.
        observerHandle = Snapshot.registerApplyObserver(applyObserver)
        try {
            DerivedStateHelper.observeDerivedStateRecalculations(
                onStart = { derivedStateNestingLevel++ },
                onDone = { derivedStateNestingLevel-- }
            ) {
                result = block(readObserver)
            }
            // Wait for the apply observer to detect a change that we care about.
            waitForChange()
        } finally {
            resetState()
        }
        return result
    }

    /**
     * Records a read of [stateObject]. Must be called while holding [lock].
     *
     * Reads made while a derived-state calculation is being reported are ignored. For a derived state
     * itself, its current value and the dependencies in its current record are also captured.
     */
    private fun recordRead(stateObject: StateObject) {
        if (derivedStateNestingLevel > 0) {
            // This read is happening from a derived state's calculation block. The derived state object
            // itself will record and cache the dependency, and after the calculation is done we'll read
            // all its dependencies explicitly below.
            return
        }

        val derivedState = stateObject.asMaybeDerivedState()
        if (derivedState != null) {
            // Derived states are special! See
            // https://blog.zachklipp.com/how-derivedstateof-works-a-deep-d-er-ive/. We need to track the
            // actual value, so we can compare the new value when a dependency is written. We also need
            // to manually track its dependencies since we might not actually run the calculation
            // function here. Note that all these APIs are internal to the Compose runtime, so we use a
            // helper to sneak past the visibility guards and keep that dangerous code isolated.

            // Cache the derived state's value so we can check if it's changed later.
            derivedStateValues[stateObject] = derivedState.currentValue

            // Manually track each of its dependencies, which were either recorded above or the last
            // time the calculation function was run.
            derivedState.dependencies.forEachKey { dependency ->
                val dependents = dependenciesToDerivedStates.getOrPut(dependency) { MutableScatterSet() }
                dependents += stateObject
            }
        }

        // Always track the object itself. If it's a derived state, it may still send a write if
        // something else causes a new value to be calculated.
        trackedStateObjects.add(stateObject)
    }

    /** Suspends until an applied change set contains a change that [checkNeedsRestart] accepts. */
    private suspend fun waitForChange() {
        do {
            // Don't call the suspending receive() if there's already something ready in the channel.
            appliedChanges.drainTo(changes)
            if (changes.isEmpty()) {
                changes.addAll(appliedChanges.receive())
                appliedChanges.drainTo(changes)
            }
            val needsRestart = checkNeedsRestart(changes)
            changes.clear()
        } while (!needsRestart)
    }

    /**
     * Returns true if [changes] contains a directly tracked state object, or a dependency of a tracked
     * derived state whose current value is no longer equivalent to the value recorded at read time.
     */
    private fun checkNeedsRestart(changes: ScatterSet<Any>): Boolean {
        lock {
            changes.forEach { stateObject ->
                if (stateObject in trackedStateObjects) {
                    // This is a normal state object that we're tracking directly – always invalidate.
                    return true
                }

                // We aren't tracking this object directly, but it may be a dependency of a derived
                // state we are tracking. Only invalidate if that derived value actually changed.
                dependenciesToDerivedStates[stateObject]?.forEach { derivedState ->
                    val proxy = derivedState.asDerivedState()
                    val oldValue = derivedStateValues[derivedState]
                    if (!proxy.isCurrentValueEquivalentTo(oldValue)) {
                        return true
                    }
                }
            }
        }
        return false
    }

    /** Unregisters the apply observer and clears all per-call tracking state. */
    private fun resetState() {
        observerHandle?.dispose()
        observerHandle = null
        trackedStateObjects.clear()
        derivedStateValues.clear()
        dependenciesToDerivedStates.clear()
        changes.clear()
        // Defensive: the next call must not start with a stale nesting depth.
        derivedStateNestingLevel = 0
        appliedChanges.drainTo()
    }

    /**
     * Receives every element currently buffered in this channel without suspending. If [set] is
     * non-null, the received elements are added to it; otherwise they are discarded.
     */
    private fun <T> ReceiveChannel<Set<T>>.drainTo(set: MutableScatterSet<T>? = null) {
        do {
            val result = tryReceive()
            if (set != null) {
                result.getOrNull()?.let(set::addAll)
            }
        } while (result.isSuccess)
    }

    /** Runs [block] while holding the monitor that guards the tracking collections. */
    private inline fun lock(block: () -> Unit) {
        synchronized(trackedStateObjects, block)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment