Skip to content

Commit bf37849

Browse files
committed
PR Feedback
1 parent 079d6e1 commit bf37849

5 files changed

Lines changed: 211 additions & 28 deletions

File tree

.vscode/mcp.json

Lines changed: 0 additions & 16 deletions
This file was deleted.

durabletask/testing/in_memory_backend.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class ActivityWorkItem:
5757
task_id: int
5858
input: Optional[str]
5959
completion_token: int
60+
version: Optional[str] = None
6061

6162

6263
@dataclass
@@ -439,14 +440,50 @@ def RestartInstance(self, request: pb.RestartInstanceRequest, context):
439440

440441
@staticmethod
441442
def _parse_work_item_filters(request: pb.GetWorkItemsRequest):
442-
"""Extract filter name sets from the request, or None if unfiltered."""
443+
"""Extract filters from the request.
444+
445+
Returns a tuple of three values, one per work-item category. Each
446+
value is either ``None`` (no filtering -- dispatch everything) or a
447+
``dict`` mapping a task name to a ``frozenset`` of accepted versions
448+
(empty frozenset means *any* version of that name is accepted).
449+
An empty ``dict`` means the worker opted into filtering for that
450+
category but listed no names, so *nothing* should match.
451+
"""
443452
if not request.HasField("workItemFilters"):
444453
return None, None, None
445454
wf = request.workItemFilters
446-
orch_names = {f.name for f in wf.orchestrations} if wf.orchestrations else None
447-
activity_names = {f.name for f in wf.activities} if wf.activities else None
448-
entity_names = {f.name for f in wf.entities} if wf.entities else None
449-
return orch_names, activity_names, entity_names
455+
456+
def _build_filter(filters):
457+
result: dict[str, frozenset[str]] = {}
458+
for f in filters:
459+
versions = frozenset(f.versions) if f.versions else frozenset()
460+
existing = result.get(f.name, frozenset())
461+
result[f.name] = existing | versions
462+
return result
463+
464+
orch_filter = _build_filter(wf.orchestrations)
465+
activity_filter = _build_filter(wf.activities)
466+
entity_filter = {f.name: frozenset() for f in wf.entities}
467+
return orch_filter, activity_filter, entity_filter
468+
469+
@staticmethod
470+
def _matches_filter(name: str, version: Optional[str],
471+
filt: Optional[dict[str, frozenset[str]]]) -> bool:
472+
"""Check whether a work item matches the parsed filter.
473+
474+
*filt* is ``None`` when the worker did not opt into filtering
475+
(everything matches). Otherwise it is a dict mapping accepted
476+
names to a frozenset of accepted versions. An empty frozenset
477+
means any version of that name is accepted.
478+
"""
479+
if filt is None:
480+
return True
481+
accepted_versions = filt.get(name)
482+
if accepted_versions is None:
483+
return False
484+
if not accepted_versions:
485+
return True # empty set -- any version
486+
return (version or "") in accepted_versions
450487

451488
def GetWorkItems(self, request: pb.GetWorkItemsRequest, context):
452489
"""Streams work items to the worker (orchestration and activity work items)."""
@@ -468,8 +505,9 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context):
468505
if not instance or not instance.pending_events:
469506
continue
470507

471-
# Skip if orchestration name doesn't match filters
472-
if orch_filter is not None and instance.name not in orch_filter:
508+
# Skip if orchestration doesn't match filters
509+
if not self._matches_filter(
510+
instance.name, instance.version, orch_filter):
473511
skipped_orchs.append(instance_id)
474512
continue
475513

@@ -515,7 +553,9 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context):
515553
matched_activity = None
516554
while self._activity_queue:
517555
candidate = self._activity_queue.popleft()
518-
if activity_filter is not None and candidate.name not in activity_filter:
556+
if not self._matches_filter(
557+
candidate.name, candidate.version,
558+
activity_filter):
519559
skipped.append(candidate)
520560
continue
521561
matched_activity = candidate
@@ -548,7 +588,9 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context):
548588
if entity_filter is not None:
549589
try:
550590
parsed = EntityInstanceId.parse(entity_id)
551-
if parsed.entity not in entity_filter:
591+
if not self._matches_filter(
592+
parsed.entity, None,
593+
entity_filter):
552594
skipped_entities.append(entity_id)
553595
continue
554596
except ValueError:
@@ -1313,12 +1355,15 @@ def _process_schedule_task_action(self, instance: OrchestrationInstance,
13131355
instance.status = pb.ORCHESTRATION_STATUS_RUNNING
13141356

13151357
# Queue activity for execution
1358+
task_version = schedule_task.version.value \
1359+
if schedule_task.HasField("version") else None
13161360
self._activity_queue.append(ActivityWorkItem(
13171361
instance_id=instance.instance_id,
13181362
name=task_name,
13191363
task_id=task_id,
13201364
input=input_value,
1321-
completion_token=instance.completion_token
1365+
completion_token=instance.completion_token,
1366+
version=task_version,
13221367
))
13231368
self._work_available.set()
13241369

durabletask/worker.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from threading import Event, Thread
1414
from types import GeneratorType
1515
from enum import Enum
16-
from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union
16+
from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union, overload
1717
import uuid
1818
from packaging.version import InvalidVersion, parse
1919

@@ -181,8 +181,11 @@ def __init__(self, name: str):
181181
182182
Args:
183183
name: The name of the entity to filter.
184+
The name is normalized to lowercase to match
185+
entity registration and instance ID conventions.
184186
"""
185-
self.name = name
187+
EntityInstanceId.validate_entity_name(name)
188+
self.name = name.lower()
186189

187190

188191
class WorkItemFilters:
@@ -516,6 +519,15 @@ def use_versioning(self, version: VersioningOptions) -> None:
516519
raise RuntimeError("Cannot set default version while the worker is running.")
517520
self._registry.versioning = version
518521

522+
@overload
523+
def use_work_item_filters(self) -> None: ...
524+
525+
@overload
526+
def use_work_item_filters(self, filters: WorkItemFilters) -> None: ...
527+
528+
@overload
529+
def use_work_item_filters(self, filters: None) -> None: ...
530+
519531
def use_work_item_filters(
520532
self,
521533
filters: Union[WorkItemFilters, None, object] = _AUTO_GENERATE_FILTERS,

tests/durabletask/test_work_item_filters.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ def test_defaults(self):
4747
f = EntityWorkItemFilter(name="myentity")
4848
assert f.name == "myentity"
4949

50+
def test_name_normalized_to_lowercase(self):
51+
f = EntityWorkItemFilter(name="Counter")
52+
assert f.name == "counter"
53+
54+
def test_invalid_name_raises(self):
55+
with pytest.raises(ValueError):
56+
EntityWorkItemFilter(name="bad@name")
57+
58+
def test_empty_name_raises(self):
59+
with pytest.raises(ValueError):
60+
EntityWorkItemFilter(name="")
61+
5062

5163
# ---------------------------------------------------------------------------
5264
# WorkItemFilters construction

tests/durabletask/test_work_item_filters_e2e.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
ActivityWorkItemFilter,
1313
EntityWorkItemFilter,
1414
OrchestrationWorkItemFilter,
15+
VersioningOptions,
16+
VersionMatchStrategy,
1517
WorkItemFilters,
1618
)
1719
from durabletask.testing import create_test_backend
@@ -256,3 +258,131 @@ def ping(self, _):
256258

257259
assert matched_invoked
258260
assert not unmatched_invoked
261+
262+
263+
# ------------------------------------------------------------------
264+
# Tests: version-aware filtering with strict versioning
265+
# ------------------------------------------------------------------
266+
267+
def _simple_v2_orchestrator(ctx: task.OrchestrationContext, input: int):
268+
"""Orchestrator that returns immediately (no activities) for version tests."""
269+
return input + 1
270+
271+
272+
def test_strict_version_matching_orchestration_completes():
273+
"""Orchestration scheduled with the matching version is processed."""
274+
with worker.TaskHubGrpcWorker(host_address=HOST) as w:
275+
w.add_orchestrator(_simple_v2_orchestrator)
276+
w.use_versioning(VersioningOptions(
277+
version="2.0",
278+
match_strategy=VersionMatchStrategy.STRICT,
279+
))
280+
w.use_work_item_filters() # auto-generate with version
281+
w.start()
282+
283+
c = client.TaskHubGrpcClient(host_address=HOST)
284+
id = c.schedule_new_orchestration(
285+
_simple_v2_orchestrator, input=10, version="2.0")
286+
state = c.wait_for_orchestration_completion(id, timeout=30)
287+
288+
assert state is not None
289+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
290+
assert state.serialized_output == "11"
291+
292+
293+
def test_strict_version_incompatible_orchestration_stays_pending():
294+
"""Orchestration with an incompatible version is not dispatched and stays pending."""
295+
with worker.TaskHubGrpcWorker(host_address=HOST) as w:
296+
w.add_orchestrator(_simple_v2_orchestrator)
297+
w.use_versioning(VersioningOptions(
298+
version="2.0",
299+
match_strategy=VersionMatchStrategy.STRICT,
300+
))
301+
w.use_work_item_filters()
302+
w.start()
303+
304+
c = client.TaskHubGrpcClient(host_address=HOST)
305+
306+
# Schedule with version "1.0" — incompatible with the worker's "2.0"
307+
bad_id = c.schedule_new_orchestration(
308+
_simple_v2_orchestrator, input=5, version="1.0")
309+
310+
# Schedule a compatible one so we can confirm the worker is active
311+
good_id = c.schedule_new_orchestration(
312+
_simple_v2_orchestrator, input=5, version="2.0")
313+
good_state = c.wait_for_orchestration_completion(good_id, timeout=30)
314+
315+
assert good_state is not None
316+
assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED
317+
318+
# The incompatible orchestration must remain pending (not failed)
319+
bad_state = c.get_orchestration_state(bad_id)
320+
assert bad_state is not None
321+
assert bad_state.runtime_status == client.OrchestrationStatus.PENDING
322+
323+
324+
def test_strict_version_no_version_orchestration_stays_pending():
325+
"""Orchestration scheduled without a version is not dispatched by a strict worker."""
326+
with worker.TaskHubGrpcWorker(host_address=HOST) as w:
327+
w.add_orchestrator(_simple_v2_orchestrator)
328+
w.use_versioning(VersioningOptions(
329+
version="2.0",
330+
match_strategy=VersionMatchStrategy.STRICT,
331+
))
332+
w.use_work_item_filters()
333+
w.start()
334+
335+
c = client.TaskHubGrpcClient(host_address=HOST)
336+
337+
# Schedule without any version
338+
no_ver_id = c.schedule_new_orchestration(
339+
_simple_v2_orchestrator, input=1)
340+
341+
# Schedule a compatible one to prove the worker is running
342+
good_id = c.schedule_new_orchestration(
343+
_simple_v2_orchestrator, input=1, version="2.0")
344+
good_state = c.wait_for_orchestration_completion(good_id, timeout=30)
345+
assert good_state is not None
346+
assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED
347+
348+
# The unversioned orchestration must remain pending
349+
no_ver_state = c.get_orchestration_state(no_ver_id)
350+
assert no_ver_state is not None
351+
assert no_ver_state.runtime_status == client.OrchestrationStatus.PENDING
352+
353+
354+
def test_strict_version_explicit_filters_with_versions():
355+
"""Explicit filters with version constraints enforce strict matching."""
356+
custom_filters = WorkItemFilters(
357+
orchestrations=[
358+
OrchestrationWorkItemFilter(
359+
name=task.get_name(_simple_v2_orchestrator),
360+
versions=["3.0"],
361+
),
362+
],
363+
)
364+
365+
with worker.TaskHubGrpcWorker(host_address=HOST) as w:
366+
w.add_orchestrator(_simple_v2_orchestrator)
367+
w.use_work_item_filters(custom_filters)
368+
w.start()
369+
370+
c = client.TaskHubGrpcClient(host_address=HOST)
371+
372+
# Version "2.0" does not match the filter's "3.0"
373+
bad_id = c.schedule_new_orchestration(
374+
_simple_v2_orchestrator, input=1, version="2.0")
375+
376+
# Version "3.0" should match
377+
good_id = c.schedule_new_orchestration(
378+
_simple_v2_orchestrator, input=1, version="3.0")
379+
good_state = c.wait_for_orchestration_completion(good_id, timeout=30)
380+
381+
assert good_state is not None
382+
assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED
383+
assert good_state.serialized_output == "2"
384+
385+
# Mismatched version must remain pending
386+
bad_state = c.get_orchestration_state(bad_id)
387+
assert bad_state is not None
388+
assert bad_state.runtime_status == client.OrchestrationStatus.PENDING

0 commit comments

Comments
 (0)