Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,12 @@ func ActorFromContext(ctx context.Context) *Actor {

return a
}

// ActorIDFromContext safely extracts the actor ID from the context.
// Returns an empty string if no actor is found.
func ActorIDFromContext(ctx context.Context) string {
if actor := ActorFromContext(ctx); actor != nil {
return actor.ID
}
return ""
}
88 changes: 88 additions & 0 deletions context/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package context

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/coder/aibridge/recorder"
)

func TestAsActor(t *testing.T) {
t.Parallel()

// Given: a metadata map
metadata := recorder.Metadata{"key": "value"}

// When: storing an actor in the context
ctx := AsActor(context.Background(), "actor-123", metadata)

// Then: the actor should be retrievable with correct ID and metadata
actor := ActorFromContext(ctx)
require.NotNil(t, actor)
assert.Equal(t, "actor-123", actor.ID)
assert.Equal(t, "value", actor.Metadata["key"])
}

func TestActorFromContext(t *testing.T) {
t.Parallel()

t.Run("returns actor when present", func(t *testing.T) {
t.Parallel()

// Given: a context with an actor
ctx := AsActor(context.Background(), "test-id", recorder.Metadata{})

// When: extracting the actor from context
actor := ActorFromContext(ctx)

// Then: the actor should be returned with correct ID
require.NotNil(t, actor)
assert.Equal(t, "test-id", actor.ID)
})

t.Run("returns nil when no actor", func(t *testing.T) {
t.Parallel()

// Given: a context without an actor
ctx := context.Background()

// When: extracting the actor from context
actor := ActorFromContext(ctx)

// Then: nil should be returned
assert.Nil(t, actor)
})
}

func TestActorIDFromContext(t *testing.T) {
t.Parallel()

t.Run("returns actor ID when present", func(t *testing.T) {
t.Parallel()

// Given: a context with an actor
ctx := AsActor(context.Background(), "test-actor-id", recorder.Metadata{})

// When: extracting the actor ID from context
got := ActorIDFromContext(ctx)

// Then: the actor ID should be returned
assert.Equal(t, "test-actor-id", got)
})

t.Run("returns empty string when no actor", func(t *testing.T) {
t.Parallel()

// Given: a context without an actor
ctx := context.Background()

// When: extracting the actor ID from context
got := ActorIDFromContext(ctx)

// Then: an empty string should be returned
assert.Empty(t, got)
})
}
2 changes: 1 addition & 1 deletion intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool)
return []attribute.KeyValue{
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, s.id.String()),
attribute.String(tracing.InitiatorID, aibcontext.ActorFromContext(r.Context()).ID),
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
attribute.String(tracing.Provider, config.ProviderOpenAI),
attribute.String(tracing.Model, s.Model()),
attribute.Bool(tracing.Streaming, streaming),
Expand Down
2 changes: 1 addition & 1 deletion intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool)
return []attribute.KeyValue{
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, s.id.String()),
attribute.String(tracing.InitiatorID, aibcontext.ActorFromContext(r.Context()).ID),
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
attribute.String(tracing.Provider, aibconfig.ProviderAnthropic),
attribute.String(tracing.Model, s.Model()),
attribute.Bool(tracing.Streaming, streaming),
Expand Down
2 changes: 1 addition & 1 deletion intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streami
return []attribute.KeyValue{
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, i.id.String()),
attribute.String(tracing.InitiatorID, aibcontext.ActorFromContext(r.Context()).ID),
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
attribute.String(tracing.Provider, config.ProviderOpenAI),
attribute.String(tracing.Model, i.Model()),
attribute.Bool(tracing.Streaming, streaming),
Expand Down