Last active
June 17, 2024 15:45
-
-
Save Und3rf10w/05ebf6fee439ff31044956886cb138f7 to your computer and use it in GitHub Desktop.
A custom azure-openai-completion-provider for the vercel ai-sdk, provided as one file
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { | |
type LanguageModelV1, | |
type LanguageModelV1FinishReason, | |
type LanguageModelV1LogProbs, | |
type LanguageModelV1StreamPart, | |
UnsupportedFunctionalityError, | |
InvalidPromptError, | |
type LanguageModelV1Prompt, | |
} from '@ai-sdk/provider'; | |
import { | |
type ParseResult, | |
createEventSourceResponseHandler, | |
createJsonResponseHandler, | |
createJsonErrorResponseHandler, | |
postJsonToApi, | |
} from '@ai-sdk/provider-utils'; | |
import { z } from 'zod'; | |
import { | |
type OpenAICompletionModelId, | |
type OpenAICompletionSettings, | |
} from '@ai-sdk/openai/internal'; | |
// Zod schema for the error payload returned by OpenAI-compatible endpoints:
// { error: { message, type, param, code } }.
export const openAIErrorDataSchema = z.object({
  error: z.object({
    message: z.string(),
    type: z.string(),
    param: z.any().nullable(),
    code: z.string().nullable(),
  }),
});

// Inferred TypeScript type of the error payload above.
export type OpenAIErrorData = z.infer<typeof openAIErrorDataSchema>;

// Shared failed-response handler: validates the error body against the
// schema and surfaces the server-provided message to the caller.
export const openaiFailedResponseHandler = createJsonErrorResponseHandler({
  errorSchema: openAIErrorDataSchema,
  errorToMessage: data => data.error.message,
});
export function mapOpenAIFinishReason( | |
finishReason: string | null | undefined, | |
): LanguageModelV1FinishReason { | |
switch (finishReason) { | |
case 'stop': | |
return 'stop'; | |
case 'length': | |
return 'length'; | |
case 'content_filter': | |
return 'content-filter'; | |
case 'function_call': | |
case 'tool_calls': | |
return 'tool-calls'; | |
default: | |
return 'unknown'; | |
} | |
} | |
interface OpenAICompletionLogProps { | |
tokens: string[]; | |
token_logprobs: number[]; | |
top_logprobs: Array<Record<string, number>> | null; | |
} | |
export function mapOpenAICompletionLogProbs( | |
logprobs: OpenAICompletionLogProps | null | undefined, | |
): LanguageModelV1LogProbs | undefined { | |
return logprobs?.tokens.map((token, index) => ({ | |
token, | |
logprob: logprobs.token_logprobs[index], | |
topLogprobs: logprobs.top_logprobs | |
? Object.entries(logprobs.top_logprobs[index]).map( | |
([token, logprob]) => ({ | |
token, | |
logprob, | |
}), | |
) | |
: [], | |
})); | |
} | |
export function convertToOpenAICompletionPrompt({ | |
prompt, | |
inputFormat, | |
user = 'user', | |
assistant = 'assistant', | |
}: { | |
prompt: LanguageModelV1Prompt; | |
inputFormat: 'prompt' | 'messages'; | |
user?: string; | |
assistant?: string; | |
}): { | |
prompt: string; | |
stopSequences?: string[]; | |
} { | |
// When the user supplied a prompt input, we don't transform it: | |
if ( | |
inputFormat === 'prompt' && | |
prompt.length === 1 && | |
prompt[0].role === 'user' && | |
prompt[0].content.length === 1 && | |
prompt[0].content[0].type === 'text' | |
) { | |
return { prompt: prompt[0].content[0].text }; | |
} | |
// otherwise transform to a chat message format: | |
let text = ''; | |
// if first message is a system message, add it to the text: | |
if (prompt[0].role === 'system') { | |
text += `${prompt[0].content}\n\n`; | |
prompt = prompt.slice(1); | |
} | |
for (const { role, content } of prompt) { | |
switch (role) { | |
case 'system': { | |
throw new InvalidPromptError({ | |
message: 'Unexpected system message in prompt: ${content}', | |
prompt, | |
}); | |
} | |
case 'user': { | |
const userMessage = content | |
.map(part => { | |
switch (part.type) { | |
case 'text': { | |
return part.text; | |
} | |
case 'image': { | |
throw new UnsupportedFunctionalityError({ | |
functionality: 'images', | |
}); | |
} | |
} | |
}) | |
.join(''); | |
text += `${user}:\n${userMessage}\n\n`; | |
break; | |
} | |
case 'assistant': { | |
const assistantMessage = content | |
.map(part => { | |
switch (part.type) { | |
case 'text': { | |
return part.text; | |
} | |
case 'tool-call': { | |
throw new UnsupportedFunctionalityError({ | |
functionality: 'tool-call messages', | |
}); | |
} | |
} | |
}) | |
.join(''); | |
text += `${assistant}:\n${assistantMessage}\n\n`; | |
break; | |
} | |
case 'tool': { | |
throw new UnsupportedFunctionalityError({ | |
functionality: 'tool messages', | |
}); | |
} | |
default: { | |
const _exhaustiveCheck: never = role; | |
throw new Error(`Unsupported role: ${_exhaustiveCheck}`); | |
} | |
} | |
} | |
// Assistant message prefix: | |
text += `${assistant}:\n`; | |
return { | |
prompt: text, | |
stopSequences: [`\n${user}:`], | |
}; | |
} | |
// Configuration injected by the provider factory for each model instance.
interface AzureOpenAICompletionConfig {
  // Provider name reported through the model's `provider` getter.
  provider: string;
  // Builds the full request URL for a given model id and API path.
  url: (options: { modelId: string; path: string }) => string;
  // 'strict' enables OpenAI-only extensions (stream_options in doStream);
  // 'compatible' omits them for OpenAI-compatible servers.
  compatibility: 'strict' | 'compatible';
  // Produces the request headers (e.g. api-key) for each call.
  headers: () => Record<string, string | undefined>;
  // Optional custom fetch implementation (proxying, testing).
  fetch?: typeof fetch;
}
export class AzureOpenAICompletionLanguageModel implements LanguageModelV1 { | |
readonly specificationVersion = 'v1'; | |
readonly defaultObjectGenerationMode = undefined; | |
readonly modelId: OpenAICompletionModelId; | |
readonly settings: OpenAICompletionSettings; | |
private readonly config: AzureOpenAICompletionConfig; | |
constructor( | |
modelId: OpenAICompletionModelId, | |
settings: OpenAICompletionSettings, | |
config: AzureOpenAICompletionConfig, | |
) { | |
this.modelId = modelId; | |
this.settings = settings; | |
this.config = config; | |
} | |
get provider(): string { | |
return this.config.provider; | |
} | |
private getArgs({ | |
mode, | |
inputFormat, | |
prompt, | |
maxTokens, | |
temperature, | |
topP, | |
frequencyPenalty, | |
presencePenalty, | |
seed, | |
}: Parameters<LanguageModelV1['doGenerate']>[0]) { | |
const type = mode.type; | |
const { prompt: completionPrompt, stopSequences } = | |
convertToOpenAICompletionPrompt({ prompt, inputFormat }); | |
const baseArgs = { | |
// model id: | |
model: this.modelId, | |
// model specific settings: | |
echo: this.settings.echo, | |
logit_bias: this.settings.logitBias, | |
logprobs: | |
typeof this.settings.logprobs === 'number' | |
? this.settings.logprobs | |
: typeof this.settings.logprobs === 'boolean' | |
? this.settings.logprobs | |
? 0 | |
: undefined | |
: undefined, | |
suffix: this.settings.suffix, | |
user: this.settings.user, | |
// standardized settings: | |
max_tokens: maxTokens, | |
temperature, | |
top_p: topP, | |
frequency_penalty: frequencyPenalty, | |
presence_penalty: presencePenalty, | |
seed, | |
// prompt: | |
prompt: completionPrompt, | |
// stop sequences: | |
stop: stopSequences, | |
}; | |
switch (type) { | |
case 'regular': { | |
if (mode.tools?.length) { | |
throw new UnsupportedFunctionalityError({ | |
functionality: 'tools', | |
}); | |
} | |
if (mode.toolChoice) { | |
throw new UnsupportedFunctionalityError({ | |
functionality: 'toolChoice', | |
}); | |
} | |
return baseArgs; | |
} | |
case 'object-json': { | |
throw new UnsupportedFunctionalityError({ | |
functionality: 'object-json mode', | |
}); | |
} | |
case 'object-tool': { | |
throw new UnsupportedFunctionalityError({ | |
functionality: 'object-tool mode', | |
}); | |
} | |
case 'object-grammar': { | |
throw new UnsupportedFunctionalityError({ | |
functionality: 'object-grammar mode', | |
}); | |
} | |
default: { | |
const _exhaustiveCheck: never = type; | |
throw new Error(`Unsupported type: ${_exhaustiveCheck}`); | |
} | |
} | |
} | |
async doGenerate( | |
options: Parameters<LanguageModelV1['doGenerate']>[0], | |
): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> { | |
const args = this.getArgs(options); | |
const { responseHeaders, value: response } = await postJsonToApi({ | |
url: this.config.url({ | |
path: '/completions', | |
modelId: this.modelId, | |
}), | |
headers: this.config.headers(), | |
body: args, | |
failedResponseHandler: openaiFailedResponseHandler, | |
successfulResponseHandler: createJsonResponseHandler( | |
openAICompletionResponseSchema, | |
), | |
abortSignal: options.abortSignal, | |
fetch: this.config.fetch, | |
}); | |
const { prompt: rawPrompt, ...rawSettings } = args; | |
const choice = response.choices[0]; | |
return { | |
text: choice.text, | |
usage: { | |
promptTokens: response.usage.prompt_tokens, | |
completionTokens: response.usage.completion_tokens, | |
}, | |
finishReason: mapOpenAIFinishReason(choice.finish_reason), | |
logprobs: mapOpenAICompletionLogProbs(choice.logprobs), | |
rawCall: { rawPrompt, rawSettings }, | |
rawResponse: { headers: responseHeaders }, | |
warnings: [], | |
}; | |
} | |
async doStream( | |
options: Parameters<LanguageModelV1['doStream']>[0], | |
): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> { | |
const args = this.getArgs(options); | |
const { responseHeaders, value: response } = await postJsonToApi({ | |
url: this.config.url({ | |
path: '/completions', | |
modelId: this.modelId, | |
}), | |
headers: this.config.headers(), | |
body: { | |
...this.getArgs(options), | |
stream: true, | |
// only include stream_options when in strict compatibility mode: | |
stream_options: | |
this.config.compatibility === 'strict' | |
? { include_usage: true } | |
: undefined, | |
}, | |
failedResponseHandler: openaiFailedResponseHandler, | |
successfulResponseHandler: createEventSourceResponseHandler( | |
openaiCompletionChunkSchema, | |
), | |
abortSignal: options.abortSignal, | |
fetch: this.config.fetch, | |
}); | |
const { prompt: rawPrompt, ...rawSettings } = args; | |
let finishReason: LanguageModelV1FinishReason = 'other'; | |
let usage: { promptTokens: number; completionTokens: number } = { | |
promptTokens: Number.NaN, | |
completionTokens: Number.NaN, | |
}; | |
let logprobs: LanguageModelV1LogProbs; | |
return { | |
stream: response.pipeThrough( | |
new TransformStream< | |
ParseResult<z.infer<typeof openaiCompletionChunkSchema>>, | |
LanguageModelV1StreamPart | |
>({ | |
transform(chunk, controller) { | |
// handle failed chunk parsing / validation: | |
if (!chunk.success) { | |
finishReason = 'error'; | |
controller.enqueue({ type: 'error', error: chunk.error }); | |
return; | |
} | |
const value = chunk.value; | |
// handle error chunks: | |
if ('error' in value) { | |
finishReason = 'error'; | |
controller.enqueue({ type: 'error', error: value.error }); | |
return; | |
} | |
if (value.usage != null) { | |
usage = { | |
promptTokens: value.usage.prompt_tokens, | |
completionTokens: value.usage.completion_tokens, | |
}; | |
} | |
const choice = value.choices[0]; | |
if (choice?.finish_reason != null) { | |
finishReason = mapOpenAIFinishReason(choice.finish_reason); | |
} | |
if (choice?.text != null) { | |
controller.enqueue({ | |
type: 'text-delta', | |
textDelta: choice.text, | |
}); | |
} | |
const mappedLogprobs = mapOpenAICompletionLogProbs( | |
choice?.logprobs, | |
); | |
if (mappedLogprobs?.length) { | |
if (logprobs === undefined) logprobs = []; | |
logprobs.push(...mappedLogprobs); | |
} | |
}, | |
flush(controller) { | |
controller.enqueue({ | |
type: 'finish', | |
finishReason, | |
logprobs, | |
usage, | |
}); | |
}, | |
}), | |
), | |
rawCall: { rawPrompt, rawSettings }, | |
rawResponse: { headers: responseHeaders }, | |
warnings: [], | |
}; | |
} | |
} | |
// limited version of the schema, focussed on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
// (non-streaming /completions response: choices + usage)
const openAICompletionResponseSchema = z.object({
  choices: z.array(
    z.object({
      text: z.string(),
      finish_reason: z.string(),
      // per-token log probabilities; present only when requested via settings
      logprobs: z
        .object({
          tokens: z.array(z.string()),
          token_logprobs: z.array(z.number()),
          top_logprobs: z.array(z.record(z.string(), z.number())).nullable(),
        })
        .nullable()
        .optional(),
    }),
  ),
  usage: z.object({
    prompt_tokens: z.number(),
    completion_tokens: z.number(),
  }),
});
// limited version of the schema, focussed on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
// (SSE chunk: either a partial-choices payload or an error payload — the
// union allows servers to deliver errors mid-stream)
const openaiCompletionChunkSchema = z.union([
  z.object({
    choices: z.array(
      z.object({
        text: z.string(),
        // nullish: intermediate chunks have no finish reason yet
        finish_reason: z.string().nullish(),
        index: z.number(),
        logprobs: z
          .object({
            tokens: z.array(z.string()),
            token_logprobs: z.array(z.number()),
            top_logprobs: z.array(z.record(z.string(), z.number())).nullable(),
          })
          .nullable()
          .optional(),
      }),
    ),
    // usage is only sent on the final chunk (and only with stream_options):
    usage: z
      .object({
        prompt_tokens: z.number(),
        completion_tokens: z.number(),
      })
      .optional()
      .nullable(),
  }),
  openAIErrorDataSchema,
]);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment