Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions langfun/core/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def concurrent_execute(
exponential_backoff: bool = True,
max_retry_interval: int = 300,
return_jobs: bool = False,
shutdown_wait_on_timeout: bool = False,
) -> list[Any]:
"""Executes a function concurrently under current component context.

Expand Down Expand Up @@ -233,6 +234,11 @@ def square(x):
exponentially.
return_jobs: If True, return a list of `Job` objects. Otherwise, return a
list of outputs.
shutdown_wait_on_timeout: If True, wait for workers to finish during
shutdown in the finally block. Used because workers blocked in
synchronous HTTP cannot honor future.cancel(); join is required to
prevent orphan-thread accumulation that exhausts the LM max_concurrency
semaphore and deadlocks ParallelRunner.

Returns:
    A list of outputs. Each is the return value of `func` based on the input
Expand Down Expand Up @@ -281,7 +287,7 @@ def square(x):
finally:
if shutdown_after_finish:
      # Unless `shutdown_wait_on_timeout` is set, do not wait for timed-out
      # worker threads to finish.
executor.shutdown(wait=False, cancel_futures=True)
executor.shutdown(wait=shutdown_wait_on_timeout, cancel_futures=True)


@dataclasses.dataclass
Expand Down Expand Up @@ -678,6 +684,7 @@ def concurrent_map(
retry_interval: int | tuple[int, int] = (5, 60),
exponential_backoff: bool = True,
return_jobs: bool = False,
shutdown_wait_on_timeout: bool = False,
) -> Iterator[Any]:
  """Maps inputs to outputs via func concurrently under current context.

Expand Down Expand Up @@ -754,6 +761,11 @@ def flaky_square(x):
exponential_backoff: If True, exponential wait time will be applied on top
of the base retry interval.
return_jobs: If True, the returned iterator will emit `Job` objects.
shutdown_wait_on_timeout: If True, wait for workers to finish during
shutdown in the finally block. Used because workers blocked in
synchronous HTTP cannot honor future.cancel(); join is required to
prevent orphan-thread accumulation that exhausts the LM max_concurrency
semaphore and deadlocks ParallelRunner.

Yields:
An iterator of (input, output, error) or Job object.
Expand Down Expand Up @@ -895,7 +907,7 @@ def update_progress_bar(progress: Progress) -> None:

if shutdown_after_finish:
      # Unless `shutdown_wait_on_timeout` is set, do not wait for timed-out
      # worker threads to finish.
executor.shutdown(wait=False, cancel_futures=True)
executor.shutdown(wait=shutdown_wait_on_timeout, cancel_futures=True)


class ExecutorPool:
Expand Down
291 changes: 291 additions & 0 deletions langfun/core/concurrent_shutdown_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
# Copyright 2025 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adversarial tests for shutdown_wait_on_timeout in langfun.core.concurrent."""

from concurrent import futures
import gc
import threading
import time
import unittest
from unittest import mock
import weakref

from langfun.core import concurrent


def _settle(baseline: int, max_delta: int = 2, budget_s: float = 2.0) -> int:
  """Polls until the live thread count drops to within `max_delta` of baseline.

  Args:
    baseline: Thread count captured before the operation under test.
    max_delta: Allowed slack above `baseline` before settling is declared.
    budget_s: Maximum wall-clock seconds to keep polling.

  Returns:
    The observed thread count — either one that satisfies the threshold, or
    the last sample taken once the polling budget is exhausted.
  """
  stop_at = time.monotonic() + budget_s
  count = threading.active_count()
  # Poll every 50ms; leaked workers typically exit once their sleep elapses.
  while count > baseline + max_delta and time.monotonic() < stop_at:
    time.sleep(0.05)
    count = threading.active_count()
  return count


class ShutdownWaitOnTimeoutTest(unittest.TestCase):
  """Adversarial coverage for `shutdown_wait_on_timeout` in `concurrent_map`.

  Each test compares the live thread count against a per-test baseline to
  detect worker threads leaked by a no-wait executor shutdown.
  """

  def setUp(self):
    super().setUp()
    gc.collect()
    # Snapshot the live thread count so each test can detect leaked workers.
    self._baseline = threading.active_count()

  def tearDown(self):
    super().tearDown()
    gc.collect()

  def test_baseline_thread_cleanup_with_wait_true(self):
    """Repeated timed-out maps with wait=True must not accumulate threads."""

    def slow(_):
      time.sleep(5.0)
      return None

    t0 = time.monotonic()
    for _ in range(5):
      try:
        list(
            concurrent.concurrent_map(
                slow,
                [1],
                max_workers=1,
                timeout=0.5,
                silence_on_errors=TimeoutError,
                shutdown_wait_on_timeout=True,
            )
        )
      except TimeoutError:
        pass
    # Thread count should return to (near) baseline once workers join.
    n = _settle(self._baseline)
    self.assertLessEqual(
        n, self._baseline + 1, f"leak after iter: {n} vs {self._baseline}"
    )
    # Upper bound: each iteration costs at most timeout + worker sleep + slack.
    self.assertLess(time.monotonic() - t0, 5 * (0.5 + 5.0 + 1.0))

  def test_backward_compat_default_false_still_leaks(self):
    """Default (False) keeps the legacy no-wait shutdown, leaking threads."""

    def slow(_):
      time.sleep(5.0)
      return None

    t0 = time.monotonic()
    for _ in range(3):
      try:
        list(
            concurrent.concurrent_map(
                slow,
                [1],
                max_workers=1,
                timeout=0.5,
                silence_on_errors=TimeoutError,
            )
        )
      except TimeoutError:
        pass
    elapsed = time.monotonic() - t0
    # Leaked (still-sleeping) workers keep the thread count above baseline.
    self.assertGreater(
        threading.active_count(),
        self._baseline,
        "default=False should preserve leak signature",
    )
    # No-wait shutdown must return promptly: timeout + slack per iteration.
    self.assertLess(elapsed, 3 * (0.5 + 1.0))

  def test_worker_ignoring_cancellation_bounded_shutdown(self):
    """shutdown(wait=True) blocks until a busy worker cooperatively stops."""
    stop = threading.Event()

    def cooperative(_):
      while not stop.is_set():
        time.sleep(0.05)

    ex = futures.ThreadPoolExecutor(max_workers=1)
    ex.submit(cooperative, 0)
    time.sleep(0.1)
    shutdown_done = threading.Event()

    def do_shutdown():
      ex.shutdown(wait=True)
      shutdown_done.set()

    t = threading.Thread(target=do_shutdown)
    t.start()
    t.join(timeout=2.0)
    # NOTE(review): if either assertion below fails, `stop` is never set and
    # the non-daemon worker loops forever, hanging process exit — consider
    # self.addCleanup(stop.set) to make failure modes safe.
    self.assertTrue(t.is_alive())
    self.assertFalse(shutdown_done.is_set())
    stop.set()
    t.join(timeout=2.0)
    self.assertFalse(t.is_alive())
    self.assertTrue(shutdown_done.is_set())
    _settle(self._baseline)

  def test_worker_raises_during_shutdown(self):
    """Worker exceptions propagate, leak no threads, and leave maps usable."""

    class BoomError(RuntimeError):
      pass

    def raiser(_):
      raise BoomError("teardown")

    with self.assertRaises(BoomError):
      list(
          concurrent.concurrent_map(
              raiser,
              [1],
              max_workers=1,
              silence_on_errors=None,
              shutdown_wait_on_timeout=True,
          )
      )
    n = _settle(self._baseline)
    self.assertLessEqual(n, self._baseline + 1)
    # A fresh map after the failure must still work end-to-end.
    out = list(
        concurrent.concurrent_map(
            lambda x: x * 2,
            [1, 2, 3],
            max_workers=2,
            shutdown_wait_on_timeout=True,
        )
    )
    self.assertEqual(len(out), 3)

  def test_concurrent_shutdown_calls(self):
    """Two racing shutdown(wait=True) calls must both return, no deadlock."""
    ex = futures.ThreadPoolExecutor(max_workers=2)
    for _ in range(2):
      ex.submit(lambda: time.sleep(0.05))
    t0 = time.monotonic()

    def call_shutdown():
      ex.shutdown(wait=True)

    t1 = threading.Thread(target=call_shutdown)
    t2 = threading.Thread(target=call_shutdown)
    t1.start()
    t2.start()
    t1.join(timeout=5.0)
    t2.join(timeout=5.0)
    self.assertFalse(t1.is_alive() or t2.is_alive(), "deadlock")
    self.assertLess(time.monotonic() - t0, 5.0)
    _settle(self._baseline)

  def test_nested_executor_pools(self):
    """A timed-out inner map inside an outer map cleans up within bounds."""

    def inner(_):
      try:
        list(
            concurrent.concurrent_map(
                lambda x: time.sleep(5.0),
                [1],
                max_workers=1,
                timeout=0.3,
                silence_on_errors=TimeoutError,
                shutdown_wait_on_timeout=True,
            )
        )
      except TimeoutError:
        pass
      return "inner-done"

    t0 = time.monotonic()
    out = list(
        concurrent.concurrent_map(
            inner, [1, 2], max_workers=2, shutdown_wait_on_timeout=True
        )
    )
    self.assertEqual(len(out), 2)
    self.assertLess(time.monotonic() - t0, 30.0)
    n = _settle(self._baseline)
    self.assertLessEqual(n, self._baseline + 2)

  def test_high_fanout_stress_50_workers(self):
    """50-way fan-out with mixed fast/slow items settles back to baseline."""

    def mixed(i):
      time.sleep(5.0 if i % 2 == 0 else 0.1)
      return i

    t0 = time.monotonic()
    try:
      list(
          concurrent.concurrent_map(
              mixed,
              list(range(50)),
              max_workers=50,
              timeout=0.5,
              silence_on_errors=TimeoutError,
              shutdown_wait_on_timeout=True,
          )
      )
    except TimeoutError:
      pass
    self.assertLess(time.monotonic() - t0, 10.0)
    # Extra settle budget: the slow workers need up to ~5s to finish sleeping.
    n = _settle(self._baseline, max_delta=2, budget_s=6.0)
    self.assertLessEqual(
        n, self._baseline + 1, f"leak after stress: {n} vs {self._baseline}"
    )

  # Regression for the retry-loop scenario that motivated this flag.
  def test_num_attempts_retry_interaction(self):
    """Simulate Gemini retry loop: 10 attempts each timing out, bounded threads."""

    def slow(_):
      time.sleep(5.0)
      return None

    max_workers = 2
    for _ in range(10):
      try:
        list(
            concurrent.concurrent_map(
                slow,
                [1],
                max_workers=max_workers,
                timeout=0.2,
                silence_on_errors=TimeoutError,
                shutdown_wait_on_timeout=True,
            )
        )
      except TimeoutError:
        pass
      # Growth must stay bounded by pool size on every attempt, not just at end.
      n = threading.active_count()
      self.assertLessEqual(
          n,
          self._baseline + max_workers,
          f"unbounded growth: {n} vs baseline {self._baseline}",
      )
    final = _settle(self._baseline)
    self.assertLessEqual(final, self._baseline + 1)

  def test_resource_fault_injection_on_shutdown(self):
    """A transient OSError from shutdown is recoverable on retry."""
    ex = futures.ThreadPoolExecutor(max_workers=1)
    real_shutdown = ex.shutdown
    calls = {"n": 0}

    def flaky(*a, **kw):
      # First call fails with an injected fault; later calls delegate through.
      calls["n"] += 1
      if calls["n"] == 1:
        raise OSError("injected")
      return real_shutdown(*a, **kw)

    ex.submit(lambda: time.sleep(0.05))
    with mock.patch.object(ex, "shutdown", side_effect=flaky):
      with self.assertRaises(OSError):
        ex.shutdown(wait=True)
    # Retry with the real shutdown; workers must join and not leak.
    real_shutdown(wait=True)
    n = _settle(self._baseline)
    self.assertLessEqual(n, self._baseline + 1)

  def test_executor_gc_after_shutdown(self):
    """A shut-down executor must be garbage-collectable (no lingering refs)."""
    ex = futures.ThreadPoolExecutor(max_workers=1)
    list(ex.map(lambda x: x, [1, 2, 3]))
    ex.shutdown(wait=True)
    ref = weakref.ref(ex)
    del ex
    gc.collect()
    self.assertIsNone(ref(), "executor not collected — strong ref leaked")


# Allows running this test file directly: `python concurrent_shutdown_test.py`.
if __name__ == "__main__":
  unittest.main()
Loading