Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/hooks/pre-commit
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/sh

if hatch fmt --check; then
echo "Hatch fmt check passed!"
else
hatch fmt
echo "Error: hatch fmt modified your files. Please re-stage and commit again."
exit 1
fi
13 changes: 12 additions & 1 deletion examples/examples-catalog.json
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,17 @@
"ApplicationLogLevel": "DEBUG",
"LogFormat": "JSON"
}
}
},
{
"name": "Plugin",
"description": "Test plugin",
"handler": "execution_with_plugin.handler",
"integration": true,
"durableConfig": {
"RetentionPeriodInDays": 7,
"ExecutionTimeout": 300
},
"path": "./src/plugin/execution_with_plugin.py"
}
]
}
70 changes: 70 additions & 0 deletions examples/src/plugin/execution_with_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Demonstrates handler execution without any durable operations."""

import logging
from typing import Any

from aws_durable_execution_sdk_python import StepContext
from aws_durable_execution_sdk_python.context import (
DurableContext,
durable_step,
durable_with_child_context,
)
from aws_durable_execution_sdk_python.execution import durable_execution
from aws_durable_execution_sdk_python.plugin import (
DurableExecutionPlugin,
AttemptStartInfo,
)


class MyPlugin(DurableExecutionPlugin):
logger = logging.getLogger("MyPlugin")

def on_execution_start(self, info):
self.logger.info(f"Execution started: {info}")

def on_execution_end(self, info):
self.logger.info(f"Execution ended: {info}")

def on_operation_start(self, info):
self.logger.info(f"Operation started: {info}")

def on_operation_end(self, info):
self.logger.info(f"Operation ended: {info}")

def on_invocation_start(self, info):
self.logger.info(f"Invocation started: {info}")

def on_invocation_end(self, info):
self.logger.info(f"Invocation ended: {info}")

def on_operation_attempt_start(self, info: AttemptStartInfo) -> None:
self.logger.info(f"Attempt started: {info}")

def on_operation_attempt_end(self, info) -> None:
self.logger.info(f"Attempt ended: {info}")


@durable_step
def add_numbers(_step_context: StepContext, a: int, b: int) -> int:
return a + b


@durable_with_child_context
def add_numbers_in_child(child_context: DurableContext, a: int, b: int):
result: int = child_context.step(
add_numbers(a, b),
name="add-a-and-b",
)
return result


@durable_execution(plugins=[MyPlugin()])
def handler(_event: Any, context: DurableContext) -> int:
result: int = context.run_in_child_context(
add_numbers_in_child(6, 4),
name="add-6-and-4",
)
return context.step(
add_numbers(result, 2),
name="add-result-to-2",
)
18 changes: 18 additions & 0 deletions examples/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,24 @@
"ExecutionTimeout": 300
}
}
},
"ExecutionWithPlugin": {
"Type": "AWS::Serverless::Function",
"Properties": {
"CodeUri": "build/",
"Handler": "execution_with_plugin.handler",
"Description": "Test plugin",
"Role": {
"Fn::GetAtt": [
"DurableFunctionRole",
"Arn"
]
},
"DurableConfig": {
"RetentionPeriodInDays": 7,
"ExecutionTimeout": 300
}
}
}
}
}
24 changes: 24 additions & 0 deletions examples/test/plugin/test_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Tests for step example."""

import pytest
from aws_durable_execution_sdk_python.execution import InvocationStatus

from src.plugin import execution_with_plugin
from test.conftest import deserialize_operation_payload


@pytest.mark.example
@pytest.mark.durable_execution(
handler=execution_with_plugin.handler,
lambda_function_name="Plugin",
)
def test_plugin(durable_runner):
"""Test basic step example."""
with durable_runner:
result = durable_runner.run(input="{}", timeout=10)

assert result.status is InvocationStatus.SUCCEEDED
assert deserialize_operation_payload(result.result) == 12

step_result = result.get_step("add-result-to-2")
assert deserialize_operation_payload(step_result.result) == 12
89 changes: 31 additions & 58 deletions src/aws_durable_execution_sdk_python/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any

from aws_durable_execution_sdk_python.context import DurableContext
Expand All @@ -26,6 +25,13 @@
Operation,
OperationType,
OperationUpdate,
InvocationStatus,
DurableExecutionInvocationOutput,
)
from aws_durable_execution_sdk_python.plugin import (
DurableExecutionPlugin,
PluginExecutor,
handle_plugins,
)
from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus

Expand Down Expand Up @@ -149,77 +155,36 @@ def from_durable_execution_invocation_input(
)


class InvocationStatus(Enum):
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
PENDING = "PENDING"


@dataclass(frozen=True)
class DurableExecutionInvocationOutput:
"""Representation the DurableExecutionInvocationOutput. This is what the Durable lambda handler returns.

If the execution has been already completed via an update to the EXECUTION operation via CheckpointDurableExecution,
payload must be empty for SUCCEEDED/FAILED status.
"""

status: InvocationStatus
result: str | None = None
error: ErrorObject | None = None

@classmethod
def from_dict(
cls, data: MutableMapping[str, Any]
) -> DurableExecutionInvocationOutput:
"""Create an instance from a dictionary.

Args:
data: Dictionary with camelCase keys matching the original structure

Returns:
A DurableExecutionInvocationOutput instance
"""
status = InvocationStatus(data.get("Status"))
error = ErrorObject.from_dict(data["Error"]) if data.get("Error") else None
return cls(status=status, result=data.get("Result"), error=error)

def to_dict(self) -> MutableMapping[str, Any]:
"""Convert to a dictionary with the original field names.

Returns:
Dictionary with the original camelCase keys
"""
result: MutableMapping[str, Any] = {"Status": self.status.value}

if self.result is not None:
# large payloads return "", because checkpointed already
result["Result"] = self.result
if self.error:
result["Error"] = self.error.to_dict()

return result

@classmethod
def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput:
"""Create a succeeded invocation output."""
return cls(status=InvocationStatus.SUCCEEDED, result=result)


# endregion Invocation models


def durable_execution(
func: Callable[[Any, DurableContext], Any] | None = None,
*,
boto3_client: Boto3LambdaClient | None = None,
plugins: list[DurableExecutionPlugin] | None = None,
) -> Callable[[Any, LambdaContext], Any]:
"""
Decorator to create a durable execution handler.

Args:
func: The user function to decorate
boto3_client: Optional boto3 Lambda client to use
plugins: Optional list of plugins to use (EXPERIMENTAL: This
parameter may change or be removed.)
"""
# Decorator called with parameters
if func is None:
logger.debug("Decorator called with parameters")
return functools.partial(durable_execution, boto3_client=boto3_client)
return functools.partial(
durable_execution, boto3_client=boto3_client, plugins=plugins
)

logger.debug("Starting durable execution handler...")

plugin_executor = PluginExecutor(plugins)

@handle_plugins(plugin_executor)
def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
invocation_input: DurableExecutionInvocationInput
service_client: DurableServiceClient
Expand Down Expand Up @@ -255,6 +220,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
operations={},
service_client=service_client,
replay_status=ReplayStatus.NEW,
Comment thread
zhongkechen marked this conversation as resolved.
plugin_executor=plugin_executor,
)

try:
Expand Down Expand Up @@ -306,6 +272,13 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
) as executor,
contextlib.closing(execution_state) as execution_state,
):
# execute the plugins
plugin_executor.on_invocation_start(
durable_execution_arn=invocation_input.durable_execution_arn,
context=context,
execution_operation=execution_state.get_execution_operation(),
is_replaying=execution_state.is_replaying(),
)
# Thread 1: Run background checkpoint processing
executor.submit(execution_state.checkpoint_batches_forever)

Expand Down
64 changes: 64 additions & 0 deletions src/aws_durable_execution_sdk_python/lambda_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,70 @@ class OperationSubType(Enum):
CHAINED_INVOKE = "ChainedInvoke"


class InvocationStatus(Enum):
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
PENDING = "PENDING"

# Used internally only: the invocation failed and the backend will retry
RETRY = "RETRY"


@dataclass(frozen=True)
class DurableExecutionInvocationOutput:
"""Representation the DurableExecutionInvocationOutput. This is what the Durable lambda handler returns.

If the execution has been already completed via an update to the EXECUTION operation via CheckpointDurableExecution,
payload must be empty for SUCCEEDED/FAILED status.
"""

status: InvocationStatus
result: str | None = None
error: ErrorObject | None = None

@classmethod
def from_dict(
cls, data: MutableMapping[str, Any]
) -> DurableExecutionInvocationOutput:
"""Create an instance from a dictionary.

Args:
data: Dictionary with camelCase keys matching the original structure

Returns:
A DurableExecutionInvocationOutput instance
"""
status = InvocationStatus(data.get("Status"))
error = ErrorObject.from_dict(data["Error"]) if data.get("Error") else None
return cls(status=status, result=data.get("Result"), error=error)

def to_dict(self) -> MutableMapping[str, Any]:
"""Convert to a dictionary with the original field names.

Returns:
Dictionary with the original camelCase keys
"""
result: MutableMapping[str, Any] = {"Status": self.status.value}

if self.result is not None:
# large payloads return "", because checkpointed already
result["Result"] = self.result
if self.error:
result["Error"] = self.error.to_dict()

return result

@classmethod
def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput:
"""Create a succeeded invocation output."""
return cls(status=InvocationStatus.SUCCEEDED, result=result)

@classmethod
def create_retry(cls, error: ErrorObject) -> DurableExecutionInvocationOutput:
"""Create a failed invocation output."""
return cls(status=InvocationStatus.RETRY, error=error)


@dataclass(frozen=True)
class ExecutionDetails:
input_payload: str | None = None
Expand Down
Loading
Loading