// A custom Azure OpenAI completion provider for the Vercel AI SDK, provided as a single file.

import {
  type LanguageModelV1,
  type LanguageModelV1FinishReason,
  type LanguageModelV1LogProbs,
  type LanguageModelV1StreamPart,
  UnsupportedFunctionalityError,
  InvalidPromptError,
  type LanguageModelV1Prompt,
} from '@ai-sdk/provider';
import {
  type ParseResult,
  createEventSourceResponseHandler,
  createJsonResponseHandler,
  createJsonErrorResponseHandler,
  postJsonToApi,
} from '@ai-sdk/provider-utils';
import { z } from 'zod';
import {
  type OpenAICompletionModelId,
  type OpenAICompletionSettings,
} from '@ai-sdk/openai/internal';
export const openAIErrorDataSchema = z.object({
  error: z.object({
    message: z.string(),
    type: z.string(),
    param: z.any().nullable(),
    code: z.string().nullable(),
  }),
});

export type OpenAIErrorData = z.infer<typeof openAIErrorDataSchema>;

export const openaiFailedResponseHandler = createJsonErrorResponseHandler({
  errorSchema: openAIErrorDataSchema,
  errorToMessage: data => data.error.message,
});
export function mapOpenAIFinishReason(
  finishReason: string | null | undefined,
): LanguageModelV1FinishReason {
  switch (finishReason) {
    case 'stop':
      return 'stop';
    case 'length':
      return 'length';
    case 'content_filter':
      return 'content-filter';
    case 'function_call':
    case 'tool_calls':
      return 'tool-calls';
    default:
      return 'unknown';
  }
}
interface OpenAICompletionLogProbs {
  tokens: string[];
  token_logprobs: number[];
  top_logprobs: Array<Record<string, number>> | null;
}

export function mapOpenAICompletionLogProbs(
  logprobs: OpenAICompletionLogProbs | null | undefined,
): LanguageModelV1LogProbs | undefined {
  return logprobs?.tokens.map((token, index) => ({
    token,
    logprob: logprobs.token_logprobs[index],
    topLogprobs: logprobs.top_logprobs
      ? Object.entries(logprobs.top_logprobs[index]).map(
          ([token, logprob]) => ({
            token,
            logprob,
          }),
        )
      : [],
  }));
}
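
// Worked example of the mapping above (the values are hypothetical): a payload of
//   { tokens: ['Hi'], token_logprobs: [-0.2], top_logprobs: [{ Hi: -0.2, Hey: -1.9 }] }
// maps to
//   [{ token: 'Hi', logprob: -0.2,
//      topLogprobs: [{ token: 'Hi', logprob: -0.2 }, { token: 'Hey', logprob: -1.9 }] }]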
export function convertToOpenAICompletionPrompt({
  prompt,
  inputFormat,
  user = 'user',
  assistant = 'assistant',
}: {
  prompt: LanguageModelV1Prompt;
  inputFormat: 'prompt' | 'messages';
  user?: string;
  assistant?: string;
}): {
  prompt: string;
  stopSequences?: string[];
} {
  // When the user supplied a prompt input, we don't transform it:
  if (
    inputFormat === 'prompt' &&
    prompt.length === 1 &&
    prompt[0].role === 'user' &&
    prompt[0].content.length === 1 &&
    prompt[0].content[0].type === 'text'
  ) {
    return { prompt: prompt[0].content[0].text };
  }

  // otherwise transform to a chat message format:
  let text = '';

  // if the first message is a system message, add it to the text:
  if (prompt[0].role === 'system') {
    text += `${prompt[0].content}\n\n`;
    prompt = prompt.slice(1);
  }

  for (const { role, content } of prompt) {
    switch (role) {
      case 'system': {
        throw new InvalidPromptError({
          message: `Unexpected system message in prompt: ${content}`,
          prompt,
        });
      }

      case 'user': {
        const userMessage = content
          .map(part => {
            switch (part.type) {
              case 'text': {
                return part.text;
              }
              case 'image': {
                throw new UnsupportedFunctionalityError({
                  functionality: 'images',
                });
              }
            }
          })
          .join('');

        text += `${user}:\n${userMessage}\n\n`;
        break;
      }

      case 'assistant': {
        const assistantMessage = content
          .map(part => {
            switch (part.type) {
              case 'text': {
                return part.text;
              }
              case 'tool-call': {
                throw new UnsupportedFunctionalityError({
                  functionality: 'tool-call messages',
                });
              }
            }
          })
          .join('');

        text += `${assistant}:\n${assistantMessage}\n\n`;
        break;
      }

      case 'tool': {
        throw new UnsupportedFunctionalityError({
          functionality: 'tool messages',
        });
      }

      default: {
        const _exhaustiveCheck: never = role;
        throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
      }
    }
  }

  // Assistant message prefix:
  text += `${assistant}:\n`;

  return {
    prompt: text,
    stopSequences: [`\n${user}:`],
  };
}
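
// Worked example (hypothetical messages): a system message "Be brief." followed
// by a user text part "Hi" produces the completion prompt
//   "Be brief.\n\nuser:\nHi\n\nassistant:\n"
// together with the stop sequence "\nuser:", which halts generation before the
// model can fabricate the next user turn.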
interface AzureOpenAICompletionConfig {
  provider: string;
  url: (options: { modelId: string; path: string }) => string;
  compatibility: 'strict' | 'compatible';
  headers: () => Record<string, string | undefined>;
  fetch?: typeof fetch;
}
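
// A sketch of a matching config (the resource name, deployment routing, and
// api-version are illustrative assumptions, not values required by this file).
// Azure OpenAI scopes requests to a deployment and authenticates with an
// `api-key` header rather than a Bearer token:
//
//   const exampleConfig: AzureOpenAICompletionConfig = {
//     provider: 'azure-openai.completion',
//     url: ({ path, modelId }) =>
//       `https://my-resource.openai.azure.com/openai/deployments/${modelId}${path}?api-version=2024-02-01`,
//     compatibility: 'compatible',
//     headers: () => ({ 'api-key': process.env.AZURE_OPENAI_API_KEY }),
//   };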
export class AzureOpenAICompletionLanguageModel implements LanguageModelV1 {
  readonly specificationVersion = 'v1';
  readonly defaultObjectGenerationMode = undefined;

  readonly modelId: OpenAICompletionModelId;
  readonly settings: OpenAICompletionSettings;

  private readonly config: AzureOpenAICompletionConfig;

  constructor(
    modelId: OpenAICompletionModelId,
    settings: OpenAICompletionSettings,
    config: AzureOpenAICompletionConfig,
  ) {
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }

  get provider(): string {
    return this.config.provider;
  }
  private getArgs({
    mode,
    inputFormat,
    prompt,
    maxTokens,
    temperature,
    topP,
    frequencyPenalty,
    presencePenalty,
    seed,
  }: Parameters<LanguageModelV1['doGenerate']>[0]) {
    const type = mode.type;

    const { prompt: completionPrompt, stopSequences } =
      convertToOpenAICompletionPrompt({ prompt, inputFormat });

    const baseArgs = {
      // model id:
      model: this.modelId,

      // model specific settings:
      echo: this.settings.echo,
      logit_bias: this.settings.logitBias,
      logprobs:
        typeof this.settings.logprobs === 'number'
          ? this.settings.logprobs
          : typeof this.settings.logprobs === 'boolean'
            ? this.settings.logprobs
              ? 0
              : undefined
            : undefined,
      suffix: this.settings.suffix,
      user: this.settings.user,

      // standardized settings:
      max_tokens: maxTokens,
      temperature,
      top_p: topP,
      frequency_penalty: frequencyPenalty,
      presence_penalty: presencePenalty,
      seed,

      // prompt:
      prompt: completionPrompt,

      // stop sequences:
      stop: stopSequences,
    };

    switch (type) {
      case 'regular': {
        if (mode.tools?.length) {
          throw new UnsupportedFunctionalityError({
            functionality: 'tools',
          });
        }

        if (mode.toolChoice) {
          throw new UnsupportedFunctionalityError({
            functionality: 'toolChoice',
          });
        }

        return baseArgs;
      }

      case 'object-json': {
        throw new UnsupportedFunctionalityError({
          functionality: 'object-json mode',
        });
      }

      case 'object-tool': {
        throw new UnsupportedFunctionalityError({
          functionality: 'object-tool mode',
        });
      }

      case 'object-grammar': {
        throw new UnsupportedFunctionalityError({
          functionality: 'object-grammar mode',
        });
      }

      default: {
        const _exhaustiveCheck: never = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }
  async doGenerate(
    options: Parameters<LanguageModelV1['doGenerate']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
    const args = this.getArgs(options);

    const { responseHeaders, value: response } = await postJsonToApi({
      url: this.config.url({
        path: '/completions',
        modelId: this.modelId,
      }),
      headers: this.config.headers(),
      body: args,
      failedResponseHandler: openaiFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        openAICompletionResponseSchema,
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch,
    });

    const { prompt: rawPrompt, ...rawSettings } = args;
    const choice = response.choices[0];

    return {
      text: choice.text,
      usage: {
        promptTokens: response.usage.prompt_tokens,
        completionTokens: response.usage.completion_tokens,
      },
      finishReason: mapOpenAIFinishReason(choice.finish_reason),
      logprobs: mapOpenAICompletionLogProbs(choice.logprobs),
      rawCall: { rawPrompt, rawSettings },
      rawResponse: { headers: responseHeaders },
      warnings: [],
    };
  }
  async doStream(
    options: Parameters<LanguageModelV1['doStream']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
    const args = this.getArgs(options);

    const { responseHeaders, value: response } = await postJsonToApi({
      url: this.config.url({
        path: '/completions',
        modelId: this.modelId,
      }),
      headers: this.config.headers(),
      body: {
        ...args,
        stream: true,

        // only include stream_options when in strict compatibility mode:
        stream_options:
          this.config.compatibility === 'strict'
            ? { include_usage: true }
            : undefined,
      },
      failedResponseHandler: openaiFailedResponseHandler,
      successfulResponseHandler: createEventSourceResponseHandler(
        openaiCompletionChunkSchema,
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch,
    });

    const { prompt: rawPrompt, ...rawSettings } = args;

    let finishReason: LanguageModelV1FinishReason = 'other';
    let usage: { promptTokens: number; completionTokens: number } = {
      promptTokens: Number.NaN,
      completionTokens: Number.NaN,
    };
    let logprobs: LanguageModelV1LogProbs;

    return {
      stream: response.pipeThrough(
        new TransformStream<
          ParseResult<z.infer<typeof openaiCompletionChunkSchema>>,
          LanguageModelV1StreamPart
        >({
          transform(chunk, controller) {
            // handle failed chunk parsing / validation:
            if (!chunk.success) {
              finishReason = 'error';
              controller.enqueue({ type: 'error', error: chunk.error });
              return;
            }

            const value = chunk.value;

            // handle error chunks:
            if ('error' in value) {
              finishReason = 'error';
              controller.enqueue({ type: 'error', error: value.error });
              return;
            }

            if (value.usage != null) {
              usage = {
                promptTokens: value.usage.prompt_tokens,
                completionTokens: value.usage.completion_tokens,
              };
            }

            const choice = value.choices[0];

            if (choice?.finish_reason != null) {
              finishReason = mapOpenAIFinishReason(choice.finish_reason);
            }

            if (choice?.text != null) {
              controller.enqueue({
                type: 'text-delta',
                textDelta: choice.text,
              });
            }

            const mappedLogprobs = mapOpenAICompletionLogProbs(
              choice?.logprobs,
            );
            if (mappedLogprobs?.length) {
              if (logprobs === undefined) logprobs = [];
              logprobs.push(...mappedLogprobs);
            }
          },

          flush(controller) {
            controller.enqueue({
              type: 'finish',
              finishReason,
              logprobs,
              usage,
            });
          },
        }),
      ),
      rawCall: { rawPrompt, rawSettings },
      rawResponse: { headers: responseHeaders },
      warnings: [],
    };
  }
}
// limited version of the schema, focussed on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const openAICompletionResponseSchema = z.object({
  choices: z.array(
    z.object({
      text: z.string(),
      finish_reason: z.string(),
      logprobs: z
        .object({
          tokens: z.array(z.string()),
          token_logprobs: z.array(z.number()),
          top_logprobs: z.array(z.record(z.string(), z.number())).nullable(),
        })
        .nullable()
        .optional(),
    }),
  ),
  usage: z.object({
    prompt_tokens: z.number(),
    completion_tokens: z.number(),
  }),
});

// limited version of the schema, focussed on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const openaiCompletionChunkSchema = z.union([
  z.object({
    choices: z.array(
      z.object({
        text: z.string(),
        finish_reason: z.string().nullish(),
        index: z.number(),
        logprobs: z
          .object({
            tokens: z.array(z.string()),
            token_logprobs: z.array(z.number()),
            top_logprobs: z.array(z.record(z.string(), z.number())).nullable(),
          })
          .nullable()
          .optional(),
      }),
    ),
    usage: z
      .object({
        prompt_tokens: z.number(),
        completion_tokens: z.number(),
      })
      .optional()
      .nullable(),
  }),
  openAIErrorDataSchema,
]);
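
// --- Usage sketch ---
// A minimal example of wiring the model up by hand and calling it through the
// AI SDK's generateText. It is commented out so this file stays a
// side-effect-free module; the resource name, deployment id, and api-version
// are assumptions for illustration only.
//
// import { generateText } from 'ai';
//
// const model = new AzureOpenAICompletionLanguageModel(
//   'gpt-35-turbo-instruct', // Azure deployment name, used as the model id
//   {},
//   {
//     provider: 'azure-openai.completion',
//     url: ({ path, modelId }) =>
//       `https://my-resource.openai.azure.com/openai/deployments/${modelId}${path}?api-version=2024-02-01`,
//     compatibility: 'compatible',
//     headers: () => ({ 'api-key': process.env.AZURE_OPENAI_API_KEY }),
//   },
// );
//
// const { text } = await generateText({
//   model,
//   prompt: 'Write a haiku about the sea.',
// });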