Skip to content

Instantly share code, notes, and snippets.

@f1shy-dev
Created May 5, 2025 17:34
Show Gist options
  • Save f1shy-dev/540eeca44270e1ae595e21d4545e3a7b to your computer and use it in GitHub Desktop.
Save f1shy-dev/540eeca44270e1ae595e21d4545e3a7b to your computer and use it in GitHub Desktop.
import { DelayedPromise } from "@/lib/utils/delayed-promise";
import type {
generateText,
LanguageModel,
LanguageModelUsage,
TextStreamPart,
ToolSet,
} from "ai";
import { streamText } from "ai";
/**
 * Runs an async operation up to `tries` times, passing the 1-based attempt
 * number to each invocation.
 *
 * @param fn    Operation to run; receives the current attempt number (1..tries).
 * @param tries Maximum number of invocations (default 3).
 * @returns The first successful result of `fn`.
 * @throws Error("Max retries reached") once every attempt has failed; the
 *         last underlying failure is attached as `cause` so it is not lost.
 */
const retry = async <T>(
  fn: (attempt: number) => Promise<T>,
  tries = 3,
): Promise<T> => {
  let lastError: unknown;
  for (let attempt = 1; attempt <= tries; attempt++) {
    try {
      return await fn(attempt);
    } catch (error) {
      // Remember why this attempt failed so the terminal error can carry it.
      lastError = error;
    }
  }
  // Keep the exact message callers may match on; attach the root cause via a
  // plain property assignment (the ErrorOptions constructor overload needs a
  // newer lib target than this file otherwise requires).
  const exhausted = new Error("Max retries reached");
  (exhausted as Error & { cause?: unknown }).cause = lastError;
  throw exhausted;
};
/**
 * Safety-filter configuration forwarded to a model via its per-model options
 * (see `SingularLMArray`). Each entry pairs a harm category with a blocking
 * threshold. NOTE(review): the literal values mirror Google Vertex AI's
 * `SafetySetting` enum names — confirm they are passed through unmodified to
 * the provider, since nothing in this file consumes them directly.
 */
export type VertexSafetySettings = Array<{
  category:
    | "HARM_CATEGORY_UNSPECIFIED"
    | "HARM_CATEGORY_HATE_SPEECH"
    | "HARM_CATEGORY_DANGEROUS_CONTENT"
    | "HARM_CATEGORY_HARASSMENT"
    | "HARM_CATEGORY_SEXUALLY_EXPLICIT";
  threshold:
    | "HARM_BLOCK_THRESHOLD_UNSPECIFIED"
    | "BLOCK_LOW_AND_ABOVE"
    | "BLOCK_MEDIUM_AND_ABOVE"
    | "BLOCK_ONLY_HIGH"
    | "BLOCK_NONE";
}>;
// A stream consumable both with `for await` (AsyncIterable) and through the
// WHATWG ReadableStream API — presumably matching the dual shape the `ai`
// SDK returns for `fullStream`; TODO confirm against the SDK's own typing.
type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;
/**
 * One fallback entry for `retryAIGenerate`: a model plus its per-model call
 * options (the underlying call's options minus `model`, which is injected by
 * `retryAIGenerate`, and `schema`), optionally followed by a trailing string.
 * NOTE(review): the third tuple element is never read in this file —
 * presumably a human-readable label or key; confirm with callers.
 */
export type SingularLMArray<T extends typeof generateText | typeof streamText> =
  | [
      LanguageModel,
      Omit<Omit<Parameters<T>[0], "model">, "schema"> & {
        safetySettings?: VertexSafetySettings;
      },
    ]
  | [
      LanguageModel,
      Omit<Omit<Parameters<T>[0], "model">, "schema"> & {
        safetySettings?: VertexSafetySettings;
      },
      string,
    ];
/**
 * Return type of `retryAIGenerate`: the wrapped call's own result augmented
 * with the number of attempts consumed and a promise for the model entry
 * that ultimately produced the output. For `streamText` calls it also adds
 * `fullStreamWithoutErrors`, a stream in which "error" chunks never appear
 * (they trigger fallback to the next model instead of propagating).
 */
type RetryAIGenerateReturnType<
  T extends typeof generateText | typeof streamText,
> = ReturnType<T> & {
  attempts: number;
  resolvedModel: Promise<SingularLMArray<T>>;
} & (T extends typeof streamText
    ? { fullStreamWithoutErrors: AsyncIterableStream<TextStreamPart<ToolSet>> }
    : unknown);
/**
 * Wraps an AI SDK text call (`generateText` or `streamText`) with multi-model
 * fallback: the entries of `models` are tried in order until one succeeds or
 * the list is exhausted, in which case Error("Max retries reached") is thrown.
 *
 * For streaming calls, a chunk of `type: "error"` mid-stream also triggers
 * fallback to the next model. Chunks from every attempt are funneled through
 * one persistent TransformStream, exposed to callers as
 * `fullStreamWithoutErrors`, so consumers see a single continuous stream
 * across retries. NOTE(review): text emitted by a failed attempt before its
 * error chunk is kept in `textAccumulator`, so the final `text` (and the
 * merged stream) can contain partial output from earlier models — confirm
 * this is intended.
 *
 * @param fn            The `ai` function to call — must be the same imported
 *                      binding, since streaming is detected by identity
 *                      (`fn === streamText`) below.
 * @param models        Ordered fallback list; entry i is used on attempt i+1.
 * @param sharedOptions Options applied to every attempt; per-model options
 *                      take precedence over these.
 */
export const retryAIGenerate = async <
  T extends typeof generateText | typeof streamText,
>(
  fn: T,
  {
    models,
    sharedOptions,
  }: {
    models: SingularLMArray<T>[];
    sharedOptions: Omit<Parameters<T>[0], "model">;
  },
): Promise<RetryAIGenerateReturnType<T>> => {
  // 1-based attempt counter, shared by the outer retry loop and the
  // mid-stream retry path inside processCurrentModelStream.
  let attempts = 1;
  // Single transform that survives across retries; only created for streaming.
  let persistentOutputTransform: TransformStream | null = null;
  // Accumulates "text" chunk contents across ALL attempts (see note above).
  let textAccumulator = "";
  // Usage promise of whichever attempt's stream is currently being drained.
  let activeUsagePromise: Promise<LanguageModelUsage> | null = null;
  // Deferred results settled once the (possibly retried) stream finishes.
  // NOTE(review): assumes DelayedPromise exposes its underlying promise as
  // `.value` and tolerates duplicate resolve() calls — usagePromise can be
  // resolved both in flush() and after the read loop below; TODO confirm.
  const textPromise = new DelayedPromise<string>();
  const usagePromise = new DelayedPromise<LanguageModelUsage>();
  const resolvedModelPromise = new DelayedPromise<SingularLMArray<T>>();
  // Resolves `resolvedModel` to the model entry for the current attempt,
  // clamping to the last entry if `attempts` has run past the list.
  const resolveModel = () => {
    let key = models[attempts - 1];
    if (!key) key = models[models.length - 1];
    resolvedModelPromise.resolve(key);
  };
  if (fn === streamText) {
    persistentOutputTransform = new TransformStream({
      transform(chunk, controller) {
        if (chunk.type === "text") {
          textAccumulator += chunk.text;
        }
        controller.enqueue(chunk);
      },
      // flush runs when the writable side is close()d, i.e. a stream attempt
      // completed without an error chunk.
      async flush() {
        textPromise.resolve(textAccumulator);
        if (activeUsagePromise) {
          usagePromise.resolve(await activeUsagePromise);
        }
        resolveModel();
      },
    });
  }
  const _wrap = async () => {
    if (attempts > models.length) {
      // All models exhausted: tear down the merged stream and settle the
      // deferred promises before throwing.
      if (persistentOutputTransform) {
        const writer = persistentOutputTransform.writable.getWriter();
        await writer.abort(new Error("Max retries reached"));
        textPromise.resolve(textAccumulator);
        usagePromise.reject(new Error("Max retries reached"));
        resolveModel();
      }
      throw new Error("Max retries reached");
    }
    try {
      // Per-model options win over sharedOptions; providerOptions are merged
      // one level deep (shared first, per-model overriding key-by-key).
      const options_combined = {
        ...(models[attempts - 1][1] || {}),
        providerOptions: {
          ...(sharedOptions.providerOptions || {}),
          ...(models[attempts - 1][1]?.providerOptions || {}),
        },
      } as Parameters<T>[0];
      console.log(
        "Attempting to generate text with model",
        models[attempts - 1][0].modelId,
        "with options",
        options_combined,
      );
      // maxRetries: 1 disables the SDK's own retry so fallback ordering here
      // stays in control.
      const final = await fn({
        ...sharedOptions,
        ...options_combined,
        model: models[attempts - 1][0],
        maxRetries: 1,
      } as Parameters<T>[0]);
      Object.assign(final, {
        attempts,
      });
      // Streaming result: pipe this attempt's fullStream into the persistent
      // transform, retrying on error chunks.
      if ("fullStream" in final && persistentOutputTransform) {
        activeUsagePromise = final.usage;
        const processCurrentModelStream = async () => {
          const reader = final.fullStream.getReader();
          const writer = persistentOutputTransform!.writable.getWriter();
          try {
            while (true) {
              const { done, value } = await reader.read();
              if (done) {
                await writer.close();
                break;
              }
              if (value.type === "error") {
                console.error("gen[error] <- in retryAIGenerate stream");
                // Instead of aborting, we'll retry with the next model
                throw value.error;
              }
              await writer.write(value);
            }
            usagePromise.resolve(await final.usage);
            resolveModel();
          } catch (error) {
            console.error("Stream processing error:", error);
            // Release the writer lock so the retried attempt (or the final
            // abort below) can acquire it again.
            writer.releaseLock();
            attempts++;
            if (attempts <= models.length) {
              // NOTE(review): this recursive _wrap() runs inside a
              // fire-and-forget task (see call below); if it throws, the
              // rejection is unhandled — consider a .catch here.
              await _wrap();
            } else {
              const finalWriter =
                persistentOutputTransform!.writable.getWriter();
              await finalWriter.abort(error);
              textPromise.resolve(textAccumulator);
              usagePromise.reject(error);
              resolveModel();
            }
          }
        };
        // Deliberately not awaited: the caller gets `final` immediately and
        // consumes output via fullStreamWithoutErrors / the deferred getters.
        processCurrentModelStream();
        // Replace the SDK's own text/usage promises with the deferred ones,
        // which only settle once the merged (possibly retried) stream ends.
        Object.defineProperty(final, "text", {
          enumerable: true,
          configurable: true,
          get() {
            return textPromise.value;
          },
        });
        Object.defineProperty(final, "usage", {
          enumerable: true,
          configurable: true,
          get() {
            return usagePromise.value;
          },
        });
        Object.defineProperty(final, "resolvedModel", {
          enumerable: true,
          configurable: true,
          get() {
            return resolvedModelPromise.value;
          },
        });
        Object.assign(final, {
          fullStreamWithoutErrors: persistentOutputTransform.readable,
        });
      }
      return final as RetryAIGenerateReturnType<T>;
    } catch (error) {
      // The initial call itself failed (not a mid-stream error): move on to
      // the next model. NOTE(review): the underlying error is discarded here;
      // the eventual "Max retries reached" carries no cause.
      attempts++;
      return await _wrap();
    }
  };
  return await _wrap();
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment