Skip to content

Instantly share code, notes, and snippets.

@kzu
Created October 9, 2024 21:12
Show Gist options
  • Save kzu/ce03963cdb0fd48ce1bbef6e6bcad52b to your computer and use it in GitHub Desktop.
Save kzu/ce03963cdb0fd48ce1bbef6e6bcad52b to your computer and use it in GitHub Desktop.
Typed extension for ChatClient
// <auto-generated />
#region License
// MIT License
//
// Copyright (c) Daniel Cazzulino
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#endregion
#nullable enable
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading.Tasks;
using System.Threading;
using System.Text.Json.Schema;
namespace OpenAI.Chat;
/// <summary>
/// Provides strong-typed extension methods for <see cref="ChatClient"/>.
/// </summary>
/// <remarks>
/// Requires .NET 8+
/// </remarks>
/// <package id="OpenAI" version="2.0.0" />
/// <package id="System.Text.Json" version="9.0.0-rc.*" />
static partial class ChatClientTypedExtensions
{
    // Maximum length applied to the response format name sent with the schema.
    // NOTE(review): 64 and the [a-zA-Z0-9_-] charset follow the OpenAI API reference
    // for `response_format.json_schema.name` — confirm if the service limits change.
    const int MaxSchemaNameLength = 64;

    // Generated schemas keyed by the type they describe, so each schema is exported once.
    static readonly ConcurrentDictionary<Type, BinaryData> jsonSchemas = new();

    // Shared by schema export and response deserialization so both agree on naming.
    static readonly JsonSerializerOptions jsonOptions = new(JsonSerializerDefaults.Web)
    {
        TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
        PropertyNameCaseInsensitive = true,
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
        NumberHandling = System.Text.Json.Serialization.JsonNumberHandling.Strict,
    };

    /// <summary>
    /// Completes the chat requesting a JSON response that matches the schema generated
    /// for <typeparamref name="T"/>, and deserializes the first text content part of
    /// the response into <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">
    /// The type to return. Arrays and the recognized generic collection types
    /// (<see cref="IEnumerable{T}"/>, <see cref="ICollection{T}"/>, <see cref="List{T}"/>,
    /// <see cref="IList{T}"/>, <see cref="IReadOnlyCollection{T}"/>, <see cref="IReadOnlyList{T}"/>)
    /// are requested wrapped in a <see cref="Values{T}"/> object and unwrapped before returning.
    /// </typeparam>
    /// <param name="client">The chat client used to send the request.</param>
    /// <param name="messages">The messages that make up the conversation.</param>
    /// <param name="options">
    /// Optional completion options; a new instance is created when <see langword="null"/>.
    /// Its <c>ResponseFormat</c> is overwritten with the generated JSON schema format.
    /// </param>
    /// <param name="cancellationToken">Token used to cancel the request.</param>
    /// <returns>
    /// The deserialized value, or <see langword="default"/> when the response has no
    /// text content, the text is empty, or the payload deserializes to <see langword="null"/>.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="client"/> or <paramref name="messages"/> is <see langword="null"/>.
    /// </exception>
    public static async Task<T?> CompleteChatAsync<T>(this ChatClient client, IEnumerable<ChatMessage> messages, ChatCompletionOptions? options = null, CancellationToken cancellationToken = default(CancellationToken))
    {
        ArgumentNullException.ThrowIfNull(client);
        ArgumentNullException.ThrowIfNull(messages);

        options ??= new ChatCompletionOptions();

        var elementType = GetElementTypeOrSelf(typeof(T));
        if (elementType == typeof(T))
        {
            // Not a recognized collection: the schema of T itself is requested.
            return await CompleteChatCoreAsync<T>(client, messages, options, GetSchemaName(elementType, plural: false), cancellationToken).ConfigureAwait(false);
        }

        // Collections are requested wrapped in an object exposing a single `data` property.
        var wrapper = await CompleteChatCoreAsync<Values<T>>(client, messages, options, GetSchemaName(elementType, plural: true), cancellationToken).ConfigureAwait(false);
        if (wrapper is null)
        {
            return default;
        }

        return wrapper.Data;
    }

    // Sends the request with the schema of TResult as response format and deserializes
    // the first text content part, returning default when there is none or it is empty.
    static async Task<TResult?> CompleteChatCoreAsync<TResult>(ChatClient client, IEnumerable<ChatMessage> messages, ChatCompletionOptions options, string schemaName, CancellationToken cancellationToken)
    {
        var schema = jsonSchemas.GetOrAdd(typeof(TResult), _ => GetJsonSchema<TResult>());
        options.ResponseFormat = ChatResponseFormat.CreateJsonSchemaFormat(schemaName, schema);

        var response = await client.CompleteChatAsync(messages, options, cancellationToken).ConfigureAwait(false);
        var json = response.Value.Content.FirstOrDefault(x => x.Kind == ChatMessageContentPartKind.Text)?.Text;
        if (string.IsNullOrEmpty(json))
        {
            return default;
        }

        return JsonSerializer.Deserialize<TResult>(json, jsonOptions);
    }

    // Returns the element type for arrays and the recognized generic collection
    // types, or the type itself for anything else.
    static Type GetElementTypeOrSelf(Type type)
    {
        if (type.IsArray)
        {
            return type.GetElementType()!;
        }

        if (type.IsGenericType)
        {
            var definition = type.GetGenericTypeDefinition();
            if (definition == typeof(IEnumerable<>) ||
                definition == typeof(ICollection<>) ||
                definition == typeof(List<>) ||
                definition == typeof(IList<>) ||
                definition == typeof(IReadOnlyCollection<>) ||
                definition == typeof(IReadOnlyList<>))
            {
                return type.GetGenericArguments()[0];
            }
        }

        return type;
    }

    // Builds the response format name from the element type name (with an "s" suffix
    // for collections). Type.Name can contain characters such as '`' (generic arity)
    // or "[]" (nested arrays), so anything outside [a-zA-Z0-9_-] is replaced with '_'
    // and the result is truncated to MaxSchemaNameLength.
    static string GetSchemaName(Type elementType, bool plural)
    {
        var name = plural ? $"{elementType.Name}s" : elementType.Name;
        var chars = name.ToCharArray();
        for (var i = 0; i < chars.Length; i++)
        {
            if (chars[i] is not ((>= 'a' and <= 'z') or (>= 'A' and <= 'Z') or (>= '0' and <= '9') or '_' or '-'))
            {
                chars[i] = '_';
            }
        }

        return new string(chars, 0, Math.Min(chars.Length, MaxSchemaNameLength));
    }

    // Exports the JSON schema for T using the shared serializer options, copying
    // [Description] attributes found on properties into the schema "description" keyword.
    static BinaryData GetJsonSchema<T>()
    {
        var node = JsonSchemaExporter.GetJsonSchemaAsNode(jsonOptions, typeof(T), new()
        {
            TreatNullObliviousAsNonNullable = true,
            TransformSchemaNode = (context, schema) =>
            {
                var description = context.PropertyInfo?.AttributeProvider?.GetCustomAttributes(typeof(DescriptionAttribute), false)
                    .OfType<DescriptionAttribute>()
                    .FirstOrDefault()?.Description;
                if (description is null)
                {
                    return schema;
                }

                if (schema is not System.Text.Json.Nodes.JsonObject schemaObject)
                {
                    // The exporter may emit a boolean schema (true/false) which cannot
                    // carry keywords; promote it to the equivalent object form first.
                    var kind = schema.GetValueKind();
                    schemaObject = new System.Text.Json.Nodes.JsonObject();
                    if (kind == JsonValueKind.False)
                    {
                        schemaObject.Add("not", true);
                    }

                    schema = schemaObject;
                }

                schemaObject["description"] = description;
                return schema;
            },
        });
        return BinaryData.FromString(node.ToJsonString());
    }

    /// <summary>
    /// Object wrapper used to request and deserialize collection results through a
    /// single <see cref="Data"/> property.
    /// </summary>
    /// <typeparam name="T">The wrapped (collection) type.</typeparam>
    public class Values<T>
    {
        /// <summary>The wrapped value.</summary>
        public required T Data { get; set; }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment