Skip to content

Commit 44df444

Browse files
timsaucerclaude
andcommitted
docs(distributing-work): drop __main__ guard from example to match site style
Other code blocks in the user guide present snippets inline at module level; the worker-pool example was the only one using ``if __name__ == "__main__":``. Restructure as two blocks (worker function + driver code), both inline, with a prose note explaining when the guard is actually needed (saving to a .py file and running under ``spawn`` / ``forkserver``). Matches the look of the surrounding docs and keeps the snippet copy-pasteable for the interactive case. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 9a4af41 commit 44df444

1 file changed

Lines changed: 27 additions & 16 deletions

File tree

docs/source/user-guide/io/distributing_work.rst

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,13 @@ expression; the receiver does not need to pre-register them.
5555
Basic worker-pool example
5656
~~~~~~~~~~~~~~~~~~~~~~~~~
5757

58-
.. code-block:: python
58+
Define a worker function that takes the expression plus a batch and
59+
returns the evaluated result:
5960

60-
import multiprocessing as mp
61+
.. code-block:: python
6162
6263
import pyarrow as pa
63-
from datafusion import SessionContext, col, udf
64+
from datafusion import SessionContext
6465
6566
6667
def evaluate(expr, batch):
@@ -70,21 +71,31 @@ Basic worker-pool example
7071
df = ctx.from_pydict({"a": batch})
7172
return df.with_column("result", expr).select("result").to_pydict()["result"]
7273
74+
Then build the expression in the driver and fan it out:
7375

74-
if __name__ == "__main__":
75-
double = udf(
76-
lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]),
77-
[pa.int64()], pa.int64(), volatility="immutable", name="double",
76+
.. code-block:: python
77+
78+
import multiprocessing as mp
79+
from datafusion import col, udf
80+
81+
double = udf(
82+
lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]),
83+
[pa.int64()], pa.int64(), volatility="immutable", name="double",
84+
)
85+
expr = double(col("a"))
86+
87+
mp_ctx = mp.get_context("forkserver")
88+
with mp_ctx.Pool(processes=4) as pool:
89+
results = pool.starmap(
90+
evaluate,
91+
[(expr, [1, 2, 3]), (expr, [10, 20, 30])],
7892
)
79-
expr = double(col("a"))
80-
81-
mp_ctx = mp.get_context("forkserver")
82-
with mp_ctx.Pool(processes=4) as pool:
83-
results = pool.starmap(
84-
evaluate,
85-
[(expr, [1, 2, 3]), (expr, [10, 20, 30])],
86-
)
87-
print(results) # [[2, 4, 6], [20, 40, 60]]
93+
print(results) # [[2, 4, 6], [20, 40, 60]]
94+
95+
When saved to a ``.py`` file and executed with the ``spawn`` or
96+
``forkserver`` start method, wrap the driver block in
97+
``if __name__ == "__main__":`` so worker processes can re-import the
98+
module without re-running it.
8899

89100

90101
What travels with the expression

0 commit comments

Comments
 (0)