Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/assistant/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ type CompletionRequest struct {
OnEvent func(StreamEvent) `json:"-"`
OnProviderObserve func(context.Context, *CompletionRequest, int) `json:"-"`
OnProviderRequest ProviderRequestHook `json:"-"`
OnToolCall func(context.Context, ToolCallEvent) `json:"-"`
OnToolResult func(context.Context, *ToolEvent) `json:"-"`
OnToolCall func(context.Context, *ToolCallEvent) error `json:"-"`
OnToolResult func(context.Context, *ToolEvent) error `json:"-"`
ToolRegistry *tool.Registry `json:"-"`
SessionID string `json:"session_id"`
SystemPrompt string `json:"system_prompt"`
Expand Down
15 changes: 15 additions & 0 deletions internal/assistant/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package assistant

import (
"context"
)

// DispatchToolCallLifecycleForTest exposes tool call lifecycle dispatch for external package tests.
func (runtime *Runtime) DispatchToolCallLifecycleForTest(ctx context.Context, call *ToolCallEvent) error {
return runtime.dispatchToolCallLifecycle(ctx, call)
}

// DispatchToolResultLifecycleForTest exposes tool result lifecycle dispatch for external package tests.
func (runtime *Runtime) DispatchToolResultLifecycleForTest(ctx context.Context, event *ToolEvent) error {
return runtime.dispatchToolResultLifecycle(ctx, event)
}
67 changes: 67 additions & 0 deletions internal/assistant/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ func (runtime *Runtime) dispatchLifecycle(
return extension.LifecycleDispatchResult{
Payload: cloneAnyMap(payload),
ProviderRequest: extension.ProviderRequestMutation{Headers: map[string]string{}},
ToolCall: extension.ToolCallMutation{Arguments: nil},
ToolResult: extension.ToolResultMutation{Result: nil, DetailsJSON: nil, Error: nil},
Name: string(name),
Errors: []string{},
Duration: 0,
Expand Down Expand Up @@ -227,6 +229,71 @@ func turnEndLifecyclePayload(
return payload
}

func (runtime *Runtime) dispatchToolCallLifecycle(ctx context.Context, call *ToolCallEvent) error {
if call == nil {
return nil
}
payload := toolCallPayload(*call)
runtime.emit(ctx, string(extension.LifecycleToolCall), payload)
if runtime.extensions == nil {
return nil
}

result, err := runtime.extensions.DispatchLifecycle(ctx, extension.LifecycleEvent{
Name: extension.LifecycleToolCall,
Payload: payload,
})
if err != nil {
return err
}
applyToolCallMutation(call, result.ToolCall)

return nil
}

func (runtime *Runtime) dispatchToolResultLifecycle(ctx context.Context, event *ToolEvent) error {
if event == nil {
return nil
}
payload := toolEventPayload(event)
runtime.emit(ctx, string(extension.LifecycleToolResult), payload)
if runtime.extensions != nil {
result, err := runtime.extensions.DispatchLifecycle(ctx, extension.LifecycleEvent{
Name: extension.LifecycleToolResult,
Payload: payload,
})
if err != nil {
return err
}
applyToolResultMutation(event, result.ToolResult)
}
if event.Error != "" {
runtime.dispatchObservationalLifecycle(ctx, extension.LifecycleToolError, toolEventPayload(event))
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

return nil
}

func applyToolCallMutation(call *ToolCallEvent, mutation extension.ToolCallMutation) {
if len(mutation.Arguments) == 0 {
return
}
call.Arguments = mutation.Arguments
call.ArgumentsJSON = encodeToolArguments(call.Arguments)
}

func applyToolResultMutation(event *ToolEvent, mutation extension.ToolResultMutation) {
if mutation.Result != nil {
event.Result = *mutation.Result
}
if mutation.DetailsJSON != nil {
event.DetailsJSON = *mutation.DetailsJSON
}
if mutation.Error != nil {
event.Error = *mutation.Error
}
}

func contextBuildLifecyclePayload(
sessionID string,
cwd string,
Expand Down
4 changes: 2 additions & 2 deletions internal/assistant/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,8 @@ func (runtime *Runtime) modelCompletionRequest(
OnEvent: onEvent,
OnProviderObserve: runtime.emitProviderRequest,
OnProviderRequest: runtime.dispatchProviderRequestHook,
OnToolCall: runtime.emitToolCall,
OnToolResult: runtime.emitToolResult,
OnToolCall: runtime.dispatchToolCallLifecycle,
OnToolResult: runtime.dispatchToolResultLifecycle,
ToolRegistry: registry,
SessionID: sessionID,
SystemPrompt: systemPrompt,
Expand Down
15 changes: 0 additions & 15 deletions internal/assistant/runtime_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,6 @@ func (runtime *Runtime) emitProviderError(ctx context.Context, request *Completi
})
}

func (runtime *Runtime) emitToolCall(ctx context.Context, call ToolCallEvent) {
runtime.dispatchObservationalLifecycle(ctx, extension.LifecycleToolCall, toolCallPayload(call))
}

func (runtime *Runtime) emitToolResult(ctx context.Context, event *ToolEvent) {
if event == nil {
return
}
payload := toolEventPayload(event)
runtime.dispatchObservationalLifecycle(ctx, extension.LifecycleToolResult, payload)
if event.Error != "" {
runtime.dispatchObservationalLifecycle(ctx, extension.LifecycleToolError, payload)
}
}

func toolCallPayload(call ToolCallEvent) map[string]any {
return map[string]any{
"call_id": call.ID,
Expand Down
15 changes: 11 additions & 4 deletions internal/assistant/runtime_events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
const (
testCompletionText = "done"
testToolName = "read"
testToolPathKey = "path"
testToolArgsJSON = `{"path":"README.md"}`
)

Expand Down Expand Up @@ -101,21 +102,27 @@ func (toolCallbackClient) Complete(
request *assistant.CompletionRequest,
) (*assistant.CompletionResult, error) {
if request.OnToolCall != nil {
request.OnToolCall(ctx, assistant.ToolCallEvent{
toolCall := assistant.ToolCallEvent{
Arguments: map[string]any{"path": "README.md"},
ID: "call-1",
Name: testToolName,
ArgumentsJSON: testToolArgsJSON,
})
}
if err := request.OnToolCall(ctx, &toolCall); err != nil {
return nil, err
}
}
if request.OnToolResult != nil {
request.OnToolResult(ctx, &assistant.ToolEvent{
toolResult := assistant.ToolEvent{
Name: testToolName,
ArgumentsJSON: testToolArgsJSON,
DetailsJSON: "",
Result: "contents",
Error: "",
})
}
if err := request.OnToolResult(ctx, &toolResult); err != nil {
return nil, err
}
}

return &assistant.CompletionResult{
Expand Down
4 changes: 2 additions & 2 deletions internal/assistant/runtime_lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ func TestRuntime_PromptEmitsSideEffectMessageAppendEvents(t *testing.T) {
Thinking: []string{"reasoning"},
ToolEvents: []assistant.ToolEvent{
{
Name: "read",
ArgumentsJSON: `{"path":"README.md"}`,
Name: testToolName,
ArgumentsJSON: testToolArgsJSON,
DetailsJSON: "",
Result: "contents",
Error: "",
Expand Down
4 changes: 2 additions & 2 deletions internal/assistant/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,8 @@ func (partialFailureCompletionClient) Complete(
})
request.OnEvent(assistant.StreamEvent{
ToolEvent: &assistant.ToolEvent{
Name: "read",
ArgumentsJSON: `{"path":"README.md"}`,
Name: testToolName,
ArgumentsJSON: testToolArgsJSON,
DetailsJSON: "",
Result: "file content",
Error: "",
Expand Down
154 changes: 154 additions & 0 deletions internal/assistant/tool_lifecycle_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package assistant_test

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/omarluq/librecode/internal/assistant"
)

const testToolLifecycleError = "boom"

func TestRuntime_ToolCallLifecycleAppliesArgumentMutation(t *testing.T) {
t.Parallel()

tests := []struct {
expectedArguments map[string]any
initialArguments map[string]any
name string
lua string
expectedArgumentsJSON string
}{
{
name: "rewrites path and adds limit",
initialArguments: map[string]any{testToolPathKey: "README.md"},
lua: `
local lc = require("librecode")
lc.on("tool_call", function(event)
return {
tool_call = {
arguments = {
path = "changed.txt",
limit = 3,
},
},
}
end)
`,
expectedArguments: map[string]any{
testToolPathKey: "changed.txt",
"limit": float64(3),
},
expectedArgumentsJSON: `{"limit":3,"path":"changed.txt"}`,
},
}

for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()

runtime, _, manager := newTestRuntimeWithManager(t, testCompletionClient{})
loadRuntimeExtension(t, manager, testCase.lua)
call := assistant.ToolCallEvent{
Arguments: testCase.initialArguments,
ID: "call-1",
Name: testToolName,
ArgumentsJSON: testToolArgsJSON,
}

err := runtime.DispatchToolCallLifecycleForTest(context.Background(), &call)

require.NoError(t, err)
assert.Equal(t, testCase.expectedArguments, call.Arguments)
assert.JSONEq(t, testCase.expectedArgumentsJSON, call.ArgumentsJSON)
})
}
}

func TestRuntime_ToolResultLifecycleAppliesResultMutation(t *testing.T) {
t.Parallel()

tests := []struct {
initialEvent *assistant.ToolEvent
name string
lua string
expectedResult string
expectedDetailsJSON string
expectedError string
}{
{
name: "redacts result and clears error",
initialEvent: &assistant.ToolEvent{
Name: testToolName,
ArgumentsJSON: testToolArgsJSON,
DetailsJSON: "{}",
Result: "secret",
Error: testToolLifecycleError,
},
lua: `
local lc = require("librecode")
lc.on("tool_result", function(event)
return {
tool_result = {
result = "redacted",
details_json = "{\"redacted\":true}",
error = "",
},
}
end)
`,
expectedResult: "redacted",
expectedDetailsJSON: `{"redacted":true}`,
expectedError: "",
},
}

for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()

runtime, _, manager := newTestRuntimeWithManager(t, testCompletionClient{})
loadRuntimeExtension(t, manager, testCase.lua)

err := runtime.DispatchToolResultLifecycleForTest(context.Background(), testCase.initialEvent)

require.NoError(t, err)
assert.Equal(t, testCase.expectedResult, testCase.initialEvent.Result)
assert.JSONEq(t, testCase.expectedDetailsJSON, testCase.initialEvent.DetailsJSON)
assert.Equal(t, testCase.expectedError, testCase.initialEvent.Error)
})
}
}

func TestRuntime_ToolResultLifecycleDispatchesToolErrorHandlers(t *testing.T) {
t.Parallel()

runtime, _, manager := newTestRuntimeWithManager(t, testCompletionClient{})
loadRuntimeExtension(t, manager, `
local lc = require("librecode")
local seen = ""
lc.on("tool_error", function(event)
seen = event.payload.name .. ":" .. event.payload.error
end)
lc.register_command("seen_tool_error", "seen_tool_error", function()
return seen
end)
`)
event := &assistant.ToolEvent{
Name: testToolName,
ArgumentsJSON: testToolArgsJSON,
DetailsJSON: "",
Result: testToolLifecycleError,
Error: testToolLifecycleError,
}

err := runtime.DispatchToolResultLifecycleForTest(context.Background(), event)

require.NoError(t, err)
output, err := manager.ExecuteCommand(context.Background(), "seen_tool_error", "")
require.NoError(t, err)
assert.Equal(t, "read:boom", output)
}
Loading
Loading