Skip to content

Commit 27bd23b

Browse files
bokelleyclaude
andcommitted
feat: add protocol field extraction for A2A responses
Extend response parsing to handle protocol-level fields (message, context_id, task_id, status, timestamp) that A2A servers may include alongside task data. These are now separated during validation to prevent schema mismatches, then preserved at the TaskResult level. Resolves #109. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
1 parent c1d812f commit 27bd23b

File tree

2 files changed

+228
-6
lines changed

2 files changed

+228
-6
lines changed

src/adcp/utils/response_parser.py

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,66 @@ def parse_mcp_content(content: list[dict[str, Any]], response_type: type[T]) ->
129129
)
130130

131131

132+
# Protocol-level fields from ProtocolResponse (core/response.json) and
133+
# ProtocolEnvelope (core/protocol_envelope.json). These are separated from
134+
# task data for schema validation, but preserved at the TaskResult level.
135+
# Note: 'data' and 'payload' are handled separately as wrapper fields.
136+
PROTOCOL_FIELDS = {
137+
"message", # Human-readable summary
138+
"context_id", # Session continuity identifier
139+
"task_id", # Async operation identifier
140+
"status", # Task execution state
141+
"timestamp", # Response timestamp
142+
}
143+
144+
145+
def _extract_task_data(data: dict[str, Any]) -> dict[str, Any]:
146+
"""
147+
Extract task-specific data from a protocol response.
148+
149+
Servers may return responses in ProtocolResponse format:
150+
{"message": "...", "context_id": "...", "data": {...}}
151+
152+
Or ProtocolEnvelope format:
153+
{"message": "...", "status": "...", "payload": {...}}
154+
155+
Or task data directly with protocol fields mixed in:
156+
{"message": "...", "products": [...], ...}
157+
158+
This function separates task-specific data for schema validation.
159+
Protocol fields are preserved at the TaskResult level.
160+
161+
Args:
162+
data: Response data dict
163+
164+
Returns:
165+
Task-specific data suitable for schema validation.
166+
Returns the same dict object if no extraction is needed.
167+
"""
168+
# Check for wrapped payload fields (ProtocolResponse uses 'data', ProtocolEnvelope uses 'payload')
169+
if "data" in data and isinstance(data["data"], dict):
170+
return data["data"]
171+
if "payload" in data and isinstance(data["payload"], dict):
172+
return data["payload"]
173+
174+
# Check if any protocol fields are present
175+
if not any(k in PROTOCOL_FIELDS for k in data):
176+
return data # Return same object for identity check
177+
178+
# Separate task data from protocol fields
179+
return {k: v for k, v in data.items() if k not in PROTOCOL_FIELDS}
180+
181+
132182
def parse_json_or_text(data: Any, response_type: type[T]) -> T:
133183
"""
134184
Parse data that might be JSON string, dict, or other format.
135185
136186
Used by A2A adapter for flexible response parsing.
137187
188+
Handles protocol-level wrapping where servers return:
189+
- {"message": "...", "data": {...task_data...}}
190+
- {"message": "...", ...task_fields...}
191+
138192
Args:
139193
data: Response data (string, dict, or other)
140194
response_type: Expected Pydantic model type
@@ -147,22 +201,42 @@ def parse_json_or_text(data: Any, response_type: type[T]) -> T:
147201
"""
148202
# If already a dict, try direct validation
149203
if isinstance(data, dict):
204+
# Try direct validation first
205+
original_error: Exception | None = None
150206
try:
151207
return _validate_union_type(data, response_type)
152-
except ValidationError as e:
153-
# Get the type name, handling Union types
154-
type_name = getattr(response_type, "__name__", str(response_type))
155-
raise ValueError(f"Response doesn't match expected schema {type_name}: {e}") from e
208+
except (ValidationError, ValueError) as e:
209+
original_error = e
210+
211+
# Try extracting task data (separates protocol fields)
212+
task_data = _extract_task_data(data)
213+
if task_data is not data:
214+
try:
215+
return _validate_union_type(task_data, response_type)
216+
except (ValidationError, ValueError):
217+
pass # Fall through to raise original error
218+
219+
# Report the original validation error
220+
type_name = getattr(response_type, "__name__", str(response_type))
221+
raise ValueError(
222+
f"Response doesn't match expected schema {type_name}: {original_error}"
223+
) from original_error
156224

157225
# If string, try JSON parsing
158226
if isinstance(data, str):
159227
try:
160228
parsed = json.loads(data)
161-
return _validate_union_type(parsed, response_type)
162229
except json.JSONDecodeError as e:
163230
raise ValueError(f"Response is not valid JSON: {e}") from e
231+
232+
# Recursively handle dict parsing (which includes protocol field extraction)
233+
if isinstance(parsed, dict):
234+
return parse_json_or_text(parsed, response_type)
235+
236+
# Non-dict JSON (shouldn't happen for AdCP responses)
237+
try:
238+
return _validate_union_type(parsed, response_type)
164239
except ValidationError as e:
165-
# Get the type name, handling Union types
166240
type_name = getattr(response_type, "__name__", str(response_type))
167241
raise ValueError(f"Response doesn't match expected schema {type_name}: {e}") from e
168242

tests/test_response_parser.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,151 @@ def test_json_string_not_matching_schema_raises_error(self):
135135

136136
with pytest.raises(ValueError, match="doesn't match expected schema"):
137137
parse_json_or_text(data, SampleResponse)
138+
139+
140+
class ProductResponse(BaseModel):
141+
"""Response type without protocol fields for testing protocol field stripping."""
142+
143+
products: list[str]
144+
total: int = 0
145+
146+
147+
class TestProtocolFieldExtraction:
148+
"""Tests for protocol field extraction from A2A responses.
149+
150+
A2A servers may include protocol-level fields (message, context_id, data)
151+
that are not part of task-specific response schemas. These are separated
152+
for task data validation, but preserved at the TaskResult level.
153+
154+
See: https://github.com/adcontextprotocol/adcp-client-python/issues/109
155+
"""
156+
157+
def test_response_with_message_field_separated(self):
158+
"""Test that protocol 'message' field is separated before validation."""
159+
# A2A server returns task data with protocol message mixed in
160+
data = {
161+
"message": "No products matched your requirements.",
162+
"products": ["product-1", "product-2"],
163+
"total": 2,
164+
}
165+
166+
result = parse_json_or_text(data, ProductResponse)
167+
168+
assert isinstance(result, ProductResponse)
169+
assert result.products == ["product-1", "product-2"]
170+
assert result.total == 2
171+
172+
def test_response_with_context_id_separated(self):
173+
"""Test that protocol 'context_id' field is separated before validation."""
174+
data = {
175+
"context_id": "session-123",
176+
"products": ["product-1"],
177+
"total": 1,
178+
}
179+
180+
result = parse_json_or_text(data, ProductResponse)
181+
182+
assert isinstance(result, ProductResponse)
183+
assert result.products == ["product-1"]
184+
185+
def test_response_with_multiple_protocol_fields_separated(self):
186+
"""Test that multiple protocol fields are separated."""
187+
data = {
188+
"message": "Found products",
189+
"context_id": "session-456",
190+
"products": ["a", "b", "c"],
191+
"total": 3,
192+
}
193+
194+
result = parse_json_or_text(data, ProductResponse)
195+
196+
assert isinstance(result, ProductResponse)
197+
assert result.products == ["a", "b", "c"]
198+
assert result.total == 3
199+
200+
def test_response_with_data_wrapper_extracted(self):
201+
"""Test that ProtocolResponse 'data' wrapper is extracted."""
202+
# Full ProtocolResponse format: {"message": "...", "data": {...task_data...}}
203+
data = {
204+
"message": "Operation completed",
205+
"context_id": "ctx-789",
206+
"data": {
207+
"products": ["wrapped-product"],
208+
"total": 1,
209+
},
210+
}
211+
212+
result = parse_json_or_text(data, ProductResponse)
213+
214+
assert isinstance(result, ProductResponse)
215+
assert result.products == ["wrapped-product"]
216+
assert result.total == 1
217+
218+
def test_response_with_payload_wrapper_extracted(self):
219+
"""Test that ProtocolEnvelope 'payload' wrapper is extracted."""
220+
# Full ProtocolEnvelope format
221+
data = {
222+
"message": "Operation completed",
223+
"status": "completed",
224+
"task_id": "task-123",
225+
"timestamp": "2025-01-01T00:00:00Z",
226+
"payload": {
227+
"products": ["envelope-product"],
228+
"total": 1,
229+
},
230+
}
231+
232+
result = parse_json_or_text(data, ProductResponse)
233+
234+
assert isinstance(result, ProductResponse)
235+
assert result.products == ["envelope-product"]
236+
assert result.total == 1
237+
238+
def test_exact_match_still_works(self):
239+
"""Test that responses exactly matching schema still work."""
240+
data = {
241+
"products": ["exact-match"],
242+
"total": 1,
243+
}
244+
245+
result = parse_json_or_text(data, ProductResponse)
246+
247+
assert result.products == ["exact-match"]
248+
assert result.total == 1
249+
250+
def test_json_string_with_protocol_fields(self):
251+
"""Test JSON string with protocol fields is parsed correctly."""
252+
data = json.dumps(
253+
{
254+
"message": "Success",
255+
"products": ["from-json-string"],
256+
"total": 1,
257+
}
258+
)
259+
260+
result = parse_json_or_text(data, ProductResponse)
261+
262+
assert result.products == ["from-json-string"]
263+
264+
def test_invalid_data_after_separation_raises_error(self):
265+
"""Test that invalid data still raises error after separation."""
266+
data = {
267+
"message": "Some message",
268+
"wrong_field": "value",
269+
}
270+
271+
with pytest.raises(ValueError, match="doesn't match expected schema"):
272+
parse_json_or_text(data, ProductResponse)
273+
274+
def test_model_with_message_field_validates_directly(self):
275+
"""Test that models containing 'message' field validate without separation."""
276+
# SampleResponse has a 'message' field, so it should validate directly
277+
data = {
278+
"message": "Hello",
279+
"count": 42,
280+
}
281+
282+
result = parse_json_or_text(data, SampleResponse)
283+
284+
assert result.message == "Hello"
285+
assert result.count == 42

0 commit comments

Comments
 (0)