Skip to content

Instantly share code, notes, and snippets.

@danielt1263
Last active August 30, 2024 06:33
Show Gist options
  • Save danielt1263/68097ac187cd43c23945bc6c15f5cc0b to your computer and use it in GitHub Desktop.
Save danielt1263/68097ac187cd43c23945bc6c15f5cc0b to your computer and use it in GitHub Desktop.
The TokenAcquisitionService automatically retries requests when it receives an unauthorized error. Complete with tests proving that it works correctly.
//
// TokenAcquisitionService.swift
//
// Created by Daniel Tartaglia on 16 Jan 2019.
// Copyright © 2024 Daniel Tartaglia. MIT License.
//
import Foundation
import RxSwift
/// A function that performs a network request and emits the HTTP response together with its body.
public typealias Response = (_ request: URLRequest) -> Observable<(response: HTTPURLResponse, data: Data)>
/// Sends a request built from the service's current token and, whenever the server answers with HTTP 401,
/// asks the service for a fresh token and transparently retries.
///
/// - Parameters:
///   - response: Sends a `URLRequest` to the network and emits the result, for example
///     `URLSession.shared.rx.response`.
///   - tokenAcquisitionService: Tracks the auth token. Every request should share the same instance.
///   - request: Builds the `URLRequest` for a given token.
/// - Returns: The response to a request that was not rejected with a 401.
/// - Note: If the service's current token is a `.failure`, the returned observable completes without emitting.
public func getData<T>(response: @escaping Response,
                       tokenAcquisitionService: TokenAcquisitionService<T>,
                       request: @escaping (T) throws -> URLRequest) -> Observable<(response: HTTPURLResponse, data: Data)> {
    // Deferred so that each (re)subscription made by `retry(when:)` reads whatever token is current at that moment.
    let currentToken = Observable<T>.deferred {
        tokenAcquisitionService.token
            .take(1)
            .compactMap { result -> T? in
                switch result {
                case let .success(token):
                    return token
                case .failure:
                    return nil
                }
            }
    }
    return currentToken
        .map(request)
        .flatMap(response)
        .map { reply -> (response: HTTPURLResponse, data: Data) in
            // Surface a 401 as an error so the retry handler below can trigger a token refresh.
            if reply.response.statusCode == 401 {
                throw TokenAcquisitionError.unauthorized
            }
            return reply
        }
        .retry(when: { errors in errors.renewToken(with: tokenAcquisitionService) })
}
// MARK: -
extension ObservableConvertibleType where Element == Error {
    /// Builds a retry trigger for `retry(when:)` from this stream of errors.
    ///
    /// `.unauthorized` errors cause `service` to fetch a new token, and the trigger fires once a new token value
    /// arrives. Every other error is re-emitted as an error, ending the retry.
    ///
    /// - Parameter service: The `TokenAcquisitionService` that stores the auth token used by the request.
    /// - Returns: An observable that emits when the request may be retried.
    public func renewToken<T>(with service: TokenAcquisitionService<T>) -> Observable<Void> {
        service.trackErrors(for: self)
    }
}
/// Errors produced or recognized by `TokenAcquisitionService`.
public enum TokenAcquisitionError: Error, Equatable {
    /// A request was rejected with HTTP 401. When the service sees this error, it acquires a new token.
    case unauthorized
    /// The token request answered with a non-2xx status. Carries that response and its body.
    case refusedToken(response: HTTPURLResponse, data: Data)
}
/// Stores the most recent authorization token and acquires a new one on demand.
///
/// Share a single instance between all requests that use the same token. Refresh requests are funneled through
/// `flatMapFirst`, so several `.unauthorized` failures that arrive while a refresh is in flight cause only one call to
/// `getToken`.
public final class TokenAcquisitionService<T> {
    /// Responds with the current token immediately and emits a new token whenever a new one is acquired. You can, for
    /// example, subscribe to it in order to save the token as it's updated. If token acquisition fails, this will emit a
    /// `.next(.failure)` event.
    public var token: Observable<Result<T, Error>> {
        return _token.asObservable()
    }

    /// Requests a new token from the network. It is called with the token that is current when the refresh starts.
    public typealias GetToken = (T) -> Observable<(response: HTTPURLResponse, data: Data)>

    /// Creates a `TokenAcquisitionService` object that will store the most recent authorization token acquired and will
    /// acquire new ones as needed.
    ///
    /// - Parameters:
    ///   - initialToken: The token the service should start with. Provide a token from storage, or an empty/nil object
    ///     representing a missing token if one has not been acquired yet.
    ///   - getToken: A function responsible for acquiring new tokens when needed.
    ///   - extractToken: A function that can extract a token from the data returned by `getToken`.
    public init(initialToken: T, getToken: @escaping GetToken, extractToken: @escaping (Data) throws -> T) {
        relay
            // Ignore further refresh requests while one is already in flight.
            .flatMapFirst { token in
                getToken(token)
                    .map { (urlResponse) -> Result<T, Error> in
                        guard urlResponse.response.statusCode / 100 == 2 else {
                            return .failure(TokenAcquisitionError.refusedToken(response: urlResponse.response, data: urlResponse.data))
                        }
                        return Result(catching: { try extractToken(urlResponse.data) })
                    }
                    // Convert errors into values so a failed refresh never terminates this subscription.
                    .catch { Observable.just(Result.failure($0)) }
            }
            .startWith(.success(initialToken))
            .subscribe(_token)
            .disposed(by: disposeBag)
    }

    /// Allows the token to be set imperatively if necessary.
    /// - Parameter token: The new token the service should use. It is immediately emitted to any subscribers to the
    ///   service.
    public func setToken(_ token: T) {
        lock.lock()
        defer { lock.unlock() }
        _token.onNext(.success(token))
    }

    /// Monitors the source for `.unauthorized` error events and passes all other errors on. When an `.unauthorized` error
    /// is seen, `self` will get a new token and emit a signal that it's safe to retry the request.
    ///
    /// - Parameter source: An `Observable` (or like type) that emits errors.
    /// - Returns: A trigger that will emit when it's safe to retry the request. It fires on any new emission of
    ///   `token` (including a `.failure`) after subscription.
    func trackErrors<O: ObservableConvertibleType>(for source: O) -> Observable<Void> where O.Element == Error {
        let lock = self.lock
        let relay = self.relay
        let error = source
            .asObservable()
            .map { error in
                // Anything other than `.unauthorized` is re-thrown, which ends the caller's retry with that error.
                guard (error as? TokenAcquisitionError) == .unauthorized else { throw error }
            }
            // `self` is captured strongly on purpose. The returned trigger must keep the service, and therefore the
            // refresh subscription owned by its `disposeBag`, alive while a retry is pending. Previously `[unowned self]`
            // crashed if the service was released first. There is no retain cycle because the service never stores
            // the returned observable.
            .flatMap { _ in self.token.take(1) }
            .do(onNext: {
                // Start a refresh using the current token. If the current value is a failure, nothing is requested.
                guard case let .success(token) = $0 else { return }
                lock.lock()
                defer { lock.unlock() }
                relay.onNext(token)
            })
            // This stream only contributes its error/completion events; the retry signal comes from `token`.
            .filter { _ in false }
            .map { _ in }
        // `skip(1)` drops the replayed current value so the trigger fires only when a new value arrives.
        return Observable.merge(token.skip(1).map { _ in }, error)
    }

    /// Holds the latest token result and replays it to new subscribers.
    private let _token = ReplaySubject<Result<T, Error>>.create(bufferSize: 1)
    /// Receives the current token whenever a refresh should start.
    private let relay = PublishSubject<T>()
    /// Serializes pushes into `_token` and `relay`.
    private let lock = NSRecursiveLock()
    /// Owns the refresh subscription created in `init`.
    private let disposeBag = DisposeBag()
}
//
// TokenAcquisitionServiceTests.swift
//
// Created by Daniel Tartaglia on 16 Jan 2019.
// Copyright © 2024 Daniel Tartaglia. MIT License.
//
import RxSwift
import RxTest
import XCTest
/// Verifies `TokenAcquisitionService` and `renewToken(with:)` against a virtual-time `TestScheduler`.
class TokenAcquisitionServiceTests: XCTestCase {
    var scheduler: TestScheduler!
    var tokenResult: TestableObserver<Result<String, Error>>!
    var triggerResult: TestableObserver<Void>!
    var bag: DisposeBag!

    override func setUp() {
        super.setUp()
        scheduler = TestScheduler(initialClock: 0)
        tokenResult = scheduler.createObserver(Result<String, Error>.self)
        triggerResult = scheduler.createObserver(Void.self)
        bag = DisposeBag()
    }

    /// The service replays its initial token and never calls `getToken` or `extractToken` on its own.
    func testInitial() {
        // given
        func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
            XCTFail()
            return .empty()
        }
        func extractToken(_ data: Data) -> String {
            XCTFail()
            return ""
        }
        let service = TokenAcquisitionService(initialToken: "first", getToken: getToken, extractToken: extractToken)
        // when
        service.token
            .bind(to: tokenResult)
            .disposed(by: bag)
        scheduler.start()
        // then
        XCTAssertEqual(tokenResult.events.map(extractEvent), [.next(0, "first")])
    }

    /// `setToken` emits the new token immediately without making a token request.
    func testUpdate() {
        // given
        func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
            XCTFail()
            return .empty()
        }
        let service = TokenAcquisitionService<String>(initialToken: "first", getToken: getToken, extractToken: extractToken)
        // when
        scheduler.scheduleAt(10) {
            service.setToken("second")
        }
        service.token
            .bind(to: tokenResult)
            .disposed(by: bag)
        scheduler.start()
        // then
        XCTAssertEqual(tokenResult.events.map(extractEvent), [.next(0, "first"), .next(10, "second")])
    }

    /// An `.unauthorized` error triggers a token request with the old token, and the trigger fires once the new token arrives.
    func testUnauthorized() {
        // given
        let trigger = scheduler.createColdObservable([.next(10, TokenAcquisitionError.unauthorized as Error)])
        func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
            XCTAssertEqual(old, "first")
            let response = HTTPURLResponse(url: URL(fileURLWithPath: ""), statusCode: 200, httpVersion: nil, headerFields: nil)!
            let data = "second".data(using: .utf8)!
            return Observable.just((response: response, data: data))
        }
        let service = TokenAcquisitionService<String>(initialToken: "first", getToken: getToken, extractToken: extractToken)
        // when
        bag.insert(
            service.token.bind(to: tokenResult),
            trigger.renewToken(with: service).bind(to: triggerResult)
        )
        scheduler.start()
        // then
        XCTAssertEqual(tokenResult.events.map(extractEvent), [.next(0, "first"), .next(10, "second")])
        XCTAssertEqual(triggerResult.events.map(\.time), [10])
    }

    /// A non-2xx token response is published as a `.refusedToken` failure, and the trigger still fires.
    func testBadTokenRequest() {
        // given
        let trigger = scheduler.createColdObservable([.next(10, TokenAcquisitionError.unauthorized as Error)])
        let response = HTTPURLResponse(url: URL(fileURLWithPath: ""), statusCode: 500, httpVersion: nil, headerFields: nil)!
        let data = "second".data(using: .utf8)!
        func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
            XCTAssertEqual(old, "first")
            return Observable.just((response: response, data: data))
        }
        let service = TokenAcquisitionService<String>(initialToken: "first", getToken: getToken, extractToken: extractToken)
        // when
        bag.insert(
            service.token.bind(to: tokenResult),
            trigger.renewToken(with: service).bind(to: triggerResult)
        )
        scheduler.start()
        // then
        XCTAssertEqual(
            tokenResult.events.map { $0.map { $0.map { $0.mapError { $0 as! TokenAcquisitionError } } } },
            [.next(0, .success("first")), .next(10, .failure(TokenAcquisitionError.refusedToken(response: response, data: data)))]
        )
        XCTAssertEqual(triggerResult.events.map(\.time), [10])
    }

    /// Errors other than `.unauthorized` pass straight through to the trigger without a token request.
    func testOtherErrorsFallThrough() {
        // given
        let trigger = scheduler.createColdObservable([.next(10, RxError.unknown as Error)])
        func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
            XCTAssertEqual(old, "first")
            let response = HTTPURLResponse(url: URL(fileURLWithPath: ""), statusCode: 200, httpVersion: nil, headerFields: nil)!
            let data = "second".data(using: .utf8)!
            return Observable.just((response: response, data: data))
        }
        let service = TokenAcquisitionService<String>(initialToken: "first", getToken: getToken, extractToken: extractToken)
        // when
        bag.insert(
            service.token.bind(to: tokenResult),
            trigger.renewToken(with: service).bind(to: triggerResult)
        )
        scheduler.start()
        // then
        XCTAssertEqual(tokenResult.events.map(extractEvent), [.next(0, "first")])
        XCTAssertEqual(triggerResult.events.map(\.time), [10])
    }

    /// A second `.unauthorized` that arrives while a refresh is in flight does not start another token request.
    func testMultipleUnauthsOnlyCauseOneTokenRequest() {
        // given
        let trigger1 = scheduler.createColdObservable([.next(10, TokenAcquisitionError.unauthorized as Error)])
        let trigger2 = scheduler.createColdObservable([.next(30, TokenAcquisitionError.unauthorized as Error)])
        let triggerResult2 = scheduler.createObserver(Void.self)
        var requestCount = 0
        func getToken(_ old: String) -> Observable<(response: HTTPURLResponse, data: Data)> {
            XCTAssertEqual(old, "first")
            requestCount += 1
            let response = HTTPURLResponse(url: URL(fileURLWithPath: ""), statusCode: 200, httpVersion: nil, headerFields: nil)!
            let data = "second".data(using: .utf8)!
            return Observable.just((response: response, data: data)).delay(.seconds(20), scheduler: scheduler)
        }
        let service = TokenAcquisitionService<String>(initialToken: "first", getToken: getToken, extractToken: extractToken)
        // when
        bag.insert(
            service.token.bind(to: tokenResult),
            trigger1.renewToken(with: service).bind(to: triggerResult),
            trigger2.renewToken(with: service).bind(to: triggerResult2)
        )
        scheduler.start()
        // then
        XCTAssertEqual(tokenResult.events.map(extractEvent), [.next(0, "first"), .next(30, "second")])
        XCTAssertEqual(triggerResult.events.map(\.time), [30])
        XCTAssertEqual(triggerResult2.events.map(\.time), [30])
        XCTAssertEqual(requestCount, 1)
    }
}
/// Decodes `data` as a UTF-8 string token, yielding an empty string when the bytes are not valid UTF-8.
func extractToken(_ data: Data) -> String {
    guard let decoded = String(data: data, encoding: .utf8) else {
        return ""
    }
    return decoded
}
extension Recorded {
    /// Returns a recording at the same virtual time whose value has been passed through `transform`.
    func map<T>(_ transform: (Value) throws -> T) rethrows -> Recorded<T> {
        let transformedValue = try transform(value)
        return Recorded<T>(time: time, value: transformedValue)
    }
}
/// Unwraps the `Result` inside a recorded event: `.success` values become plain `.next` elements, and a `.failure`
/// becomes an `.error` event (via `Event.map`).
func extractEvent(_ event: Recorded<Event<Result<String, any Error>>>) -> Recorded<Event<String>> {
    return event.map { recordedEvent in
        recordedEvent.map { outcome in try outcome.get() }
    }
}
@supervtb
Copy link

@danielt1263 Hi, I use Moya. This is my network request method:

init(endpointClosure: @escaping MoyaProvider<YunguApi>.EndpointClosure = MoyaProvider<YunguApi>.defaultEndpointMapping,
      requestClosure: @escaping MoyaProvider<YunguApi>.RequestClosure = MoyaProvider<YunguApi>.defaultRequestMapping,
      stubClosure: @escaping MoyaProvider<YunguApi>.StubClosure = MoyaProvider<YunguApi>.neverStub,
      session: Session = MoyaProvider<YunguApi>.defaultAlamofireSession(),
      plugins: [PluginType] = [],
      trackInflights: Bool = false,
      online: Observable<Bool> = connectedToInternet()) {
   
   self.online = online
   self.provider = MoyaProvider(endpointClosure: endpointClosure, requestClosure: requestClosure, stubClosure: stubClosure, session: session, plugins: plugins, trackInflights: trackInflights) 
 }
 
 func request(_ token: YunguApi, isSecondTryAfterAuth: Bool = false) -> Observable<Moya.Response> {
   let actualRequest = provider.rx.request(token)
   return online
     .ignore(value: false)  // Wait until we're online
     .take(1)        // Take 1 to make sure we only invoke the API once.
     .flatMap { _ in // Turn the online state into a network request
       return actualRequest
         
   }
 }

I tried using it, but it doesn't work. How can I adapt it? Thank you.

@danielt1263 Hello, do you have any suggestions for how we can use it with RxMoya? :)

@danielt1263
Copy link
Author

@supervtb I don't know Moya, but you will need something that looks kind of like:

func myRequest(_ token: YunguApi) -> Observable<(response: HTTPURLResponse, data: Data)> {
	YourClass<YunguApi>().request(token)
		.map { moyaResponse in
			return (moyaResponse.httpURLResponse, moyaResponse.data)
		}
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment