Skip to content

Commit 0285533

Browse files
committed
sessions on top of postgres
1 parent 0c66e2f commit 0285533

12 files changed

Lines changed: 1565 additions & 21 deletions

openai_agents/memory/README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Session Memory Examples
2+
3+
Session memory examples for OpenAI Agents SDK integrated with Temporal workflows.
4+
5+
*Adapted from [OpenAI Agents SDK session memory examples](https://github.com/openai/openai-agents-python/tree/main/examples/memory)*
6+
7+
Before running these examples, be sure to review the [prerequisites and background on the integration](../README.md).
8+
9+
## Running the Examples
10+
11+
### PostgreSQL Session Memory
12+
13+
This example uses a PostgreSQL database to store session data.
14+
15+
You need can use the standard PostgreSQL environment variables to configure the database connection.
16+
These include `PGDATABASE`, `PGUSER`, `PGPASSWORD`, `PGHOST`, and `PGPORT`.
17+
We also support the `DATABASE_URL` environment variable.
18+
19+
To confirm that your environment is configured correctly, just run the `psql` command after setting the environment variables.
20+
For example:
21+
```bash
22+
PGDATABASE=postgres psql
23+
```
24+
25+
Start the worker:
26+
```bash
27+
PGDATABASE=postgres uv run openai_agents/memory/run_postgres_session_worker.py
28+
```
29+
30+
Then run the workflow:
31+
32+
```bash
33+
uv run openai_agents/memory/run_postgres_session_workflow.py
34+
```
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Worker-level database connection state management.
2+
3+
WARNING: This implementation uses global state and is not safe for concurrent
4+
testing (e.g., pytest-xdist). Run tests sequentially to avoid race conditions.
5+
"""
6+
7+
import asyncpg
8+
from typing import Optional
9+
10+
11+
# Module-level connection state
12+
_connection: Optional[asyncpg.Connection] = None
13+
14+
15+
def set_worker_connection(connection: asyncpg.Connection) -> None:
16+
"""Set the worker-level database connection."""
17+
global _connection
18+
_connection = connection
19+
20+
21+
def get_worker_connection() -> asyncpg.Connection:
22+
"""Get the worker-level database connection.
23+
24+
Raises:
25+
RuntimeError: If no connection has been set.
26+
"""
27+
if _connection is None:
28+
raise RuntimeError(
29+
"No worker-level database connection has been set. "
30+
"Call set_worker_connection() before using activities."
31+
)
32+
return _connection
33+
34+
35+
def clear_worker_connection() -> None:
36+
"""Clear the worker-level database connection."""
37+
global _connection
38+
_connection = None
39+
40+
41+
def has_worker_connection() -> bool:
42+
"""Check if a worker-level connection is available."""
43+
return _connection is not None

openai_agents/memory/db_utils.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import json
2+
import asyncpg
3+
from typing import Callable, Awaitable, TypeVar
4+
from temporalio import activity
5+
from pydantic import BaseModel
6+
7+
T = TypeVar("T")
8+
9+
10+
class IdempotenceHelper(BaseModel):
11+
table_name: str
12+
13+
def __init__(self, table_name: str):
14+
super().__init__(table_name=table_name)
15+
self.table_name = table_name
16+
17+
async def create_table(self, conn: asyncpg.Connection) -> None:
18+
await conn.execute(
19+
f"""
20+
CREATE TABLE IF NOT EXISTS {self.table_name} (
21+
run_id UUID NOT NULL,
22+
activity_id TEXT NOT NULL,
23+
operation_started_at TIMESTAMP NOT NULL,
24+
operation_completed_at TIMESTAMP NULL,
25+
operation_result TEXT NULL,
26+
PRIMARY KEY (run_id, activity_id)
27+
)
28+
"""
29+
)
30+
31+
async def idempotent_update(
32+
self,
33+
conn: asyncpg.Connection,
34+
operation: Callable[[asyncpg.Connection], Awaitable[T]],
35+
) -> T | None:
36+
"""Insert idempotence row; on conflict, read and return existing result.
37+
38+
The operation must be an async callable of the form:
39+
async def op(conn: asyncpg.Connection) -> T
40+
"""
41+
activity_info = activity.info()
42+
run_id = activity_info.workflow_run_id
43+
activity_id = activity_info.activity_id
44+
45+
async with conn.transaction():
46+
did_insert = await conn.fetchrow(
47+
(
48+
f"INSERT INTO {self.table_name} "
49+
f"(run_id, activity_id, operation_started_at) "
50+
f"VALUES ($1, $2, NOW()) "
51+
f"ON CONFLICT (run_id, activity_id) DO NOTHING "
52+
f"RETURNING 1"
53+
),
54+
run_id,
55+
activity_id,
56+
)
57+
58+
if did_insert:
59+
res = await operation(conn)
60+
61+
if hasattr(res, "model_dump_json"):
62+
op_result = res.model_dump_json()
63+
else:
64+
op_result = json.dumps(res)
65+
66+
await conn.execute(
67+
f"UPDATE {self.table_name} SET operation_completed_at = NOW(), operation_result = $1 WHERE run_id = $2 AND activity_id = $3",
68+
op_result,
69+
run_id,
70+
activity_id,
71+
)
72+
return res
73+
else:
74+
row = await conn.fetchrow(
75+
f"SELECT operation_result FROM {self.table_name} WHERE run_id = $1 AND activity_id = $2",
76+
run_id,
77+
activity_id,
78+
)
79+
if not row or row["operation_result"] is None:
80+
return None
81+
try:
82+
return json.loads(row["operation_result"])
83+
except Exception:
84+
return row["operation_result"]

0 commit comments

Comments
 (0)