Skip to content

Commit 843f5d8

Browse files
committed
fix: provider consolidation and remaining audit fixes
Unify Provider, TokenStreamer, and StreamProvider into a single Provider interface with Complete + CompleteStream + Name. Remove dead StreamProvider interface, StreamEvent/StreamConfig types, SSEDecoder, and Stream() methods from all providers. Agent now calls CompleteStream directly. Also: delete leftover basic/cli binaries, add .gitignore, expand BashTool dangerous patterns (21 total), add symlink-safe path protection, and implement weighted token estimation heuristic.
1 parent d8dcd77 commit 843f5d8

23 files changed

Lines changed: 689 additions & 654 deletions

.gitignore

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Compiled binaries
2+
*.exe
3+
*.exe~
4+
*.dll
5+
*.so
6+
*.dylib
7+
basic
8+
cli
9+
10+
# Test binary
11+
*.test
12+
13+
# Output of go coverage
14+
*.out
15+
16+
# Go workspace
17+
go.work
18+
go.work.sum
19+
20+
# IDE
21+
.idea/
22+
.vscode/
23+
*.swp
24+
*.swo
25+
26+
# OS
27+
.DS_Store
28+
Thumbs.db

agent.go

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
// ErrUnknownTool is returned when a tool call references a tool that is not registered.
2020
var ErrUnknownTool = errors.New("unknown tool")
2121

22-
const maxAgentIterations = 20
22+
const defaultMaxAgentIterations = 20
2323

2424
// ToolCall represents a tool invocation.
2525
type ToolCall struct {
@@ -74,6 +74,9 @@ type Agent struct {
7474
Skills []Skill
7575
Messages []Message
7676

77+
// MaxIterations overrides the default iteration limit when set to > 0.
78+
MaxIterations int
79+
7780
// Context compaction config.
7881
contextConfig ContextConfig
7982

@@ -361,7 +364,12 @@ func (a *Agent) Run(ctx context.Context, systemPrompt, userMessage string, emitF
361364

362365
opts := a.completionOpts()
363366

364-
for i := 0; i < maxAgentIterations; i++ {
367+
maxIter := defaultMaxAgentIterations
368+
if a.MaxIterations > 0 {
369+
maxIter = a.MaxIterations
370+
}
371+
372+
for i := 0; i < maxIter; i++ {
365373
a.logger.Info("agent iteration", "step", i+1)
366374
emit(Event{Type: string(EventTurnStart), Content: fmt.Sprintf("turn %d", i+1)})
367375

@@ -376,17 +384,11 @@ func (a *Agent) Run(ctx context.Context, systemPrompt, userMessage string, emitF
376384
response string
377385
err error
378386
)
379-
if ts, ok := a.provider.(TokenStreamer); ok {
380-
response, err = RetryWithResult(ctx, DefaultRetryConfig, func() (string, error) {
381-
return ts.CompleteStream(ctx, messages, opts, func(token string) {
382-
emit(Event{Type: string(EventTokenUpdate), Content: token})
383-
})
387+
response, err = RetryWithResult(ctx, DefaultRetryConfig, func() (string, error) {
388+
return a.provider.CompleteStream(ctx, messages, opts, func(token string) {
389+
emit(Event{Type: string(EventTokenUpdate), Content: token})
384390
})
385-
} else {
386-
response, err = RetryWithResult(ctx, DefaultRetryConfig, func() (string, error) {
387-
return a.provider.Complete(ctx, messages, opts)
388-
})
389-
}
391+
})
390392
if err != nil {
391393
emit(Event{Type: string(EventError), Content: err.Error(), IsError: true})
392394
return "", fmt.Errorf("provider error at step %d: %w", i+1, err)
@@ -419,7 +421,7 @@ func (a *Agent) Run(ctx context.Context, systemPrompt, userMessage string, emitF
419421
emit(Event{Type: string(EventTurnEnd), Content: ""})
420422
}
421423

422-
return "", fmt.Errorf("agent exceeded max iterations (%d)", maxAgentIterations)
424+
return "", fmt.Errorf("agent exceeded max iterations (%d)", maxIter)
423425
}
424426

425427
func (a *Agent) emit(e Event) {
@@ -846,7 +848,12 @@ func (a *Agent) PromptMessages(ctx context.Context, messages []Message) chan Eve
846848

847849
opts := a.completionOpts()
848850

849-
for i := 0; i < maxAgentIterations; i++ {
851+
maxIter := defaultMaxAgentIterations
852+
if a.MaxIterations > 0 {
853+
maxIter = a.MaxIterations
854+
}
855+
856+
for i := 0; i < maxIter; i++ {
850857
// Check for cancellation before each turn.
851858
select {
852859
case <-loopCtx.Done():
@@ -868,17 +875,11 @@ func (a *Agent) PromptMessages(ctx context.Context, messages []Message) chan Eve
868875
response string
869876
turnErr error
870877
)
871-
if ts, ok := a.provider.(TokenStreamer); ok {
872-
response, turnErr = RetryWithResult(loopCtx, DefaultRetryConfig, func() (string, error) {
873-
return ts.CompleteStream(loopCtx, fullMessages, opts, func(token string) {
874-
emitFn(Event{Type: string(EventTokenUpdate), Content: token})
875-
})
878+
response, turnErr = RetryWithResult(loopCtx, DefaultRetryConfig, func() (string, error) {
879+
return a.provider.CompleteStream(loopCtx, fullMessages, opts, func(token string) {
880+
emitFn(Event{Type: string(EventTokenUpdate), Content: token})
876881
})
877-
} else {
878-
response, turnErr = RetryWithResult(loopCtx, DefaultRetryConfig, func() (string, error) {
879-
return a.provider.Complete(loopCtx, fullMessages, opts)
880-
})
881-
}
882+
})
882883
if turnErr != nil {
883884
emitFn(Event{Type: string(EventError), Content: turnErr.Error(), IsError: true})
884885
break

agent_test.go

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -310,12 +310,12 @@ func TestAgentHooks_BeforeTurnReceivesMessages(t *testing.T) {
310310
}
311311

312312
// ---------------------------------------------------------------------------
313-
// TokenStreamer tests
313+
// Streaming tests
314314
// ---------------------------------------------------------------------------
315315

316-
// TestAgentTokenStreamer_EventsEmitted verifies that EventTokenUpdate events
317-
// are emitted when the provider implements TokenStreamer.
318-
func TestAgentTokenStreamer_EventsEmitted(t *testing.T) {
316+
// TestAgentStreaming_EventsEmitted verifies that EventTokenUpdate events
317+
// are emitted during streaming completions.
318+
func TestAgentStreaming_EventsEmitted(t *testing.T) {
319319
p := iteragent.NewMockStream("hello world")
320320
a := iteragent.New(p, nil, testLogger())
321321

@@ -338,9 +338,9 @@ func TestAgentTokenStreamer_EventsEmitted(t *testing.T) {
338338
}
339339
}
340340

341-
// TestAgentTokenStreamer_RunReturnsFullResponse verifies Run() returns the
341+
// TestAgentStreaming_RunReturnsFullResponse verifies Run() returns the
342342
// complete response even when token streaming is active.
343-
func TestAgentTokenStreamer_RunReturnsFullResponse(t *testing.T) {
343+
func TestAgentStreaming_RunReturnsFullResponse(t *testing.T) {
344344
p := iteragent.NewMockStream("the quick brown fox")
345345
a := iteragent.New(p, nil, testLogger())
346346

@@ -353,10 +353,10 @@ func TestAgentTokenStreamer_RunReturnsFullResponse(t *testing.T) {
353353
}
354354
}
355355

356-
// TestAgentTokenStreamer_WithTools verifies streaming still works when a tool
356+
// TestAgentStreaming_WithTools verifies streaming still works when a tool
357357
// call precedes the final response. The mock streams every turn, so we verify
358358
// that EventTokenUpdate events were emitted and the final answer is correct.
359-
func TestAgentTokenStreamer_WithTools(t *testing.T) {
359+
func TestAgentStreaming_WithTools(t *testing.T) {
360360
toolCall := iteragent.ToolCall{Tool: "noop", Args: map[string]string{}}
361361
noopTool := iteragent.Tool{
362362
Name: "noop",
@@ -388,21 +388,6 @@ func TestAgentTokenStreamer_WithTools(t *testing.T) {
388388
}
389389
}
390390

391-
// TestAgentTokenStreamer_NonStreamingProviderNoTokenEvents verifies that a
392-
// plain Provider (no TokenStreamer) emits NO EventTokenUpdate events.
393-
func TestAgentTokenStreamer_NonStreamingProviderNoTokenEvents(t *testing.T) {
394-
p := iteragent.NewMock("plain response") // does not implement TokenStreamer
395-
a := iteragent.New(p, nil, testLogger())
396-
397-
events := a.Prompt(context.Background(), "hello")
398-
for e := range events {
399-
if iteragent.EventType(e.Type) == iteragent.EventTokenUpdate {
400-
t.Errorf("unexpected EventTokenUpdate from non-streaming provider: %q", e.Content)
401-
}
402-
}
403-
a.Finish()
404-
}
405-
406391
// TestAgentClose verifies that Close() is safe to call on an agent with no MCP
407392
// servers and does not block or panic.
408393
func TestAgentClose(t *testing.T) {
@@ -545,7 +530,7 @@ func TestPromptMessages_Hooks_OnToolStartEnd(t *testing.T) {
545530
}
546531

547532
// TestPromptMessages_Hooks_TokenUpdate verifies EventTokenUpdate events are emitted
548-
// via PromptMessages when the provider implements TokenStreamer.
533+
// via PromptMessages during streaming completions.
549534
func TestPromptMessages_Hooks_TokenUpdate(t *testing.T) {
550535
const want = "streamed via prompt messages"
551536
p := iteragent.NewMockStream(want)
@@ -572,7 +557,7 @@ func TestPromptMessages_Hooks_TokenUpdate(t *testing.T) {
572557
}
573558

574559
// TestSubAgentStreaming verifies that SubAgent.Run delegates to the embedded Agent,
575-
// which means TokenStreamer and EventTokenUpdate both flow through correctly after the refactor.
560+
// which means streaming and EventTokenUpdate both flow through correctly.
576561
func TestSubAgentStreaming(t *testing.T) {
577562
const want = "subagent streamed answer"
578563
p := iteragent.NewMockStream(want)
@@ -657,6 +642,9 @@ func (p *blockingProvider) Complete(ctx context.Context, messages []iteragent.Me
657642
<-ctx.Done()
658643
return "", ctx.Err()
659644
}
645+
func (p *blockingProvider) CompleteStream(ctx context.Context, messages []iteragent.Message, opts iteragent.CompletionOptions, onToken func(string)) (string, error) {
646+
return p.Complete(ctx, messages, opts)
647+
}
660648

661649
// ---------------------------------------------------------------------------
662650
// MCP integration tests

anthropic.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type anthropicProvider struct {
2323
}
2424

2525
// NewAnthropic returns a native Anthropic provider.
26-
// The returned provider implements both Provider and TokenStreamer.
26+
// The returned provider implements Provider with streaming support.
2727
func NewAnthropic(cfg AnthropicConfig) Provider {
2828
return &anthropicProvider{
2929
cfg: cfg,
@@ -102,7 +102,7 @@ func (p *anthropicProvider) buildAnthropicBody(messages []Message, opt Completio
102102
return json.Marshal(reqMap)
103103
}
104104

105-
// CompleteStream implements TokenStreamer. It uses SSE to deliver tokens
105+
// CompleteStream implements Provider. It uses SSE to deliver tokens
106106
// incrementally via onToken as they arrive from the Anthropic API.
107107
func (p *anthropicProvider) CompleteStream(ctx context.Context, messages []Message, opt CompletionOptions, onToken func(string)) (string, error) {
108108
body, err := p.buildAnthropicBody(messages, opt, true)

azure.go

Lines changed: 1 addition & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func (p *AzureOpenAIProvider) Complete(ctx context.Context, messages []Message,
106106
return response.Choices[0].Message.Content, nil
107107
}
108108

109-
// CompleteStream implements TokenStreamer for Azure OpenAI using the SSE streaming endpoint.
109+
// CompleteStream implements Provider for Azure OpenAI using the SSE streaming endpoint.
110110
func (p *AzureOpenAIProvider) CompleteStream(ctx context.Context, messages []Message, opt CompletionOptions, onToken func(string)) (string, error) {
111111
apiVersion := p.config.APIVersion
112112
if apiVersion == "" {
@@ -161,121 +161,3 @@ func messagesToAzureFormat(messages []Message) []map[string]interface{} {
161161
return result
162162
}
163163

164-
func (p *AzureOpenAIProvider) Stream(ctx context.Context, config StreamConfig, messages []Message, onEvent func(StreamEvent)) (Message, error) {
165-
apiVersion := p.config.APIVersion
166-
if apiVersion == "" {
167-
apiVersion = "2024-02-15-preview"
168-
}
169-
170-
url := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s",
171-
p.config.Endpoint, p.config.Deployment, apiVersion)
172-
173-
body := map[string]interface{}{
174-
"messages": messagesToAzureFormat(messages),
175-
"stream": true,
176-
}
177-
178-
if config.MaxTokens > 0 {
179-
body["max_tokens"] = config.MaxTokens
180-
}
181-
if config.Temperature > 0 {
182-
body["temperature"] = config.Temperature
183-
}
184-
185-
jsonBody, err := json.Marshal(body)
186-
if err != nil {
187-
return Message{}, fmt.Errorf("marshal request: %w", err)
188-
}
189-
190-
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
191-
if err != nil {
192-
return Message{}, fmt.Errorf("create request: %w", err)
193-
}
194-
195-
req.Header.Set("Content-Type", "application/json")
196-
req.Header.Set("api-key", p.config.APIKey)
197-
198-
resp, err := p.client.Do(req)
199-
if err != nil {
200-
return Message{}, fmt.Errorf("request failed: %w", err)
201-
}
202-
defer resp.Body.Close()
203-
204-
if resp.StatusCode != http.StatusOK {
205-
respBody, _ := io.ReadAll(resp.Body) // best-effort for error message
206-
return Message{}, fmt.Errorf("Azure OpenAI error (%d): %s", resp.StatusCode, string(respBody))
207-
}
208-
209-
var content strings.Builder
210-
decoder := NewSSEDecoder(resp.Body)
211-
212-
for {
213-
event, err := decoder.Decode()
214-
if err == io.EOF {
215-
break
216-
}
217-
if err != nil {
218-
break
219-
}
220-
221-
if event.Type == "content" || event.Type == "content_block" {
222-
content.WriteString(event.Content)
223-
onEvent(event)
224-
}
225-
}
226-
227-
return Message{
228-
Role: "assistant",
229-
Content: content.String(),
230-
}, nil
231-
}
232-
233-
type SSEDecoder struct {
234-
reader io.Reader
235-
buf []byte
236-
}
237-
238-
func NewSSEDecoder(reader io.Reader) *SSEDecoder {
239-
return &SSEDecoder{reader: reader, buf: make([]byte, 1024)}
240-
}
241-
242-
func (d *SSEDecoder) Decode() (StreamEvent, error) {
243-
var line string
244-
for {
245-
n, err := d.reader.Read(d.buf)
246-
if n == 0 || err != nil {
247-
return StreamEvent{}, err
248-
}
249-
line = string(d.buf[:n])
250-
if strings.HasPrefix(line, "data:") {
251-
break
252-
}
253-
}
254-
255-
line = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
256-
if line == "[DONE]" {
257-
return StreamEvent{}, io.EOF
258-
}
259-
260-
var delta struct {
261-
Choices []struct {
262-
Delta struct {
263-
Content string `json:"content"`
264-
} `json:"delta"`
265-
} `json:"choices"`
266-
}
267-
268-
if err := json.Unmarshal([]byte(line), &delta); err != nil {
269-
return StreamEvent{}, err
270-
}
271-
272-
content := ""
273-
if len(delta.Choices) > 0 {
274-
content = delta.Choices[0].Delta.Content
275-
}
276-
277-
return StreamEvent{
278-
Type: StreamEventContent,
279-
Content: content,
280-
}, nil
281-
}

azure_test.go

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -286,27 +286,3 @@ func TestAzureComplete_ContextCancelled(t *testing.T) {
286286
}
287287
}
288288

289-
// ---------------------------------------------------------------------------
290-
// SSEDecoder
291-
// ---------------------------------------------------------------------------
292-
293-
func TestSSEDecoder_ContentEvent(t *testing.T) {
294-
data := "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n"
295-
decoder := iteragent.NewSSEDecoder(strings.NewReader(data))
296-
event, err := decoder.Decode()
297-
if err != nil {
298-
t.Fatalf("unexpected error: %v", err)
299-
}
300-
if event.Content != "hello" {
301-
t.Errorf("expected content 'hello', got %q", event.Content)
302-
}
303-
}
304-
305-
func TestSSEDecoder_DoneSignal(t *testing.T) {
306-
import_io := "data: [DONE]\n"
307-
decoder := iteragent.NewSSEDecoder(strings.NewReader(import_io))
308-
_, err := decoder.Decode()
309-
if err == nil {
310-
t.Error("expected EOF for [DONE] signal")
311-
}
312-
}

basic

-8.26 MB
Binary file not shown.

0 commit comments

Comments
 (0)