Skip to content
Merged
32 changes: 32 additions & 0 deletions e2e/pipelines/test_named_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys

import pytest
from thinc.api import NumpyOps, get_current_ops, set_current_ops

from haystack import Document, Pipeline
from haystack.components.extractors import (
Expand Down Expand Up @@ -132,3 +133,34 @@ def _check_predictions(predicted, expected) -> None:
assert a.entity == b.entity
assert a.start == b.start
assert a.end == b.end

class TestNamedEntityExtractorDeviceRestoration:
    """Checks that the spaCy backend's device scope leaves the global Thinc ops as it found them."""

    def test_spacy_backend_restores_device_state(self):
        """
        Verify that NamedEntityExtractor (spaCy) restores the previous Thinc Ops state
        after the component runs.
        """
        # Remember the ops that were active before this test touched anything, so
        # the cleanup below can hand back that exact object instead of replacing
        # the process-wide state with a brand-new one.
        original_ops = get_current_ops()

        # 1. Setup a custom state: a marked ops object that can be recognised later.
        custom_ops = NumpyOps()
        setattr(custom_ops, "owner", "test_marker")
        set_current_ops(custom_ops)

        try:
            # 2. Initialize and run (triggering the context manager)
            extractor = NamedEntityExtractor(backend="spacy", model="en_core_web_sm")

            # Since _SpacyBackend is private, we access it via getattr to avoid IDE warnings
            backend = getattr(extractor, "_backend")
            select_device_method = getattr(backend, "_select_device")

            with select_device_method():
                # Inside the context, the state might change
                pass

            # 3. Verify state is restored: the marker set above must still be
            # present on whatever ops object is current now.
            final_ops = get_current_ops()
            assert getattr(final_ops, "owner", None) == "test_marker"

        finally:
            # Clean up global state by reinstating what was active before the test.
            set_current_ops(original_ops)
11 changes: 3 additions & 8 deletions haystack/components/extractors/named_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
with LazyImport(message="Run 'pip install spacy'") as spacy_import:
import spacy
from spacy import Language as SpacyPipeline
from thinc.api import get_current_ops, set_current_ops


class NamedEntityExtractorBackend(Enum):
Expand Down Expand Up @@ -492,17 +493,11 @@ def _select_device(self) -> Iterator[None]:
"""
Context manager used to run spaCy models on a specific GPU in a scoped manner.
"""

# TODO: This won't restore the active device.
# Since there are no opaque API functions to determine
# the active device in spaCy/Thinc, we can't do much
# about it as a consumer unless we start poking into their
# internals.
device_id = self._device.to_spacy()
previous_ops = get_current_ops()
try:
if device_id >= 0:
spacy.require_gpu(device_id)
yield
finally:
if device_id >= 0:
spacy.require_cpu()
set_current_ops(previous_ops)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
fixes:
- |
Fixed a bug in ``NamedEntityExtractor`` where the spaCy/Thinc device state was not correctly
restored after execution, potentially affecting the device configuration of other spaCy
components in the same process.
Loading