Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1739,10 +1739,13 @@ async def workflow_sleep(
else None
)
fut = self.create_future()
self._timer_impl(
timer_handle = self._timer_impl(
duration,
_TimerOptions(user_metadata=user_metadata),
lambda: fut.set_result(None),
lambda: fut.set_result(None) if not fut.done() else None,
)
fut.add_done_callback(
lambda f: timer_handle.cancel() if f.cancelled() else None
)
await fut

Expand Down
60 changes: 60 additions & 0 deletions tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3431,6 +3431,66 @@ async def test_workflow_cancel_signal_and_timer_fired_in_same_task(
await result_task


@workflow.defn
class CancelWorkflowSleepTaskWorkflow:
"""Like CancelSignalAndTimerFiredInSameTaskWorkflow but uses workflow.sleep."""

_ready = False
timer_task: asyncio.Task[None] # type: ignore[reportUninitializedInstanceVariable]

@workflow.run
async def run(self) -> str:
self.timer_task = asyncio.create_task(workflow.sleep(60 * 60))
self._ready = True
try:
await self.timer_task
return "timer_completed"
except asyncio.CancelledError:
return "timer_cancelled"

@workflow.query
def ready(self) -> bool:
return self._ready

@workflow.signal
def cancel_timer(self) -> None:
self.timer_task.cancel()


async def test_workflow_sleep_task_cancellation(
client: Client,
):
async with new_worker(
client,
CancelWorkflowSleepTaskWorkflow,
) as worker:
handle = await client.start_workflow(
CancelWorkflowSleepTaskWorkflow.run,
id=f"workflow-{uuid.uuid4()}",
task_queue=worker.task_queue,
)

async def ready() -> bool:
return await handle.query(CancelWorkflowSleepTaskWorkflow.ready)

await assert_eq_eventually(True, ready)
await handle.signal(CancelWorkflowSleepTaskWorkflow.cancel_timer)
result = await handle.result()

assert result == "timer_cancelled"
# Verify the Temporal timer was actually cancelled on the server
resp = await client.workflow_service.get_workflow_execution_history(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not super important, but it's easier to use WorkflowHandle::fetch_history_events

GetWorkflowExecutionHistoryRequest(
namespace=client.namespace,
execution=WorkflowExecution(workflow_id=handle.id),
)
)
timer_canceled = any(
e.event_type == EventType.EVENT_TYPE_TIMER_CANCELED for e in resp.history.events
)
assert timer_canceled, "Expected TimerCanceled event in history"


class MyCustomError(ApplicationError):
def __init__(self, message: str) -> None:
super().__init__(message, type="MyCustomError", non_retryable=True)
Expand Down
Loading