Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions core/src/main/java/com/google/adk/tools/ExampleTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
public final class ExampleTool extends BaseTool {

private final Optional<BaseExampleProvider> exampleProvider;
private final Optional<List<Example>> examples;
private final List<Example> examples;

/** Single private constructor; create via builder or fromConfig. */
private ExampleTool(Builder builder) {
Expand All @@ -57,7 +57,7 @@ private ExampleTool(Builder builder) {
? "Adds few-shot examples to the request"
: builder.description);
this.exampleProvider = builder.provider;
this.examples = builder.examples.isEmpty() ? Optional.empty() : Optional.of(builder.examples);
this.examples = builder.examples;
}

@Override
Expand All @@ -77,9 +77,9 @@ public Completable processLlmRequest(
final String examplesBlock;
if (exampleProvider.isPresent()) {
examplesBlock = ExampleUtils.buildExampleSi(exampleProvider.get(), query);
} else if (examples.isPresent()) {
} else if (!examples.isEmpty()) {
// Adapter provider that returns a fixed list irrespective of query
BaseExampleProvider provider = q -> examples.get();
BaseExampleProvider provider = (unusedQuery) -> examples;
examplesBlock = ExampleUtils.buildExampleSi(provider, query);
} else {
return Completable.complete();
Expand Down Expand Up @@ -157,6 +157,7 @@ public static Builder builder() {
return new Builder();
}

/** Builder for {@link ExampleTool}. */
public static final class Builder {
private final List<Example> examples = new ArrayList<>();
private String name = "example_tool";
Expand Down
91 changes: 91 additions & 0 deletions core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.adk.flows.llmflows;

import static com.google.common.truth.Truth.assertThat;

import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.RunConfig;
import com.google.adk.examples.BaseExampleProvider;
import com.google.adk.examples.Example;
import com.google.adk.models.LlmRequest;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.Content;
import com.google.genai.types.Part;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public final class ExamplesTest {

private static class TestExampleProvider implements BaseExampleProvider {
@Override
public List<Example> getExamples(String query) {
return ImmutableList.of(
Example.builder()
.input(Content.fromParts(Part.fromText("input1")))
.output(
ImmutableList.of(
Content.builder().parts(Part.fromText("output1")).role("model").build()))
.build());
}
}

@Test
public void processRequest_withExampleProvider_addsExamplesToInstructions() {
LlmAgent agent =
LlmAgent.builder().name("test-agent").exampleProvider(new TestExampleProvider()).build();
InvocationContext context =
InvocationContext.builder()
.invocationId("invocation1")
.agent(agent)
.userContent(Content.fromParts(Part.fromText("what is up?")))
.runConfig(RunConfig.builder().build())
.build();
LlmRequest request = LlmRequest.builder().build();
Examples examplesProcessor = new Examples();

RequestProcessor.RequestProcessingResult result =
examplesProcessor.processRequest(context, request).blockingGet();

assertThat(result.updatedRequest().getSystemInstructions()).isNotEmpty();
assertThat(result.updatedRequest().getSystemInstructions().get(0))
.contains("[user]\ninput1\n\n[model]\noutput1\n");
}

@Test
public void processRequest_withoutExampleProvider_doesNotAddExamplesToInstructions() {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
InvocationContext context =
InvocationContext.builder()
.invocationId("invocation1")
.agent(agent)
.userContent(Content.fromParts(Part.fromText("what is up?")))
.runConfig(RunConfig.builder().build())
.build();
LlmRequest request = LlmRequest.builder().build();
Examples examplesProcessor = new Examples();

RequestProcessor.RequestProcessingResult result =
examplesProcessor.processRequest(context, request).blockingGet();

assertThat(result.updatedRequest().getSystemInstructions()).isEmpty();
}
}
42 changes: 40 additions & 2 deletions core/src/test/java/com/google/adk/tools/ExampleToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,43 @@ public void processLlmRequest_withInlineExamples_appendsFewShot() {
assertThat(si).contains("qout");
}

@Test
public void processLlmRequest_withProvider_appendsFewShot() {
ExampleTool tool = ExampleTool.builder().setExampleProvider(ProviderHolder.EXAMPLES).build();

InvocationContext ctx = newContext();
LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash");

tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait();
LlmRequest updated = builder.build();

assertThat(updated.getSystemInstructions()).isNotEmpty();
String si = String.join("\n", updated.getSystemInstructions());
assertThat(si).contains("Begin few-shot");
assertThat(si).contains("qin");
assertThat(si).contains("qout");
}

@Test
public void processLlmRequest_withEmptyUserContent_doesNotAppendFewShot() {
ExampleTool tool = ExampleTool.builder().addExample(makeExample("qin", "qout")).build();
InvocationContext ctxWithContent = newContext();
InvocationContext ctx =
InvocationContext.builder()
.invocationId(ctxWithContent.invocationId())
.agent(ctxWithContent.agent())
.session(ctxWithContent.session())
.userContent(Content.fromParts(Part.fromText("")))
.runConfig(ctxWithContent.runConfig())
.build();
LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash");

tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait();
LlmRequest updated = builder.build();

assertThat(updated.getSystemInstructions()).isEmpty();
}

@Test
public void fromConfig_withInlineExamples_buildsTool() throws Exception {
BaseTool.ToolArgsConfig args = new BaseTool.ToolArgsConfig();
Expand All @@ -95,7 +132,7 @@ public void fromConfig_withInlineExamples_buildsTool() throws Exception {
/** Holder for a provider referenced via ClassName.FIELD reflection. */
static final class ProviderHolder {
public static final BaseExampleProvider EXAMPLES =
(query) -> ImmutableList.of(makeExample("qin", "qout"));
(unusedQuery) -> ImmutableList.of(makeExample("qin", "qout"));

private ProviderHolder() {}
}
Expand Down Expand Up @@ -255,7 +292,8 @@ public void fromConfig_withWrongTypeProviderField_throwsConfigurationException()
/** Holder with non-static field for testing. */
static final class NonStaticProviderHolder {
@SuppressWarnings("ConstantField") // Intentionally non-static for testing
public final BaseExampleProvider INSTANCE = (query) -> ImmutableList.of(makeExample("q", "a"));
public final BaseExampleProvider INSTANCE =
(unusedQuery) -> ImmutableList.of(makeExample("q", "a"));

private NonStaticProviderHolder() {}
}
Expand Down