Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 123 additions & 31 deletions src/Azure.DataApiBuilder.Mcp/Core/McpStdioServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Azure.DataApiBuilder.Config.ObjectModel;
using Azure.DataApiBuilder.Core.AuthenticationHelpers.AuthenticationSimulator;
using Azure.DataApiBuilder.Core.Configurations;
using Azure.DataApiBuilder.Core.Telemetry;
using Azure.DataApiBuilder.Mcp.Model;
using Azure.DataApiBuilder.Mcp.Utils;
using Microsoft.AspNetCore.Http;
Expand Down Expand Up @@ -46,8 +47,6 @@ public McpStdioServer(McpToolRegistry toolRegistry, IServiceProvider serviceProv
/// <returns>A task representing the asynchronous operation.</returns>
public async Task RunAsync(CancellationToken cancellationToken)
{
Console.Error.WriteLine("[MCP DEBUG] MCP stdio server started.");

// Use UTF-8 WITHOUT BOM
UTF8Encoding utf8NoBom = new(encoderShouldEmitUTF8Identifier: false);

Expand Down Expand Up @@ -77,15 +76,13 @@ public async Task RunAsync(CancellationToken cancellationToken)
{
doc = JsonDocument.Parse(line);
}
catch (JsonException jsonEx)
catch (JsonException)
{
Console.Error.WriteLine($"[MCP DEBUG] JSON parse error: {jsonEx.Message}");
WriteError(id: null, code: McpStdioJsonRpcErrorCodes.PARSE_ERROR, message: "Parse error");
continue;
}
catch (Exception ex)
catch (Exception)
{
Console.Error.WriteLine($"[MCP DEBUG] Unexpected error parsing request: {ex.Message}");
WriteError(id: null, code: McpStdioJsonRpcErrorCodes.INTERNAL_ERROR, message: "Internal error");
continue;
}
Expand Down Expand Up @@ -131,6 +128,10 @@ public async Task RunAsync(CancellationToken cancellationToken)
WriteResult(id, new { ok = true });
break;

case "logging/setLevel":
HandleSetLogLevel(id, root);
break;

case "shutdown":
WriteResult(id, new { ok = true });
return;
Expand Down Expand Up @@ -171,30 +172,50 @@ private void HandleInitialize(JsonElement? id)
RuntimeConfig runtimeConfig = runtimeConfigProvider.GetConfig();
instructions = runtimeConfig.Runtime?.Mcp?.Description;
}
catch (Exception ex)
catch (Exception)
{
// Log to stderr for diagnostics and rethrow to avoid masking configuration errors
Console.Error.WriteLine($"[MCP WARNING] Failed to retrieve MCP description from config: {ex.Message}");
// Rethrow to avoid masking configuration errors
throw;
}
}

// Create the initialize response
object result = new
// Create the initialize response - only include instructions if non-empty
object result;
if (!string.IsNullOrWhiteSpace(instructions))
{
protocolVersion = _protocolVersion,
capabilities = new
result = new
{
tools = new { listChanged = true },
logging = new { }
},
serverInfo = new
protocolVersion = _protocolVersion,
capabilities = new
{
tools = new { listChanged = true },
logging = new { }
},
serverInfo = new
{
name = McpProtocolDefaults.MCP_SERVER_NAME,
version = McpProtocolDefaults.MCP_SERVER_VERSION
},
instructions = instructions
};
}
else
{
result = new
{
name = McpProtocolDefaults.MCP_SERVER_NAME,
version = McpProtocolDefaults.MCP_SERVER_VERSION
},
instructions = !string.IsNullOrWhiteSpace(instructions) ? instructions : null
};
protocolVersion = _protocolVersion,
capabilities = new
{
tools = new { listChanged = true },
logging = new { }
},
serverInfo = new
{
name = McpProtocolDefaults.MCP_SERVER_NAME,
version = McpProtocolDefaults.MCP_SERVER_VERSION
}
};
}

WriteResult(id, result);
}
Expand Down Expand Up @@ -228,6 +249,85 @@ private void HandleListTools(JsonElement? id)
WriteResult(id, new { tools = toolsWire });
}

/// <summary>
/// Handles the "logging/setLevel" JSON-RPC method by updating the runtime log level.
/// </summary>
/// <param name="id">The request identifier extracted from the incoming JSON-RPC request.</param>
/// <param name="root">The root JSON element of the incoming JSON-RPC request.</param>
/// <remarks>
/// Log level precedence (highest to lowest):
/// 1. CLI --LogLevel flag - cannot be overridden
/// 2. Config runtime.telemetry.log-level - cannot be overridden by MCP
/// 3. MCP logging/setLevel - only works if neither CLI nor Config explicitly set a level
/// 4. Default: None for MCP stdio mode (silent by default to keep stdout clean for JSON-RPC)
///
/// If CLI or Config set the log level, this method accepts the request but silently ignores it.
/// The client won't get an error, but CLI/Config wins.
///
/// When MCP sets a level other than "none", this also restores Console.Error to the real stderr
/// stream so that logs become visible (Console may have been redirected to null at startup).
/// It also enables MCP log notifications so logs are sent to the client via notifications/message.
/// </remarks>
private void HandleSetLogLevel(JsonElement? id, JsonElement root)
{
// Extract the level parameter from the request
string? level = null;
if (root.TryGetProperty("params", out JsonElement paramsEl) &&
paramsEl.TryGetProperty("level", out JsonElement levelEl) &&
levelEl.ValueKind == JsonValueKind.String)
{
level = levelEl.GetString();
}

if (string.IsNullOrWhiteSpace(level))
{
WriteError(id, McpStdioJsonRpcErrorCodes.INVALID_PARAMS, "Missing or invalid 'level' parameter");
return;
}

// Get the ILogLevelController from service provider
ILogLevelController? logLevelController = _serviceProvider.GetService<ILogLevelController>();
if (logLevelController is null)
{
// Log level controller not available - still accept request per MCP spec
WriteResult(id, new { });
return;
}

// Attempt to update the log level
// If CLI or Config overrode, this returns false but we still return success to the client
bool updated = logLevelController.UpdateFromMcp(level);

// If MCP successfully changed the log level to something other than "none",
// ensure Console.Error is pointing to the real stderr (not TextWriter.Null).
// This handles the case where MCP stdio mode started with LogLevel.None (quiet startup)
// and the client later enables logging via logging/setLevel.
bool isLoggingEnabled = !string.Equals(level, "none", StringComparison.OrdinalIgnoreCase);
if (updated && isLoggingEnabled)
{
RestoreStderrIfNeeded();
}

// Always return success (empty result object) per MCP spec
WriteResult(id, new { });
}

/// <summary>
/// Restores Console.Error to the real stderr stream if it was redirected to TextWriter.Null.
/// This enables log output after MCP client sends logging/setLevel with a level other than "none".
/// </summary>
private static void RestoreStderrIfNeeded()
{
// Always restore stderr to the real stream when MCP enables logging.
// This is safe to call multiple times - we just re-wrap the standard error stream.
Stream stderr = Console.OpenStandardError();
StreamWriter stderrWriter = new(stderr, new UTF8Encoding(encoderShouldEmitUTF8Identifier: false))
{
AutoFlush = true
};
Console.SetError(stderrWriter);
}

/// <summary>
/// Handles the "tools/call" JSON-RPC method by executing the specified tool with the provided arguments.
/// </summary>
Expand Down Expand Up @@ -259,14 +359,12 @@ private async Task HandleCallToolAsync(JsonElement? id, JsonElement root, Cancel

if (string.IsNullOrWhiteSpace(toolName))
{
Console.Error.WriteLine("[MCP DEBUG] callTool → missing tool name.");
WriteError(id, McpStdioJsonRpcErrorCodes.INVALID_PARAMS, "Missing tool name");
return;
}

if (!_toolRegistry.TryGetTool(toolName!, out IMcpTool? tool) || tool is null)
{
Console.Error.WriteLine($"[MCP DEBUG] callTool → tool not found: {toolName}");
WriteError(id, McpStdioJsonRpcErrorCodes.INVALID_PARAMS, $"Tool not found: {toolName}");
return;
}
Expand All @@ -276,13 +374,7 @@ private async Task HandleCallToolAsync(JsonElement? id, JsonElement root, Cancel
{
if (@params.TryGetProperty("arguments", out JsonElement argsEl) && argsEl.ValueKind == JsonValueKind.Object)
{
string rawArgs = argsEl.GetRawText();
Console.Error.WriteLine($"[MCP DEBUG] callTool → tool: {toolName}, args: {rawArgs}");
argsDoc = JsonDocument.Parse(rawArgs);
}
else
{
Console.Error.WriteLine($"[MCP DEBUG] callTool → tool: {toolName}, args: <none>");
argsDoc = JsonDocument.Parse(argsEl.GetRawText());
}

// Execute the tool with telemetry.
Expand Down
81 changes: 81 additions & 0 deletions src/Cli.Tests/CustomLoggerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Cli.Tests;

/// <summary>
/// Tests for CustomLoggerProvider and CustomConsoleLogger, verifying
/// that log level labels use ASP.NET Core abbreviated format.
/// </summary>
[TestClass]
public class CustomLoggerTests
{
/// <summary>
/// Validates that each enabled log level produces the correct abbreviated label
/// matching ASP.NET Core's default console formatter convention.
/// Trace and Debug are below the logger's minimum level and produce no output.
/// </summary>
[DataTestMethod]
[DataRow(LogLevel.Information, "info:")]
[DataRow(LogLevel.Warning, "warn:")]
public void LogOutput_UsesAbbreviatedLogLevelLabels(LogLevel logLevel, string expectedPrefix)
{
CustomLoggerProvider provider = new();
ILogger logger = provider.CreateLogger("TestCategory");

TextWriter originalOut = Console.Out;
try
{
StringWriter writer = new();
Console.SetOut(writer);

logger.Log(logLevel, "test message");

string output = writer.ToString();
Assert.IsTrue(
output.StartsWith(expectedPrefix),
$"Expected output to start with '{expectedPrefix}' but got: '{output}'");
Assert.IsTrue(
output.Contains("test message"),
$"Expected output to contain 'test message' but got: '{output}'");
}
finally
{
Console.SetOut(originalOut);
}
}

/// <summary>
/// Validates that each log level error and above produces the correct abbreviated
/// label matching ASP.NET Core's default console formatter convention.
/// Error and Critical logs should go to the stderr stream.
/// </summary>
[DataTestMethod]
[DataRow(LogLevel.Error, "fail:")]
[DataRow(LogLevel.Critical, "crit:")]
public void LogError_UsesAbbreviatedLogLevelLabels(LogLevel logLevel, string expectedPrefix)
{
CustomLoggerProvider provider = new();
ILogger logger = provider.CreateLogger("TestCategory");

TextWriter originalError = Console.Error;
try
{
StringWriter writer = new();
Console.SetError(writer);
logger.Log(logLevel, "test message");

string output = writer.ToString();
Assert.IsTrue(
output.StartsWith(expectedPrefix),
$"Expected output to start with '{expectedPrefix}' but got: '{output}'");
Assert.IsTrue(
output.Contains("test message"),
$"Expected output to contain 'test message' but got: '{output}'");
}
finally
{
Console.SetError(originalError);
}
}
}
Loading
Loading