Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion temporalio/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
See https://github.com/temporalio/sdk-python/tree/main#nexus
"""

from ._decorators import workflow_run_operation
from ._decorators import workflow_run_operation, temporal_operation
from ._operation_context import (
Info,
LoggerAdapter,
Expand All @@ -19,6 +19,7 @@
wait_for_worker_shutdown_sync,
)
from ._token import WorkflowHandle
from ._temporal_client import TemporalNexusClient, TemporalOperationResult

__all__ = (
"workflow_run_operation",
Expand All @@ -35,4 +36,7 @@
"wait_for_worker_shutdown",
"wait_for_worker_shutdown_sync",
"WorkflowHandle",
"TemporalNexusClient",
"TemporalOperationResult",
"temporal_operation",
)
119 changes: 118 additions & 1 deletion temporalio/nexus/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,20 @@
StartOperationContext,
)

from temporalio.nexus._temporal_client import (
TemporalNexusClient,
TemporalOperationResult,
)

from ._operation_context import WorkflowRunOperationContext
from ._operation_handlers import WorkflowRunOperationHandler
from ._operation_handlers import (
TemporalNexusOperationHandler,
WorkflowRunOperationHandler,
)
from ._token import WorkflowHandle
from ._util import (
get_callable_name,
get_temporal_operation_start_method_input_and_output_type_annotations,
get_workflow_run_start_method_input_and_output_type_annotations,
set_operation_factory,
)
Expand Down Expand Up @@ -130,3 +139,111 @@ async def _start(
return decorator

return decorator(start)


@overload
def temporal_operation(
start: Callable[
[ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
],
) -> Callable[
[ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
]: ...


@overload
def temporal_operation(
*,
name: str | None = None,
) -> Callable[
[
Callable[
[ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
]
],
Callable[
[ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
],
]: ...


def temporal_operation(
start: None
| (
Callable[
[ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
]
) = None,
*,
name: str | None = None,
) -> (
Callable[
[ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
]
| Callable[
[
Callable[
[ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
]
],
Callable[
[ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
],
]
):
"""Decorator marking a method as the start method for an operation that interacts with Temporal."""

def decorator(
start: Callable[
[ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
],
) -> Callable[
[ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
]:
(
input_type,
output_type,
) = get_temporal_operation_start_method_input_and_output_type_annotations(start)

def operation_handler_factory(
self: ServiceHandlerT,
) -> OperationHandler[InputT, OutputT]:
async def _start(
ctx: StartOperationContext, client: TemporalNexusClient, input: InputT
) -> TemporalOperationResult[OutputT]:
return await start(
self,
ctx,
client,
input,
)

_start.__doc__ = start.__doc__
return TemporalNexusOperationHandler(_start)

method_name = get_callable_name(start)
op = nexusrpc.Operation(
name=name or method_name,
input_type=input_type,
output_type=output_type,
)
op.method_name = method_name
nexusrpc.set_operation(operation_handler_factory, op)

set_operation_factory(start, operation_handler_factory)
return start

if start is None:
return decorator

return decorator(start)
150 changes: 106 additions & 44 deletions temporalio/nexus/_operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,50 +492,34 @@ async def start_workflow(
Nexus caller is itself a workflow, this means that the workflow in the caller
namespace web UI will contain links to the started workflow, and vice versa.
"""
# We must pass nexus_completion_callbacks, workflow_event_links, and request_id,
# but these are deliberately not exposed in overloads, hence the type-check
# violation.

# Here we are starting a "nexus-backing" workflow. That means that the StartWorkflow request
# contains nexus-specific data such as a completion callback (used by the handler server
# namespace to deliver the result to the caller namespace when the workflow reaches a
# terminal state) and inbound links to the caller workflow (attached to history events of
# the workflow started in the handler namespace, and displayed in the UI).
with _nexus_backing_workflow_start_context():
wf_handle = await self._temporal_context.client.start_workflow( # type: ignore
workflow=workflow,
arg=arg,
args=args,
id=id,
task_queue=task_queue or self._temporal_context.info().task_queue,
result_type=result_type,
execution_timeout=execution_timeout,
run_timeout=run_timeout,
task_timeout=task_timeout,
id_reuse_policy=id_reuse_policy,
id_conflict_policy=id_conflict_policy,
retry_policy=retry_policy,
cron_schedule=cron_schedule,
memo=memo,
search_attributes=search_attributes,
static_summary=static_summary,
static_details=static_details,
start_delay=start_delay,
start_signal=start_signal,
start_signal_args=start_signal_args,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
request_eager_start=request_eager_start,
priority=priority,
versioning_override=versioning_override,
callbacks=self._temporal_context._get_callbacks(),
workflow_event_links=self._temporal_context._get_workflow_event_links(),
request_id=self._temporal_context.nexus_context.request_id,
)

self._temporal_context._add_outbound_links(wf_handle)

return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle)
return await _start_nexus_backing_workflow(
temporal_context=self._temporal_context,
workflow=workflow,
arg=arg,
args=args,
id=id,
task_queue=task_queue,
result_type=result_type,
execution_timeout=execution_timeout,
run_timeout=run_timeout,
task_timeout=task_timeout,
id_reuse_policy=id_reuse_policy,
id_conflict_policy=id_conflict_policy,
retry_policy=retry_policy,
cron_schedule=cron_schedule,
memo=memo,
search_attributes=search_attributes,
static_summary=static_summary,
static_details=static_details,
start_delay=start_delay,
start_signal=start_signal,
start_signal_args=start_signal_args,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
request_eager_start=request_eager_start,
priority=priority,
versioning_override=versioning_override,
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -586,3 +570,81 @@ def process(

logger = LoggerAdapter(logging.getLogger("temporalio.nexus"), None)
"""Logger that emits additional data describing the current Nexus operation."""


async def _start_nexus_backing_workflow(
temporal_context: _TemporalStartOperationContext,
workflow: str | Callable[..., Awaitable[ReturnType]],
arg: Any = temporalio.common._arg_unset,
*,
args: Sequence[Any] = [],
id: str,
task_queue: str | None = None,
result_type: type | None = None,
execution_timeout: timedelta | None = None,
run_timeout: timedelta | None = None,
task_timeout: timedelta | None = None,
id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE,
id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED,
retry_policy: temporalio.common.RetryPolicy | None = None,
cron_schedule: str = "",
memo: Mapping[str, Any] | None = None,
search_attributes: None
| (
temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes
) = None,
static_summary: str | None = None,
static_details: str | None = None,
start_delay: timedelta | None = None,
start_signal: str | None = None,
start_signal_args: Sequence[Any] = [],
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
request_eager_start: bool = False,
priority: temporalio.common.Priority = temporalio.common.Priority.default,
versioning_override: temporalio.common.VersioningOverride | None = None,
) -> WorkflowHandle[ReturnType]:
# We must pass nexus_completion_callbacks, workflow_event_links, and request_id,
# but these are deliberately not exposed in overloads, hence the type-check
# violation.

# Here we are starting a "nexus-backing" workflow. That means that the StartWorkflow request
# contains nexus-specific data such as a completion callback (used by the handler server
# namespace to deliver the result to the caller namespace when the workflow reaches a
# terminal state) and inbound links to the caller workflow (attached to history events of
# the workflow started in the handler namespace, and displayed in the UI).
with _nexus_backing_workflow_start_context():
wf_handle = await temporal_context.client.start_workflow( # type: ignore
workflow=workflow,
arg=arg,
args=args,
id=id,
task_queue=task_queue or temporal_context.info().task_queue,
result_type=result_type,
execution_timeout=execution_timeout,
run_timeout=run_timeout,
task_timeout=task_timeout,
id_reuse_policy=id_reuse_policy,
id_conflict_policy=id_conflict_policy,
retry_policy=retry_policy,
cron_schedule=cron_schedule,
memo=memo,
search_attributes=search_attributes,
static_summary=static_summary,
static_details=static_details,
start_delay=start_delay,
start_signal=start_signal,
start_signal_args=start_signal_args,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
request_eager_start=request_eager_start,
priority=priority,
versioning_override=versioning_override,
callbacks=temporal_context._get_callbacks(),
workflow_event_links=temporal_context._get_workflow_event_links(),
request_id=temporal_context.nexus_context.request_id,
)

temporal_context._add_outbound_links(wf_handle)

return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle)
Loading
Loading