Skip to content

Commit c69ca20

Browse files
fix: improve types for tree operations (#889)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Tobias Raabe <raabe@posteo.de>
1 parent 4cf361b commit c69ca20

8 files changed

Lines changed: 323 additions & 97 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
55
releases are available on [PyPI](https://pypi.org/project/pytask) and
66
[Anaconda.org](https://anaconda.org/conda-forge/pytask).
77

8+
## Unreleased
9+
10+
- [#889](https://github.com/pytask-dev/pytask/pull/889) improves typing for tree
11+
operations by wrapping optree's pytree utilities with pytask-specific signatures
12+
and requiring optree 0.16.0 or newer.
13+
814
## 0.6.0 - 2026-05-01
915

1016
- [#875](https://github.com/pytask-dev/pytask/pull/875) improves the documentation

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ dependencies = [
2424
"click>=8.1.8,!=8.2.0",
2525
"click-default-group>=1.2.4",
2626
"msgspec>=0.18.6",
27-
"optree>=0.9.0",
27+
"optree>=0.16.0",
2828
"packaging>=23.0.0",
2929
"pluggy>=1.3.0",
3030
"rich>=13.8.0",

src/_pytask/collect_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
from _pytask._inspect import get_annotations
1313
from _pytask.exceptions import NodeNotCollectedError
1414
from _pytask.models import NodeInfo
15+
from _pytask.node_protocols import NodeTree
1516
from _pytask.node_protocols import PNode
1617
from _pytask.node_protocols import PProvisionalNode
1718
from _pytask.nodes import PythonNode
1819
from _pytask.task_utils import parse_keyword_arguments_from_signature_defaults
19-
from _pytask.tree_util import PyTree
2020
from _pytask.tree_util import tree_leaves
2121
from _pytask.tree_util import tree_map_with_path
2222
from _pytask.typing import ProductType
@@ -254,7 +254,7 @@ def _collect_nodes_and_provisional_nodes( # noqa: PLR0913
254254
task_path: Path | None,
255255
parameter_name: str,
256256
value: Any,
257-
) -> PyTree[PProvisionalNode | PNode]:
257+
) -> NodeTree:
258258
return tree_map_with_path(
259259
lambda p, x: collection_func(
260260
session,

src/_pytask/node_protocols.py

Lines changed: 62 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,29 @@
33
from typing import TYPE_CHECKING
44
from typing import Any
55
from typing import Protocol
6+
from typing import TypeAlias
67
from typing import runtime_checkable
78

9+
from _pytask.tree_util import PyTree
10+
811
if TYPE_CHECKING:
912
from collections.abc import Callable
1013
from pathlib import Path
1114

1215
from _pytask.mark import Mark
13-
from _pytask.tree_util import PyTree
1416
from _pytask.typing import NodePath
1517

1618

17-
__all__ = ["PNode", "PPathNode", "PProvisionalNode", "PTask", "PTaskWithPath"]
19+
__all__ = [
20+
"NodeTree",
21+
"PNode",
22+
"PPathNode",
23+
"PProvisionalNode",
24+
"PTask",
25+
"PTaskWithPath",
26+
"TaskIO",
27+
"TaskNode",
28+
]
1829

1930

2031
@runtime_checkable
@@ -64,45 +75,6 @@ class PPathNode(PNode, Protocol):
6475
path: NodePath
6576

6677

67-
@runtime_checkable
68-
class PTask(Protocol):
69-
"""Protocol for nodes."""
70-
71-
name: str
72-
depends_on: dict[str, PyTree[PNode | PProvisionalNode]]
73-
produces: dict[str, PyTree[PNode | PProvisionalNode]]
74-
function: Callable[..., Any]
75-
markers: list[Mark]
76-
report_sections: list[tuple[str, str, str]]
77-
attributes: dict[Any, Any]
78-
79-
@property
80-
def signature(self) -> str:
81-
"""Return the signature of the node."""
82-
83-
def state(self) -> str | None:
84-
"""Return the state of the node.
85-
86-
The state can be something like a hash or a last modified timestamp. If the node
87-
does not exist, you can also return ``None``.
88-
89-
"""
90-
91-
def execute(self, **kwargs: Any) -> Any:
92-
"""Return the value of the node that will be injected into the task."""
93-
94-
95-
@runtime_checkable
96-
class PTaskWithPath(PTask, Protocol):
97-
"""Tasks with paths.
98-
99-
Tasks with paths receive special handling when it comes to printing their names.
100-
101-
"""
102-
103-
path: Path
104-
105-
10678
@runtime_checkable
10779
class PProvisionalNode(Protocol):
10880
"""A protocol for provisional nodes.
@@ -141,3 +113,52 @@ def load(self, is_product: bool = False) -> Any: # pragma: no cover
141113

142114
def collect(self) -> list[Any]:
143115
"""Collect the objects that are defined by the provisional nodes."""
116+
117+
118+
TaskNode: TypeAlias = PNode | PProvisionalNode
119+
"""A concrete or provisional pytask node."""
120+
121+
NodeTree: TypeAlias = PyTree[TaskNode]
122+
"""A pytask tree whose leaves are concrete or provisional nodes."""
123+
124+
TaskIO: TypeAlias = dict[str, NodeTree]
125+
"""The top-level task argument mapping for dependencies and products."""
126+
127+
128+
@runtime_checkable
129+
class PTask(Protocol):
130+
"""Protocol for nodes."""
131+
132+
name: str
133+
depends_on: TaskIO
134+
produces: TaskIO
135+
function: Callable[..., Any]
136+
markers: list[Mark]
137+
report_sections: list[tuple[str, str, str]]
138+
attributes: dict[Any, Any]
139+
140+
@property
141+
def signature(self) -> str:
142+
"""Return the signature of the node."""
143+
144+
def state(self) -> str | None:
145+
"""Return the state of the node.
146+
147+
The state can be something like a hash or a last modified timestamp. If the node
148+
does not exist, you can also return ``None``.
149+
150+
"""
151+
152+
def execute(self, **kwargs: Any) -> Any:
153+
"""Return the value of the node that will be injected into the task."""
154+
155+
156+
@runtime_checkable
157+
class PTaskWithPath(PTask, Protocol):
158+
"""Tasks with paths.
159+
160+
Tasks with paths receive special handling when it comes to printing their names.
161+
162+
"""
163+
164+
path: Path

src/_pytask/nodes.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from _pytask.node_protocols import PProvisionalNode
2323
from _pytask.node_protocols import PTask
2424
from _pytask.node_protocols import PTaskWithPath
25+
from _pytask.node_protocols import TaskIO
2526
from _pytask.path import hash_path
2627
from _pytask.typing import NoDefault
2728
from _pytask.typing import NodePath
@@ -34,7 +35,6 @@
3435

3536
from _pytask.mark import Mark
3637
from _pytask.models import NodeInfo
37-
from _pytask.tree_util import PyTree
3838

3939

4040
__all__ = [
@@ -77,10 +77,8 @@ class TaskWithoutPath(PTask):
7777

7878
name: str
7979
function: Callable[..., Any]
80-
depends_on: dict[str, PyTree[PNode | PProvisionalNode]] = field(
81-
default_factory=dict
82-
)
83-
produces: dict[str, PyTree[PNode | PProvisionalNode]] = field(default_factory=dict)
80+
depends_on: TaskIO = field(default_factory=dict)
81+
produces: TaskIO = field(default_factory=dict)
8482
markers: list[Mark] = field(default_factory=list)
8583
report_sections: list[tuple[str, str, str]] = field(default_factory=list)
8684
attributes: dict[Any, Any] = field(default_factory=dict)
@@ -133,10 +131,8 @@ class Task(PTaskWithPath):
133131
path: Path
134132
function: Callable[..., Any]
135133
name: str = field(default="", init=False)
136-
depends_on: dict[str, PyTree[PNode | PProvisionalNode]] = field(
137-
default_factory=dict
138-
)
139-
produces: dict[str, PyTree[PNode | PProvisionalNode]] = field(default_factory=dict)
134+
depends_on: TaskIO = field(default_factory=dict)
135+
produces: TaskIO = field(default_factory=dict)
140136
markers: list[Mark] = field(default_factory=list)
141137
report_sections: list[tuple[str, str, str]] = field(default_factory=list)
142138
attributes: dict[Any, Any] = field(default_factory=dict)

src/_pytask/provisional_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@
1010
from _pytask.collect_utils import collect_dependency
1111
from _pytask.dag import create_dag_from_session
1212
from _pytask.models import NodeInfo
13-
from _pytask.node_protocols import PNode
13+
from _pytask.node_protocols import NodeTree
1414
from _pytask.node_protocols import PProvisionalNode
1515
from _pytask.node_protocols import PTask
1616
from _pytask.node_protocols import PTaskWithPath
1717
from _pytask.nodes import Task
1818
from _pytask.reports import ExecutionReport
19-
from _pytask.tree_util import PyTree
2019
from _pytask.tree_util import tree_map_with_path
2120
from _pytask.typing import is_task_generator
2221

@@ -29,7 +28,7 @@
2928

3029
def collect_provisional_nodes(
3130
session: Session, task: PTask, node: Any, path: tuple[Any, ...]
32-
) -> PyTree[PNode | PProvisionalNode]:
31+
) -> NodeTree:
3332
"""Collect provisional nodes.
3433
3534
1. Call the [`pytask.PProvisionalNode.collect`][] to receive the raw nodes.

0 commit comments

Comments
 (0)