Skip to content

Commit 50cec66

Browse files
committed
codec
1 parent b739289 commit 50cec66

6 files changed

Lines changed: 42 additions & 12 deletions

File tree

src/agentex/lib/core/clients/temporal/temporal_client.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,10 @@
7676

7777

7878
class TemporalClient:
79-
def __init__(self, temporal_client: Client | None = None, plugins: list[Any] = []):
79+
def __init__(self, temporal_client: Client | None = None, plugins: list[Any] = [], payload_codec: Any | None = None):
8080
self._client: Client | None = temporal_client
8181
self._plugins = plugins
82+
self._payload_codec = payload_codec
8283

8384
@property
8485
def client(self) -> Client:
@@ -88,7 +89,7 @@ def client(self) -> Client:
8889
return self._client
8990

9091
@classmethod
91-
async def create(cls, temporal_address: str, plugins: list[Any] = []):
92+
async def create(cls, temporal_address: str, plugins: list[Any] = [], payload_codec: Any | None = None):
9293
if temporal_address in [
9394
"false",
9495
"False",
@@ -101,8 +102,8 @@ async def create(cls, temporal_address: str, plugins: list[Any] = []):
101102
]:
102103
_client = None
103104
else:
104-
_client = await get_temporal_client(temporal_address, plugins=plugins)
105-
return cls(_client, plugins)
105+
_client = await get_temporal_client(temporal_address, plugins=plugins, payload_codec=payload_codec)
106+
return cls(_client, plugins, payload_codec)
106107

107108
async def setup(self, temporal_address: str):
108109
self._client = await self._get_temporal_client(temporal_address=temporal_address)
@@ -120,7 +121,7 @@ async def _get_temporal_client(self, temporal_address: str) -> Client | None:
120121
]:
121122
return None
122123
else:
123-
return await get_temporal_client(temporal_address, plugins=self._plugins)
124+
return await get_temporal_client(temporal_address, plugins=self._plugins, payload_codec=self._payload_codec)
124125

125126
async def start_workflow(
126127
self,

src/agentex/lib/core/clients/temporal/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
from typing import Any
44

5+
import dataclasses
6+
57
from temporalio.client import Client, Plugin as ClientPlugin
8+
from temporalio.converter import PayloadCodec
69
from temporalio.worker import Interceptor
710
from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig
811
from temporalio.contrib.pydantic import pydantic_data_converter
@@ -79,14 +82,20 @@ def validate_worker_interceptors(interceptors: list[Any]) -> None:
7982
)
8083

8184

82-
async def get_temporal_client(temporal_address: str, metrics_url: str | None = None, plugins: list[Any] = []) -> Client:
85+
async def get_temporal_client(
86+
temporal_address: str,
87+
metrics_url: str | None = None,
88+
plugins: list[Any] = [],
89+
payload_codec: PayloadCodec | None = None,
90+
) -> Client:
8391
"""
8492
Create a Temporal client with plugin integration.
8593
8694
Args:
8795
temporal_address: Temporal server address
8896
metrics_url: Optional metrics endpoint URL
8997
plugins: List of Temporal plugins to include
98+
payload_codec: Optional payload codec for encoding/decoding payloads (e.g. encryption, compression)
9099
91100
Returns:
92101
Configured Temporal client
@@ -109,7 +118,10 @@ async def get_temporal_client(temporal_address: str, metrics_url: str | None = N
109118
}
110119

111120
if not has_openai_plugin:
112-
connect_kwargs["data_converter"] = pydantic_data_converter
121+
data_converter = pydantic_data_converter
122+
if payload_codec:
123+
data_converter = dataclasses.replace(data_converter, payload_codec=payload_codec)
124+
connect_kwargs["data_converter"] = data_converter
113125

114126
if not metrics_url:
115127
client = await Client.connect(**connect_kwargs)

src/agentex/lib/core/temporal/workers/worker.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig
2020
from temporalio.converter import (
2121
DataConverter,
22+
PayloadCodec,
2223
JSONTypeConverter,
2324
AdvancedJSONEncoder,
2425
DefaultPayloadConverter,
@@ -89,7 +90,12 @@ def _validate_interceptors(interceptors: list) -> None:
8990
)
9091

9192

92-
async def get_temporal_client(temporal_address: str, metrics_url: str | None = None, plugins: list = []) -> Client:
93+
async def get_temporal_client(
94+
temporal_address: str,
95+
metrics_url: str | None = None,
96+
plugins: list = [],
97+
payload_codec: PayloadCodec | None = None,
98+
) -> Client:
9399
if plugins != []: # We don't need to validate the plugins if they are empty
94100
_validate_plugins(plugins)
95101

@@ -108,7 +114,10 @@ async def get_temporal_client(temporal_address: str, metrics_url: str | None = N
108114

109115
# Only set data_converter if OpenAI plugin is not present
110116
if not has_openai_plugin:
111-
connect_kwargs["data_converter"] = custom_data_converter
117+
data_converter = custom_data_converter
118+
if payload_codec:
119+
data_converter = dataclasses.replace(data_converter, payload_codec=payload_codec)
120+
connect_kwargs["data_converter"] = data_converter
112121

113122
if not metrics_url:
114123
client = await Client.connect(**connect_kwargs)
@@ -129,6 +138,7 @@ def __init__(
129138
plugins: list = [],
130139
interceptors: list = [],
131140
metrics_url: str | None = None,
141+
payload_codec: PayloadCodec | None = None,
132142
):
133143
self.task_queue = task_queue
134144
self.activity_handles = []
@@ -140,6 +150,7 @@ def __init__(
140150
self.plugins = plugins
141151
self.interceptors = interceptors
142152
self.metrics_url = metrics_url
153+
self.payload_codec = payload_codec
143154

144155
@overload
145156
async def run(
@@ -175,6 +186,7 @@ async def run(
175186
temporal_address=os.environ.get("TEMPORAL_ADDRESS", "localhost:7233"),
176187
plugins=self.plugins,
177188
metrics_url=self.metrics_url,
189+
payload_codec=self.payload_codec,
178190
)
179191

180192
# Enable debug mode if AgentEx debug is enabled (disables deadlock detection)

src/agentex/lib/sdk/fastacp/fastacp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def create_async_acp(config: AsyncACPConfig, **kwargs) -> BaseACPServer:
6363
temporal_config["plugins"] = config.plugins # type: ignore[attr-defined]
6464
if hasattr(config, "interceptors"):
6565
temporal_config["interceptors"] = config.interceptors # type: ignore[attr-defined]
66+
if hasattr(config, "payload_codec"):
67+
temporal_config["payload_codec"] = config.payload_codec # type: ignore[attr-defined]
6668
return implementation_class.create(**temporal_config)
6769
else:
6870
return implementation_class.create(**kwargs)

src/agentex/lib/sdk/fastacp/impl/temporal_acp.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,22 @@ def __init__(
3131
temporal_task_service: TemporalTaskService | None = None,
3232
plugins: list[Any] | None = None,
3333
interceptors: list[Any] | None = None,
34+
payload_codec: Any | None = None,
3435
):
3536
super().__init__()
3637
self._temporal_task_service = temporal_task_service
3738
self._temporal_address = temporal_address
3839
self._plugins = plugins or []
3940
self._interceptors = interceptors or []
41+
self._payload_codec = payload_codec
4042

4143
@classmethod
4244
@override
43-
def create(cls, temporal_address: str, plugins: list[Any] | None = None, interceptors: list[Any] | None = None) -> "TemporalACP":
45+
def create(cls, temporal_address: str, plugins: list[Any] | None = None, interceptors: list[Any] | None = None, payload_codec: Any | None = None) -> "TemporalACP":
4446
logger.info("Initializing TemporalACP instance")
4547

4648
# Create instance without temporal client initially
47-
temporal_acp = cls(temporal_address=temporal_address, plugins=plugins, interceptors=interceptors)
49+
temporal_acp = cls(temporal_address=temporal_address, plugins=plugins, interceptors=interceptors, payload_codec=payload_codec)
4850
temporal_acp._setup_handlers()
4951
logger.info("TemporalACP instance initialized now")
5052
return temporal_acp
@@ -60,7 +62,7 @@ async def lifespan(app: FastAPI):
6062
if self._temporal_task_service is None:
6163
env_vars = EnvironmentVariables.refresh()
6264
temporal_client = await TemporalClient.create(
63-
temporal_address=self._temporal_address, plugins=self._plugins
65+
temporal_address=self._temporal_address, plugins=self._plugins, payload_codec=self._payload_codec
6466
)
6567
self._temporal_task_service = TemporalTaskService(
6668
temporal_client=temporal_client,

src/agentex/lib/types/fastacp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class TemporalACPConfig(AsyncACPConfig):
5656
temporal_address: str = Field(default="temporal-frontend.temporal.svc.cluster.local:7233", frozen=True)
5757
plugins: list[Any] = Field(default=[], frozen=True)
5858
interceptors: list[Any] = Field(default=[], frozen=True)
59+
payload_codec: Any = Field(default=None, frozen=True)
5960

6061
@field_validator("plugins")
6162
@classmethod

0 commit comments

Comments
 (0)