Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions src/a2a/compat/v0_3/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

from typing import Any

from cryptography.fernet import (
Fernet,
)
from google.protobuf.json_format import MessageToDict, ParseDict

from a2a.compat.v0_3 import types as types_v03
from a2a.server.models import PushNotificationConfigModel, TaskModel
from a2a.types import a2a_pb2 as pb2_v10


Expand Down Expand Up @@ -1367,3 +1371,77 @@ def to_compat_get_extended_agent_card_request(
) -> types_v03.GetAuthenticatedExtendedCardRequest:
"""Convert get extended agent card request to v0.3 compat type."""
return types_v03.GetAuthenticatedExtendedCardRequest(id=request_id)


def core_to_compat_task_model(task: pb2_v10.Task, owner: str) -> TaskModel:
"""Converts a 1.0 core Task to a TaskModel using v0.3 JSON structure."""
compat_task = to_compat_task(task)
data = compat_task.model_dump(mode='json')

return TaskModel(
id=task.id,
context_id=task.context_id,
owner=owner,
status=data.get('status'),
history=data.get('history'),
artifacts=data.get('artifacts'),
task_metadata=data.get('metadata'),
protocol_version='0.3',
)


def compat_task_model_to_core(task_model: TaskModel) -> pb2_v10.Task:
"""Converts a TaskModel with v0.3 structure to a 1.0 core Task."""
compat_task = types_v03.Task(
id=task_model.id,
context_id=task_model.context_id,
status=types_v03.TaskStatus.model_validate(task_model.status),
artifacts=(
[types_v03.Artifact.model_validate(a) for a in task_model.artifacts]
if task_model.artifacts
else []
),
history=(
[types_v03.Message.model_validate(h) for h in task_model.history]
if task_model.history
else []
),
metadata=task_model.task_metadata,
)
return to_core_task(compat_task)


def core_to_compat_push_notification_config_model(
task_id: str,
config: pb2_v10.TaskPushNotificationConfig,
owner: str,
fernet: Fernet | None = None,
) -> PushNotificationConfigModel:
"""Converts a 1.0 core TaskPushNotificationConfig to a PushNotificationConfigModel using v0.3 JSON structure."""
compat_config = to_compat_push_notification_config(config)

json_payload = compat_config.model_dump_json().encode('utf-8')
data_to_store = fernet.encrypt(json_payload) if fernet else json_payload

return PushNotificationConfigModel(
task_id=task_id,
config_id=config.id,
owner=owner,
config_data=data_to_store,
protocol_version='0.3',
)


def compat_push_notification_config_model_to_core(
model_instance: str, task_id: str
) -> pb2_v10.TaskPushNotificationConfig:
"""Converts a PushNotificationConfigModel with v0.3 structure back to a 1.0 core TaskPushNotificationConfig."""
inner_config = types_v03.PushNotificationConfig.model_validate_json(
model_instance
)
return to_core_task_push_notification_config(
types_v03.TaskPushNotificationConfig(
task_id=task_id,
push_notification_config=inner_config,
)
)
43 changes: 28 additions & 15 deletions src/a2a/server/tasks/database_push_notification_config_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# ruff: noqa: PLC0415
import inspect
import logging

from typing import TYPE_CHECKING
Expand All @@ -13,9 +14,7 @@
AsyncSession,
async_sessionmaker,
)
from sqlalchemy.orm import (
class_mapper,
)
from sqlalchemy.orm import class_mapper
except ImportError as e:
raise ImportError(
'DatabasePushNotificationConfigStore requires SQLAlchemy and a database driver. '
Expand All @@ -26,8 +25,12 @@
"or 'pip install a2a-sdk[sql]'"
) from e

from a2a.compat.v0_3 import conversions
from a2a.compat.v0_3 import types as types_v03
if TYPE_CHECKING:
from collections.abc import Callable

from a2a.compat.v0_3.conversions import (
compat_push_notification_config_model_to_core,
)
from a2a.server.context import ServerCallContext
from a2a.server.models import (
Base,
Expand All @@ -43,8 +46,6 @@

if TYPE_CHECKING:
from cryptography.fernet import Fernet


logger = logging.getLogger(__name__)


Expand All @@ -62,6 +63,9 @@ class DatabasePushNotificationConfigStore(PushNotificationConfigStore):
_fernet: 'Fernet | None'
owner_resolver: OwnerResolver

core_to_model_conversion: 'Callable[[str, TaskPushNotificationConfig, str, Fernet | None], PushNotificationConfigModel] | None' = None
model_to_core_conversion: 'Callable[[PushNotificationConfigModel], TaskPushNotificationConfig] | None' = None

def __init__(
self,
engine: AsyncEngine,
Expand Down Expand Up @@ -152,6 +156,13 @@ def _to_orm(

The config data is serialized to JSON bytes, and encrypted if a key is configured.
"""
if conversion := self.core_to_model_conversion:
# If it's a bound method of this instance, call the underlying function
# to avoid passing 'self' twice.
if inspect.ismethod(conversion):
return conversion.__func__(task_id, config, owner, self._fernet)
return conversion(task_id, config, owner, self._fernet)

json_payload = MessageToJson(config).encode('utf-8')

if self._fernet:
Expand All @@ -174,6 +185,13 @@ def _from_orm(

Handles decryption if a key is configured, with a fallback to plain JSON.
"""
if conversion := self.model_to_core_conversion:
# If it's a bound method of this instance, call the underlying function
# to avoid passing 'self' twice.
if inspect.ismethod(conversion):
return conversion.__func__(model_instance)
return conversion(model_instance)

payload = model_instance.config_data

if self._fernet:
Expand Down Expand Up @@ -359,12 +377,7 @@ def _parse_config(
"""
if protocol_version == '1.0':
return Parse(json_payload, TaskPushNotificationConfig())
inner_config = types_v03.PushNotificationConfig.model_validate_json(
json_payload
)
return conversions.to_core_task_push_notification_config(
types_v03.TaskPushNotificationConfig(
task_id=task_id or '',
push_notification_config=inner_config,
)

return compat_push_notification_config_model_to_core(
json_payload, task_id or ''
)
60 changes: 25 additions & 35 deletions src/a2a/server/tasks/database_task_store.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
import inspect
import logging

from datetime import datetime, timezone
from typing import TYPE_CHECKING


try:
from sqlalchemy import (
Table,
and_,
delete,
func,
or_,
select,
)
from sqlalchemy import Table, and_, delete, func, or_, select
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
)
from sqlalchemy.orm import (
class_mapper,
)
from sqlalchemy.orm import class_mapper
except ImportError as e:
raise ImportError(
'DatabaseTaskStore requires SQLAlchemy and a database driver. '
Expand All @@ -30,25 +23,27 @@
"or 'pip install a2a-sdk[sql]'"
) from e

if TYPE_CHECKING:
from collections.abc import Callable

from google.protobuf.json_format import MessageToDict, ParseDict

from a2a.compat.v0_3 import conversions
from a2a.compat.v0_3 import types as types_v03
from a2a.server.context import ServerCallContext
from a2a.server.models import Base, TaskModel, create_task_model
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
from a2a.server.tasks.task_store import TaskStore
from a2a.types import a2a_pb2
from a2a.types.a2a_pb2 import Task
from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE
from a2a.utils.errors import InvalidParamsError
from a2a.utils.task import decode_page_token, encode_page_token


logger = logging.getLogger(__name__)


class DatabaseTaskStore(TaskStore):

Check notice on line 46 in src/a2a/server/tasks/database_task_store.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/tasks/inmemory_task_store.py (5-17)
"""SQLAlchemy-based implementation of TaskStore.

Stores task objects in a database supported by SQLAlchemy.
Expand All @@ -61,6 +56,9 @@
task_model: type[TaskModel]
owner_resolver: OwnerResolver

core_to_model_conversion: 'Callable[[Task, str], TaskModel] | None' = None
model_to_core_conversion: 'Callable[[TaskModel], Task] | None' = None

def __init__(
self,
engine: AsyncEngine,
Expand Down Expand Up @@ -119,6 +117,13 @@

def _to_orm(self, task: Task, owner: str) -> TaskModel:
"""Maps a Proto Task to a SQLAlchemy TaskModel instance."""
if conversion := self.core_to_model_conversion:
# If it's a bound method of this instance, call the underlying function
# to avoid passing 'self' twice.
if inspect.ismethod(conversion):
return conversion.__func__(task, owner)
return conversion(task, owner)

return self.task_model(
id=task.id,
context_id=task.context_id,
Expand All @@ -140,6 +145,13 @@

def _from_orm(self, task_model: TaskModel) -> Task:
"""Maps a SQLAlchemy TaskModel to a Proto Task instance."""
if conversion := self.model_to_core_conversion:
# If it's a bound method of this instance, call the underlying function
# to avoid passing 'self' twice.
if inspect.ismethod(conversion):
return conversion.__func__(task_model)
return conversion(task_model)

if task_model.protocol_version == '1.0':
task = Task(
id=task_model.id,
Expand All @@ -160,29 +172,7 @@
return task

# Legacy conversion
legacy_task = types_v03.Task(
id=task_model.id,
context_id=task_model.context_id,
status=types_v03.TaskStatus.model_validate(task_model.status),
artifacts=(
[
types_v03.Artifact.model_validate(a)
for a in task_model.artifacts
]
if task_model.artifacts
else []
),
history=(
[
types_v03.Message.model_validate(m)
for m in task_model.history
]
if task_model.history
else []
),
metadata=task_model.task_metadata or {},
)
return conversions.to_core_task(legacy_task)
return conversions.compat_task_model_to_core(task_model)

async def save(
self, task: Task, context: ServerCallContext | None = None
Expand Down
Loading
Loading