Skip to content

Commit 02d83e0

Browse files
committed
PR feedback
1 parent bf37849 commit 02d83e0

5 files changed

Lines changed: 188 additions & 55 deletions

File tree

durabletask/__init__.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,22 @@
33

44
"""Durable Task SDK for Python"""
55

6-
from durabletask.worker import ConcurrencyOptions, VersioningOptions, WorkItemFilters
6+
from durabletask.worker import (
7+
ActivityWorkItemFilter,
8+
ConcurrencyOptions,
9+
EntityWorkItemFilter,
10+
OrchestrationWorkItemFilter,
11+
VersioningOptions,
12+
WorkItemFilters,
13+
)
714

8-
__all__ = ["ConcurrencyOptions", "VersioningOptions", "WorkItemFilters"]
15+
__all__ = [
16+
"ActivityWorkItemFilter",
17+
"ConcurrencyOptions",
18+
"EntityWorkItemFilter",
19+
"OrchestrationWorkItemFilter",
20+
"VersioningOptions",
21+
"WorkItemFilters",
22+
]
923

1024
PACKAGE_NAME = "durabletask"

durabletask/testing/in_memory_backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,11 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context):
594594
skipped_entities.append(entity_id)
595595
continue
596596
except ValueError:
597-
pass
597+
self._logger.warning(
598+
f"Cannot parse entity ID '{entity_id}' "
599+
f"for filter matching; skipping")
600+
skipped_entities.append(entity_id)
601+
continue
598602

599603
# Skip if this entity is already being processed
600604
if entity_id in self._entity_in_flight:

durabletask/worker.py

Lines changed: 29 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import random
1010
import time
1111
from concurrent.futures import ThreadPoolExecutor
12+
from dataclasses import dataclass, field
1213
from datetime import datetime, timedelta, timezone
1314
from threading import Event, Thread
1415
from types import GeneratorType
@@ -145,49 +146,43 @@ def __init__(self, version: Optional[str] = None,
145146
_AUTO_GENERATE_FILTERS = object()
146147

147148

149+
@dataclass(frozen=True)
148150
class OrchestrationWorkItemFilter:
149151
"""Specifies a filter for orchestration work items."""
150152

151-
def __init__(self, name: str, versions: Optional[list[str]] = None):
152-
"""Initialize an orchestration filter.
153-
154-
Args:
155-
name: The name of the orchestration to filter.
156-
versions: Optional list of versions to filter.
157-
"""
158-
self.name = name
159-
self.versions: list[str] = versions if versions is not None else []
153+
name: str
154+
"""The name of the orchestration to filter."""
155+
versions: list[str] = field(default_factory=list)
156+
"""Optional list of versions to filter."""
160157

161158

159+
@dataclass(frozen=True)
162160
class ActivityWorkItemFilter:
163161
"""Specifies a filter for activity work items."""
164162

165-
def __init__(self, name: str, versions: Optional[list[str]] = None):
166-
"""Initialize an activity filter.
167-
168-
Args:
169-
name: The name of the activity to filter.
170-
versions: Optional list of versions to filter.
171-
"""
172-
self.name = name
173-
self.versions: list[str] = versions if versions is not None else []
163+
name: str
164+
"""The name of the activity to filter."""
165+
versions: list[str] = field(default_factory=list)
166+
"""Optional list of versions to filter."""
174167

175168

169+
@dataclass(frozen=True)
176170
class EntityWorkItemFilter:
177-
"""Specifies a filter for entity work items."""
171+
"""Specifies a filter for entity work items.
178172
179-
def __init__(self, name: str):
180-
"""Initialize an entity filter.
173+
The name is normalized to lowercase to match entity registration
174+
and instance ID conventions.
175+
"""
181176

182-
Args:
183-
name: The name of the entity to filter.
184-
The name is normalized to lowercase to match
185-
entity registration and instance ID conventions.
186-
"""
187-
EntityInstanceId.validate_entity_name(name)
188-
self.name = name.lower()
177+
name: str
178+
"""The name of the entity to filter."""
189179

180+
def __post_init__(self):
181+
EntityInstanceId.validate_entity_name(self.name)
182+
object.__setattr__(self, 'name', self.name.lower())
190183

184+
185+
@dataclass(frozen=True)
191186
class WorkItemFilters:
192187
"""Work item filters for a Durable Task Worker.
193188
@@ -199,28 +194,12 @@ class WorkItemFilters:
199194
:meth:`TaskHubGrpcWorker.use_work_item_filters` to enable filtering.
200195
"""
201196

202-
def __init__(
203-
self,
204-
orchestrations: Optional[list[OrchestrationWorkItemFilter]] = None,
205-
activities: Optional[list[ActivityWorkItemFilter]] = None,
206-
entities: Optional[list[EntityWorkItemFilter]] = None,
207-
):
208-
"""Initialize work item filters.
209-
210-
Args:
211-
orchestrations: List of orchestration filters.
212-
activities: List of activity filters.
213-
entities: List of entity filters.
214-
"""
215-
self.orchestrations: list[OrchestrationWorkItemFilter] = (
216-
orchestrations if orchestrations is not None else []
217-
)
218-
self.activities: list[ActivityWorkItemFilter] = (
219-
activities if activities is not None else []
220-
)
221-
self.entities: list[EntityWorkItemFilter] = (
222-
entities if entities is not None else []
223-
)
197+
orchestrations: list[OrchestrationWorkItemFilter] = field(default_factory=list)
198+
"""List of orchestration filters."""
199+
activities: list[ActivityWorkItemFilter] = field(default_factory=list)
200+
"""List of activity filters."""
201+
entities: list[EntityWorkItemFilter] = field(default_factory=list)
202+
"""List of entity filters."""
224203

225204
@classmethod
226205
def _from_registry(cls, registry: '_Registry') -> 'WorkItemFilters':

examples/work_item_filtering.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,5 +100,3 @@ def farewell_orchestrator(ctx: task.OrchestrationContext, name: str):
100100
print(f" Completed: {state.serialized_output}")
101101
elif state:
102102
print(f" Failed: {state.failure_details}")
103-
104-
exit()

tests/durabletask-azuremanaged/test_dts_work_item_filters_e2e.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
ActivityWorkItemFilter,
1414
EntityWorkItemFilter,
1515
OrchestrationWorkItemFilter,
16+
VersioningOptions,
17+
VersionMatchStrategy,
1618
WorkItemFilters,
1719
)
1820
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
@@ -262,3 +264,139 @@ def ping(self, _):
262264

263265
assert matched_invoked
264266
assert not unmatched_invoked
267+
268+
269+
# ------------------------------------------------------------------
270+
# Tests: version-aware filtering with strict versioning
271+
# ------------------------------------------------------------------
272+
273+
def _simple_v2_orchestrator(ctx: task.OrchestrationContext, input: int):
274+
"""Orchestrator that returns immediately (no activities) for version tests."""
275+
return input + 1
276+
277+
278+
def test_strict_version_matching_orchestration_completes():
279+
"""Orchestration scheduled with the matching version is processed."""
280+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
281+
taskhub=taskhub_name, token_credential=None) as w:
282+
w.add_orchestrator(_simple_v2_orchestrator)
283+
w.use_versioning(VersioningOptions(
284+
version="2.0",
285+
match_strategy=VersionMatchStrategy.STRICT,
286+
))
287+
w.use_work_item_filters() # auto-generate with version
288+
w.start()
289+
290+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
291+
taskhub=taskhub_name, token_credential=None)
292+
id = c.schedule_new_orchestration(
293+
_simple_v2_orchestrator, input=10, version="2.0")
294+
state = c.wait_for_orchestration_completion(id, timeout=30)
295+
296+
assert state is not None
297+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
298+
assert state.serialized_output == "11"
299+
300+
301+
def test_strict_version_incompatible_orchestration_stays_pending():
302+
"""Orchestration with an incompatible version is not dispatched and stays pending."""
303+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
304+
taskhub=taskhub_name, token_credential=None) as w:
305+
w.add_orchestrator(_simple_v2_orchestrator)
306+
w.use_versioning(VersioningOptions(
307+
version="2.0",
308+
match_strategy=VersionMatchStrategy.STRICT,
309+
))
310+
w.use_work_item_filters()
311+
w.start()
312+
313+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
314+
taskhub=taskhub_name, token_credential=None)
315+
316+
# Schedule with version "1.0" — incompatible with the worker's "2.0"
317+
bad_id = c.schedule_new_orchestration(
318+
_simple_v2_orchestrator, input=5, version="1.0")
319+
320+
# Schedule a compatible one so we can confirm the worker is active
321+
good_id = c.schedule_new_orchestration(
322+
_simple_v2_orchestrator, input=5, version="2.0")
323+
good_state = c.wait_for_orchestration_completion(good_id, timeout=30)
324+
325+
assert good_state is not None
326+
assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED
327+
328+
# The incompatible orchestration must remain pending (not failed)
329+
bad_state = c.get_orchestration_state(bad_id)
330+
assert bad_state is not None
331+
assert bad_state.runtime_status == client.OrchestrationStatus.PENDING
332+
333+
334+
def test_strict_version_no_version_orchestration_stays_pending():
335+
"""Orchestration scheduled without a version is not dispatched by a strict worker."""
336+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
337+
taskhub=taskhub_name, token_credential=None) as w:
338+
w.add_orchestrator(_simple_v2_orchestrator)
339+
w.use_versioning(VersioningOptions(
340+
version="2.0",
341+
match_strategy=VersionMatchStrategy.STRICT,
342+
))
343+
w.use_work_item_filters()
344+
w.start()
345+
346+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
347+
taskhub=taskhub_name, token_credential=None)
348+
349+
# Schedule without any version
350+
no_ver_id = c.schedule_new_orchestration(
351+
_simple_v2_orchestrator, input=1)
352+
353+
# Schedule a compatible one to prove the worker is running
354+
good_id = c.schedule_new_orchestration(
355+
_simple_v2_orchestrator, input=1, version="2.0")
356+
good_state = c.wait_for_orchestration_completion(good_id, timeout=30)
357+
assert good_state is not None
358+
assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED
359+
360+
# The unversioned orchestration must remain pending
361+
no_ver_state = c.get_orchestration_state(no_ver_id)
362+
assert no_ver_state is not None
363+
assert no_ver_state.runtime_status == client.OrchestrationStatus.PENDING
364+
365+
366+
def test_strict_version_explicit_filters_with_versions():
367+
"""Explicit filters with version constraints enforce strict matching."""
368+
custom_filters = WorkItemFilters(
369+
orchestrations=[
370+
OrchestrationWorkItemFilter(
371+
name=task.get_name(_simple_v2_orchestrator),
372+
versions=["3.0"],
373+
),
374+
],
375+
)
376+
377+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
378+
taskhub=taskhub_name, token_credential=None) as w:
379+
w.add_orchestrator(_simple_v2_orchestrator)
380+
w.use_work_item_filters(custom_filters)
381+
w.start()
382+
383+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
384+
taskhub=taskhub_name, token_credential=None)
385+
386+
# Version "2.0" does not match the filter's "3.0"
387+
bad_id = c.schedule_new_orchestration(
388+
_simple_v2_orchestrator, input=1, version="2.0")
389+
390+
# Version "3.0" should match
391+
good_id = c.schedule_new_orchestration(
392+
_simple_v2_orchestrator, input=1, version="3.0")
393+
good_state = c.wait_for_orchestration_completion(good_id, timeout=30)
394+
395+
assert good_state is not None
396+
assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED
397+
assert good_state.serialized_output == "2"
398+
399+
# Mismatched version must remain pending
400+
bad_state = c.get_orchestration_state(bad_id)
401+
assert bad_state is not None
402+
assert bad_state.runtime_status == client.OrchestrationStatus.PENDING

0 commit comments

Comments
 (0)