Created
August 13, 2024 20:56
-
-
Save airhorns/393d1df954ae59a319ee2c5f1c3cf8a7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Memo of already-strictified schemas, keyed by the source schema instance.
 * A WeakMap is used so cache entries do not keep otherwise-unreferenced schemas alive.
 */
const strictSchemaCache: WeakMap<ZodSchema, ZodSchema> = new WeakMap();
/**
 * Rebuild a single zod schema node in strict form, recursing into child schemas through the
 * memoized `strictifyZodSchema`. Leaf schemas are returned unchanged (same instance); every
 * other supported node is reconstructed. Descriptions are NOT copied here — see
 * `strictZodSchema`.
 *
 * @throws InternalError for schema types this walker does not support.
 */
const rebuildStrictZodSchema = (schema: ZodSchema): ZodSchema => {
  if (schema instanceof ZodObject) {
    const newShape: Record<string, ZodSchema> = {};
    for (const [key, fieldSchema] of Object.entries(schema.shape)) {
      // A field literally named "default" typed as `z.any().optional()` is dropped entirely.
      // NOTE(review): presumably because `any` has no strict structured-output equivalent — confirm.
      if (key === "default" && fieldSchema instanceof ZodOptional && fieldSchema.unwrap() instanceof ZodAny) {
        continue;
      }
      newShape[key] = strictifyZodSchema(fieldSchema);
    }
    return z.object(newShape).required();
  }

  if (schema instanceof ZodOptional) {
    // "May be absent" is encoded as "may be null" so the key can be required in strict mode.
    return z
      .union([strictifyZodSchema(schema.unwrap()), z.null()])
      // un-nullify the value on output so parsed data matches the original optional shape
      .transform((value) => (isNull(value) ? undefined : value));
  }

  if (schema instanceof ZodNullable) {
    return z.nullable(strictifyZodSchema(schema.unwrap()));
  }

  if (schema instanceof ZodArray) {
    return z.array(strictifyZodSchema(schema.element));
  }

  if (schema instanceof ZodUnion) {
    return z.union(schema.options.map(strictifyZodSchema));
  }

  if (schema instanceof ZodDiscriminatedUnion) {
    return z.discriminatedUnion(schema.discriminator, schema.options.map(strictifyZodSchema));
  }

  if (schema instanceof ZodTuple) {
    const tuple = z.tuple(schema.items.map(strictifyZodSchema));
    // Preserve a variadic `.rest(...)` element; rebuilding from `items` alone would drop it.
    const rest: ZodSchema | null = schema._def.rest;
    return rest ? tuple.rest(strictifyZodSchema(rest)) : tuple;
  }

  if (schema instanceof ZodRecord) {
    // Keep the original key schema; `z.record(valueSchema)` alone would widen keys to any string.
    return z.record(schema.keySchema, strictifyZodSchema(schema.valueSchema));
  }

  if (schema instanceof ZodMap) {
    return z.map(strictifyZodSchema(schema.keySchema), strictifyZodSchema(schema.valueSchema));
  }

  if (schema instanceof ZodLazy) {
    // Defer the recursion so self-referential schemas do not recurse forever.
    return z.lazy(() => strictifyZodSchema(schema.schema));
  }

  if (
    schema instanceof ZodEnum ||
    schema instanceof ZodNativeEnum ||
    schema instanceof ZodString ||
    schema instanceof ZodNumber ||
    schema instanceof ZodBoolean ||
    schema instanceof ZodLiteral ||
    schema instanceof z.ZodNull
  ) {
    return schema;
  }

  throw new InternalError(`Unsupported schema type: ${schema.constructor.name}`);
};

/**
 * Build the strict counterpart of one zod schema node.
 *
 * Objects get every key required (children strictified). Optional schemas become
 * `inner | null`, with `null` mapped back to `undefined` on parse. Nullable, array, union,
 * discriminated-union, tuple (including a rest element), record (value side; the key schema
 * is kept), map and lazy schemas are rebuilt around strictified children. Enum, native enum,
 * string, number, boolean, literal and null schemas are returned unchanged.
 *
 * Any `.describe()` text on the source node is copied onto a rebuilt node, since rebuilding
 * creates a fresh schema instance that would otherwise lose it.
 *
 * @param schema - the source schema node.
 * @returns the strict schema (the same instance for supported leaf types).
 * @throws InternalError for unsupported schema types.
 */
const strictZodSchema = (schema: ZodSchema): ZodSchema => {
  const strict = rebuildStrictZodSchema(schema);
  const { description } = schema;
  // Leaves are returned as-is and already carry their description; only fresh nodes need it.
  return strict !== schema && description !== undefined ? strict.describe(description) : strict;
};
/**
 * Walk a zod schema and make all fields required. Optional fields become unions with null,
 * which is compatible with OpenAI's structured output restrictions.
 *
 * Results are memoized per input schema instance: calling this again with the same schema
 * object returns the same strict schema object.
 */
export const strictifyZodSchema = (schema: ZodSchema): ZodSchema => {
  let strict = strictSchemaCache.get(schema);
  if (strict === undefined) {
    strict = strictZodSchema(schema);
    strictSchemaCache.set(schema, strict);
  }
  return strict;
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment