Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ public abstract class BaseAgent {

private final List<? extends BaseAgent> subAgents;

private final Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback;
private final Optional<List<? extends AfterAgentCallback>> afterAgentCallback;
private final List<? extends BeforeAgentCallback> beforeAgentCallback;
private final List<? extends AfterAgentCallback> afterAgentCallback;

/**
* Creates a new BaseAgent.
Expand All @@ -83,8 +83,9 @@ public BaseAgent(
this.description = description;
this.parentAgent = null;
this.subAgents = subAgents != null ? subAgents : ImmutableList.of();
this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback);
this.afterAgentCallback = Optional.ofNullable(afterAgentCallback);
this.beforeAgentCallback =
beforeAgentCallback != null ? beforeAgentCallback : ImmutableList.of();
this.afterAgentCallback = afterAgentCallback != null ? afterAgentCallback : ImmutableList.of();

// Establish parent relationships for all sub-agents if needed.
for (BaseAgent subAgent : this.subAgents) {
Expand Down Expand Up @@ -171,11 +172,11 @@ public List<? extends BaseAgent> subAgents() {
return subAgents;
}

public Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback() {
public List<? extends BeforeAgentCallback> beforeAgentCallback() {
return beforeAgentCallback;
}

public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
public List<? extends AfterAgentCallback> afterAgentCallback() {
return afterAgentCallback;
}

Expand All @@ -185,7 +186,7 @@ public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends BeforeAgentCallback> canonicalBeforeAgentCallbacks() {
return beforeAgentCallback.orElse(ImmutableList.of());
return beforeAgentCallback;
}

/**
Expand All @@ -194,7 +195,7 @@ public List<? extends BeforeAgentCallback> canonicalBeforeAgentCallbacks() {
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends AfterAgentCallback> canonicalAfterAgentCallbacks() {
return afterAgentCallback.orElse(ImmutableList.of());
return afterAgentCallback;
}

/**
Expand Down Expand Up @@ -239,8 +240,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
() ->
callCallback(
beforeCallbacksToFunctions(
invocationContext.pluginManager(),
beforeAgentCallback.orElse(ImmutableList.of())),
invocationContext.pluginManager(), beforeAgentCallback),
invocationContext)
.flatMapPublisher(
beforeEventOpt -> {
Expand All @@ -257,7 +257,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
callCallback(
afterCallbacksToFunctions(
invocationContext.pluginManager(),
afterAgentCallback.orElse(ImmutableList.of())),
afterAgentCallback),
invocationContext)
.flatMapPublisher(Flowable::fromOptional));

Expand Down
20 changes: 20 additions & 0 deletions core/src/test/java/com/google/adk/agents/BaseAgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -316,4 +316,24 @@ public void canonicalCallbacks_returnsListWhenPresent() {
assertThat(agent.canonicalBeforeAgentCallbacks()).containsExactly(bc);
assertThat(agent.canonicalAfterAgentCallbacks()).containsExactly(ac);
}

@Test
public void runLive_invokesRunLiveImpl() {
var runLiveCallback = TestCallback.<Void>returningEmpty();
Content runLiveImplContent = Content.fromParts(Part.fromText("live_output"));
TestBaseAgent agent =
new TestBaseAgent(
TEST_AGENT_NAME,
TEST_AGENT_DESCRIPTION,
/* beforeAgentCallbacks= */ ImmutableList.of(),
/* afterAgentCallbacks= */ ImmutableList.of(),
runLiveCallback.asRunLiveImplSupplier(runLiveImplContent));
InvocationContext invocationContext = TestUtils.createInvocationContext(agent);

List<Event> results = agent.runLive(invocationContext).toList().blockingGet();

assertThat(results).hasSize(1);
assertThat(results.get(0).content()).hasValue(runLiveImplContent);
assertThat(runLiveCallback.wasCalled()).isTrue();
}
}
25 changes: 14 additions & 11 deletions core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1161,20 +1161,25 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks()

String pfx = "test.callbacks.";
registry.register(
pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty());
pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty());
registry.register(
pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty());
registry.register(pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (ctx) -> Maybe.empty());
pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty());
registry.register(
pfx + "before_model_1", (Callbacks.BeforeModelCallback) (ctx, req) -> Maybe.empty());
pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (unusedCtx) -> Maybe.empty());
registry.register(
pfx + "after_model_1", (Callbacks.AfterModelCallback) (ctx, resp) -> Maybe.empty());
pfx + "before_model_1",
(Callbacks.BeforeModelCallback) (unusedCtx, unusedReq) -> Maybe.empty());
registry.register(
pfx + "after_model_1",
(Callbacks.AfterModelCallback) (unusedCtx, unusedResp) -> Maybe.empty());
registry.register(
pfx + "before_tool_1",
(Callbacks.BeforeToolCallback) (inv, tool, args, toolCtx) -> Maybe.empty());
(Callbacks.BeforeToolCallback)
(unusedInv, unusedTool, unusedArgs, unusedToolCtx) -> Maybe.empty());
registry.register(
pfx + "after_tool_1",
(Callbacks.AfterToolCallback) (inv, tool, args, toolCtx, resp) -> Maybe.empty());
(Callbacks.AfterToolCallback)
(unusedInv, unusedTool, unusedArgs, unusedToolCtx, unusedResp) -> Maybe.empty());

File configFile = tempFolder.newFile("with_callbacks.yaml");
Files.writeString(
Expand Down Expand Up @@ -1204,10 +1209,8 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks()
assertThat(agent).isInstanceOf(LlmAgent.class);
LlmAgent llm = (LlmAgent) agent;

assertThat(agent.beforeAgentCallback()).isPresent();
assertThat(agent.beforeAgentCallback().get()).hasSize(2);
assertThat(agent.afterAgentCallback()).isPresent();
assertThat(agent.afterAgentCallback().get()).hasSize(1);
assertThat(agent.beforeAgentCallback()).hasSize(2);
assertThat(agent.afterAgentCallback()).hasSize(1);

assertThat(llm.beforeModelCallback()).isPresent();
assertThat(llm.beforeModelCallback().get()).hasSize(1);
Expand Down
41 changes: 29 additions & 12 deletions core/src/test/java/com/google/adk/testing/TestCallback.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,63 +102,80 @@ public Supplier<Flowable<Event>> asRunAsyncImplSupplier(String contentText) {
return asRunAsyncImplSupplier(Content.fromParts(Part.fromText(contentText)));
}

/**
* Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable}
* with an event containing the given content.
*/
public Supplier<Flowable<Event>> asRunLiveImplSupplier(Content content) {
return () ->
Flowable.defer(
() -> {
markAsCalled();
return Flowable.just(Event.builder().content(content).build());
});
}

@SuppressWarnings("unchecked") // This cast is safe if T is Content.
public BeforeAgentCallback asBeforeAgentCallback() {
return ctx -> (Maybe<Content>) callMaybe();
return (unusedCtx) -> (Maybe<Content>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Content.
public BeforeAgentCallbackSync asBeforeAgentCallbackSync() {
return ctx -> (Optional<Content>) callOptional();
return (unusedCtx) -> (Optional<Content>) callOptional();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Content.
public AfterAgentCallback asAfterAgentCallback() {
return ctx -> (Maybe<Content>) callMaybe();
return (unusedCtx) -> (Maybe<Content>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Content.
public AfterAgentCallbackSync asAfterAgentCallbackSync() {
return ctx -> (Optional<Content>) callOptional();
return (unusedCtx) -> (Optional<Content>) callOptional();
}

@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
public BeforeModelCallback asBeforeModelCallback() {
return (ctx, req) -> (Maybe<LlmResponse>) callMaybe();
return (unusedCtx, unusedReq) -> (Maybe<LlmResponse>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
public BeforeModelCallbackSync asBeforeModelCallbackSync() {
return (ctx, req) -> (Optional<LlmResponse>) callOptional();
return (unusedCtx, unusedReq) -> (Optional<LlmResponse>) callOptional();
}

@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
public AfterModelCallback asAfterModelCallback() {
return (ctx, res) -> (Maybe<LlmResponse>) callMaybe();
return (unusedCtx, unusedRes) -> (Maybe<LlmResponse>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
public AfterModelCallbackSync asAfterModelCallbackSync() {
return (ctx, res) -> (Optional<LlmResponse>) callOptional();
return (unusedCtx, unusedRes) -> (Optional<LlmResponse>) callOptional();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
public BeforeToolCallback asBeforeToolCallback() {
return (invCtx, tool, toolArgs, toolCtx) -> (Maybe<Map<String, Object>>) callMaybe();
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx) ->
(Maybe<Map<String, Object>>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
public BeforeToolCallbackSync asBeforeToolCallbackSync() {
return (invCtx, tool, toolArgs, toolCtx) -> (Optional<Map<String, Object>>) callOptional();
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx) ->
(Optional<Map<String, Object>>) callOptional();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
public AfterToolCallback asAfterToolCallback() {
return (invCtx, tool, toolArgs, toolCtx, res) -> (Maybe<Map<String, Object>>) callMaybe();
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx, unusedRes) ->
(Maybe<Map<String, Object>>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
public AfterToolCallbackSync asAfterToolCallbackSync() {
return (invCtx, tool, toolArgs, toolCtx, res) -> (Optional<Map<String, Object>>) callOptional();
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx, unusedRes) ->
(Optional<Map<String, Object>>) callOptional();
}
}