
Commit d9d84ee

google-genai-bot authored and copybara-github committed
feat: Trigger traceCallLlm to set call_llm attributes before span ends
PiperOrigin-RevId: 882023326
1 parent 0d6dd55 commit d9d84ee

3 files changed

Lines changed: 78 additions & 60 deletions
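
This commit routes Tracing.traceCallLlm through a new onSuccess hook on the tracing composer, so the call_llm attributes are written to the span that the composer opened for the operation while that span is still live, instead of from a separate doOnNext stage upstream of the trace operator. A minimal sketch of the resulting call-site pattern, assuming an upstream Flowable of model responses (called responses here) plus the context, spanContext, eventForCallbackUsage, and llmRequestBuilder variables that appear in the BaseLlmFlow diff below:

    // Sketch only: the surrounding stream and the request/response types are taken from the diff below.
    Flowable<LlmResponse> traced =
        responses.compose(
            Tracing.<LlmResponse>trace("call_llm")
                .setParent(spanContext)
                .onSuccess(
                    (span, llmResp) ->
                        Tracing.traceCallLlm(
                            span,
                            context,
                            eventForCallbackUsage.id(),
                            llmRequestBuilder.build(),
                            llmResp)));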


core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 11 additions & 8 deletions
@@ -190,20 +190,23 @@ private Flowable<LlmResponse> callLlm(
                             context, llmRequestBuilder, eventForCallbackUsage, exception)
                         .switchIfEmpty(Single.error(exception))
                         .toFlowable())
-        .doOnNext(
-            llmResp ->
-                Tracing.traceCallLlm(
-                    context,
-                    eventForCallbackUsage.id(),
-                    llmRequestBuilder.build(),
-                    llmResp))
         .doOnError(
             error -> {
               Span span = Span.current();
               span.setStatus(StatusCode.ERROR, error.getMessage());
               span.recordException(error);
             })
-        .compose(Tracing.<LlmResponse>trace("call_llm").setParent(spanContext))
+        .compose(
+            Tracing.<LlmResponse>trace("call_llm")
+                .setParent(spanContext)
+                .onSuccess(
+                    (span, llmResp) ->
+                        Tracing.traceCallLlm(
+                            span,
+                            context,
+                            eventForCallbackUsage.id(),
+                            llmRequestBuilder.build(),
+                            llmResp)))
         .concatMap(
             llmResp ->
                 handleAfterModelCallback(context, llmResp, eventForCallbackUsage)

core/src/main/java/com/google/adk/telemetry/Tracing.java

Lines changed: 66 additions & 51 deletions
@@ -54,6 +54,7 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 import org.reactivestreams.Publisher;
@@ -292,58 +293,49 @@ private static Map<String, Object> buildLlmRequestForTrace(LlmRequest llmRequest
    * @param llmResponse The LLM response object.
    */
   public static void traceCallLlm(
+      Span span,
       InvocationContext invocationContext,
       String eventId,
       LlmRequest llmRequest,
       LlmResponse llmResponse) {
-    traceWithSpan(
-        "traceCallLlm",
-        span -> {
-          span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent");
-          llmRequest
-              .model()
-              .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName));
-
-          setInvocationAttributes(span, invocationContext, eventId);
-
-          setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest));
-          setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse);
-
-          llmRequest
-              .config()
-              .ifPresent(
-                  config -> {
-                    config
-                        .topP()
-                        .ifPresent(
-                            topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue()));
-                    config
-                        .maxOutputTokens()
-                        .ifPresent(
-                            maxTokens ->
-                                span.setAttribute(
-                                    GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue()));
-                  });
-          llmResponse
-              .usageMetadata()
-              .ifPresent(
-                  usage -> {
-                    usage
-                        .promptTokenCount()
-                        .ifPresent(
-                            tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens));
-                    usage
-                        .candidatesTokenCount()
-                        .ifPresent(
-                            tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens));
-                  });
-          llmResponse
-              .finishReason()
-              .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT))
-              .ifPresent(
-                  reason ->
-                      span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason)));
-        });
+    span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent");
+    llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName));
+
+    setInvocationAttributes(span, invocationContext, eventId);
+
+    setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest));
+    setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse);
+
+    llmRequest
+        .config()
+        .ifPresent(
+            config -> {
+              config
+                  .topP()
+                  .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue()));
+              config
+                  .maxOutputTokens()
+                  .ifPresent(
+                      maxTokens ->
+                          span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue()));
+            });
+    llmResponse
+        .usageMetadata()
+        .ifPresent(
+            usage -> {
+              usage
+                  .promptTokenCount()
+                  .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens));
+              usage
+                  .candidatesTokenCount()
+                  .ifPresent(
+                      tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens));
+            });
+    llmResponse
+        .finishReason()
+        .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT))
+        .ifPresent(
+            reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason)));
   }
 
   /**
@@ -455,6 +447,7 @@ public static final class TracerProvider<T>
     private final String spanName;
     private Context explicitParentContext;
     private final List<Consumer<Span>> spanConfigurers = new ArrayList<>();
+    private BiConsumer<Span, T> onSuccessConsumer;
 
     private TracerProvider(String spanName) {
       this.spanName = spanName;
@@ -474,6 +467,16 @@ public TracerProvider<T> setParent(Context parentContext) {
       return this;
     }
 
+    /**
+     * Registers a callback to be executed with the span and the result item when the stream emits a
+     * success value.
+     */
+    @CanIgnoreReturnValue
+    public TracerProvider<T> onSuccess(BiConsumer<Span, T> consumer) {
+      this.onSuccessConsumer = consumer;
+      return this;
+    }
+
     private Context getParentContext() {
       return explicitParentContext != null ? explicitParentContext : Context.current();
     }
@@ -504,7 +507,11 @@ public Publisher<T> apply(Flowable<T> upstream) {
       return Flowable.defer(
           () -> {
             TracingLifecycle lifecycle = new TracingLifecycle();
-            return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end);
+            Flowable<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
+            if (onSuccessConsumer != null) {
+              pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t));
+            }
+            return pipeline.doFinally(lifecycle::end);
           });
     }
 
@@ -513,7 +520,11 @@ public SingleSource<T> apply(Single<T> upstream) {
       return Single.defer(
           () -> {
             TracingLifecycle lifecycle = new TracingLifecycle();
-            return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end);
+            Single<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
+            if (onSuccessConsumer != null) {
+              pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t));
+            }
+            return pipeline.doFinally(lifecycle::end);
           });
     }
 
@@ -522,7 +533,11 @@ public MaybeSource<T> apply(Maybe<T> upstream) {
       return Maybe.defer(
           () -> {
             TracingLifecycle lifecycle = new TracingLifecycle();
-            return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end);
+            Maybe<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
+            if (onSuccessConsumer != null) {
+              pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t));
+            }
+            return pipeline.doFinally(lifecycle::end);
           });
     }
 
core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java

Lines changed: 1 addition & 1 deletion
@@ -503,7 +503,7 @@ public void testTraceCallLlm() {
                       .totalTokenCount(30)
                       .build())
               .build();
-      Tracing.traceCallLlm(buildInvocationContext(), "event-1", llmRequest, llmResponse);
+      Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse);
     } finally {
       span.end();
     }
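
For reference, the shape the updated test exercises is a manually started span that is ended in a finally block, with the five-argument traceCallLlm invoked while the span is open. A hedged sketch of that shape (tracer, llmRequest, and llmResponse are assumed to exist; buildInvocationContext() is the test helper named in the diff above; the actual test setup may differ):

    // Sketch only, not the actual test body.
    Span span = tracer.spanBuilder("call_llm").startSpan();
    try (Scope scope = span.makeCurrent()) {
      Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse);
    } finally {
      span.end();
    }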
