66import time
77import warnings
88from concurrent .futures import Future
9+ from functools import partial
910from pathlib import Path
1011from types import ModuleType
1112from types import TracebackType
@@ -296,23 +297,7 @@ def _execute_task( # noqa: PLR0913
296297 exc_info , show_locals , console_options
297298 )
298299 else :
299- if "return" in task .produces :
300- structure_out = tree_structure (out )
301- structure_return = tree_structure (task .produces ["return" ])
302- # strict must be false when none is leaf.
303- if not structure_return .is_prefix (structure_out , strict = False ):
304- msg = (
305- "The structure of the return annotation is not a subtree of "
306- "the structure of the function return.\n \n Function return: "
307- f"{ structure_out } \n \n Return annotation: { structure_return } "
308- )
309- raise ValueError (msg )
310-
311- nodes = tree_leaves (task .produces ["return" ])
312- values = structure_return .flatten_up_to (out )
313- for node , value in zip (nodes , values ):
314- node .save (value )
315-
300+ _handle_task_function_return (task , out )
316301 processed_exc_info = None
317302
318303 task_display_name = getattr (task , "display_name" , task .name )
@@ -347,6 +332,27 @@ def _process_exception(
347332 return (* exc_info [:2 ], text )
348333
349334
335+ def _handle_task_function_return (task : PTask , out : Any ) -> None :
336+ if "return" not in task .produces :
337+ return
338+
339+ structure_out = tree_structure (out )
340+ structure_return = tree_structure (task .produces ["return" ])
341+ # strict must be false when none is leaf.
342+ if not structure_return .is_prefix (structure_out , strict = False ):
343+ msg = (
344+ "The structure of the return annotation is not a subtree of "
345+ "the structure of the function return.\n \n Function return: "
346+ f"{ structure_out } \n \n Return annotation: { structure_return } "
347+ )
348+ raise ValueError (msg )
349+
350+ nodes = tree_leaves (task .produces ["return" ])
351+ values = structure_return .flatten_up_to (out )
352+ for node , value in zip (nodes , values ):
353+ node .save (value )
354+
355+
350356class DefaultBackendNameSpace :
351357 """The name space for hooks related to threads."""
352358
@@ -362,13 +368,13 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
362368 if session .config ["n_workers" ] > 1 :
363369 kwargs = _create_kwargs_for_task (task )
364370 return session .config ["_parallel_executor" ].submit (
365- _mock_processes_for_threads , func = task . execute , ** kwargs
371+ _mock_processes_for_threads , task = task , ** kwargs
366372 )
367373 return None
368374
369375
370376def _mock_processes_for_threads (
371- func : Callable [..., Any ] , ** kwargs : Any
377+ task : PTask , ** kwargs : Any
372378) -> tuple [
373379 None , list [Any ], tuple [type [BaseException ], BaseException , TracebackType ] | None
374380]:
@@ -381,10 +387,11 @@ def _mock_processes_for_threads(
381387 """
382388 __tracebackhide__ = True
383389 try :
384- func (** kwargs )
390+ out = task . function (** kwargs )
385391 except Exception : # noqa: BLE001
386392 exc_info = sys .exc_info ()
387393 else :
394+ _handle_task_function_return (task , out )
388395 exc_info = None
389396 return None , [], exc_info
390397
@@ -430,18 +437,17 @@ def sleep(self) -> None:
430437def _get_module (func : Callable [..., Any ], path : Path | None ) -> ModuleType :
431438 """Get the module of a python function.
432439
433- For Python <3.10, functools.partial does not set a `__module__` attribute which is
434- why ``inspect.getmodule`` returns ``None`` and ``cloudpickle.pickle_by_value``
435- fails. In later versions, ``functools`` is returned and everything seems to work
436- fine.
440+ ``functools.partial`` obfuscates the module of the function and
441+ ``inspect.getmodule`` returns :mod`functools`. Therefore, we recover the original
442+ function.
437443
438- Therefore, we use the path from the task module to aid the search which works for
439- Python <3.10.
440-
441- We do not unwrap the partialed function with ``func.func``, since pytask in general
442- does not really support ``functools.partial``. Instead, use ``@task(kwargs=...)``.
444+ We use the path from the task module to aid the search although it is not clear
445+ whether it helps.
443446
444447 """
448+ if isinstance (func , partial ):
449+ func = func .func
450+
445451 if path :
446452 return inspect .getmodule (func , path .as_posix ())
447453 return inspect .getmodule (func )
0 commit comments