Skip to content

Commit 4709af8

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Propogating the otel context
This change ensures that the OpenTelemetry context is correctly propagated across asynchronous boundaries throughout the ADK, primarily within RxJava streams. ### Key Changes * **Context Propagation:** Replaces manual `Scope` management (which often fails in reactive code) with `.compose(Tracing.withContext(context))`. This ensures the OTel context is preserved when work moves between different threads or schedulers. * **`Runner` Refactoring:** * Adds a top-level `"invocation"` span to `runAsync` and `runLive` calls. * Captures the context at entry points and propagates it through the internal execution flow (`runAsyncImpl`, `runLiveImpl`, `runAgentWithFreshSession`). * **`BaseLlmFlow` & `Functions`:** Updates preprocessing, postprocessing, and tool execution logic to maintain context. This ensures that spans created within tools or processors are correctly parented. * **`PluginManager`:** Ensures that plugin callbacks (like `afterRunCallback` and `onEventCallback`) execute within the captured context. * **Testing:** Adds several unit tests across `BaseLlmFlowTest`, `FunctionsTest`, `PluginManagerTest`, and `RunnerTest` that specifically verify context propagation using `ContextKey` and `Schedulers.computation()`. ### Files Modified * **`BaseLlmFlow.java`**, **`Functions.java`**, **`PluginManager.java`**, **`Runner.java`**: Core logic updates for context propagation. * **`LlmAgentTest.java`**, **`BaseLlmFlowTest.java`**, **`FunctionsTest.java`**, **`PluginManagerTest.java`**, **`RunnerTest.java`**: New tests for OTel integration. * **`BUILD` files**: Updated dependencies for OpenTelemetry APIs and SDK testing. PiperOrigin-RevId: 881463869
1 parent 3c702b1 commit 4709af8

9 files changed

Lines changed: 683 additions & 165 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ public BaseLlmFlow(
9191
* RequestProcessor} transforming the provided {@code llmRequestRef} in-place, and emits the
9292
* events generated by them.
9393
*/
94-
protected Flowable<Event> preprocess(
94+
private Flowable<Event> preprocess(
9595
InvocationContext context, AtomicReference<LlmRequest> llmRequestRef) {
96+
Context currentContext = Context.current();
9697
LlmAgent agent = (LlmAgent) context.agent();
9798

9899
RequestProcessor toolsProcessor =
@@ -114,9 +115,11 @@ protected Flowable<Event> preprocess(
114115
.concatMap(
115116
processor ->
116117
Single.defer(() -> processor.processRequest(context, llmRequestRef.get()))
118+
.compose(Tracing.withContext(currentContext))
117119
.doOnSuccess(result -> llmRequestRef.set(result.updatedRequest()))
118120
.flattenAsFlowable(
119-
result -> result.events() != null ? result.events() : ImmutableList.of()));
121+
result -> result.events() != null ? result.events() : ImmutableList.of()))
122+
.compose(Tracing.withContext(currentContext));
120123
}
121124

122125
/**
@@ -146,13 +149,13 @@ protected Flowable<Event> postprocess(
146149
}
147150
Context parentContext = Context.current();
148151

149-
return currentLlmResponse.flatMapPublisher(
150-
updatedResponse -> {
151-
try (Scope scope = parentContext.makeCurrent()) {
152-
return buildPostprocessingEvents(
153-
updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest);
154-
}
155-
});
152+
return currentLlmResponse
153+
.compose(Tracing.withContext(parentContext))
154+
.flatMapPublisher(
155+
updatedResponse ->
156+
buildPostprocessingEvents(
157+
updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest))
158+
.compose(Tracing.withContext(parentContext));
156159
}
157160

158161
/**
@@ -222,6 +225,7 @@ private Flowable<LlmResponse> callLlm(
222225
*/
223226
private Maybe<LlmResponse> handleBeforeModelCallback(
224227
InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) {
228+
Context currentContext = Context.current();
225229
Event callbackEvent = modelResponseEvent.toBuilder().build();
226230
CallbackContext callbackContext =
227231
new CallbackContext(context, callbackEvent.actions(), callbackEvent.id());
@@ -240,7 +244,11 @@ private Maybe<LlmResponse> handleBeforeModelCallback(
240244
Maybe.defer(
241245
() ->
242246
Flowable.fromIterable(callbacks)
243-
.concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder))
247+
.concatMapMaybe(
248+
callback ->
249+
callback
250+
.call(callbackContext, llmRequestBuilder)
251+
.compose(Tracing.withContext(currentContext)))
244252
.firstElement());
245253

246254
return pluginResult.switchIfEmpty(callbackResult);
@@ -257,6 +265,7 @@ private Maybe<LlmResponse> handleOnModelErrorCallback(
257265
LlmRequest.Builder llmRequestBuilder,
258266
Event modelResponseEvent,
259267
Throwable throwable) {
268+
Context currentContext = Context.current();
260269
Event callbackEvent = modelResponseEvent.toBuilder().build();
261270
CallbackContext callbackContext =
262271
new CallbackContext(context, callbackEvent.actions(), callbackEvent.id());
@@ -277,7 +286,11 @@ private Maybe<LlmResponse> handleOnModelErrorCallback(
277286
() -> {
278287
LlmRequest llmRequest = llmRequestBuilder.build();
279288
return Flowable.fromIterable(callbacks)
280-
.concatMapMaybe(callback -> callback.call(callbackContext, llmRequest, ex))
289+
.concatMapMaybe(
290+
callback ->
291+
callback
292+
.call(callbackContext, llmRequest, ex)
293+
.compose(Tracing.withContext(currentContext)))
281294
.firstElement();
282295
});
283296

@@ -292,6 +305,7 @@ private Maybe<LlmResponse> handleOnModelErrorCallback(
292305
*/
293306
private Single<LlmResponse> handleAfterModelCallback(
294307
InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) {
308+
Context currentContext = Context.current();
295309
Event callbackEvent = modelResponseEvent.toBuilder().build();
296310
CallbackContext callbackContext =
297311
new CallbackContext(context, callbackEvent.actions(), callbackEvent.id());
@@ -310,7 +324,11 @@ private Single<LlmResponse> handleAfterModelCallback(
310324
Maybe.defer(
311325
() ->
312326
Flowable.fromIterable(callbacks)
313-
.concatMapMaybe(callback -> callback.call(callbackContext, llmResponse))
327+
.concatMapMaybe(
328+
callback ->
329+
callback
330+
.call(callbackContext, llmResponse)
331+
.compose(Tracing.withContext(currentContext)))
314332
.firstElement());
315333

316334
return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse);

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 66 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
import com.google.genai.types.Part;
4343
import io.opentelemetry.api.trace.Span;
4444
import io.opentelemetry.context.Context;
45-
import io.opentelemetry.context.Scope;
4645
import io.reactivex.rxjava3.core.Flowable;
4746
import io.reactivex.rxjava3.core.Maybe;
4847
import io.reactivex.rxjava3.core.Observable;
@@ -163,7 +162,9 @@ public static Maybe<Event> handleFunctionCalls(
163162
}
164163
return functionResponseEventsObservable
165164
.toList()
166-
.flatMapMaybe(
165+
.toMaybe()
166+
.compose(Tracing.withContext(parentContext))
167+
.flatMap(
167168
events -> {
168169
if (events.isEmpty()) {
169170
return Maybe.empty();
@@ -226,7 +227,9 @@ public static Maybe<Event> handleFunctionCallsLive(
226227

227228
return responseEventsObservable
228229
.toList()
229-
.flatMapMaybe(
230+
.toMaybe()
231+
.compose(Tracing.withContext(parentContext))
232+
.flatMap(
230233
events -> {
231234
if (events.isEmpty()) {
232235
return Maybe.empty();
@@ -243,47 +246,45 @@ private static Function<FunctionCall, Maybe<Event>> getFunctionCallMapper(
243246
Context parentContext) {
244247
return functionCall ->
245248
Maybe.defer(
246-
() -> {
247-
try (Scope scope = parentContext.makeCurrent()) {
248-
BaseTool tool = tools.get(functionCall.name().get());
249-
ToolContext toolContext =
250-
ToolContext.builder(invocationContext)
251-
.functionCallId(functionCall.id().orElse(""))
252-
.toolConfirmation(
253-
functionCall.id().map(toolConfirmations::get).orElse(null))
254-
.build();
255-
256-
Map<String, Object> functionArgs =
257-
functionCall.args().map(HashMap::new).orElse(new HashMap<>());
258-
259-
Maybe<Map<String, Object>> maybeFunctionResult =
260-
maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext)
261-
.switchIfEmpty(
262-
Maybe.defer(
263-
() -> {
264-
try (Scope innerScope = parentContext.makeCurrent()) {
265-
return isLive
266-
? processFunctionLive(
267-
invocationContext,
268-
tool,
269-
toolContext,
270-
functionCall,
271-
functionArgs,
272-
parentContext)
273-
: callTool(tool, functionArgs, toolContext, parentContext);
274-
}
275-
}));
276-
277-
return postProcessFunctionResult(
278-
maybeFunctionResult,
279-
invocationContext,
280-
tool,
281-
functionArgs,
282-
toolContext,
283-
isLive,
284-
parentContext);
285-
}
286-
});
249+
() -> {
250+
BaseTool tool = tools.get(functionCall.name().get());
251+
ToolContext toolContext =
252+
ToolContext.builder(invocationContext)
253+
.functionCallId(functionCall.id().orElse(""))
254+
.toolConfirmation(
255+
functionCall.id().map(toolConfirmations::get).orElse(null))
256+
.build();
257+
258+
Map<String, Object> functionArgs =
259+
functionCall.args().map(HashMap::new).orElse(new HashMap<>());
260+
261+
Maybe<Map<String, Object>> maybeFunctionResult =
262+
maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext)
263+
.switchIfEmpty(
264+
Maybe.defer(
265+
() ->
266+
isLive
267+
? processFunctionLive(
268+
invocationContext,
269+
tool,
270+
toolContext,
271+
functionCall,
272+
functionArgs,
273+
parentContext)
274+
: callTool(
275+
tool, functionArgs, toolContext, parentContext))
276+
.compose(Tracing.withContext(parentContext)));
277+
278+
return postProcessFunctionResult(
279+
maybeFunctionResult,
280+
invocationContext,
281+
tool,
282+
functionArgs,
283+
toolContext,
284+
isLive,
285+
parentContext);
286+
})
287+
.compose(Tracing.withContext(parentContext));
287288
}
288289

289290
/**
@@ -410,34 +411,27 @@ private static Maybe<Event> postProcessFunctionResult(
410411
})
411412
.flatMapMaybe(
412413
optionalInitialResult -> {
413-
try (Scope scope = parentContext.makeCurrent()) {
414-
Map<String, Object> initialFunctionResult = optionalInitialResult.orElse(null);
415-
416-
return maybeInvokeAfterToolCall(
417-
invocationContext, tool, functionArgs, toolContext, initialFunctionResult)
418-
.map(Optional::of)
419-
.defaultIfEmpty(Optional.ofNullable(initialFunctionResult))
420-
.flatMapMaybe(
421-
finalOptionalResult -> {
422-
Map<String, Object> finalFunctionResult =
423-
finalOptionalResult.orElse(null);
424-
if (tool.longRunning() && finalFunctionResult == null) {
425-
return Maybe.empty();
426-
}
427-
return Maybe.fromCallable(
428-
() ->
429-
buildResponseEvent(
430-
tool,
431-
finalFunctionResult,
432-
toolContext,
433-
invocationContext))
434-
.compose(
435-
Tracing.<Event>trace("tool_response [" + tool.name() + "]")
436-
.setParent(parentContext))
437-
.doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event));
438-
});
439-
}
440-
});
414+
Map<String, Object> initialFunctionResult = optionalInitialResult.orElse(null);
415+
416+
return maybeInvokeAfterToolCall(
417+
invocationContext, tool, functionArgs, toolContext, initialFunctionResult)
418+
.map(Optional::of)
419+
.defaultIfEmpty(Optional.ofNullable(initialFunctionResult))
420+
.flatMapMaybe(
421+
finalOptionalResult -> {
422+
Map<String, Object> finalFunctionResult = finalOptionalResult.orElse(null);
423+
if (tool.longRunning() && finalFunctionResult == null) {
424+
return Maybe.empty();
425+
}
426+
Event event =
427+
buildResponseEvent(
428+
tool, finalFunctionResult, toolContext, invocationContext);
429+
Tracing.traceToolResponse(event.id(), event);
430+
return Maybe.just(event);
431+
});
432+
})
433+
.compose(
434+
Tracing.<Event>trace("tool_response [" + tool.name() + "]").setParent(parentContext));
441435
}
442436

443437
private static Optional<Event> mergeParallelFunctionResponseEvents(

core/src/main/java/com/google/adk/plugins/PluginManager.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
import com.google.adk.events.Event;
2222
import com.google.adk.models.LlmRequest;
2323
import com.google.adk.models.LlmResponse;
24+
import com.google.adk.telemetry.Tracing;
2425
import com.google.adk.tools.BaseTool;
2526
import com.google.adk.tools.ToolContext;
2627
import com.google.common.annotations.VisibleForTesting;
2728
import com.google.common.collect.ImmutableList;
2829
import com.google.genai.types.Content;
30+
import io.opentelemetry.context.Context;
2931
import io.reactivex.rxjava3.core.Completable;
3032
import io.reactivex.rxjava3.core.Flowable;
3133
import io.reactivex.rxjava3.core.Maybe;
@@ -126,6 +128,7 @@ public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
126128

127129
@Override
128130
public Completable afterRunCallback(InvocationContext invocationContext) {
131+
Context capturedContext = Context.current();
129132
return Flowable.fromIterable(plugins)
130133
.concatMapCompletable(
131134
plugin ->
@@ -136,20 +139,22 @@ public Completable afterRunCallback(InvocationContext invocationContext) {
136139
logger.error(
137140
"[{}] Error during callback 'afterRunCallback'",
138141
plugin.getName(),
139-
e)));
142+
e))
143+
.compose(Tracing.withContext(capturedContext)));
140144
}
141145

142146
@Override
143147
public Completable close() {
148+
Context capturedContext = Context.current();
144149
return Flowable.fromIterable(plugins)
145150
.concatMapCompletableDelayError(
146151
plugin ->
147152
plugin
148153
.close()
149154
.doOnError(
150155
e ->
151-
logger.error(
152-
"[{}] Error during callback 'close'", plugin.getName(), e)));
156+
logger.error("[{}] Error during callback 'close'", plugin.getName(), e))
157+
.compose(Tracing.withContext(capturedContext)));
153158
}
154159

155160
@Override
@@ -227,7 +232,7 @@ public Maybe<Map<String, Object>> onToolErrorCallback(
227232
*/
228233
private <T> Maybe<T> runMaybeCallbacks(
229234
Function<Plugin, Maybe<T>> callbackExecutor, String callbackName) {
230-
235+
Context capturedContext = Context.current();
231236
return Flowable.fromIterable(this.plugins)
232237
.concatMapMaybe(
233238
plugin ->
@@ -246,7 +251,8 @@ private <T> Maybe<T> runMaybeCallbacks(
246251
"[{}] Error during callback '{}'",
247252
plugin.getName(),
248253
callbackName,
249-
e)))
254+
e))
255+
.compose(Tracing.<T>withContext(capturedContext)))
250256
.firstElement();
251257
}
252258
}

0 commit comments

Comments
 (0)