Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
Expand Down Expand Up @@ -126,11 +127,77 @@ public async Task<AgentResponse<T>> RunAsync<T>(
{
serializerOptions ??= AgentAbstractionsJsonUtilities.DefaultOptions;

var responseFormat = ChatResponseFormat.ForJsonSchema<T>(serializerOptions);

(responseFormat, bool isWrappedInObject) = EnsureObjectSchema(responseFormat);

options = options?.Clone() ?? new AgentRunOptions();
options.ResponseFormat = ChatResponseFormat.ForJsonSchema<T>(serializerOptions);
options.ResponseFormat = responseFormat;

AgentResponse response = await this.RunAsync(messages, session, options, cancellationToken).ConfigureAwait(false);

return new AgentResponse<T>(response, serializerOptions);
return new AgentResponse<T>(response, serializerOptions) { IsWrappedInObject = isWrappedInObject };
}

private static bool SchemaRepresentsObject(JsonElement? schema)
{
if (schema is not { } schemaElement)
{
return false;
}

if (schemaElement.ValueKind is JsonValueKind.Object)
{
foreach (var property in schemaElement.EnumerateObject())
{
if (property.NameEquals("type"u8))
{
return property.Value.ValueKind == JsonValueKind.String
&& property.Value.ValueEquals("object"u8);
}
}
}

return false;
}

private static (ChatResponseFormatJson ResponseFormat, bool IsWrappedInObject) EnsureObjectSchema(ChatResponseFormatJson responseFormat)
{
if (responseFormat.Schema is null)
{
throw new InvalidOperationException("The response format must have a valid JSON schema.");
}

var schema = responseFormat.Schema.Value;
bool isWrappedInObject = false;

if (!SchemaRepresentsObject(responseFormat.Schema))
{
// For non-object-representing schemas, we wrap them in an object schema, because all
// the real LLM providers today require an object schema as the root. This is currently
// true even for providers that support native structured output.
isWrappedInObject = true;
schema = JsonSerializer.SerializeToElement(new JsonObject
{
{ "$schema", "https://json-schema.org/draft/2020-12/schema" },
{ "type", "object" },
{ "properties", new JsonObject { { "data", JsonElementToJsonNode(schema) } } },
{ "additionalProperties", false },
{ "required", new JsonArray("data") },
}, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonObject)));

responseFormat = ChatResponseFormat.ForJsonSchema(schema, responseFormat.SchemaName, responseFormat.SchemaDescription);
}

return (responseFormat, isWrappedInObject);
}

private static JsonNode? JsonElementToJsonNode(JsonElement element) =>
element.ValueKind switch
{
JsonValueKind.Null => null,
JsonValueKind.Array => JsonArray.Create(element),
JsonValueKind.Object => JsonObject.Create(element),
_ => JsonValue.Create(element)
};
}
26 changes: 26 additions & 0 deletions dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponse{T}.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ public AgentResponse(ChatResponse response, JsonSerializerOptions serializerOpti
this._serializerOptions = serializerOptions;
}

/// <summary>
/// Gets or sets a value indicating whether the JSON schema has an extra object wrapper.
/// </summary>
/// <remarks>
/// The wrapper is required for any non-JSON-object-typed values such as numbers, enum values, and arrays.
/// </remarks>
internal bool IsWrappedInObject { get; init; }

/// <summary>
/// Gets the result value of the agent response as an instance of <typeparamref name="T"/>.
/// </summary>
Expand All @@ -57,6 +65,11 @@ public virtual T Result
throw new InvalidOperationException("The response did not contain JSON to be deserialized.");
}

if (this.IsWrappedInObject)
{
json = UnwrapDataProperty(json!);
}

T? deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo<T>)this._serializerOptions.GetTypeInfo(typeof(T)));
if (deserialized is null)
{
Expand All @@ -67,6 +80,19 @@ public virtual T Result
}
}

private static string UnwrapDataProperty(string json)
{
using var document = JsonDocument.Parse(json);
if (document.RootElement.ValueKind == JsonValueKind.Object &&
document.RootElement.TryGetProperty("data", out JsonElement dataElement))
{
return dataElement.GetRawText();
}

// If root is not an object or "data" property is not found, return the original JSON as a fallback
return json;
}

private static T? DeserializeFirstTopLevelObject(string json, JsonTypeInfo<T> typeInfo)
{
#if NET
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,25 @@ public virtual async Task RunWithGenericTypeReturnsExpectedResultAsync()
Assert.Equal("Paris", response.Result.Name);
}

[RetryFact(Constants.RetryCount, Constants.RetryDelay)]
public virtual async Task RunWithPrimitiveTypeReturnsExpectedResultAsync()
{
// Arrange
var agent = this.Fixture.Agent;
var session = await agent.CreateSessionAsync();
await using var cleanup = new SessionCleanup(session, this.Fixture);

// Act - Request a primitive type, which requires wrapping in an object schema
AgentResponse<int> response = await agent.RunAsync<int>(
new ChatMessage(ChatRole.User, "What is the sum of 15 and 27? Respond with just the number."),
session);

// Assert
Assert.NotNull(response);
Assert.Single(response.Messages);
Assert.Equal(42, response.Result);
}

protected static bool TryDeserialize<T>(string json, JsonSerializerOptions jsonSerializerOptions, out T structuredOutput)
{
try
Expand Down
Loading
Loading