-
Notifications
You must be signed in to change notification settings - Fork 3
feat: add MCP injection support to responses streaming interceptor #143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ba4e96d
eaec7d1
1455005
7006f7a
5c5c528
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,6 @@ package responses | |
| import ( | ||
| "context" | ||
| "errors" | ||
| "fmt" | ||
| "net/http" | ||
| "time" | ||
|
|
||
|
|
@@ -15,7 +14,6 @@ import ( | |
| "github.com/google/uuid" | ||
| "github.com/openai/openai-go/v3/option" | ||
| "github.com/openai/openai-go/v3/responses" | ||
| "github.com/tidwall/sjson" | ||
| "go.opentelemetry.io/otel/attribute" | ||
| "go.opentelemetry.io/otel/trace" | ||
| ) | ||
|
|
@@ -62,98 +60,54 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * | |
|
|
||
| var ( | ||
| response *responses.Response | ||
| err error | ||
| upstreamErr error | ||
| respCopy responseCopier | ||
| ) | ||
|
|
||
| for { | ||
| shouldLoop := true | ||
| recordPromptOnce := true | ||
| for shouldLoop { | ||
| srv := i.newResponsesService() | ||
| respCopy = responseCopier{} | ||
|
|
||
| opts := i.requestOptions(&respCopy) | ||
| opts = append(opts, option.WithRequestTimeout(time.Second*600)) | ||
| response, upstreamErr = i.newResponse(ctx, srv, opts) | ||
|
|
||
| if upstreamErr != nil { | ||
| if upstreamErr != nil || response == nil { | ||
| break | ||
| } | ||
|
|
||
| // response could be nil eg. fixtures/openai/responses/blocking/wrong_response_format.txtar | ||
| if response == nil { | ||
| break | ||
| // Record prompt usage on first successful response. | ||
| if recordPromptOnce { | ||
| recordPromptOnce = false | ||
| i.recordUserPrompt(ctx, response.ID) | ||
| } | ||
|
|
||
| // Record prompt usage on first successful response. | ||
| i.recordUserPrompt(ctx, response.ID) | ||
| // Record token usage for each inner loop iteration | ||
| i.recordTokenUsage(ctx, response) | ||
|
|
||
| // Check if there any injected tools to invoke. | ||
| pending := i.getPendingInjectedToolCalls(ctx, response) | ||
| if len(pending) == 0 { | ||
| // No injected tools, record non-injected tool usage. | ||
| i.recordNonInjectedToolUsage(ctx, response) | ||
|
|
||
| // No injected function calls need to be invoked, flow is complete. | ||
| break | ||
| } | ||
|
|
||
| shouldLoop, err := i.handleInnerAgenticLoop(ctx, pending, response) | ||
| pending := i.getPendingInjectedToolCalls(response) | ||
| shouldLoop, err = i.handleInnerAgenticLoop(ctx, pending, response) | ||
| if err != nil { | ||
| i.sendCustomErr(ctx, w, http.StatusInternalServerError, err) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does it mean that if one tool returns an error, we failed the entire prompt?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Depends what "returns an error" means.
Thanks to this comment I've found re-marshaling error was ignored. Will add |
||
| shouldLoop = false | ||
| } | ||
|
|
||
| if !shouldLoop { | ||
| break | ||
| } | ||
| } | ||
|
|
||
| i.recordNonInjectedToolUsage(ctx, response) | ||
pawbana marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if upstreamErr != nil && !respCopy.responseReceived.Load() { | ||
| // no response received from upstream, return custom error | ||
| i.sendCustomErr(ctx, w, http.StatusInternalServerError, upstreamErr) | ||
| } | ||
|
|
||
| err := respCopy.forwardResp(w) | ||
|
|
||
| err = respCopy.forwardResp(w) | ||
| return errors.Join(upstreamErr, err) | ||
| } | ||
|
|
||
| // handleInnerAgenticLoop orchestrates the inner agentic loop whereby injected tools | ||
| // are invoked and their results are sent back to the model. | ||
| // This is in contrast to regular tool calls which will be handled by the client | ||
| // in its own agentic loop. | ||
| func (i *BlockingResponsesInterceptor) handleInnerAgenticLoop(ctx context.Context, pending []responses.ResponseFunctionToolCall, response *responses.Response) (bool, error) { | ||
| // Invoke any injected function calls. | ||
| // The Responses API refers to what we call "tools" as "functions", so we keep the terminology | ||
| // consistent in this package. | ||
| // See https://platform.openai.com/docs/guides/function-calling | ||
| results, err := i.handleInjectedToolCalls(ctx, pending, response) | ||
| if err != nil { | ||
| return false, fmt.Errorf("failed to handle injected tool calls: %w", err) | ||
| } | ||
|
|
||
| // No tool results means no tools were invocable, so the flow is complete. | ||
| if len(results) == 0 { | ||
| return false, nil | ||
| } | ||
|
|
||
| // We'll use the tool results to issue another request to provide the model with. | ||
| i.prepareRequestForAgenticLoop(response) | ||
| i.req.Input.OfInputItemList = append(i.req.Input.OfInputItemList, results...) | ||
|
|
||
| // TODO: we should avoid re-marshaling Input, but since it changes from a string to | ||
| // a list in this loop, we have to. | ||
| // See responsesInterceptionBase.requestOptions for more details about marshaling issues. | ||
| i.reqPayload, err = sjson.SetBytes(i.reqPayload, "input", i.req.Input) | ||
| if err != nil { | ||
| i.logger.Error(ctx, "failure to marshal new input in inner agentic loop", slog.Error(err)) | ||
| // TODO: what should be returned under this condition? | ||
| return false, nil | ||
| } | ||
|
|
||
| return true, nil | ||
| } | ||
|
|
||
| func (i *BlockingResponsesInterceptor) newResponse(ctx context.Context, srv responses.ResponseService, opts []option.RequestOption) (_ *responses.Response, outErr error) { | ||
| ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) | ||
| defer tracing.EndSpanErr(span, &outErr) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -76,6 +76,31 @@ func (i *responsesInterceptionBase) disableParallelToolCalls() { | |
| } | ||
| } | ||
|
|
||
| // handleInnerAgenticLoop orchestrates the inner agentic loop whereby injected tools | ||
| // are invoked and their results are sent back to the model. | ||
| // This is in contrast to regular tool calls which will be handled by the client | ||
| // in its own agentic loop. | ||
| func (i *responsesInterceptionBase) handleInnerAgenticLoop(ctx context.Context, pending []responses.ResponseFunctionToolCall, response *responses.Response) (bool, error) { | ||
| // Invoke any injected function calls. | ||
| // The Responses API refers to what we call "tools" as "functions", so we keep the terminology | ||
| // consistent in this package. | ||
| // See https://platform.openai.com/docs/guides/function-calling | ||
| results, err := i.handleInjectedToolCalls(ctx, pending, response) | ||
| if err != nil { | ||
| return false, fmt.Errorf("failed to handle injected tool calls: %w", err) | ||
| } | ||
|
|
||
| // No tool results means no tools were invocable, so the flow is complete. | ||
| if len(results) == 0 { | ||
| return false, nil | ||
| } | ||
|
|
||
| // We'll use the tool results to issue another request to provide the model with. | ||
| err = i.prepareRequestForAgenticLoop(ctx, response, results) | ||
|
|
||
| return true, err | ||
| } | ||
|
|
||
| // handleInjectedToolCalls checks for function calls that we need to handle in our inner agentic loop. | ||
| // These are functions injected by the MCP proxy. | ||
| // Returns a list of tool call results. | ||
|
|
@@ -99,19 +124,60 @@ func (i *responsesInterceptionBase) handleInjectedToolCalls(ctx context.Context, | |
|
|
||
| // prepareRequestForAgenticLoop prepares the request by setting the output of the given | ||
| // response as input to the next request, in order for the tool call result(s) to make function correctly. | ||
| func (i *responsesInterceptionBase) prepareRequestForAgenticLoop(response *responses.Response) { | ||
| func (i *responsesInterceptionBase) prepareRequestForAgenticLoop(ctx context.Context, response *responses.Response, toolResults []responses.ResponseInputItemUnionParam) error { | ||
| var err error | ||
| originalInputSize := len(i.req.Input.OfInputItemList) | ||
|
|
||
| // Unset the string input; we need a list now. | ||
| i.req.Input.OfString = param.Opt[string]{} | ||
| if i.req.Input.OfString.Valid() { | ||
| // convert old string value to list item | ||
| i.req.Input.OfInputItemList = responses.ResponseInputParam{ | ||
| responses.ResponseInputItemParamOfMessage( | ||
| i.req.Input.OfString.Value, | ||
| responses.EasyInputMessageRoleUser, | ||
| ), | ||
| } | ||
|
|
||
| // clear old value | ||
| i.req.Input.OfString = param.Opt[string]{} | ||
| } | ||
|
|
||
| // OutputText is also available, but by definition the trigger for a function call is not a simple | ||
| // text response from the model. | ||
| for _, output := range response.Output { | ||
| i.appendOutputToInput(i.req, output) | ||
| if inputItem := i.convertOutputToInput(output); inputItem != nil { | ||
| i.req.Input.OfInputItemList = append(i.req.Input.OfInputItemList, *inputItem) | ||
| } | ||
| } | ||
|
|
||
| for _, result := range toolResults { | ||
| i.req.Input.OfInputItemList = append(i.req.Input.OfInputItemList, result) | ||
| } | ||
|
|
||
| // If original payload was in string format or was an empty list re-marshal whole input | ||
| if originalInputSize == 0 { | ||
| if i.reqPayload, err = sjson.SetBytes(i.reqPayload, "input", i.req.Input.OfInputItemList); err != nil { | ||
| i.logger.Error(ctx, "failure to marshal new input in inner agentic loop", slog.Error(err)) | ||
| return fmt.Errorf("failed to marshal input: %v", err) | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| // Append newly added items to reqPayload field | ||
| // New items are appended to limit Input re-marshaling. | ||
| // See responsesInterceptionBase.requestOptions for more details about marshaling issues. | ||
| for j := originalInputSize; j < len(i.req.Input.OfInputItemList); j++ { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 👌 nice |
||
| if i.reqPayload, err = sjson.SetBytes(i.reqPayload, "input.-1", i.req.Input.OfInputItemList[j]); err != nil { | ||
| i.logger.Error(ctx, "failure to marshal output item to new input in inner agentic loop", slog.Error(err)) | ||
| return fmt.Errorf("failed to marshal input: %v", err) | ||
| } | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| // getPendingInjectedToolCalls extracts function calls from the response that are managed by MCP proxy | ||
| func (i *responsesInterceptionBase) getPendingInjectedToolCalls(ctx context.Context, response *responses.Response) []responses.ResponseFunctionToolCall { | ||
| func (i *responsesInterceptionBase) getPendingInjectedToolCalls(response *responses.Response) []responses.ResponseFunctionToolCall { | ||
| var calls []responses.ResponseFunctionToolCall | ||
|
|
||
| for _, item := range response.Output { | ||
|
|
@@ -171,14 +237,14 @@ func (i *responsesInterceptionBase) invokeInjectedTool(ctx context.Context, resp | |
| return responses.ResponseInputItemParamOfFunctionCallOutput(fc.CallID, output) | ||
| } | ||
|
|
||
| // appendOutputToInput converts a response output item to an input item and appends it to the | ||
| // convertOutputToInput converts a response output item to an input item and appends it to the | ||
| // request's input list. This is used in agentic loops where we need to feed the model's output | ||
| // back as input for the next iteration (e.g., when processing tool call results). | ||
| // | ||
| // The conversion uses the openai-go library's ToParam() methods where available, which leverage | ||
| // param.Override() with raw JSON to preserve all fields. For types without ToParam(), we use | ||
| // the ResponseInputItemParamOf* helper functions. | ||
| func (i *responsesInterceptionBase) appendOutputToInput(req *ResponsesNewParamsWrapper, item responses.ResponseOutputItemUnion) { | ||
| func (i *responsesInterceptionBase) convertOutputToInput(item responses.ResponseOutputItemUnion) *responses.ResponseInputItemUnionParam { | ||
| var inputItem responses.ResponseInputItemUnionParam | ||
|
|
||
| switch item.Type { | ||
|
|
@@ -228,8 +294,8 @@ func (i *responsesInterceptionBase) appendOutputToInput(req *ResponsesNewParamsW | |
| // - mcp_call, mcp_list_tools, mcp_approval_request: MCP-specific outputs | ||
| default: | ||
| i.logger.Debug(context.Background(), "skipping output item type for input", slog.F("type", item.Type)) | ||
| return | ||
| return nil | ||
| } | ||
|
|
||
| req.Input.OfInputItemList = append(req.Input.OfInputItemList, inputItem) | ||
| return &inputItem | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.