diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index ea049d90e..8c2a15709 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -3,7 +3,7 @@ See https://github.com/temporalio/sdk-python/tree/main#nexus """ -from ._decorators import workflow_run_operation +from ._decorators import temporal_operation, workflow_run_operation from ._operation_context import ( Info, LoggerAdapter, @@ -18,6 +18,7 @@ wait_for_worker_shutdown, wait_for_worker_shutdown_sync, ) +from ._temporal_client import TemporalNexusClient, TemporalOperationResult from ._token import WorkflowHandle __all__ = ( @@ -35,4 +36,7 @@ "wait_for_worker_shutdown", "wait_for_worker_shutdown_sync", "WorkflowHandle", + "TemporalNexusClient", + "TemporalOperationResult", + "temporal_operation", ) diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 6dfd3daff..cfb7ebb7c 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -13,11 +13,20 @@ StartOperationContext, ) +from temporalio.nexus._temporal_client import ( + TemporalNexusClient, + TemporalOperationResult, +) + from ._operation_context import WorkflowRunOperationContext -from ._operation_handlers import WorkflowRunOperationHandler +from ._operation_handlers import ( + TemporalNexusOperationHandler, + WorkflowRunOperationHandler, +) from ._token import WorkflowHandle from ._util import ( get_callable_name, + get_temporal_operation_start_method_input_and_output_type_annotations, get_workflow_run_start_method_input_and_output_type_annotations, set_operation_factory, ) @@ -130,3 +139,111 @@ async def _start( return decorator return decorator(start) + + +@overload +def temporal_operation( + start: Callable[ + [ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ], +) -> Callable[ + [ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT], + 
Awaitable[TemporalOperationResult[OutputT]], +]: ... + + +@overload +def temporal_operation( + *, + name: str | None = None, +) -> Callable[ + [ + Callable[ + [ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ] + ], + Callable[ + [ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ], +]: ... + + +def temporal_operation( + start: None + | ( + Callable[ + [ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ] + ) = None, + *, + name: str | None = None, +) -> ( + Callable[ + [ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ] + | Callable[ + [ + Callable[ + [ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ] + ], + Callable[ + [ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ], + ] +): + """Decorator marking a method as the start method for an operation that interacts with Temporal.""" + + def decorator( + start: Callable[ + [ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ], + ) -> Callable[ + [ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ]: + ( + input_type, + output_type, + ) = get_temporal_operation_start_method_input_and_output_type_annotations(start) + + def operation_handler_factory( + self: ServiceHandlerT, + ) -> OperationHandler[InputT, OutputT]: + async def _start( + ctx: StartOperationContext, client: TemporalNexusClient, input: InputT + ) -> TemporalOperationResult[OutputT]: + return await start( + self, + ctx, + client, + input, + ) + + _start.__doc__ = start.__doc__ + return 
TemporalNexusOperationHandler(_start) + + method_name = get_callable_name(start) + op = nexusrpc.Operation( + name=name or method_name, + input_type=input_type, + output_type=output_type, + ) + op.method_name = method_name + nexusrpc.set_operation(operation_handler_factory, op) + + set_operation_factory(start, operation_handler_factory) + return start + + if start is None: + return decorator + + return decorator(start) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 04462c900..e1a3f6987 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -492,50 +492,34 @@ async def start_workflow( Nexus caller is itself a workflow, this means that the workflow in the caller namespace web UI will contain links to the started workflow, and vice versa. """ - # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, - # but these are deliberately not exposed in overloads, hence the type-check - # violation. - - # Here we are starting a "nexus-backing" workflow. That means that the StartWorkflow request - # contains nexus-specific data such as a completion callback (used by the handler server - # namespace to deliver the result to the caller namespace when the workflow reaches a - # terminal state) and inbound links to the caller workflow (attached to history events of - # the workflow started in the handler namespace, and displayed in the UI). 
- with _nexus_backing_workflow_start_context(): - wf_handle = await self._temporal_context.client.start_workflow( # type: ignore - workflow=workflow, - arg=arg, - args=args, - id=id, - task_queue=task_queue or self._temporal_context.info().task_queue, - result_type=result_type, - execution_timeout=execution_timeout, - run_timeout=run_timeout, - task_timeout=task_timeout, - id_reuse_policy=id_reuse_policy, - id_conflict_policy=id_conflict_policy, - retry_policy=retry_policy, - cron_schedule=cron_schedule, - memo=memo, - search_attributes=search_attributes, - static_summary=static_summary, - static_details=static_details, - start_delay=start_delay, - start_signal=start_signal, - start_signal_args=start_signal_args, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - request_eager_start=request_eager_start, - priority=priority, - versioning_override=versioning_override, - callbacks=self._temporal_context._get_callbacks(), - workflow_event_links=self._temporal_context._get_workflow_event_links(), - request_id=self._temporal_context.nexus_context.request_id, - ) - - self._temporal_context._add_outbound_links(wf_handle) - - return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) + return await _start_nexus_backing_workflow( + temporal_context=self._temporal_context, + workflow=workflow, + arg=arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + 
versioning_override=versioning_override, + ) @dataclass(frozen=True) @@ -586,3 +570,81 @@ def process( logger = LoggerAdapter(logging.getLogger("temporalio.nexus"), None) """Logger that emits additional data describing the current Nexus operation.""" + + +async def _start_nexus_backing_workflow( + temporal_context: _TemporalStartOperationContext, + workflow: str | Callable[..., Awaitable[ReturnType]], + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str | None = None, + result_type: type | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, +) -> WorkflowHandle[ReturnType]: + # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, + # but these are deliberately not exposed in overloads, hence the type-check + # violation. + + # Here we are starting a "nexus-backing" workflow. 
That means that the StartWorkflow request + # contains nexus-specific data such as a completion callback (used by the handler server + # namespace to deliver the result to the caller namespace when the workflow reaches a + # terminal state) and inbound links to the caller workflow (attached to history events of + # the workflow started in the handler namespace, and displayed in the UI). + with _nexus_backing_workflow_start_context(): + wf_handle = await temporal_context.client.start_workflow( # type: ignore + workflow=workflow, + arg=arg, + args=args, + id=id, + task_queue=task_queue or temporal_context.info().task_queue, + result_type=result_type, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + versioning_override=versioning_override, + callbacks=temporal_context._get_callbacks(), + workflow_event_links=temporal_context._get_workflow_event_links(), + request_id=temporal_context.nexus_context.request_id, + ) + + temporal_context._add_outbound_links(wf_handle) + + return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 68035ca41..b01b68561 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -1,9 +1,7 @@ from __future__ import annotations from collections.abc import Awaitable, Callable -from typing import ( - Any, -) +from typing import Any from nexusrpc import ( HandlerError, @@ -16,12 
+14,18 @@ OperationHandler, StartOperationContext, StartOperationResultAsync, + StartOperationResultSync, ) from temporalio.nexus._operation_context import ( _temporal_cancel_operation_context, + _TemporalCancelOperationContext, +) +from temporalio.nexus._temporal_client import ( + TemporalNexusClient, + TemporalOperationResult, ) -from temporalio.nexus._token import WorkflowHandle +from temporalio.nexus._token import OperationToken, OperationTokenType, WorkflowHandle from ._util import ( is_async_callable, @@ -112,3 +116,54 @@ async def _cancel_workflow( type=HandlerErrorType.NOT_FOUND, ) from err await client_workflow_handle.cancel(**kwargs) + + +class TemporalNexusOperationHandler(OperationHandler[InputT, OutputT]): + """Operation handler for Nexus operations that interact with Temporal.""" + + def __init__( + self, + start: Callable[ + [StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ], + ) -> None: + """Initialize the Temporal operation handler.""" + if not is_async_callable(start): + raise RuntimeError( + f"{start} is not an `async def` method. " + "TemporalNexusOperationHandler must be initialized with an " + "`async def` start method." 
+ ) + self._start = start + if start.__doc__: + if start_func := getattr(self.start, "__func__", None): + start_func.__doc__ = start.__doc__ + + async def start( + self, ctx: StartOperationContext, input: InputT + ) -> StartOperationResultSync[OutputT] | StartOperationResultAsync: + nexus_client = TemporalNexusClient() + result = await self._start(ctx, nexus_client, input) + return result._to_nexus_result() + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + temporal_context = _TemporalCancelOperationContext.get() + client = temporal_context.client + + operation_token = OperationToken.decode(token) + if client.namespace != operation_token.namespace: + raise ValueError( + f"Client namespace {client.namespace} does not match " + f"operation token namespace {operation_token.namespace}" + ) + + match operation_token.type: + case OperationTokenType.WORKFLOW: + await self.cancel_workflow_run(ctx, operation_token.workflow_id) + + async def cancel_workflow_run(self, _ctx: CancelOperationContext, workflow_id: str): + """Cancels the workflow identified by workflow_id""" + temporal_context = _TemporalCancelOperationContext.get() + workflow_handle = temporal_context.client.get_workflow_handle(workflow_id) + await workflow_handle.cancel() diff --git a/temporalio/nexus/_temporal_client.py b/temporalio/nexus/_temporal_client.py new file mode 100644 index 000000000..0bcd21990 --- /dev/null +++ b/temporalio/nexus/_temporal_client.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable, Mapping, Sequence +from dataclasses import dataclass +from datetime import timedelta +from typing import ( + TYPE_CHECKING, + Any, + Concatenate, + Generic, + Self, + TypeVar, + cast, + overload, +) + +from nexusrpc import HandlerError, HandlerErrorType +from nexusrpc.handler import StartOperationResultAsync, StartOperationResultSync + +import temporalio.common +from temporalio.nexus._operation_context 
import ( + _start_nexus_backing_workflow, + _TemporalStartOperationContext, +) +from temporalio.types import ( + MethodAsyncNoParam, + MethodAsyncSingleParam, + MultiParamSpec, + ParamType, + ReturnType, + SelfType, +) + +if TYPE_CHECKING: + import temporalio.client + + +_ResultT = TypeVar("_ResultT") + + +@dataclass(frozen=True) +class TemporalOperationResult(Generic[_ResultT]): + """Unified result: sync value or async token.""" + + value: _ResultT | object = temporalio.common._arg_unset + token: str | None = None + + @classmethod + def sync(cls, value: _ResultT) -> "TemporalOperationResult[_ResultT]": + return cls(value=value) + + @classmethod + def async_token(cls, token: str) -> Self: + return cls(token=token) + + def _to_nexus_result( + self, + ) -> StartOperationResultSync[_ResultT] | StartOperationResultAsync: + if self.token is not None: + return StartOperationResultAsync(self.token) + elif self.value is not temporalio.common._arg_unset: + return StartOperationResultSync(cast(_ResultT, self.value)) + else: + raise RuntimeError( + "Invalid TemporalOperationResult. Neither token nor value are set." 
+ ) + + +class TemporalNexusClient: + """Nexus-aware wrapper around a Temporal Client.""" + + def __init__(self) -> None: + self._temporal_context = _TemporalStartOperationContext.get() + self.started_async = asyncio.Event() + + @property + def client(self) -> temporalio.client.Client: + return self._temporal_context.client + + # Overload for no-param workflow + @overload + async def start_workflow( + self, + workflow: MethodAsyncNoParam[SelfType, ReturnType], + *, + id: str, + task_queue: str | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> TemporalOperationResult[ReturnType]: ... 
+ + # Overload for single-param workflow + @overload + async def start_workflow( + self, + workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> TemporalOperationResult[ReturnType]: ... 
+ + # Overload for multi-param workflow + @overload + async def start_workflow( + self, + workflow: Callable[ + Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType] + ], + *, + args: Sequence[Any], + id: str, + task_queue: str | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> TemporalOperationResult[ReturnType]: ... 
+ + # Overload for string-name workflow + @overload + async def start_workflow( + self, + workflow: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str | None = None, + result_type: type[ReturnType] | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> TemporalOperationResult[ReturnType]: ... 
+ + async def start_workflow( + self, + workflow: str | Callable[..., Awaitable[ReturnType]], + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str | None = None, + result_type: type | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> TemporalOperationResult[ReturnType]: + if self.started_async.is_set(): + raise HandlerError( + "Only one async operation can be started per operation handler invocation. 
Use TemporalNexusClient.client for additional workflow interactions", + type=HandlerErrorType.BAD_REQUEST, + ) + + wf_handle = await _start_nexus_backing_workflow( + temporal_context=self._temporal_context, + workflow=workflow, + arg=arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + versioning_override=versioning_override, + ) + self.started_async.set() + return TemporalOperationResult.async_token(wf_handle.to_token()) diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index 0a3d27375..39e62703b 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -3,17 +3,99 @@ import base64 import json from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Literal +from enum import IntEnum +from typing import TYPE_CHECKING, Any, Generic, Self from nexusrpc import OutputT -OperationTokenType = Literal[1] -OPERATION_TOKEN_TYPE_WORKFLOW: OperationTokenType = 1 + +class OperationTokenType(IntEnum): + WORKFLOW = 1 + if TYPE_CHECKING: import temporalio.client +@dataclass(frozen=True, kw_only=True) +class OperationToken: + version: int | None = None + type: OperationTokenType + namespace: str + workflow_id: str + + def encode(self) -> str: + """Convert handle to a base64url-encoded token string.""" + token_details: dict[str, Any] = { + "t": self.type, + "ns": self.namespace, + "wid": self.workflow_id, + } + if self.version is not None: 
+ token_details["v"] = self.version + return _base64url_encode_no_padding( + json.dumps( + token_details, + separators=(",", ":"), + ).encode("utf-8") + ) + + @classmethod + def decode(cls, token: str) -> Self: + """Decodes and validates a token from its base64url-encoded string representation.""" + if not token: + raise TypeError("invalid token: token is empty") + try: + decoded_bytes = _base64url_decode_no_padding(token) + except Exception as err: + raise TypeError("failed to decode token as base64url") from err + try: + token_details = json.loads(decoded_bytes.decode("utf-8")) + except Exception as err: + raise TypeError("failed to unmarshal operation token") from err + + if not isinstance(token_details, dict): + raise TypeError(f"invalid token: expected dict, got {type(token_details)}") + + raw_token_type = token_details.get("t") + if not isinstance(raw_token_type, int): + raise TypeError( + f"invalid token: expected token type to be an int, got {type(raw_token_type)}" + ) + + token_type = OperationTokenType(raw_token_type) + + version = token_details.get("v") + if version is not None and not isinstance(version, int): + raise TypeError( + f"invalid token: expected version to be an int or null, got {type(version)}" + ) + + workflow_id = token_details.get("wid") + if not isinstance(workflow_id, str): + raise TypeError( + f"invalid token: expected workflow id to be a string, got {type(workflow_id)}" + ) + + if token_type == OperationTokenType.WORKFLOW and not workflow_id: + raise TypeError( + "invalid token: expected non-empty workflow id for token type `WORKFLOW`" + ) + + namespace = token_details.get("ns") + if not isinstance(namespace, str) or not namespace: + raise TypeError( + f"invalid token: expected namespace to be a string or null, got {type(namespace)}" + ) + + return cls( + type=OperationTokenType(token_type), + namespace=namespace, + workflow_id=workflow_id, + version=version, + ) + + @dataclass(frozen=True) class WorkflowHandle(Generic[OutputT]): """A 
handle to a workflow that is backing a Nexus operation. @@ -59,65 +141,36 @@ def _unsafe_from_client_workflow_handle( def to_token(self) -> str: """Convert handle to a base64url-encoded token string.""" - return _base64url_encode_no_padding( - json.dumps( - { - "t": OPERATION_TOKEN_TYPE_WORKFLOW, - "ns": self.namespace, - "wid": self.workflow_id, - }, - separators=(",", ":"), - ).encode("utf-8") - ) + return OperationToken( + type=OperationTokenType.WORKFLOW, + namespace=self.namespace, + workflow_id=self.workflow_id, + ).encode() @classmethod def from_token(cls, token: str) -> WorkflowHandle[OutputT]: """Decodes and validates a token from its base64url-encoded string representation.""" - if not token: - raise TypeError("invalid workflow token: token is empty") - try: - decoded_bytes = _base64url_decode_no_padding(token) - except Exception as err: - raise TypeError("failed to decode token as base64url") from err - try: - workflow_operation_token = json.loads(decoded_bytes.decode("utf-8")) - except Exception as err: - raise TypeError("failed to unmarshal workflow operation token") from err - - if not isinstance(workflow_operation_token, dict): + op_token = OperationToken.decode(token) + if op_token.type != OperationTokenType.WORKFLOW: raise TypeError( - f"invalid workflow token: expected dict, got {type(workflow_operation_token)}" + f"invalid workflow token type: {op_token.type}, expected: {OperationTokenType.WORKFLOW}" ) - token_type = workflow_operation_token.get("t") - if token_type != OPERATION_TOKEN_TYPE_WORKFLOW: - raise TypeError( - f"invalid workflow token type: {token_type}, expected: {OPERATION_TOKEN_TYPE_WORKFLOW}" - ) - - version = workflow_operation_token.get("v") - if version is not None and version != 0: + if op_token.version is not None and op_token.version != 0: raise TypeError( "invalid workflow token: 'v' field, if present, must be 0 or null/absent" ) - workflow_id = workflow_operation_token.get("wid") - if not workflow_id or not 
isinstance(workflow_id, str): - raise TypeError( - "invalid workflow token: missing, empty, or non-string workflow ID (wid)" - ) - - namespace = workflow_operation_token.get("ns") - if namespace is None or not isinstance(namespace, str): + if not isinstance(op_token.namespace, str): # Allow empty string for ns, but it must be present and a string raise TypeError( "invalid workflow token: missing or non-string namespace (ns)" ) return cls( - namespace=namespace, - workflow_id=workflow_id, - version=version, + namespace=op_token.namespace, + workflow_id=op_token.workflow_id, + version=op_token.version, ) diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index 48d3ad644..6cba8493e 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -15,8 +15,13 @@ InputT, OutputT, ) +from nexusrpc.handler import StartOperationContext from temporalio.nexus._operation_context import WorkflowRunOperationContext +from temporalio.nexus._temporal_client import ( + TemporalNexusClient, + TemporalOperationResult, +) from ._token import ( WorkflowHandle as WorkflowHandle, @@ -39,13 +44,54 @@ def get_workflow_run_start_method_input_and_output_type_annotations( ``start`` must be a type-annotated start method that returns a :py:class:`temporalio.nexus.WorkflowHandle`. """ - input_type, output_type = _get_start_method_input_and_output_type_annotations(start) + return _get_wrapped_start_method_input_and_output_type_annotations( + start, + expected_param_types=(WorkflowRunOperationContext,), + expected_return_origin=WorkflowHandle, + ) + + +def get_temporal_operation_start_method_input_and_output_type_annotations( + start: Callable[ + [ServiceHandlerT, StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ], +) -> tuple[ + type[InputT] | None, + type[OutputT] | None, +]: + """Return operation input and output types. 
+ + ``start`` must be a type-annotated start method that returns a + :py:class:`temporalio.nexus.TemporalOperationResult`. + """ + return _get_wrapped_start_method_input_and_output_type_annotations( + start, + expected_param_types=(StartOperationContext, TemporalNexusClient), + expected_return_origin=TemporalOperationResult, + ) + + +def _get_wrapped_start_method_input_and_output_type_annotations( + start: Callable[..., Any], + *, + expected_param_types: tuple[type[Any], ...], + expected_return_origin: type[Any], +) -> tuple[ + type[Any] | None, + type[Any] | None, +]: + input_type, output_type = _get_start_method_input_and_output_type_annotations( + start, + expected_param_types=expected_param_types, + ) origin_type = typing.get_origin(output_type) if not origin_type: output_type = None - elif not issubclass(origin_type, WorkflowHandle): + elif not _is_subclass(origin_type, expected_return_origin): warnings.warn( - f"Expected return type of {start.__name__} to be a subclass of WorkflowHandle, " + f"Expected return type of {start.__name__} to be a subclass of " + f"{expected_return_origin.__name__}, " f"but is {output_type}" ) output_type = None @@ -65,13 +111,12 @@ def get_workflow_run_start_method_input_and_output_type_annotations( def _get_start_method_input_and_output_type_annotations( - start: Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], - Awaitable[WorkflowHandle[OutputT]], - ], + start: Callable[..., Any], + *, + expected_param_types: tuple[type[Any], ...], ) -> tuple[ - type[InputT] | None, - type[OutputT] | None, + type[Any] | None, + type[Any] | None, ]: try: type_annotations = typing.get_type_hints(start) @@ -81,27 +126,39 @@ def _get_start_method_input_and_output_type_annotations( ) return None, None output_type = type_annotations.pop("return", None) + expected_parameter_count = len(expected_param_types) + 1 - if len(type_annotations) != 2: + if len(type_annotations) != expected_parameter_count: suffix = f": {type_annotations}" if 
type_annotations else "" warnings.warn( - f"Expected decorated start method {start} to have exactly 2 " - f"type-annotated parameters (ctx and input), but it has {len(type_annotations)}" + f"Expected decorated start method {start} to have exactly " + f"{expected_parameter_count} type-annotated parameters, " + f"but it has {len(type_annotations)}" f"{suffix}." ) input_type = None else: - ctx_type, input_type = type_annotations.values() - if not issubclass(ctx_type, WorkflowRunOperationContext): - warnings.warn( - f"Expected first parameter of {start} to be an instance of " - f"WorkflowRunOperationContext, but is {ctx_type}." - ) - input_type = None + *param_types, input_type = type_annotations.values() + for index, (param_type, expected_param_type) in enumerate( + zip(param_types, expected_param_types), start=1 + ): + if not _is_subclass(param_type, expected_param_type): + warnings.warn( + f"Expected parameter {index} of {start} to be an instance of " + f"{expected_param_type.__name__}, but is {param_type}." + ) + input_type = None return input_type, output_type +def _is_subclass(cls: Any, class_or_tuple: type[Any]) -> bool: + try: + return issubclass(cls, class_or_tuple) + except TypeError: + return False + + def get_callable_name(fn: Callable[..., Any]) -> str: """Return the name of a callable object.""" method_name = getattr(fn, "__name__", None) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 40f17302c..6df9bf629 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -5557,6 +5557,31 @@ async def start_operation( summary: str | None = None, ) -> NexusOperationHandle[OutputT]: ... 
+ # Overload for temporal_operation methods + @overload + @abstractmethod + async def start_operation( + self, + operation: Callable[ + [ + ServiceHandlerT, + nexusrpc.handler.StartOperationContext, + temporalio.nexus.TemporalNexusClient, + InputT, + ], + Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]], + ], + input: InputT, + *, + output_type: type[OutputT] | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, + headers: Mapping[str, str] | None = None, + summary: str | None = None, + ) -> NexusOperationHandle[OutputT]: ... + # Overload for sync_operation methods (async def) @overload @abstractmethod @@ -5703,6 +5728,31 @@ async def execute_operation( summary: str | None = None, ) -> OutputT: ... + # Overload for temporal_operation methods + @overload + @abstractmethod + async def execute_operation( + self, + operation: Callable[ + [ + ServiceHandlerT, + nexusrpc.handler.StartOperationContext, + temporalio.nexus.TemporalNexusClient, + InputT, + ], + Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]], + ], + input: InputT, + *, + output_type: type[OutputT] | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, + headers: Mapping[str, str] | None = None, + summary: str | None = None, + ) -> OutputT: ... 
+ # Overload for sync_operation methods (async def) @overload @abstractmethod diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index f467f8aa3..fe37296e9 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -254,6 +254,44 @@ async def check() -> PendingActivityInfo: return await assert_eventually(check, timeout=timeout) +async def assert_event_subsequence( + wf_handle: WorkflowHandle, + expected_events: list[EventType.ValueType], + timeout: timedelta = timedelta(seconds=5), +) -> None: + """ + Given a workflow handle and a sequence of event types, assert that the workflow's history + contains that subsequence of events in the order specified. + """ + + async def check(): + history = await wf_handle.fetch_history() + + _all_events = iter(history.events) + _expected_events = iter(expected_events) + + previous_expected_event_type_name = None + for expected_event_type in _expected_events: + expected_event_type_name = EventType.Name(expected_event_type).removeprefix( + "EVENT_TYPE_" + ) + has_expected = next( + (e for e in _all_events if e.event_type == expected_event_type), + None, + ) + if not has_expected: + if previous_expected_event_type_name is not None: + prefix = f"After {previous_expected_event_type_name}, " + else: + prefix = "" + raise AssertionError( + f"{prefix}expected {expected_event_type_name} in workflow {wf_handle.id}" + ) + previous_expected_event_type_name = expected_event_type_name + + await assert_eventually(check, timeout=timeout) + + async def get_pending_activity_info( handle: WorkflowHandle, activity_id: str, diff --git a/tests/nexus/test_type_errors.py b/tests/nexus/test_nexus_type_errors.py similarity index 56% rename from tests/nexus/test_type_errors.py rename to tests/nexus/test_nexus_type_errors.py index 1f5d3e2a7..bbb76cda0 100644 --- a/tests/nexus/test_type_errors.py +++ b/tests/nexus/test_nexus_type_errors.py @@ -21,10 +21,55 @@ class MyOutput: pass +@workflow.defn +class MyNoArgProcWorkflow: + 
@workflow.run + async def run(self) -> None: + pass + + +@workflow.defn +class MyOneArgProcWorkflow: + @workflow.run + async def run(self, _input: MyInput) -> None: + pass + + +@workflow.defn +class MyTwoArgProcWorkflow: + @workflow.run + async def run(self, _input: MyInput, _arg2: int) -> None: + pass + + +@workflow.defn +class MyThreeArgProcWorkflow: + @workflow.run + async def run(self, _input: MyInput, _arg2: int, _arg3: int) -> None: + pass + + +@workflow.defn +class MyFourArgProcWorkflow: + @workflow.run + async def run(self, _input: MyInput, _arg2: int, _arg3: int, _arg4: int) -> None: + pass + + +@workflow.defn +class MyFiveArgProcWorkflow: + @workflow.run + async def run( + self, _input: MyInput, _arg2: int, _arg3: int, _arg4: int, _arg5: int + ) -> None: + pass + + @nexusrpc.service class MyService: my_sync_operation: nexusrpc.Operation[MyInput, MyOutput] my_workflow_run_operation: nexusrpc.Operation[MyInput, MyOutput] + my_temporal_operation: nexusrpc.Operation[int, None] @nexusrpc.handler.service_handler(service=MyService) @@ -41,6 +86,71 @@ async def my_workflow_run_operation( ) -> temporalio.nexus.WorkflowHandle[MyOutput]: raise NotImplementedError + @temporalio.nexus.temporal_operation + async def my_temporal_operation( + self, + _ctx: nexusrpc.handler.StartOperationContext, + client: temporalio.nexus.TemporalNexusClient, + input: int, + ) -> temporalio.nexus.TemporalOperationResult[None]: + """ + Typed proc workflow starts from a generic Temporal Nexus operation handler + infer TemporalOperationResult[None] for 0 to 5 workflow parameters. 
+ """ + if input == 0: + result_0: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow(MyNoArgProcWorkflow.run, id="proc-0") + return result_0 + if input == 1: + result_1: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow( + MyOneArgProcWorkflow.run, MyInput(), id="proc-1" + ) + return result_1 + if input == 2: + result_2: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow( + MyTwoArgProcWorkflow.run, args=[MyInput(), 2], id="proc-2" + ) + return result_2 + if input == 3: + result_3: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow( + MyThreeArgProcWorkflow.run, + args=[MyInput(), 2, 3], + id="proc-3", + ) + return result_3 + if input == 4: + result_4: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow( + MyFourArgProcWorkflow.run, + args=[MyInput(), 2, 3, 4], + id="proc-4", + ) + return result_4 + if input == 5: + result_5: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow( + MyFiveArgProcWorkflow.run, + args=[MyInput(), 2, 3, 4, 5], + id="proc-5", + ) + return result_5 + # assert-type-error-pyright: 'No overloads for "start_workflow" match' + return await client.start_workflow( # type: ignore + MyOneArgProcWorkflow.run, + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter' + "wrong-input-type", # type: ignore + id="proc-wrong-input", + ) + @nexusrpc.handler.service_handler(service=MyService) class MyServiceHandler2: @@ -56,6 +166,15 @@ async def my_workflow_run_operation( ) -> temporalio.nexus.WorkflowHandle[MyOutput]: raise NotImplementedError + @temporalio.nexus.temporal_operation + async def my_temporal_operation( + self, + _ctx: nexusrpc.handler.StartOperationContext, + _client: temporalio.nexus.TemporalNexusClient, + _input: int, + ) -> temporalio.nexus.TemporalOperationResult[None]: + raise NotImplementedError + 
@nexusrpc.handler.service_handler class MyServiceHandlerWithoutServiceDefinition: @@ -71,6 +190,15 @@ async def my_workflow_run_operation( ) -> temporalio.nexus.WorkflowHandle[MyOutput]: raise NotImplementedError + @temporalio.nexus.temporal_operation + async def my_temporal_operation( + self, + _ctx: nexusrpc.handler.StartOperationContext, + _client: temporalio.nexus.TemporalNexusClient, + _input: int, + ) -> temporalio.nexus.TemporalOperationResult[None]: + raise NotImplementedError + @workflow.defn class MyWorkflow1: @@ -106,6 +234,15 @@ async def test_invoke_by_operation_definition_happy_path(self) -> None: ) _output_2_1: MyOutput = await _handle_2 + # temporal operation + _output_3: None = await nexus_client.execute_operation( # type: ignore + MyService.my_temporal_operation, 0 + ) + _handle_3: workflow.NexusOperationHandle[ + None + ] = await nexus_client.start_operation(MyService.my_temporal_operation, 0) + _output_3_1: None = await _handle_3 # type: ignore + @workflow.defn class MyWorkflow2: @@ -143,6 +280,17 @@ async def test_invoke_by_operation_handler_happy_path(self) -> None: ) _output_2_1: MyOutput = await _handle_2 + # temporal operation + _output_3: None = await nexus_client.execute_operation( # type: ignore + MyServiceHandler.my_temporal_operation, 0 + ) + _handle_3: workflow.NexusOperationHandle[ + None + ] = await nexus_client.start_operation( + MyServiceHandler.my_temporal_operation, 0 + ) + _output_3_1: None = await _handle_3 # type: ignore + @workflow.defn class MyWorkflow3: @@ -162,6 +310,12 @@ async def test_invoke_by_operation_definition_wrong_input_type(self) -> None: # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "input"' "wrong-input-type", # type: ignore ) + # assert-type-error-pyright: 'No overloads for "execute_operation" match' + await nexus_client.execute_operation( # type: ignore + MyService.my_temporal_operation, + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter 
"input"' + "wrong-input-type", # type: ignore + ) @workflow.defn @@ -182,6 +336,12 @@ async def test_invoke_by_operation_handler_wrong_input_type(self) -> None: # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "input"' "wrong-input-type", # type: ignore ) + # assert-type-error-pyright: 'No overloads for "execute_operation" match' + await nexus_client.execute_operation( # type: ignore + MyServiceHandler.my_temporal_operation, # type: ignore[arg-type] + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "input"' + "wrong-input-type", # type: ignore + ) @workflow.defn diff --git a/tests/nexus/test_temporal_operation.py b/tests/nexus/test_temporal_operation.py new file mode 100644 index 000000000..0c5e59570 --- /dev/null +++ b/tests/nexus/test_temporal_operation.py @@ -0,0 +1,438 @@ +import uuid +from dataclasses import dataclass + +import nexusrpc +import pytest +from nexusrpc import HandlerErrorType, Operation, service +from nexusrpc.handler import ( + StartOperationContext, + service_handler, +) + +import temporalio.exceptions +from temporalio import nexus, workflow +from temporalio.client import Client, WorkflowFailureError +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers import EventType, assert_event_subsequence +from tests.helpers.nexus import make_nexus_endpoint_name + + +@dataclass +class Input: + value: str + task_queue: str + + +@workflow.defn +class EchoWorkflow: + @workflow.run + async def run(self, input: Input) -> str: + return input.value + + +@service +class TestService: + echo: Operation[Input, str] + blocking: Operation[None, None] + double_start: Operation[Input, None] + sync_result: Operation[Input, str] + + +@service_handler(service=TestService) +class EchoServiceHandler: + @nexus.temporal_operation + async def echo( + self, + _ctx: StartOperationContext, + client: nexus.TemporalNexusClient, + input: Input, + ) -> 
nexus.TemporalOperationResult[str]:
+        return await client.start_workflow(
+            EchoWorkflow.run, input, id=f"echo-{input.value}"
+        )
+
+    @nexus.temporal_operation
+    async def blocking(
+        self,
+        _ctx: StartOperationContext,
+        client: nexus.TemporalNexusClient,
+        _input: None,
+    ) -> nexus.TemporalOperationResult[None]:
+        return await client.start_workflow(
+            BlockingWorkflow.run, id=f"blocking-{uuid.uuid4()}"
+        )
+
+    @nexus.temporal_operation
+    async def double_start(
+        self,
+        _ctx: StartOperationContext,
+        client: nexus.TemporalNexusClient,
+        input: Input,
+    ) -> nexus.TemporalOperationResult[None]:
+        await client.start_workflow(
+            EchoWorkflow.run, input, id=f"double-start-{uuid.uuid4()}"
+        )
+        await client.start_workflow(
+            EchoWorkflow.run, input, id=f"double-start-{uuid.uuid4()}"
+        )
+        return nexus.TemporalOperationResult.sync(None)
+
+    @nexus.temporal_operation
+    async def sync_result(
+        self,
+        _ctx: StartOperationContext,
+        _client: nexus.TemporalNexusClient,
+        input: Input,
+    ) -> nexus.TemporalOperationResult[str]:
+        return nexus.TemporalOperationResult.sync(input.value)
+
+
+@workflow.defn
+class EchoWorkflowCaller:
+    @workflow.run
+    async def run(self, input: Input) -> str:
+        client = workflow.create_nexus_client(
+            service=TestService, endpoint=make_nexus_endpoint_name(input.task_queue)
+        )
+        return await client.execute_operation(TestService.echo, input)
+
+
+async def test_temporal_operation_start_workflow(
+    client: Client, env: WorkflowEnvironment
+):
+    task_queue = str(uuid.uuid4())
+    endpoint_name = make_nexus_endpoint_name(task_queue)
+    await env.create_nexus_endpoint(endpoint_name, task_queue)
+    async with Worker(
+        env.client,
+        task_queue=task_queue,
+        nexus_service_handlers=[EchoServiceHandler()],
+        workflows=[EchoWorkflow, EchoWorkflowCaller],
+    ):
+        wf_handle = await client.start_workflow(
+            EchoWorkflowCaller.run,
+            Input(value="test", task_queue=task_queue),
+            task_queue=task_queue,
+            id=str(uuid.uuid4()),
+        )
+        result = await wf_handle.result()
+ assert result == "test" + + await assert_event_subsequence( + wf_handle, + [ + EventType.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED, + EventType.EVENT_TYPE_NEXUS_OPERATION_STARTED, + EventType.EVENT_TYPE_NEXUS_OPERATION_COMPLETED, + ], + ) + + +@workflow.defn +class BlockingWorkflow: + done: bool = False + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self.done) + + @workflow.update + async def unblock(self): + self.done = True + + +@workflow.defn +class CancelBlockingWorkflowCaller: + op_started = False + + @workflow.run + async def run(self, input: Input) -> None: + client = workflow.create_nexus_client( + service=TestService, endpoint=make_nexus_endpoint_name(input.task_queue) + ) + op_handle = await client.start_operation(TestService.blocking, None) + self.op_started = True + return await op_handle + + @workflow.update + async def wait_operation_started(self): + await workflow.wait_condition(lambda: self.op_started) + + +async def test_temporal_operation_cancel_workflow( + client: Client, env: WorkflowEnvironment +): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[EchoServiceHandler()], + workflows=[BlockingWorkflow, CancelBlockingWorkflowCaller], + ): + wf_handle = await client.start_workflow( + CancelBlockingWorkflowCaller.run, + Input(value="test", task_queue=task_queue), + task_queue=task_queue, + id=f"blocking-{uuid.uuid4()}", + ) + + await wf_handle.execute_update( + CancelBlockingWorkflowCaller.wait_operation_started + ) + + await wf_handle.cancel() + + await assert_event_subsequence( + wf_handle, + [ + EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED, + EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED, + EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED, + ], + ) + + +@workflow.defn +class DoubleStartWorkflowCaller: + 
@workflow.run + async def run(self, input: Input) -> None: + client = workflow.create_nexus_client( + service=TestService, endpoint=make_nexus_endpoint_name(input.task_queue) + ) + op_handle = await client.start_operation(TestService.double_start, input) + return await op_handle + + +async def test_temporal_operation_double_start_raises_handler_err( + client: Client, env: WorkflowEnvironment +): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[EchoServiceHandler()], + workflows=[EchoWorkflow, DoubleStartWorkflowCaller], + ): + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + DoubleStartWorkflowCaller.run, + Input(value="test", task_queue=task_queue), + task_queue=task_queue, + id=f"double-start-{uuid.uuid4()}", + ) + + assert isinstance(err.value.cause, temporalio.exceptions.NexusOperationError) + assert isinstance(err.value.cause.cause, nexusrpc.HandlerError) + assert err.value.cause.cause.type == HandlerErrorType.BAD_REQUEST + assert ( + "Only one async operation can be started per operation handler invocation" + in err.value.cause.cause.message + ) + + +@workflow.defn +class SyncResultCaller: + @workflow.run + async def run(self, input: Input) -> str: + client = workflow.create_nexus_client( + service=TestService, endpoint=make_nexus_endpoint_name(input.task_queue) + ) + return await client.execute_operation(TestService.sync_result, input) + + +async def test_temporal_operation_sync_result(client: Client, env: WorkflowEnvironment): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[EchoServiceHandler()], + workflows=[SyncResultCaller], + ): + wf_handle = await 
client.start_workflow( + SyncResultCaller.run, + Input(value="test", task_queue=task_queue), + task_queue=task_queue, + id=str(uuid.uuid4()), + ) + result = await wf_handle.result() + assert result == "test" + + # Sync results do not produce a NEXUS_OPERATION_STARTED event, + await assert_event_subsequence( + wf_handle, + [ + EventType.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED, + EventType.EVENT_TYPE_NEXUS_OPERATION_COMPLETED, + ], + ) + + +@dataclass +class TemporalOperationOverloadTestValue: + value: int + + +@workflow.defn +class TemporalOperationOverloadTestWorkflow: + @workflow.run + async def run( + self, input: TemporalOperationOverloadTestValue + ) -> TemporalOperationOverloadTestValue: + return TemporalOperationOverloadTestValue(value=input.value * 2) + + +@workflow.defn +class TemporalOperationOverloadTestWorkflowNoParam: + @workflow.run + async def run(self) -> TemporalOperationOverloadTestValue: + return TemporalOperationOverloadTestValue(value=0) + + +@service_handler +class TemporalOperationOverloadTestServiceHandler: + @nexus.temporal_operation + async def no_param( + self, + _ctx: StartOperationContext, + client: nexus.TemporalNexusClient, + _input: TemporalOperationOverloadTestValue, + ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: + return await client.start_workflow( + TemporalOperationOverloadTestWorkflowNoParam.run, + id=str(uuid.uuid4()), + ) + + @nexus.temporal_operation + async def single_param( + self, + _ctx: StartOperationContext, + client: nexus.TemporalNexusClient, + input: TemporalOperationOverloadTestValue, + ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: + return await client.start_workflow( + TemporalOperationOverloadTestWorkflow.run, + input, + id=str(uuid.uuid4()), + ) + + @nexus.temporal_operation + async def multi_param( + self, + _ctx: StartOperationContext, + client: nexus.TemporalNexusClient, + input: TemporalOperationOverloadTestValue, + ) -> 
nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: + return await client.start_workflow( + TemporalOperationOverloadTestWorkflow.run, + args=[input], + id=str(uuid.uuid4()), + ) + + @nexus.temporal_operation + async def by_name( + self, + _ctx: StartOperationContext, + client: nexus.TemporalNexusClient, + input: TemporalOperationOverloadTestValue, + ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: + return await client.start_workflow( + "TemporalOperationOverloadTestWorkflow", + input, + id=str(uuid.uuid4()), + result_type=TemporalOperationOverloadTestValue, + ) + + @nexus.temporal_operation + async def by_name_multi_param( + self, + _ctx: StartOperationContext, + client: nexus.TemporalNexusClient, + input: TemporalOperationOverloadTestValue, + ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: + return await client.start_workflow( + "TemporalOperationOverloadTestWorkflow", + args=[input], + id=str(uuid.uuid4()), + result_type=TemporalOperationOverloadTestValue, + ) + + +@workflow.defn +class TemporalOperationOverloadTestCallerWorkflow: + @workflow.run + async def run( + self, op: str, input: TemporalOperationOverloadTestValue + ) -> TemporalOperationOverloadTestValue: + client = workflow.create_nexus_client( + service=TemporalOperationOverloadTestServiceHandler, + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + ) + + if op == "no_param": + return await client.execute_operation( + TemporalOperationOverloadTestServiceHandler.no_param, input + ) + elif op == "single_param": + return await client.execute_operation( + TemporalOperationOverloadTestServiceHandler.single_param, input + ) + elif op == "multi_param": + return await client.execute_operation( + TemporalOperationOverloadTestServiceHandler.multi_param, input + ) + elif op == "by_name": + return await client.execute_operation( + TemporalOperationOverloadTestServiceHandler.by_name, input + ) + elif op == "by_name_multi_param": + return 
await client.execute_operation( + TemporalOperationOverloadTestServiceHandler.by_name_multi_param, input + ) + else: + raise ValueError(f"Unknown op: {op}") + + +@pytest.mark.parametrize( + "op", + [ + "no_param", + "single_param", + "multi_param", + "by_name", + "by_name_multi_param", + ], +) +async def test_temporal_operation_overloads( + client: Client, env: WorkflowEnvironment, op: str +): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + client, + task_queue=task_queue, + workflows=[ + TemporalOperationOverloadTestCallerWorkflow, + TemporalOperationOverloadTestWorkflow, + TemporalOperationOverloadTestWorkflowNoParam, + ], + nexus_service_handlers=[TemporalOperationOverloadTestServiceHandler()], + ): + result = await client.execute_workflow( + TemporalOperationOverloadTestCallerWorkflow.run, + args=[op, TemporalOperationOverloadTestValue(value=2)], + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + assert result == ( + TemporalOperationOverloadTestValue(value=0) + if op == "no_param" + else TemporalOperationOverloadTestValue(value=4) + ) diff --git a/tests/nexus/test_workflow_caller_cancellation_types.py b/tests/nexus/test_workflow_caller_cancellation_types.py index 6ebba5759..e2e1b3da9 100644 --- a/tests/nexus/test_workflow_caller_cancellation_types.py +++ b/tests/nexus/test_workflow_caller_cancellation_types.py @@ -20,6 +20,7 @@ from temporalio.common import WorkflowIDConflictPolicy from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker +from tests.helpers import assert_event_subsequence from tests.helpers.nexus import make_nexus_endpoint_name @@ -456,38 +457,3 @@ async def get_event_time( return event.event_time.ToDatetime().replace(tzinfo=timezone.utc) event_type_name = EventType.Name(event_type).removeprefix("EVENT_TYPE_") assert False, f"Event {event_type_name} not found in {wf_handle.id}" - - 
-async def assert_event_subsequence( - wf_handle: WorkflowHandle, - expected_events: list[EventType.ValueType], -) -> None: - """ - Given a workflow handle and a sequence of event types, assert that the workflow's history - contains that subsequence of events in the order specified. - """ - all_events = [] - async for e in wf_handle.fetch_history_events(): - all_events.append(e) - - _all_events = iter(all_events) - _expected_events = iter(expected_events) - - previous_expected_event_type_name = None - for expected_event_type in _expected_events: - expected_event_type_name = EventType.Name(expected_event_type).removeprefix( - "EVENT_TYPE_" - ) - has_expected = next( - (e for e in _all_events if e.event_type == expected_event_type), - None, - ) - if not has_expected: - if previous_expected_event_type_name is not None: - prefix = f"After {previous_expected_event_type_name}, " - else: - prefix = "" - pytest.fail( - f"{prefix}expected {expected_event_type_name} in workflow {wf_handle.id}" - ) - previous_expected_event_type_name = expected_event_type_name diff --git a/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py b/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py index a344f1b5c..e7826895d 100644 --- a/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py +++ b/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py @@ -23,9 +23,9 @@ from temporalio.common import WorkflowIDConflictPolicy from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker +from tests.helpers import assert_event_subsequence from tests.helpers.nexus import make_nexus_endpoint_name from tests.nexus.test_workflow_caller_cancellation_types import ( - assert_event_subsequence, get_event_time, has_event, )