Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ public String getUsername() {
List<String> extensionHeaderValues = rc.request().headers().getAll(A2AHeaders_v0_3.X_A2A_EXTENSIONS);
Set<String> requestedExtensions = A2AExtensions.getRequestedExtensions(extensionHeaderValues);

return new ServerCallContext(user, state, requestedExtensions);
return new ServerCallContext(user, state, requestedExtensions, "0.3");
} else {
CallContextFactory_v0_3 builder = callContextFactory.get();
return builder.build(rc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import org.a2aproject.sdk.server.ServerCallContext;
import io.vertx.ext.web.RoutingContext;

/**
* Factory interface for creating ServerCallContext from a Vert.x RoutingContext.
*
* <p>Implementations MUST pass {@code "0.3"} as the protocol version when constructing
* {@link ServerCallContext} so that push notification payloads are formatted correctly.</p>
*/
public interface CallContextFactory_v0_3 {
ServerCallContext build(RoutingContext rc);
}
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ public String getUsername() {
List<String> extensionHeaderValues = rc.request().headers().getAll(A2AHeaders_v0_3.X_A2A_EXTENSIONS);
Set<String> requestedExtensions = A2AExtensions.getRequestedExtensions(extensionHeaderValues);

return new ServerCallContext(user, state, requestedExtensions);
return new ServerCallContext(user, state, requestedExtensions, "0.3");
} else {
CallContextFactory_v0_3 builder = callContextFactory.get();
return builder.build(rc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import org.a2aproject.sdk.server.ServerCallContext;
import io.vertx.ext.web.RoutingContext;

/**
* Factory interface for creating ServerCallContext from a Vert.x RoutingContext.
*
* <p>Implementations MUST pass {@code "0.3"} as the protocol version when constructing
* {@link ServerCallContext} so that push notification payloads are formatted correctly.</p>
*/
public interface CallContextFactory_v0_3 {
ServerCallContext build(RoutingContext rc);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package org.a2aproject.sdk.compat03.conversion;

import jakarta.enterprise.context.ApplicationScoped;

import org.a2aproject.sdk.compat03.conversion.mappers.domain.TaskMapper_v0_3;
import org.a2aproject.sdk.compat03.json.JsonProcessingException_v0_3;
import org.a2aproject.sdk.compat03.json.JsonUtil_v0_3;
import org.a2aproject.sdk.compat03.spec.Task_v0_3;
import org.a2aproject.sdk.server.tasks.PushNotificationPayloadFormatter;
import org.a2aproject.sdk.spec.Message;
import org.a2aproject.sdk.spec.StreamingEventKind;
import org.a2aproject.sdk.spec.Task;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ApplicationScoped
public class PushNotificationPayloadFormatter_v0_3 implements PushNotificationPayloadFormatter {

private static final Logger LOGGER = LoggerFactory.getLogger(PushNotificationPayloadFormatter_v0_3.class);

@Override
public String targetVersion() {
return "0.3";
}

@Override
public @Nullable String formatPayload(StreamingEventKind event, @Nullable Task taskSnapshot) {
if (event instanceof Message) {
return null;
}
if (taskSnapshot == null) {
LOGGER.warn("Cannot format v0.3 push notification: no task snapshot available");
return null;
}
Task_v0_3 v03Task = TaskMapper_v0_3.INSTANCE.fromV10(taskSnapshot);
try {
return JsonUtil_v0_3.toJson(v03Task);
} catch (JsonProcessingException_v0_3 e) {
LOGGER.error("Failed to serialize v0.3 task for push notification", e);
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public abstract class AbstractA2ARequestHandlerTest_v0_3 {
.parts(new TextPart_v0_3("test message"))
.build();

private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {};
private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = (event, snapshot) -> {};

// V1.0 backend infrastructure
protected AgentExecutor agentExecutor;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package org.a2aproject.sdk.compat03.conversion;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import org.a2aproject.sdk.spec.Message;
import org.a2aproject.sdk.spec.Task;
import org.a2aproject.sdk.spec.TaskState;
import org.a2aproject.sdk.spec.TaskStatus;
import org.a2aproject.sdk.spec.TaskStatusUpdateEvent;
import org.a2aproject.sdk.spec.TextPart;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

class PushNotificationPayloadFormatter_v0_3_Test {

private PushNotificationPayloadFormatter_v0_3 formatter;

@BeforeEach
void setUp() {
formatter = new PushNotificationPayloadFormatter_v0_3();
}

@Test
void targetVersionIs03() {
assertEquals("0.3", formatter.targetVersion());
}

@Test
void formatsTaskEventAsV03Task() {
Task task = Task.builder()
.id("t1").contextId("c1")
.status(new TaskStatus(TaskState.TASK_STATE_COMPLETED))
.build();

String payload = formatter.formatPayload(task, task);

assertNotNull(payload);
assertTrue(payload.contains("\"kind\":\"task\""));
assertTrue(payload.contains("\"id\":\"t1\""));
assertTrue(payload.contains("\"status\""));
}

@Test
void formatsStatusUpdateUsingTaskSnapshot() {
Task snapshot = Task.builder()
.id("t1").contextId("c1")
.status(new TaskStatus(TaskState.TASK_STATE_WORKING))
.build();

TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder()
.taskId("t1").contextId("c1")
.status(new TaskStatus(TaskState.TASK_STATE_WORKING))
.build();

String payload = formatter.formatPayload(event, snapshot);

assertNotNull(payload);
assertTrue(payload.contains("\"kind\":\"task\""));
assertTrue(payload.contains("\"id\":\"t1\""));
}

@Test
void skipsMessageEvents() {
Task snapshot = Task.builder()
.id("t1").contextId("c1")
.status(new TaskStatus(TaskState.TASK_STATE_WORKING))
.build();

Message message = Message.builder()
.messageId("m1")
.role(Message.Role.ROLE_AGENT)
.parts(new TextPart("hello"))
.build();

String payload = formatter.formatPayload(message, snapshot);

assertNull(payload);
}

@Test
void returnsNullWhenSnapshotIsNull() {
Task task = Task.builder()
.id("t1").contextId("c1")
.status(new TaskStatus(TaskState.TASK_STATE_COMPLETED))
.build();

String payload = formatter.formatPayload(task, null);

assertNull(payload);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
/**
* Factory interface for creating ServerCallContext from gRPC StreamObserver.
* Implementations can provide custom context creation logic.
*
* <p>Implementations MUST pass {@code "0.3"} as the protocol version when constructing
* {@link ServerCallContext} so that push notification payloads are formatted correctly.</p>
*/
public interface CallContextFactory_v0_3 {
<V> ServerCallContext create(StreamObserver<V> responseObserver);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ private <V> ServerCallContext createCallContext(StreamObserver<V> responseObserv
Map<String, Object> state = new HashMap<>();
state.put("grpc_response_observer", responseObserver);
Set<String> requestedExtensions = new HashSet<>();
return new ServerCallContext(user, state, requestedExtensions);
return new ServerCallContext(user, state, requestedExtensions, "0.3");
} else {
return factory.create(responseObserver);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public class GrpcHandler_v0_3_Test extends AbstractA2ARequestHandlerTest_v0_3 {
.build();

private final ServerCallContext callContext = new ServerCallContext(
UnauthenticatedUser.INSTANCE, Map.of("foo", "bar"), new HashSet<>());
UnauthenticatedUser.INSTANCE, Map.of("foo", "bar"), new HashSet<>(), "0.3");

// ========================================
// GetTask Tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
public class JSONRPCHandler_v0_3_Test extends AbstractA2ARequestHandlerTest_v0_3 {

private final ServerCallContext callContext = new ServerCallContext(
UnauthenticatedUser.INSTANCE, Map.of("foo", "bar"), new HashSet<>());
UnauthenticatedUser.INSTANCE, Map.of("foo", "bar"), new HashSet<>(), "0.3");

// ========================================
// GetTask Tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
public class RestHandler_v0_3_Test extends AbstractA2ARequestHandlerTest_v0_3 {

private final ServerCallContext callContext = new ServerCallContext(
UnauthenticatedUser.INSTANCE, Map.of("foo", "bar"), new HashSet<>());
UnauthenticatedUser.INSTANCE, Map.of("foo", "bar"), new HashSet<>(), "0.3");

// ========================================
// GetTask Tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.a2aproject.sdk.util.Assert;
import org.a2aproject.sdk.util.PageToken;
import org.a2aproject.sdk.spec.TaskPushNotificationConfig;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -36,6 +37,12 @@ public class JpaDatabasePushNotificationConfigStore implements PushNotificationC
@Transactional
@Override
public TaskPushNotificationConfig setInfo(TaskPushNotificationConfig notificationConfig) {
return setInfo(notificationConfig, null);
}

@Transactional
@Override
public TaskPushNotificationConfig setInfo(TaskPushNotificationConfig notificationConfig, @Nullable String protocolVersion) {
String taskId = Assert.checkNotNullParam("taskId", notificationConfig.taskId());
// Ensure config has an ID - default to taskId if not provided (mirroring InMemoryPushNotificationConfigStore behavior)
if (notificationConfig.id().isEmpty()) {
Expand All @@ -44,6 +51,7 @@ public TaskPushNotificationConfig setInfo(TaskPushNotificationConfig notificatio
notificationConfig = TaskPushNotificationConfig.builder(notificationConfig).id(taskId).build();
}

String resolvedVersion = PushNotificationConfigStore.resolveProtocolVersion(protocolVersion);
LOGGER.debug("Saving PushNotificationConfig for Task '{}' with ID: {}", taskId, notificationConfig.id());
try {
TaskConfigId configId = new TaskConfigId(taskId, notificationConfig.id());
Expand All @@ -54,11 +62,12 @@ public TaskPushNotificationConfig setInfo(TaskPushNotificationConfig notificatio
if (existingJpaConfig != null) {
// Update existing entity
existingJpaConfig.setConfig(notificationConfig);
existingJpaConfig.setProtocolVersion(resolvedVersion);
LOGGER.debug("Updated existing PushNotificationConfig for Task '{}' with ID: {}",
taskId, notificationConfig.id());
} else {
// Create new entity
JpaPushNotificationConfig jpaConfig = JpaPushNotificationConfig.createFromConfig(taskId, notificationConfig);
JpaPushNotificationConfig jpaConfig = JpaPushNotificationConfig.createFromConfig(taskId, notificationConfig, resolvedVersion);
em.persist(jpaConfig);
LOGGER.debug("Persisted new PushNotificationConfig for Task '{}' with ID: {}",
taskId, notificationConfig.id());
Expand Down Expand Up @@ -164,4 +173,27 @@ public void deleteInfo(String taskId, String configId) {
}
}

@Transactional
@Override
public @Nullable String getProtocolVersion(String taskId, String configId) {
JpaPushNotificationConfig jpaConfig = em.find(JpaPushNotificationConfig.class,
new TaskConfigId(taskId, configId));
return jpaConfig != null ? jpaConfig.getProtocolVersion() : null;
}

@Transactional
@Override
public java.util.Map<String, String> getProtocolVersions(String taskId) {
List<Object[]> results = em.createQuery(
"SELECT c.id.configId, c.protocolVersion FROM JpaPushNotificationConfig c " +
"WHERE c.id.taskId = :taskId AND c.protocolVersion IS NOT NULL", Object[].class)
.setParameter("taskId", taskId)
.getResultList();
java.util.Map<String, String> versions = new java.util.HashMap<>();
for (Object[] row : results) {
versions.put((String) row[0], (String) row[1]);
}
return versions;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import org.a2aproject.sdk.jsonrpc.common.json.JsonProcessingException;
import org.a2aproject.sdk.jsonrpc.common.json.JsonUtil;
import org.a2aproject.sdk.spec.TaskPushNotificationConfig;
import org.jspecify.annotations.Nullable;

import java.time.Instant;

@Entity
Expand All @@ -21,6 +23,9 @@ public class JpaPushNotificationConfig {
@Column(name = "task_data", columnDefinition = "TEXT", nullable = false)
private String configJson;

@Column(name = "protocol_version")
private String protocolVersion;

@Column(name = "created_at")
private Instant createdAt;

Expand Down Expand Up @@ -79,11 +84,20 @@ public void setCreatedAt(Instant createdAt) {
this.createdAt = createdAt;
}

static JpaPushNotificationConfig createFromConfig(String taskId, TaskPushNotificationConfig config) throws JsonProcessingException {
public @Nullable String getProtocolVersion() {
return protocolVersion;
}

public void setProtocolVersion(String protocolVersion) {
this.protocolVersion = protocolVersion;
}

static JpaPushNotificationConfig createFromConfig(String taskId, TaskPushNotificationConfig config, @Nullable String protocolVersion) throws JsonProcessingException {
String json = JsonUtil.toJson(config);
JpaPushNotificationConfig jpaPushNotificationConfig =
new JpaPushNotificationConfig(new TaskConfigId(taskId, config.id()), json);
jpaPushNotificationConfig.config = config;
jpaPushNotificationConfig.protocolVersion = protocolVersion;
return jpaPushNotificationConfig;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public void testDirectNotificationTrigger() {
.build();

// Directly trigger the mock
mockPushNotificationSender.sendNotification(testTask);
mockPushNotificationSender.sendNotification(testTask, null);

// Verify it was captured
Queue<Task> captured = mockPushNotificationSender.getCapturedTasks();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ public void testSendNotificationSuccess() throws Exception {
when(mockPostBuilder.post()).thenReturn(mockHttpResponse);
when(mockHttpResponse.success()).thenReturn(true);

notificationSender.sendNotification(task);
notificationSender.sendNotification(task, null);

// Verify HTTP client was called
ArgumentCaptor<String> bodyCaptor = ArgumentCaptor.forClass(String.class);
Expand Down Expand Up @@ -281,7 +281,7 @@ public void testSendNotificationWithToken() throws Exception {
when(mockPostBuilder.post()).thenReturn(mockHttpResponse);
when(mockHttpResponse.success()).thenReturn(true);

notificationSender.sendNotification(task);
notificationSender.sendNotification(task, null);

// TODO: Once token authentication is implemented, verify that:
// 1. The token is included in request headers (e.g., X-A2A-Notification-Token)
Expand All @@ -307,7 +307,7 @@ public void testSendNotificationNoConfig() throws Exception {
String taskId = "task_send_no_config";
Task task = createSampleTask(taskId, TaskState.TASK_STATE_COMPLETED);

notificationSender.sendNotification(task);
notificationSender.sendNotification(task, null);

// Verify HTTP client was never called
verify(mockHttpClient, never()).createPost();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class MockPushNotificationSender implements PushNotificationSender {
private final Queue<StreamingEventKind> capturedEvents = new ConcurrentLinkedQueue<>();

@Override
public void sendNotification(StreamingEventKind event) {
public void sendNotification(StreamingEventKind event, Task taskSnapshot) {
capturedEvents.add(event);
}

Expand Down
Loading