Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions core/src/main/java/com/google/adk/tools/VertexAiSearchTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import com.google.genai.types.VertexAISearch;
import com.google.genai.types.VertexAISearchDataStoreSpec;
import io.reactivex.rxjava3.core.Completable;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

Expand All @@ -39,7 +38,7 @@
public abstract class VertexAiSearchTool extends BaseTool {
public abstract Optional<String> dataStoreId();

public abstract Optional<List<VertexAISearchDataStoreSpec>> dataStoreSpecs();
public abstract ImmutableList<VertexAISearchDataStoreSpec> dataStoreSpecs();

public abstract Optional<String> searchEngineId();

Expand All @@ -54,7 +53,7 @@ public abstract class VertexAiSearchTool extends BaseTool {
public abstract Optional<String> dataStore();

public static Builder builder() {
return new AutoValue_VertexAiSearchTool.Builder();
return new AutoValue_VertexAiSearchTool.Builder().dataStoreSpecs(ImmutableList.of());
}

VertexAiSearchTool() {
Expand All @@ -80,24 +79,29 @@ public Completable processLlmRequest(
searchEngineId().ifPresent(vertexAiSearchBuilder::engine);
filter().ifPresent(vertexAiSearchBuilder::filter);
maxResults().ifPresent(vertexAiSearchBuilder::maxResults);
dataStoreSpecs().ifPresent(vertexAiSearchBuilder::dataStoreSpecs);
if (!dataStoreSpecs().isEmpty()) {
vertexAiSearchBuilder.dataStoreSpecs(dataStoreSpecs());
}

Tool retrievalTool =
Tool.builder()
.retrieval(Retrieval.builder().vertexAiSearch(vertexAiSearchBuilder.build()).build())
.build();

List<Tool> currentTools =
new ArrayList<>(
llmRequest.config().flatMap(GenerateContentConfig::tools).orElse(ImmutableList.of()));
currentTools.add(retrievalTool);

ImmutableList<Tool> currentTools =
ImmutableList.<Tool>builder()
.addAll(
llmRequest
.config()
.flatMap(GenerateContentConfig::tools)
.orElse(ImmutableList.of()))
.add(retrievalTool)
.build();
GenerateContentConfig newConfig =
llmRequest
.config()
.map(GenerateContentConfig::toBuilder)
.orElse(GenerateContentConfig.builder())
.tools(ImmutableList.copyOf(currentTools))
.tools(currentTools)
.build();
llmRequestBuilder.config(newConfig);
return Completable.complete();
Expand Down Expand Up @@ -126,14 +130,18 @@ public abstract static class Builder {

public final VertexAiSearchTool build() {
VertexAiSearchTool tool = autoBuild();
if ((tool.dataStoreId().isEmpty() && tool.searchEngineId().isEmpty())
|| (tool.dataStoreId().isPresent() && tool.searchEngineId().isPresent())) {
boolean hasDataStoreId =
tool.dataStoreId().isPresent() && !tool.dataStoreId().get().isEmpty();
boolean hasSearchEngineId =
tool.searchEngineId().isPresent() && !tool.searchEngineId().get().isEmpty();
if (hasDataStoreId == hasSearchEngineId) {
throw new IllegalArgumentException(
"Either dataStoreId or searchEngineId must be specified.");
"One and only one of dataStoreId or searchEngineId must not be empty.");
}
if (tool.dataStoreSpecs().isPresent() && tool.searchEngineId().isEmpty()) {
boolean hasDataStoreSpecs = !tool.dataStoreSpecs().isEmpty();
if (hasDataStoreSpecs && !hasSearchEngineId) {
throw new IllegalArgumentException(
"searchEngineId must be specified if dataStoreSpecs is specified.");
"searchEngineId must not be empty if dataStoreSpecs is not empty.");
}
return tool;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void build_noDataStoreIdOrSearchEngineId_throwsException() {
assertThrows(IllegalArgumentException.class, () -> VertexAiSearchTool.builder().build());
assertThat(exception)
.hasMessageThat()
.isEqualTo("Either dataStoreId or searchEngineId must be specified.");
.isEqualTo("One and only one of dataStoreId or searchEngineId must not be empty.");
}

@Test
Expand All @@ -52,22 +52,46 @@ public void build_bothDataStoreIdAndSearchEngineId_throwsException() {
() -> VertexAiSearchTool.builder().dataStoreId("ds1").searchEngineId("se1").build());
assertThat(exception)
.hasMessageThat()
.isEqualTo("Either dataStoreId or searchEngineId must be specified.");
.isEqualTo("One and only one of dataStoreId or searchEngineId must not be empty.");
}

@Test
public void build_dataStoreSpecsWithoutSearchEngineId_throwsException() {
VertexAISearchDataStoreSpec spec =
VertexAISearchDataStoreSpec.builder().dataStore("ds1").build();
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
VertexAiSearchTool.builder()
.dataStoreId("ds1")
.dataStoreSpecs(ImmutableList.of())
.dataStoreSpecs(ImmutableList.of(spec))
.build());
assertThat(exception)
.hasMessageThat()
.isEqualTo("searchEngineId must be specified if dataStoreSpecs is specified.");
.isEqualTo("searchEngineId must not be empty if dataStoreSpecs is not empty.");
}

@Test
public void build_emptyDataStoreId_throwsException() {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() -> VertexAiSearchTool.builder().dataStoreId("").build());
assertThat(exception)
.hasMessageThat()
.isEqualTo("One and only one of dataStoreId or searchEngineId must not be empty.");
}

@Test
public void build_emptySearchEngineId_throwsException() {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() -> VertexAiSearchTool.builder().searchEngineId("").build());
assertThat(exception)
.hasMessageThat()
.isEqualTo("One and only one of dataStoreId or searchEngineId must not be empty.");
}

@Test
Expand All @@ -82,6 +106,19 @@ public void build_withSearchEngineId_succeeds() {
assertThat(tool.searchEngineId()).hasValue("se1");
}

@Test
public void build_withSearchEngineIdAndDataStoreSpecs_succeeds() {
VertexAISearchDataStoreSpec spec =
VertexAISearchDataStoreSpec.builder().dataStore("ds1").build();
VertexAiSearchTool tool =
VertexAiSearchTool.builder()
.searchEngineId("se1")
.dataStoreSpecs(ImmutableList.of(spec))
.build();
assertThat(tool.searchEngineId()).hasValue("se1");
assertThat(tool.dataStoreSpecs()).containsExactly(spec);
}

@Test
public void processLlmRequest_addsRetrievalTool() {
VertexAiSearchTool tool =
Expand Down Expand Up @@ -135,4 +172,18 @@ public void processLlmRequest_withDataStoreSpecs_addsRetrievalTool() {
.get())
.containsExactly(spec);
}

@Test
public void processLlmRequest_nonGeminiModel_throwsException() {
VertexAiSearchTool tool = VertexAiSearchTool.builder().searchEngineId("se1").build();
LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("other-model");
tool.processLlmRequest(llmRequestBuilder, ToolContext.builder(invocationContext).build())
.test()
.assertError(
throwable ->
throwable instanceof IllegalArgumentException
&& throwable
.getMessage()
.equals("Vertex AI Search tool is only supported for Gemini models."));
}
}