Created
May 5, 2025 17:34
-
-
Save f1shy-dev/540eeca44270e1ae595e21d4545e3a7b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { DelayedPromise } from "@/lib/utils/delayed-promise"; | |
import type { | |
generateText, | |
LanguageModel, | |
LanguageModelUsage, | |
TextStreamPart, | |
ToolSet, | |
} from "ai"; | |
import { streamText } from "ai"; | |
/**
 * Runs `fn` repeatedly until it succeeds or `tries` attempts have been used.
 *
 * @param fn - Operation to run; receives the 1-based number of the current attempt.
 * @param tries - Maximum number of attempts (defaults to 3).
 * @returns The value produced by the first successful attempt.
 * @throws Error with message "Max retries reached" once every attempt has failed
 *   (or immediately when `tries` is less than 1). The error thrown by the last
 *   attempt, if any, is attached as `cause` so the failure reason is not lost.
 */
const retry = async <T>(
  fn: (attempt: number) => Promise<T>,
  tries = 3,
): Promise<T> => {
  // Remember the most recent failure instead of silently discarding it.
  let lastError: unknown;

  for (let attempt = 1; attempt <= tries; attempt++) {
    try {
      return await fn(attempt);
    } catch (error) {
      lastError = error;
    }
  }

  // `cause` is assigned as a property (rather than via the constructor options)
  // so this compiles regardless of which ES lib version the project targets.
  const exhausted: Error & { cause?: unknown } = new Error(
    "Max retries reached",
  );
  if (lastError !== undefined) {
    exhausted.cause = lastError;
  }
  throw exhausted;
};
/** Harm category names accepted in a safety-settings entry. */
type VertexHarmCategory =
  | "HARM_CATEGORY_UNSPECIFIED"
  | "HARM_CATEGORY_HATE_SPEECH"
  | "HARM_CATEGORY_DANGEROUS_CONTENT"
  | "HARM_CATEGORY_HARASSMENT"
  | "HARM_CATEGORY_SEXUALLY_EXPLICIT";

/** Blocking threshold names accepted in a safety-settings entry. */
type VertexHarmBlockThreshold =
  | "HARM_BLOCK_THRESHOLD_UNSPECIFIED"
  | "BLOCK_LOW_AND_ABOVE"
  | "BLOCK_MEDIUM_AND_ABOVE"
  | "BLOCK_ONLY_HIGH"
  | "BLOCK_NONE";

/**
 * List of safety-setting entries, each pairing one harm category with one
 * blocking threshold. In this file it is accepted as the optional
 * `safetySettings` field of a model's per-model options.
 */
export type VertexSafetySettings = {
  category: VertexHarmCategory;
  threshold: VertexHarmBlockThreshold;
}[];
/** A value that is both a `ReadableStream<T>` and an `AsyncIterable<T>`. */
type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;

/**
 * Options that may be supplied for one specific model: the first argument of
 * `generateText` / `streamText` without `model` and `schema`, plus an optional
 * `safetySettings` list.
 */
type PerModelCallOptions<T extends typeof generateText | typeof streamText> =
  Omit<Omit<Parameters<T>[0], "model">, "schema"> & {
    safetySettings?: VertexSafetySettings;
  };

/**
 * One fallback candidate: `[model, perModelOptions]`, optionally followed by a
 * string. The string element is never read in this file.
 * NOTE(review): presumably a human-readable label for the candidate — confirm
 * against callers.
 */
export type SingularLMArray<T extends typeof generateText | typeof streamText> =
  | [LanguageModel, PerModelCallOptions<T>]
  | [LanguageModel, PerModelCallOptions<T>, string];
/** Fields added to every result returned by `retryAIGenerate`. */
type RetryResultMetadata<T extends typeof generateText | typeof streamText> = {
  /** 1-based number of the attempt that produced the result object. */
  attempts: number;
  /** Settles with the `models` entry that the call ended up on. */
  resolvedModel: Promise<SingularLMArray<T>>;
};

/** Extra fields that exist only when the wrapped function is `streamText`. */
type RetryStreamOnlyFields<T extends typeof generateText | typeof streamText> =
  T extends typeof streamText
    ? { fullStreamWithoutErrors: AsyncIterableStream<TextStreamPart<ToolSet>> }
    : unknown;

/**
 * Result type of `retryAIGenerate`: whatever the wrapped function returns,
 * intersected with the retry metadata and, for `streamText`, the merged
 * error-free stream.
 *
 * NOTE(review): this uses `ReturnType<T>` as written. If `generateText`
 * returns a promise, this is the promise type rather than the awaited result —
 * confirm whether `Awaited<ReturnType<T>>` was intended.
 */
type RetryAIGenerateReturnType<
  T extends typeof generateText | typeof streamText,
> = ReturnType<T> & RetryResultMetadata<T> & RetryStreamOnlyFields<T>;
/**
 * Runs an AI SDK text call (`generateText` or `streamText`) against an ordered
 * list of fallback models, moving on to the next model whenever the current
 * one fails.
 *
 * For each attempt the per-model options are spread over `sharedOptions`;
 * `providerOptions` from both are shallow-merged with the per-model values
 * winning, and the call is always made with `maxRetries: 1`.
 *
 * Streaming (`fn === streamText`): chunks of every attempted model are piped
 * into one persistent stream exposed as `fullStreamWithoutErrors`. An `error`
 * chunk or a read/write failure switches to the next model instead of being
 * forwarded; chunks already forwarded from the failed model stay in the
 * output. On the returned object, `text`, `usage`, `resolvedModel` and
 * `attempts` are redefined to reflect the whole fallback chain rather than
 * only the first model.
 *
 * @param fn - `generateText` or `streamText` from the `ai` package.
 * @param models - Ordered fallback candidates, tried first to last.
 * @param sharedOptions - Call options applied to every attempt.
 * @returns The result object of the first call that returned, augmented with
 *   `attempts` and `resolvedModel` (and `fullStreamWithoutErrors` when
 *   streaming).
 * @throws Error with message "Max retries reached" when the call threw for
 *   every model (or `models` is empty); the last underlying failure is
 *   attached as `cause`.
 */
export const retryAIGenerate = async <
  T extends typeof generateText | typeof streamText,
>(
  fn: T,
  {
    models,
    sharedOptions,
  }: {
    models: SingularLMArray<T>[];
    sharedOptions: Omit<Parameters<T>[0], "model">;
  },
): Promise<RetryAIGenerateReturnType<T>> => {
  // 1-based index of the model currently being tried.
  let attempts = 1;
  // Most recent failure; attached as `cause` to the final error.
  let lastError: unknown;
  let persistentOutputTransform: TransformStream | null = null;
  let textAccumulator = "";
  let activeUsagePromise: Promise<LanguageModelUsage> | null = null;
  const textPromise = new DelayedPromise<string>();
  const usagePromise = new DelayedPromise<LanguageModelUsage>();
  const resolvedModelPromise = new DelayedPromise<SingularLMArray<T>>();

  // Resolves `resolvedModel` with the current candidate, falling back to the
  // last candidate once `attempts` has run past the end of the list.
  const resolveModel = () => {
    resolvedModelPromise.resolve(
      models[attempts - 1] ?? models[models.length - 1],
    );
  };

  // Builds the terminal error, keeping the underlying failure reachable.
  // `cause` is set as a property so this compiles on any ES lib target.
  const createMaxRetriesError = (): Error => {
    const exhausted: Error & { cause?: unknown } = new Error(
      "Max retries reached",
    );
    if (lastError !== undefined) {
      exhausted.cause = lastError;
    }
    return exhausted;
  };

  if (fn === streamText) {
    // One transform shared by all attempts: it accumulates the text and, when
    // the writable side is closed, settles the aggregate promises.
    persistentOutputTransform = new TransformStream({
      transform(chunk, controller) {
        if (chunk.type === "text") {
          textAccumulator += chunk.text;
        }
        controller.enqueue(chunk);
      },
      async flush() {
        textPromise.resolve(textAccumulator);
        if (activeUsagePromise) {
          usagePromise.resolve(await activeUsagePromise);
        }
        resolveModel();
      },
    });
  }

  // Explicit return type: the function is recursive, so it cannot be inferred.
  const _wrap = async (): Promise<RetryAIGenerateReturnType<T>> => {
    if (attempts > models.length) {
      const exhausted = createMaxRetriesError();
      if (persistentOutputTransform) {
        const writer = persistentOutputTransform.writable.getWriter();
        await writer.abort(exhausted);
        textPromise.resolve(textAccumulator);
        usagePromise.reject(exhausted);
        resolveModel();
      }
      throw exhausted;
    }

    try {
      const candidate = models[attempts - 1];
      const combinedOptions = {
        ...(candidate[1] || {}),
        providerOptions: {
          ...(sharedOptions.providerOptions || {}),
          ...(candidate[1]?.providerOptions || {}),
        },
      } as Parameters<T>[0];

      console.log(
        "Attempting to generate text with model",
        candidate[0].modelId,
        "with options",
        combinedOptions,
      );

      const final = await fn({
        ...sharedOptions,
        ...combinedOptions,
        model: candidate[0],
        maxRetries: 1,
      } as Parameters<T>[0]);

      Object.assign(final, {
        attempts,
      });

      if ("fullStream" in final && persistentOutputTransform) {
        // Capture this model's own usage promise BEFORE `usage` is redefined
        // below; reading `final.usage` later would return the aggregate promise.
        const modelUsage = final.usage;
        activeUsagePromise = modelUsage;

        const processCurrentModelStream = async () => {
          const reader = final.fullStream.getReader();
          const writer = persistentOutputTransform!.writable.getWriter();
          try {
            while (true) {
              const { done, value } = await reader.read();
              if (done) {
                // Closing the writable side runs `flush()` above.
                await writer.close();
                break;
              }
              if (value.type === "error") {
                console.error("gen[error] <- in retryAIGenerate stream");
                // Instead of aborting, we'll retry with the next model
                throw value.error;
              }
              await writer.write(value);
            }
            usagePromise.resolve(await modelUsage);
            resolveModel();
          } catch (error) {
            console.error("Stream processing error:", error);
            lastError = error;
            // Best effort: stop pulling from the failed model's stream so it
            // is not left open and locked.
            void reader.cancel(error).catch(() => undefined);
            // Free the shared writable side for the next attempt.
            writer.releaseLock();
            attempts++;
            if (attempts <= models.length) {
              await _wrap();
            } else {
              const finalWriter =
                persistentOutputTransform!.writable.getWriter();
              await finalWriter.abort(error);
              textPromise.resolve(textAccumulator);
              usagePromise.reject(error);
              resolveModel();
            }
          }
        };

        // Runs in the background. A rejection here means a nested `_wrap()`
        // ran out of models; it has already aborted the stream and settled the
        // aggregate promises, so log rather than raise an unhandled rejection.
        void processCurrentModelStream().catch((error: unknown) => {
          console.error("Stream retry failed:", error);
        });

        Object.defineProperty(final, "text", {
          enumerable: true,
          configurable: true,
          get() {
            return textPromise.value;
          },
        });
        Object.defineProperty(final, "usage", {
          enumerable: true,
          configurable: true,
          get() {
            return usagePromise.value;
          },
        });
        Object.defineProperty(final, "resolvedModel", {
          enumerable: true,
          configurable: true,
          get() {
            return resolvedModelPromise.value;
          },
        });
        // Live view of the attempt counter: fallbacks happen after this object
        // has been returned, so a snapshot would stay at its initial value.
        Object.defineProperty(final, "attempts", {
          enumerable: true,
          configurable: true,
          get() {
            return Math.min(attempts, models.length);
          },
        });
        Object.assign(final, {
          fullStreamWithoutErrors: persistentOutputTransform.readable,
        });
      } else {
        // Non-streaming result: the model is known as soon as the call returns.
        Object.assign(final, {
          resolvedModel: Promise.resolve(candidate),
        });
      }

      return final as RetryAIGenerateReturnType<T>;
    } catch (error) {
      lastError = error;
      attempts++;
      return await _wrap();
    }
  };

  return await _wrap();
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment