Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 77 additions & 18 deletions lib/braintrust/eval/context.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ class Context
:project_id, :project_name, :state, :tracer_provider,
:on_progress, :parent_span_attr, :generation

# @param task [Task] Normalized task wrapper
# @param scorers [Array<Scorer>] Normalized scorer wrappers
# @param cases [Cases] Normalized eval cases
# @param experiment_id [String, nil] Experiment ID for logging and trace linkage
# @param experiment_name [String, nil] Experiment name, included in span attributes
# @param project_id [String, nil] Project ID
# @param project_name [String, nil] Project name
# @param state [Braintrust::State, nil] Authenticated API state; nil for local-only evals
# @param tracer_provider [#tracer, nil] OpenTelemetry tracer provider
# @param on_progress [Proc, nil] Callback invoked after each case completes, receiving a progress Hash
# @param parent_span_attr [String, nil] Formatted parent span identifier ("type:id"), linking spans to a parent context
# @param generation [Integer, nil] Generation number from the parent span context, used to link spans in a trace hierarchy
def initialize(task:, scorers:, cases:, experiment_id: nil, experiment_name: nil,
project_id: nil, project_name: nil, state: nil, tracer_provider: nil,
on_progress: nil, parent_span_attr: nil, generation: nil)
Expand All @@ -29,37 +41,77 @@ def initialize(task:, scorers:, cases:, experiment_id: nil, experiment_name: nil
end

# Build a Context from raw user inputs.
# Factory normalizes task, scorers, and cases into typed wrappers.
# Parent is resolved into parent_span_attr and generation.
# Delegates to Factory for normalization.
# @param task [Task, Proc, #call] Task to evaluate; wrapped into a {Task} if needed
# @param scorers [Array<Scorer, Proc, String, Scorer::ID, #call>] Scorers; each is normalized into a {Scorer}
# @param cases [Cases, Array, Enumerable] Eval cases; wrapped into {Cases} if needed
# @param experiment_id [String, nil] Experiment ID for logging
# @param experiment_name [String, nil] Experiment name, included in span attributes
# @param project_id [String, nil] Project ID
# @param project_name [String, nil] Project name; required when resolving scorer slugs
# @param state [Braintrust::State, nil] Authenticated API state; nil for local-only evals
# @param tracer_provider [#tracer, nil] OpenTelemetry tracer provider; defaults to global provider
# @param on_progress [Proc, nil] Callback invoked after each case completes, receiving a progress Hash
# @param parent [Hash, nil] Parent span info with keys :object_type, :object_id, and optionally :generation
# @return [Context]
def self.build(task:, scorers:, cases:, experiment_id: nil, experiment_name: nil,
project_id: nil, project_name: nil, state: nil, tracer_provider: nil,
on_progress: nil, parent: nil)
factory = Factory.new(state: state, tracer_provider: tracer_provider, project_name: project_name)

Context.new(
task: factory.normalize_task(task),
scorers: factory.normalize_scorers(scorers),
cases: factory.normalize_cases(cases),
experiment_id: experiment_id,
experiment_name: experiment_name,
project_id: project_id,
project_name: project_name,
state: state,
tracer_provider: tracer_provider,
on_progress: on_progress,
parent_span_attr: factory.resolve_parent_span_attr(parent),
generation: parent&.dig(:generation)
Factory.new(
state: state, tracer_provider: tracer_provider,
project_id: project_id, project_name: project_name
).build(
task: task, scorers: scorers, cases: cases,
experiment_id: experiment_id, experiment_name: experiment_name,
on_progress: on_progress, parent: parent
)
end

# Encapsulates normalization of raw user inputs into typed wrappers.
class Factory
def initialize(state: nil, tracer_provider: nil, project_name: nil)
# @param state [Braintrust::State, nil] Authenticated API state; passed through to scorer resolution
# @param tracer_provider [#tracer, nil] OpenTelemetry tracer provider; passed through to remote scorers
# @param project_id [String, nil] Project ID; passed through to the built Context
# @param project_name [String, nil] Project name; required when resolving scorer slugs
def initialize(state: nil, tracer_provider: nil, project_id: nil, project_name: nil)
@state = state
@tracer_provider = tracer_provider
@project_id = project_id
@project_name = project_name
end

# Wrap raw user inputs in their typed forms and assemble the resulting {Context}.
# Project, state, and tracer settings are taken from this factory's own configuration.
# @param task [Task, Proc, #call] Raw task; passed through normalize_task
# @param scorers [Array] Raw scorers; passed through normalize_scorers
# @param cases [Cases, Array, Enumerable] Raw eval cases; passed through normalize_cases
# @param experiment_id [String, nil] Forwarded to the Context unchanged
# @param experiment_name [String, nil] Forwarded to the Context unchanged
# @param on_progress [Proc, nil] Forwarded to the Context unchanged
# @param parent [Hash, nil] Parent span info with keys :object_type, :object_id, and optionally :generation
# @return [Context]
def build(task:, scorers:, cases:, experiment_id: nil, experiment_name: nil,
  on_progress: nil, parent: nil)
  # Normalize in the same order as the keyword arguments: task, scorers, cases.
  wrapped_task = normalize_task(task)
  wrapped_scorers = normalize_scorers(scorers)
  wrapped_cases = normalize_cases(cases)

  # Use the globally registered provider when the factory was not given one.
  provider = @tracer_provider || OpenTelemetry.tracer_provider

  Context.new(
    task: wrapped_task,
    scorers: wrapped_scorers,
    cases: wrapped_cases,
    experiment_id: experiment_id,
    experiment_name: experiment_name,
    project_id: @project_id,
    project_name: @project_name,
    state: @state,
    tracer_provider: provider,
    on_progress: on_progress,
    parent_span_attr: resolve_parent_span_attr(parent),
    generation: parent&.dig(:generation)
  )
end

private

# @param raw [Cases, Array, Enumerable, #each]
# @return [Cases]
# @raise [ArgumentError] if raw is not enumerable
def normalize_cases(raw)
case raw
when Cases
Expand All @@ -75,11 +127,15 @@ def normalize_cases(raw)
end
end

# Build the "type:id" string that identifies a parent span.
# @param parent [Hash, nil] Parent span info; reads the :object_type and :object_id keys
# @return [String, nil] "type:id" (e.g. "experiment_id:abc-123"), or nil when no parent is given
def resolve_parent_span_attr(parent)
  if parent
    format("%s:%s", parent[:object_type], parent[:object_id])
  end
end

# @param raw [Task, Proc, #call]
# @return [Task]
def normalize_task(raw)
case raw
when Task
Expand All @@ -95,6 +151,9 @@ def normalize_task(raw)
end
end

# @param raw [Array<Scorer, Proc, String, Scorer::ID, #call>]
# @return [Array<Scorer>]
# @raise [ArgumentError] if a String slug is given without a project name
def normalize_scorers(raw)
raw.map do |scorer|
case scorer
Expand Down
71 changes: 35 additions & 36 deletions lib/braintrust/eval/runner.rb
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ class Runner
# @param eval_context [Context] Normalized eval context
def initialize(eval_context)
@eval_context = eval_context
tracer_provider = eval_context.tracer_provider || OpenTelemetry.tracer_provider
@tracer = tracer_provider.tracer("braintrust-eval")
@tracer = eval_context.tracer_provider.tracer("braintrust-eval")

# Mutex for thread-safe score collection
@score_mutex = Mutex.new
Expand Down Expand Up @@ -79,68 +78,68 @@ def run(parallelism: 1)

# Run a single test case with OpenTelemetry tracing
# Creates eval span (parent) with task and score as children
# @param case_context [CaseContext] The per-case accumulator
# @param kase [CaseContext] The per-case accumulator
# @param errors [Queue] Thread-safe error collection queue
def run_eval_case(case_context, errors)
def run_eval_case(kase, errors)
# Each eval case starts its own trace — detach from any ambient span context
eval_span = tracer.start_root_span("eval")
OpenTelemetry::Trace.with_span(eval_span) do
# Set attributes known before task execution
eval_span.set_attribute("braintrust.parent", eval_context.parent_span_attr) if eval_context.parent_span_attr
set_json_attr(eval_span, "braintrust.span_attributes", build_span_attributes("eval"))
set_json_attr(eval_span, "braintrust.input_json", {input: case_context.input})
set_json_attr(eval_span, "braintrust.expected", case_context.expected) if case_context.expected
set_json_attr(eval_span, "braintrust.metadata", case_context.metadata) if case_context.metadata
eval_span.set_attribute("braintrust.tags", case_context.tags) if case_context.tags
eval_span.set_attribute("braintrust.origin", case_context.origin) if case_context.origin
set_json_attr(eval_span, "braintrust.input_json", {input: kase.input})
set_json_attr(eval_span, "braintrust.expected", kase.expected) if kase.expected
set_json_attr(eval_span, "braintrust.metadata", kase.metadata) if kase.metadata
eval_span.set_attribute("braintrust.tags", kase.tags) if kase.tags
eval_span.set_attribute("braintrust.origin", kase.origin) if kase.origin

# Run task
begin
case_context.output = run_task(case_context)
kase.output = run_task(kase)
rescue => e
# Error already recorded on task span, set eval span status
eval_span.status = OpenTelemetry::Trace::Status.error(e.message)
set_json_attr(eval_span, "braintrust.output_json", {output: nil})
errors << "Task failed for input '#{case_context.input}': #{e.message}"
report_progress(eval_span, case_context, error: e.message)
errors << "Task failed for input '#{kase.input}': #{e.message}"
report_progress(eval_span, kase, error: e.message)
next
end

# Flush spans so they're queryable via BTQL, then build trace
eval_context.tracer_provider&.force_flush
case_context.trace = build_trace(eval_span)
eval_context.tracer_provider.force_flush if eval_context.tracer_provider.respond_to?(:force_flush)
kase.trace = build_trace(eval_span)

# Run scorers
begin
run_scorers(case_context)
run_scorers(kase)
rescue => e
# Error already recorded on score span, set eval span status
eval_span.status = OpenTelemetry::Trace::Status.error(e.message)
errors << "Scorers failed for input '#{case_context.input}': #{e.message}"
errors << "Scorers failed for input '#{kase.input}': #{e.message}"
end

# Set output after task completes
set_json_attr(eval_span, "braintrust.output_json", {output: case_context.output})
set_json_attr(eval_span, "braintrust.output_json", {output: kase.output})

report_progress(eval_span, case_context, data: case_context.output)
report_progress(eval_span, kase, data: kase.output)
end
ensure
eval_span&.finish
end

# Run task with OpenTelemetry tracing
# Creates task span with input and output
# @param case_context [CaseContext] The per-case context
# @param kase [CaseContext] The per-case context
# @return [Object] Task output
def run_task(case_context)
def run_task(kase)
tracer.in_span("task") do |task_span|
task_span.set_attribute("braintrust.parent", eval_context.parent_span_attr) if eval_context.parent_span_attr
set_json_attr(task_span, "braintrust.span_attributes", build_span_attributes("task"))
set_json_attr(task_span, "braintrust.input_json", case_context.input)
set_json_attr(task_span, "braintrust.input_json", kase.input)

begin
output = eval_context.task.call(
input: case_context.input
input: kase.input
)
set_json_attr(task_span, "braintrust.output_json", output)
output
Expand All @@ -155,20 +154,20 @@ def run_task(case_context)

# Run scorers with OpenTelemetry tracing.
# Creates one span per scorer, each a direct child of the current (eval) span.
# @param case_context [CaseContext] The per-case context (output must be populated)
def run_scorers(case_context)
# @param kase [CaseContext] The per-case context (output must be populated)
def run_scorers(kase)
scorer_kwargs = {
input: case_context.input,
expected: case_context.expected,
output: case_context.output,
metadata: case_context.metadata || {},
trace: case_context.trace
input: kase.input,
expected: kase.expected,
output: kase.output,
metadata: kase.metadata || {},
trace: kase.trace
}
scorer_input = {
input: case_context.input,
expected: case_context.expected,
output: case_context.output,
metadata: case_context.metadata || {}
input: kase.input,
expected: kase.expected,
output: kase.output,
metadata: kase.metadata || {}
}

scorer_error = nil
Expand Down Expand Up @@ -241,11 +240,11 @@ def build_case_context(eval_case)

# Report progress for a case via on_progress callback.
# Rescues errors in the callback so a broken handler never crashes the eval.
def report_progress(eval_span, case_context, **fields)
def report_progress(eval_span, kase, **fields)
return unless eval_context.on_progress
progress = {"id" => eval_span.context.hex_span_id}.merge(fields.transform_keys(&:to_s))
if case_context.origin
progress["origin"] = case_context.origin.is_a?(String) ? JSON.parse(case_context.origin) : case_context.origin
if kase.origin
progress["origin"] = kase.origin.is_a?(String) ? JSON.parse(kase.origin) : kase.origin
end
eval_context.on_progress.call(progress)
rescue => e
Expand Down
Loading
Loading