Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions core/src/main/java/com/google/adk/tools/BaseTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.fasterxml.jackson.annotation.JsonAnySetter;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.JsonBaseModel;
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
import com.google.adk.models.LlmRequest;
Expand All @@ -38,6 +39,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import javax.annotation.Nonnull;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
Expand Down Expand Up @@ -93,6 +95,84 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
throw new UnsupportedOperationException("This method is not implemented.");
}

/**
* Calls a tool with generic arguments and returns a map of results. The args type {@code T} need
* to be serializable with {@link JsonBaseModel#getMapper()}
*/
public <T> Single<Map<String, Object>> runAsync(T args, ToolContext toolContext) {
return runAsync(args, toolContext, JsonBaseModel.getMapper());
}

/**
* Calls a tool with generic arguments using a custom {@link ObjectMapper} and returns a map of
* results. The args type {@code T} needs to be serializable with the provided {@link
* ObjectMapper}.
*/
public <T> Single<Map<String, Object>> runAsync(
T args, ToolContext toolContext, ObjectMapper objectMapper) {
return runAsyncGeneric(args, toolContext, objectMapper, output -> output);
}

/**
* Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results
* converted to a specified class. The input type {@code I} needs to be serializable and the
* output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}.
*/
public <I, O> Single<O> runAsync(
I args, ToolContext toolContext, ObjectMapper objectMapper, Class<? extends O> oClass) {
return runAsyncGeneric(
args, toolContext, objectMapper, output -> objectMapper.convertValue(output, oClass));
}

/**
* Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results
* converted to a specified type reference. The input type {@code I} needs to be serializable and
* the output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}.
*/
public <I, O> Single<O> runAsync(
I args,
ToolContext toolContext,
ObjectMapper objectMapper,
TypeReference<? extends O> typeReference) {
return runAsyncGeneric(
args,
toolContext,
objectMapper,
output -> objectMapper.convertValue(output, typeReference));
}

/**
* Calls a tool with generic arguments, returning the results converted to a specified class. The
* input type {@code I} needs to be serializable and the output type {@code O} needs to be
* deserializable with {@link JsonBaseModel#getMapper()}
*/
public <I, O> Single<O> runAsync(I args, ToolContext toolContext, Class<? extends O> oClass) {
return runAsync(args, toolContext, JsonBaseModel.getMapper(), oClass);
}

/**
* Calls a tool with generic arguments, returning the results converted to a specified type
* reference. The input type needs to be serializable and the output type needs to be
* deserializable with {@link JsonBaseModel#getMapper()}
*/
public <I, O> Single<O> runAsync(
I args, ToolContext toolContext, TypeReference<? extends O> typeReference) {
return runAsync(args, toolContext, JsonBaseModel.getMapper(), typeReference);
}

private <I, O> Single<O> runAsyncGeneric(
I args,
ToolContext toolContext,
ObjectMapper objectMapper,
Function<? super Map<String, Object>, ? extends O> deserializer) {
return Single.defer(
() ->
Single.just(
objectMapper.convertValue(args, new TypeReference<Map<String, Object>>() {})))
.flatMap(argsMap -> runAsync(argsMap, toolContext))
.map(deserializer::apply);
}

/**
* Processes the outgoing {@link LlmRequest.Builder}.
*
Expand Down
144 changes: 144 additions & 0 deletions core/src/test/java/com/google/adk/tools/BaseToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@

import static com.google.common.truth.Truth.assertThat;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonSetter;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LlmAgent;
import com.google.adk.models.Gemini;
import com.google.adk.models.LlmRequest;
import com.google.adk.sessions.InMemorySessionService;
import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.GoogleMaps;
Expand All @@ -17,6 +25,7 @@
import com.google.genai.types.UrlContext;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.observers.TestObserver;
import java.util.Map;
import java.util.Optional;
import org.junit.Test;
Expand All @@ -27,6 +36,20 @@
@RunWith(JUnit4.class)
public final class BaseToolTest {

private final BaseTool doublingBaseTool =
new BaseTool("doubling-test-tool", "returns doubled args") {
@Override
public Single<Map<String, Object>> runAsync(
Map<String, Object> args, ToolContext toolContext) {
String sArg = (String) args.get("s");
Integer iArg = (Integer) args.get("i");
return Single.just(
ImmutableMap.<String, Object>of(
"s", sArg + sArg,
"i", iArg + iArg));
}
};

@Test
public void processLlmRequestNoDeclarationReturnsSameRequest() {
BaseTool tool =
Expand Down Expand Up @@ -247,4 +270,125 @@ public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() {
assertThat(updatedLlmRequest.config().get().tools().get())
.containsExactly(Tool.builder().googleMaps(GoogleMaps.builder().build()).build());
}

@AutoValue
@JsonDeserialize(builder = AutoValue_BaseToolTest_TestToolArgs.Builder.class)
public abstract static class TestToolArgs {
@JsonProperty("i")
public abstract int getI();

@JsonProperty("s")
public abstract String getS();

TestToolArgs() {}

public static Builder builder() {
return new AutoValue_BaseToolTest_TestToolArgs.Builder();
}

@AutoValue.Builder
public abstract static class Builder {

@JsonSetter("i")
public abstract Builder setI(int i);

@JsonSetter("s")
public abstract Builder setS(String s);

public abstract TestToolArgs build();

@JsonCreator
public static Builder builder() {
return TestToolArgs.builder();
}
}
}

@Test
public void runAsync_withTypeReference_convertsArguments() throws Exception {
TestToolArgs.Builder builder = TestToolArgs.builder().setI(42).setS("foo");

Single<TestToolArgs> out =
doublingBaseTool.runAsync(
builder.build(), /* toolContext= */ null, new TypeReference<TestToolArgs>() {});
TestObserver<TestToolArgs> testObserver = out.test();

testObserver.assertComplete();
TestToolArgs expected = TestToolArgs.builder().setI(84).setS("foofoo").build();
testObserver.assertValue(expected);
}

@Test
public void runAsync_withClass_convertsArguments() throws Exception {
TestToolArgs.Builder builder = TestToolArgs.builder().setI(21).setS("bar");

Single<TestToolArgs> out =
doublingBaseTool.runAsync(builder.build(), /* toolContext= */ null, TestToolArgs.class);
TestObserver<TestToolArgs> testObserver = out.test();

testObserver.assertComplete();
TestToolArgs expected = TestToolArgs.builder().setI(42).setS("barbar").build();
testObserver.assertValue(expected);
}

@Test
public void runAsync_withObjectOnly_convertsArguments() throws Exception {
TestToolArgs.Builder builder = TestToolArgs.builder().setI(11).setS("baz");

Single<Map<String, Object>> out =
doublingBaseTool.runAsync(builder.build(), /* toolContext= */ null);
TestObserver<Map<String, Object>> testObserver = out.test();

testObserver.assertComplete();
ImmutableMap<String, Object> expected = ImmutableMap.of("i", 22, "s", "bazbaz");
testObserver.assertValue(expected);
}

@Test
public void runAsync_withObjectMapperAndObjectOnly_convertsArguments() throws Exception {
TestToolArgs.Builder builder = TestToolArgs.builder().setI(11).setS("baz");
ObjectMapper objectMapper = new ObjectMapper();

Single<Map<String, Object>> out =
doublingBaseTool.runAsync(builder.build(), /* toolContext= */ null, objectMapper);
TestObserver<Map<String, Object>> testObserver = out.test();

testObserver.assertComplete();
ImmutableMap<String, Object> expected = ImmutableMap.of("i", 22, "s", "bazbaz");
testObserver.assertValue(expected);
}

@Test
public void runAsync_withTypeReferenceAndObjectMapper_convertsArguments() throws Exception {
TestToolArgs.Builder builder = TestToolArgs.builder().setI(42).setS("foo");
ObjectMapper objectMapper = new ObjectMapper();

Single<TestToolArgs> out =
doublingBaseTool.runAsync(
builder.build(),
/* toolContext= */ null,
objectMapper,
new TypeReference<TestToolArgs>() {});

TestObserver<TestToolArgs> testObserver = out.test();

testObserver.assertComplete();
TestToolArgs expected = TestToolArgs.builder().setI(84).setS("foofoo").build();
testObserver.assertValue(expected);
}

@Test
public void runAsync_withClassAndObjectMapper_convertsArguments() throws Exception {
TestToolArgs.Builder builder = TestToolArgs.builder().setI(21).setS("bar");
ObjectMapper objectMapper = new ObjectMapper();

Single<TestToolArgs> out =
doublingBaseTool.runAsync(
builder.build(), /* toolContext= */ null, objectMapper, TestToolArgs.class);
TestObserver<TestToolArgs> testObserver = out.test();

testObserver.assertComplete();
TestToolArgs expected = TestToolArgs.builder().setI(42).setS("barbar").build();
testObserver.assertValue(expected);
}
}
Loading