10 changes: 7 additions & 3 deletions haystack/components/extractors/llm_metadata_extractor.py
@@ -439,10 +439,14 @@ async def run_async(self, documents: list[Document], page_range: list[str | int]
         # Create ChatMessage prompts for each document
         all_prompts = self._prepare_prompts(documents=documents, expanded_range=expanded_range)

-        # Run the LLM on each prompt
+        # Run the LLM on each prompt, bounding concurrency per task so max_workers is enforced.
         sem = Semaphore(max(1, self.max_workers))
-        async with sem:
-            results = await gather(*[self._run_async(prompt) for prompt in all_prompts])
+
+        async def _bounded_run(prompt: ChatMessage | None) -> dict[str, Any]:
+            async with sem:
+                return await self._run_async(prompt)
+
+        results = await gather(*[_bounded_run(prompt) for prompt in all_prompts])

         successful_documents, failed_documents = self._process_results(documents, results)
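For readers less familiar with this asyncio idiom, here is a minimal, self-contained sketch of the two patterns (illustrative code only, not part of Haystack; the names `work`, `broken`, and `fixed` are invented for the example). Acquiring the semaphore once around `gather(...)` only gates the coordinating coroutine, so every gathered task still runs at once; acquiring it inside each task is what actually caps concurrency.

import asyncio

async def work(i: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for an LLM request
    return i

async def broken(n: int, max_workers: int) -> list[int]:
    # Buggy pattern: the semaphore is held once by the coordinating
    # coroutine, so all n tasks inside gather() run at the same time.
    sem = asyncio.Semaphore(max_workers)
    async with sem:
        return await asyncio.gather(*[work(i) for i in range(n)])

async def fixed(n: int, max_workers: int) -> list[int]:
    # Fixed pattern: each task acquires the semaphore itself, so at
    # most max_workers tasks are in flight at any moment.
    sem = asyncio.Semaphore(max_workers)

    async def bounded(i: int) -> int:
        async with sem:
            return await work(i)

    return await asyncio.gather(*[bounded(i) for i in range(n)])

print(asyncio.run(fixed(10, 2)))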
@@ -0,0 +1,9 @@
+---
+fixes:
+  - |
+    Fixed a bug in ``LLMMetadataExtractor.run_async`` where the ``asyncio.Semaphore``
+    intended to bound concurrent LLM calls to ``max_workers`` was acquired once
+    around the outer ``gather(...)`` call instead of inside each task. As a result,
+    ``max_workers`` had no effect in ``run_async`` and all LLM requests for a batch
+    were issued simultaneously. The semaphore is now acquired per task, so
+    ``max_workers`` correctly caps in-flight requests.
36 changes: 36 additions & 0 deletions test/components/extractors/test_llm_metadata_extractor.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+import asyncio
 import os
 from unittest.mock import Mock

@@ -345,6 +346,41 @@ async def test_run_with_document_content_none_async(self, monkeypatch: pytest.Mo
         # Ensure no attempt was made to call the LLM
         mock_chat_generator.run_async.assert_not_called()

+    @pytest.mark.asyncio
+    async def test_run_async_respects_max_workers(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+
+        max_workers = 2
+        in_flight = 0
+        peak_in_flight = 0
+
+        mock_chat_generator = Mock(spec=OpenAIChatGenerator)
+
+        async def fake_run_async(messages, **kwargs):
+            nonlocal in_flight, peak_in_flight
+            in_flight += 1
+            peak_in_flight = max(peak_in_flight, in_flight)
+            try:
+                await asyncio.sleep(0.01)
+                return {"replies": [ChatMessage.from_assistant('{"entities": []}')]}
+            finally:
+                in_flight -= 1
+
+        mock_chat_generator.run_async = fake_run_async
+
+        extractor = LLMMetadataExtractor(
+            prompt="prompt {{document.content}}",
+            chat_generator=mock_chat_generator,
+            expected_keys=["entities"],
+            max_workers=max_workers,
+        )
+
+        docs = [Document(content=f"doc {i}") for i in range(10)]
+        result = await extractor.run_async(documents=docs)
+
+        assert len(result["documents"]) == 10
+        assert peak_in_flight <= max_workers
+
     @pytest.mark.integration
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
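A note on the test's design: `Mock(spec=OpenAIChatGenerator)` preserves the generator's interface while the manually assigned `fake_run_async` coroutine counts overlapping invocations, so `peak_in_flight` is a direct measurement of in-flight concurrency rather than an inference from timing. Assuming the repository's standard pytest-asyncio setup, the new test can be run on its own with `pytest test/components/extractors/test_llm_metadata_extractor.py -k test_run_async_respects_max_workers`.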