44import inspect
55import sys
66import time
7+ import warnings
78from concurrent .futures import Future
89from types import TracebackType
910from typing import Any
11+ from typing import Callable
1012
1113import cloudpickle
1214from pybaum .tree_util import tree_map
1315from pytask import console
1416from pytask import ExecutionReport
17+ from pytask import get_marks
1518from pytask import hookimpl
19+ from pytask import Mark
1620from pytask import remove_internal_traceback_frames_from_exc_info
1721from pytask import Session
1822from pytask import Task
1923from pytask_parallel .backends import PARALLEL_BACKENDS
2024from rich .console import ConsoleOptions
2125from rich .traceback import Traceback
2226
27+ # Can be removed if pinned to pytask >= 0.2.6.
28+ try :
29+ from pytask import parse_warning_filter
30+ from pytask import warning_record_to_str
31+ from pytask import WarningReport
32+ except ImportError :
33+ from _pytask .warnings import parse_warning_filter
34+ from _pytask .warnings import warning_record_to_str
35+ from _pytask .warnings_utils import WarningReport
36+
2337
2438@hookimpl
2539def pytask_post_parse (config : dict [str , Any ]) -> None :
@@ -85,42 +99,38 @@ def pytask_execute_build(session: Session) -> bool | None:
8599
86100 for task_name in list (running_tasks ):
87101 future = running_tasks [task_name ]
88- if future .done () and (
89- future .exception () is not None
90- or future .result () is not None
91- ):
92- task = session .dag .nodes [task_name ]["task" ]
93- if future .exception () is not None :
94- exception = future .exception ()
95- exc_info = (
96- type (exception ),
97- exception ,
98- exception .__traceback__ ,
99- )
100- else :
101- exc_info = future .result ()
102-
103- newly_collected_reports .append (
104- ExecutionReport .from_task_and_exception (task , exc_info )
102+ if future .done ():
103+ warning_reports , task_exception = future .result ()
104+ session .warnings .extend (warning_reports )
105+ exc_info = (
106+ _parse_future_exception (future .exception ())
107+ or task_exception
105108 )
106- running_tasks .pop (task_name )
107- session .scheduler .done (task_name )
108- elif future .done () and future .exception () is None :
109- task = session .dag .nodes [task_name ]["task" ]
110- try :
111- session .hook .pytask_execute_task_teardown (
112- session = session , task = task
113- )
114- except Exception :
115- report = ExecutionReport .from_task_and_exception (
116- task , sys .exc_info ()
109+ if exc_info is not None :
110+ task = session .dag .nodes [task_name ]["task" ]
111+ newly_collected_reports .append (
112+ ExecutionReport .from_task_and_exception (
113+ task , exc_info
114+ )
117115 )
116+ running_tasks .pop (task_name )
117+ session .scheduler .done (task_name )
118118 else :
119- report = ExecutionReport .from_task (task )
120-
121- running_tasks .pop (task_name )
122- newly_collected_reports .append (report )
123- session .scheduler .done (task_name )
119+ task = session .dag .nodes [task_name ]["task" ]
120+ try :
121+ session .hook .pytask_execute_task_teardown (
122+ session = session , task = task
123+ )
124+ except Exception :
125+ report = ExecutionReport .from_task_and_exception (
126+ task , sys .exc_info ()
127+ )
128+ else :
129+ report = ExecutionReport .from_task (task )
130+
131+ running_tasks .pop (task_name )
132+ newly_collected_reports .append (report )
133+ session .scheduler .done (task_name )
124134 else :
125135 pass
126136
@@ -144,6 +154,17 @@ def pytask_execute_build(session: Session) -> bool | None:
144154 return None
145155
146156
157+ def _parse_future_exception (
158+ exception : BaseException | None ,
159+ ) -> tuple [type [BaseException ], BaseException , TracebackType ] | None :
160+ """Parse a future exception."""
161+ return (
162+ None
163+ if exception is None
164+ else (type (exception ), exception , exception .__traceback__ )
165+ )
166+
167+
147168class ProcessesNameSpace :
148169 """The name space for hooks related to processes."""
149170
@@ -167,6 +188,9 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
167188 bytes_kwargs = bytes_kwargs ,
168189 show_locals = session .config ["show_locals" ],
169190 console_options = console .options ,
191+ session_filterwarnings = session .config ["filterwarnings" ],
192+ task_filterwarnings = get_marks (task , "filterwarnings" ),
193+ task_short_name = task .short_name ,
170194 )
171195 return None
172196
@@ -176,7 +200,10 @@ def _unserialize_and_execute_task(
176200 bytes_kwargs : bytes ,
177201 show_locals : bool ,
178202 console_options : ConsoleOptions ,
179- ) -> tuple [type [BaseException ], BaseException , str ] | None :
203+ session_filterwarnings : tuple [str , ...],
204+ task_filterwarnings : tuple [Mark , ...],
205+ task_short_name : str ,
206+ ) -> tuple [list [WarningReport ], tuple [type [BaseException ], BaseException , str ] | None ]:
180207 """Unserialize and execute task.
181208
182209 This function receives bytes and unpickles them to a task which is them execute in a
@@ -188,13 +215,40 @@ def _unserialize_and_execute_task(
188215 task = cloudpickle .loads (bytes_function )
189216 kwargs = cloudpickle .loads (bytes_kwargs )
190217
191- try :
192- task .execute (** kwargs )
193- except Exception :
194- exc_info = sys .exc_info ()
195- processed_exc_info = _process_exception (exc_info , show_locals , console_options )
196- return processed_exc_info
197- return None
218+ with warnings .catch_warnings (record = True ) as log :
219+ # mypy can't infer that record=True means log is not None; help it.
220+ assert log is not None
221+
222+ for arg in session_filterwarnings :
223+ warnings .filterwarnings (* parse_warning_filter (arg , escape = False ))
224+
225+ # apply filters from "filterwarnings" marks
226+ for mark in task_filterwarnings :
227+ for arg in mark .args :
228+ warnings .filterwarnings (* parse_warning_filter (arg , escape = False ))
229+
230+ try :
231+ task .execute (** kwargs )
232+ except Exception :
233+ exc_info = sys .exc_info ()
234+ processed_exc_info = _process_exception (
235+ exc_info , show_locals , console_options
236+ )
237+ else :
238+ processed_exc_info = None
239+
240+ warning_reports = []
241+ for warning_message in log :
242+ fs_location = warning_message .filename , warning_message .lineno
243+ warning_reports .append (
244+ WarningReport (
245+ message = warning_record_to_str (warning_message ),
246+ fs_location = fs_location ,
247+ id_ = task_short_name ,
248+ )
249+ )
250+
251+ return warning_reports , processed_exc_info
198252
199253
200254def _process_exception (
@@ -224,11 +278,33 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
224278 """
225279 if session .config ["n_workers" ] > 1 :
226280 kwargs = _create_kwargs_for_task (task )
227- return session .executor .submit (task .execute , ** kwargs )
281+ return session .executor .submit (
282+ _mock_processes_for_threads , func = task .execute , ** kwargs
283+ )
228284 else :
229285 return None
230286
231287
288+ def _mock_processes_for_threads (
289+ func : Callable [..., Any ], ** kwargs : Any
290+ ) -> tuple [list [Any ], tuple [type [BaseException ], BaseException , TracebackType ] | None ]:
291+ """Mock execution function such that it returns the same as for processes.
292+
293+ The function for processes returns ``warning_reports`` and an ``exception``. With
294+ threads, these object are collected by the main and not the subprocess. So, we just
295+ return placeholders.
296+
297+ """
298+ __tracebackhide__ = True
299+ try :
300+ func (** kwargs )
301+ except Exception :
302+ exc_info = sys .exc_info ()
303+ else :
304+ exc_info = None
305+ return [], exc_info
306+
307+
232308def _create_kwargs_for_task (task : Task ) -> dict [Any , Any ]:
233309 """Create kwargs for task function."""
234310 kwargs = {** task .kwargs }
0 commit comments