Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions sdks/python/apache_beam/ml/inference/agent_development_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,17 +229,6 @@ def run_inference(
for element in batch:
session_id: str = inference_args.get("session_id", str(uuid.uuid4()))

# Ensure a session exists for this invocation
try:
model.session_service.create_session(
app_name=self._app_name,
user_id=user_id,
session_id=session_id,
)
except sessions.SessionExistsError:
# It's okay if the session already exists for shared session IDs.
pass

# Wrap plain strings in a Content object
if isinstance(element, str):
# pyrefly: ignore[bad-instantiation]
Expand Down Expand Up @@ -288,6 +277,16 @@ async def _invoke_agent(
The text of the agent's final response, or ``None`` if the agent
produced no final text response.
"""
# Ensure a session exists for this invocation
try:
await model.session_service.create_session(
app_name=self._app_name,
user_id=user_id,
session_id=session_id,
)
except sessions.SessionExistsError:
# It's okay if the session already exists for shared session IDs.
pass
async for event in runner.run_async(
user_id=user_id,
session_id=session_id,
Expand Down
Loading