Last active
July 20, 2021 11:08
-
-
Save christianscott/9694282406f1cc05808ea386aaab6d03 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Preconditions } from 'base/preconditions'; | |
import { flushPromises } from 'base/testing/flush_promises'; | |
import { runTasksWithRateLimit, Task } from '../run_tasks_with_rate_limit'; | |
type Resolve<T> = (value?: T | PromiseLike<T>) => void;
type Reject = (reason: any) => void;
/** A promise that also carries the functions that settle it. */
type Deferred<T = void> = Promise<T> & { resolve: Resolve<T>; reject: Reject };

/**
 * Creates a pending promise with its own `resolve` and `reject` attached, so a
 * test can decide from the outside exactly when (and how) it settles.
 */
const deferred = <T>(): Deferred<T> => {
  const handles: { resolve?: Resolve<T>; reject?: Reject } = {};
  const pending = new Promise((onResolve, onReject) => {
    handles.resolve = onResolve;
    handles.reject = onReject;
  }) as Deferred<T>;
  // The Promise executor runs synchronously, so both handles are populated by
  // now; checkExists enforces that invariant rather than using `!`.
  pending.resolve = Preconditions.checkExists(handles.resolve);
  pending.reject = Preconditions.checkExists(handles.reject);
  return pending;
};
describe('runTasksWithRateLimit', () => {
  // Fake timers let the min-period test drive the implementation's
  // setTimeout-based delays with `jest.runOnlyPendingTimers()`.
  jest.useFakeTimers();

  // Each mock task hands back its own externally-settled promise, so every
  // test controls exactly when each task finishes. Recreated before each test.
  let first: Deferred<string>;
  let second: Deferred<string>;
  let third: Deferred<string>;
  let firstTask: jest.Mocked<Task<string>>;
  let secondTask: jest.Mocked<Task<string>>;
  let thirdTask: jest.Mocked<Task<string>>;
  let tasks: Task<string>[];

  beforeEach(() => {
    first = deferred<string>();
    second = deferred<string>();
    third = deferred<string>();
    firstTask = jest.fn(() => first);
    secondTask = jest.fn(() => second);
    thirdTask = jest.fn(() => third);
    tasks = [firstTask, secondTask, thirdTask];
  });

  it('resolves to an array with each of the promise results', async () => {
    const allDone = runTasksWithRateLimit(tasks);
    // NOTE(review): the tasks are settled in task order here, so this test
    // alone does not show that results follow input order when tasks finish
    // out of order.
    first.resolve('result 1');
    second.resolve('result 2');
    third.resolve('result 3');
    await expect(allDone).resolves.toEqual(['result 1', 'result 2', 'result 3']);
  });

  it('respects max concurrent tasks limit', async () => {
    runTasksWithRateLimit(tasks, 1);
    // With a limit of 1, only the first task may start right away.
    expect(firstTask).toHaveBeenCalled();
    expect(secondTask).not.toHaveBeenCalled();

    // Finishing it frees the single slot for the next task.
    first.resolve('result 1');
    await flushPromises();
    expect(secondTask).toHaveBeenCalled();
  });

  it('respects min period between tasks', async () => {
    runTasksWithRateLimit(tasks, 1, 1000);
    // The delay also precedes the very first task, so advance timers first.
    jest.runOnlyPendingTimers();
    await flushPromises();
    expect(firstTask).toHaveBeenCalled();
    expect(secondTask).not.toHaveBeenCalled();

    first.resolve('result 1');
    await flushPromises(); // let the first task finish and queue the next one
    jest.runOnlyPendingTimers(); // fire the delay guarding the second task
    await flushPromises(); // let the resumed sleep actually start it
    expect(secondTask).toHaveBeenCalled();
  });

  it('rejects as soon as any promise rejects', async () => {
    const allDone = runTasksWithRateLimit(tasks, 1);
    first.resolve('result 1');
    await flushPromises();
    expect(secondTask).toHaveBeenCalled();

    second.reject('error 2');
    await expect(allDone).rejects.toEqual('error 2');
    // The failure means the last task is never started.
    expect(thirdTask).not.toHaveBeenCalled();
  });
});
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Preconditions } from 'base/preconditions'; | |
/**
 * Returns a promise that resolves with no value once `ms` milliseconds have
 * elapsed, as scheduled by `setTimeout`.
 */
function sleep(ms: number): Promise<void> {
  const elapsed = new Promise<void>(onElapsed => {
    setTimeout(() => onElapsed(), ms);
  });
  return elapsed;
}
/** A unit of asynchronous work that only starts when the scheduler invokes it. */
export type Task<T> = () => Promise<T>;

/**
 * Runs `tasks` with at most `concurrentLimit` of them in flight at once.
 *
 * The returned promise resolves only after every task has finished. The
 * results are in the same order as `tasks`, not in completion order. When
 * `minPeriodBetween` is non-zero, every task start is preceded by
 * `sleep(minPeriodBetween)`, including the starts in the initial batch. With
 * no delay, the initial batch is started synchronously.
 *
 * @param tasks Task factories; each is invoked at most once.
 * @param concurrentLimit Maximum number of tasks running at the same time.
 *   Defaults to `tasks.length`, meaning everything runs at once. Values below
 *   1 (and NaN) are treated as 1 so that the promise can always settle.
 * @param minPeriodBetween Milliseconds to wait before starting each task.
 * @returns A promise of the results, indexed like `tasks`. It rejects with the
 *   first error from any task, whether thrown synchronously or returned as a
 *   rejection. After that, no further tasks are started.
 */
export function runTasksWithRateLimit<T>(
  tasks: Task<T>[],
  concurrentLimit: number = tasks.length,
  minPeriodBetween: number = 0,
): Promise<T[]> {
  if (tasks.length === 0) {
    // Nothing would ever complete, so without this the promise would hang.
    return Promise.resolve([]);
  }
  // `>= 1` is false for 0, negative numbers and NaN.
  const limit = concurrentLimit >= 1 ? concurrentLimit : 1;
  const queue = tasks.map((task, index) => ({ task, index }));
  const results: T[] = new Array<T>(tasks.length);
  let completed = 0;
  let failed = false;

  return new Promise<T[]>((resolve, reject) => {
    const fail = (error: unknown): void => {
      failed = true;
      reject(error);
    };

    // Invokes one task and wires its outcome back into the scheduler.
    const start = (task: Task<T>, index: number): void => {
      let pending: Promise<T>;
      try {
        pending = task();
      } catch (error) {
        // A synchronous throw must reject the batch instead of escaping.
        fail(error);
        return;
      }
      void pending.then(result => {
        if (failed) {
          return;
        }
        results[index] = result;
        completed += 1;
        // Resolve only once *every* task is done, not merely the queue drained.
        if (completed === tasks.length) {
          resolve(results);
        } else {
          startNext();
        }
      }, fail);
    };

    // Pulls the next queued task, if any, honouring the optional delay.
    const startNext = (): void => {
      if (failed) {
        return;
      }
      const next = queue.shift();
      if (next === undefined) {
        return;
      }
      if (minPeriodBetween) {
        void sleep(minPeriodBetween).then(() => {
          if (!failed) {
            start(next.task, next.index);
          }
        });
      } else {
        start(next.task, next.index);
      }
    };

    // `!failed` matters: startNext stops dequeuing after a failure, so with an
    // unbounded limit this loop would otherwise never terminate.
    for (let i = 0; i < limit && !failed && queue.length > 0; i += 1) {
      startNext();
    }
  });
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment