Skip to content

Instantly share code, notes, and snippets.

@dfed
Last active December 10, 2022 23:34
Show Gist options
  • Save dfed/f365487d6ed369690ab2126725695f08 to your computer and use it in GitHub Desktop.
Save dfed/f365487d6ed369690ab2126725695f08 to your computer and use it in GitHub Desktop.
// MIT License
//
// Copyright (c) 2022 Dan Federman
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
// https://github.com/dfed/swift-async-queue
import AsyncQueue
/// An actor that stores an optional current value and forwards every value it is sent to its observers.
///
/// Sending and subscribing are both funneled through a single `ActorQueue`, so they are enqueued in the
/// order the nonisolated entry points (`sendValue(_:)`, `observe(_:onTermination:)`) are called.
/// An observer first receives the current value (if there is one) and then each later value.
public actor Observable<T: Sendable>: Sendable {

    // MARK: Lifecycle

    /// Creates an observable.
    /// - Parameters:
    ///   - priority: The priority passed to the `ActorQueue` that orders sends and subscriptions.
    ///   - initialValue: If non-nil, sent through `sendValue(_:)` as the first value.
    public init(priority: TaskPriority? = nil, initialValue: T? = nil) {
        queue = ActorQueue(priority: priority)
        if let initialValue {
            sendValue(initialValue)
        }
    }

    deinit {
        // Finish every registered stream so that the `for await` loops in `observe` exit.
        for continuation in continuations.values {
            continuation.finish()
        }
    }

    // MARK: Public

    /// Observes values in the receiver.
    /// - Parameters:
    ///   - observer: A closure that receives values in the receiver in order, starting with the current value (if there is a current value)
    ///   - onTermination: A closure called when the observed stream ends.
    /// - Returns: A task that contains the observation stream. Cancel this task to cancel observation.
    @discardableResult
    nonisolated
    public func observe(
        _ observer: @escaping @Sendable (T) async -> Void,
        onTermination: (@Sendable () async -> Void)? = nil
    ) -> Task<Void, Never> {
        // Capture the stream in its current state before we enter the async queue to ensure we don't drop values.
        Task { [observedStream = stream] in
            for await value in observedStream {
                guard !Task.isCancelled else {
                    // Our observation has been cancelled but the continuation's onTermination closure hasn't executed just yet.
                    // Drop this value and wait for the observed stream to be terminated.
                    continue
                }
                await observer(value)
            }
            await onTermination?()
        }
    }

    /// Enqueues `value` to become the current value; observers registered at that point are sent it.
    /// - Parameter value: The value to store and broadcast.
    nonisolated
    public func sendValue(_ value: T) {
        queue.async {
            await self.setCurrentValue(to: value)
        }
    }

    // MARK: Private

    /// Queue through which every send, subscription, and unsubscription is dispatched.
    private let queue: ActorQueue

    /// Builds a new stream on each access and enqueues registration of its continuation with the actor.
    nonisolated
    private var stream: AsyncStream<T> {
        AsyncStream(T.self) { continuation in
            queue.async {
                await self.beginReceiving(on: continuation)
            }
        }
    }

    private func setCurrentValue(to value: T) {
        currentValue = value
    }

    /// Registers `continuation`, arranges for its removal on termination, and replays the current value to it.
    private func beginReceiving(
        on continuation: AsyncStream<T>.Continuation)
    {
        let identifier = UUID()
        continuations[identifier] = continuation
        continuation.onTermination = { @Sendable [weak self] _ in
            guard let self else { return }
            self.queue.async {
                await self.removeContinuation(with: identifier)
            }
        }
        // If the observer was cancelled before this registration ran, the stream has already terminated and
        // the handler installed above will never be invoked. `yield` reports that state, so drop the
        // continuation here instead of holding on to it until `deinit`.
        if let currentValue, case .terminated = continuation.yield(currentValue) {
            continuations[identifier] = nil
        }
    }

    private func removeContinuation(with identifier: UUID) {
        continuations[identifier] = nil
    }

    /// Continuations of the streams currently being observed, keyed by a per-subscription identifier.
    private var continuations = [UUID: AsyncStream<T>.Continuation]()

    /// The most recently sent value. Setting a non-nil value broadcasts it to every registered continuation.
    private var currentValue: T? {
        didSet {
            guard let currentValue else { return }
            var terminatedIdentifiers = [UUID]()
            for (identifier, continuation) in continuations {
                if case .terminated = continuation.yield(currentValue) {
                    terminatedIdentifiers.append(identifier)
                }
            }
            // Prune streams that ended without their termination handler removing them
            // (for example, a stream that terminated before the handler was installed).
            for identifier in terminatedIdentifiers {
                continuations[identifier] = nil
            }
        }
    }
}
/// Behavior tests for `Observable`: delivery order, cancellation, termination callbacks, and replay of only the latest value.
/// NOTE(review): `Counter` is not declared in this file; presumably `incrementAndExpectCount(equals:)` bumps a running
/// count and asserts that it matches the argument — confirm against its definition.
final class ObservableTests: XCTestCase {
// Replaced in `setUp()`, so every test starts with a fresh observable created with an initial value of `0`.
private var systemUnderTest: Observable! = Observable(initialValue: 0)
// MARK: XCTestCase
override func setUp() async throws {
try await super.setUp()
systemUnderTest = Observable(initialValue: 0)
}
// MARK: Behavior Tests
/// Subscribes, sends 1...1_000, and waits until the observer has been handed the value 1_000.
func test_observe_includesAllValuesInOrder() async {
let expectation = self.expectation(description: #function)
let counter = Counter()
systemUnderTest.observe { value in
// Delivery starts with the initial value 0, hence the `+ 1` offset between value and count.
await counter.incrementAndExpectCount(equals: value + 1)
if value == 1_000 {
expectation.fulfill()
}
}
for iteration in 1...1_000 {
systemUnderTest.sendValue(iteration)
}
await waitForExpectations(timeout: 1.0)
}
/// Cancels the observer after its first delivery and then sends another value.
/// NOTE(review): the test returns right after sending `2` without waiting, so a late delivery may go
/// undetected — confirm whether that is intended.
func test_observe_stopsPropagationOfValuesOnCancel() async {
let counter = Counter()
let expectation = self.expectation(description: #function)
expectation.expectedFulfillmentCount = 1
// Sent before subscribing, so `1` replaces the initial `0` as the value replayed to the new observer.
systemUnderTest.sendValue(1)
let observer = systemUnderTest.observe { value in
await counter.incrementAndExpectCount(equals: value)
expectation.fulfill()
}
await waitForExpectations(timeout: 1.0)
observer.cancel()
systemUnderTest.sendValue(2)
}
/// Cancels the observation task immediately and waits for the `onTermination` closure to run.
func test_observe_callsCompletionOnCancel() async {
let expectation = self.expectation(description: #function)
systemUnderTest.sendValue(1)
let observer = systemUnderTest.observe({ _ in }, onTermination: {
expectation.fulfill()
})
observer.cancel()
await waitForExpectations(timeout: 1.0)
systemUnderTest.sendValue(2)
}
/// Drops the test's reference to the observable and waits for the `onTermination` closure to run.
func test_observe_callsCompletionOnStreamEnd() async {
let expectation = self.expectation(description: #function)
systemUnderTest.sendValue(1)
systemUnderTest.observe({ _ in }, onTermination: {
expectation.fulfill()
})
// Deallocate the prior observable to trigger completion.
systemUnderTest = Observable()
await waitForExpectations(timeout: 1.0)
}
/// Repeats send → observe → wait → cancel 100 times with one shared counter.
/// Each new observer fulfills a fresh expectation and is checked against the value sent in that iteration.
func test_observe_doesNotReceiveAllPastValues() async {
let counter = Counter()
for iteration in 1...100 {
let expectation = self.expectation(description: "\(#function):\(iteration)")
systemUnderTest.sendValue(iteration)
let observer = systemUnderTest.observe { value in
// The counter is shared across iterations, so this presumably only holds if each observer is
// invoked exactly once, with the latest value rather than the earlier ones.
await counter.incrementAndExpectCount(equals: value)
expectation.fulfill()
}
await waitForExpectations(timeout: 1.0)
// Cancel before the next iteration sends a new value.
observer.cancel()
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment