Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go/ai/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,12 @@ func (p *Part) unmarshalPartFromSchema(s partSchema) {
default:
p.Kind = PartText
p.Text = s.Text
p.ContentType = ""
p.ContentType = "plain/text"
if s.Data != "" {
// Note: if part is completely empty, we use text by default.
p.Kind = PartData
p.Text = s.Data
p.ContentType = ""
}
}
p.Metadata = s.Metadata
Expand Down
12 changes: 11 additions & 1 deletion go/ai/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,17 @@ func (o *commonGenOptions) applyGenerate(genOpts *generateOptions) error {
func WithMessages(messages ...*Message) CommonGenOption {
return &commonGenOptions{
MessagesFn: func(ctx context.Context, _ any) ([]*Message, error) {
return messages, nil
buf, err := json.Marshal(messages)
if err != nil {
return nil, err
}

msgs := make([]*Message, 0, len(messages))
if err := json.Unmarshal(buf, &msgs); err != nil {
return nil, err
}

return msgs, nil
Comment on lines +207 to +217
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While using json.Marshal and json.Unmarshal is a correct way to create a deep copy of the messages, it can introduce performance overhead due to reflection and memory allocations, especially for a large number of messages.

For a library like this where performance can be important for consumers, consider implementing a more performant, manual Clone() method on the Message struct and its nested types. This would avoid the overhead of JSON processing.

Here's a conceptual example of what that might look like:

func (m *Message) Clone() *Message {
    if m == nil {
        return nil
    }
    cloned := &Message{
        Role:     m.Role,
        Metadata: maps.Clone(m.Metadata), // maps.Clone is available in Go 1.21+
    }
    if m.Content != nil {
        cloned.Content = make([]*Part, len(m.Content))
        for i, p := range m.Content {
            cloned.Content[i] = p.Clone() // Assumes Part has a Clone method
        }
    }
    return cloned
}

This would require adding Clone() methods to Part and other nested structs as well. This could be a good follow-up improvement.

},
}
}
Expand Down
31 changes: 31 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3292,3 +3292,34 @@ Hello {{name}}, please help me with {{task}}.
}
})
}

func TestPromptWithDifferentInput(t *testing.T) {
reg := registry.New()
prompt, err := LoadPromptFromSource(reg, `Hello, {{name}}!`, "rawPrompt", "test-ns")
if err != nil {
t.Fatalf("LoadPromptFromRaw failed: %v", err)
}

if prompt == nil {
t.Fatal("prompt is nil")
}

for i := range 10 {
opts, err := prompt.Render(t.Context(), map[string]int{
"name": i,
})

if err != nil {
t.Fatalf("Render failed on iteration %d: %v", i, err)
}

if len(opts.Messages) != 1 {
t.Fatalf("Expected 1 message, got %d", len(opts.Messages))
}

expected := fmt.Sprintf("Hello, %d!", i)
if opts.Messages[0].Text() != expected {
t.Errorf("Iteration %d: got %q, want %q", i, opts.Messages[0].Text(), expected)
}
}
}