Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions haystack/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
from haystack import logging
from haystack.core.component.component import _hook_component_init
from haystack.core.errors import DeserializationError, SerializationError
from haystack.utils.auth import Secret, deserialize_secrets_inplace
from haystack.utils.auth import Secret
from haystack.utils.device import ComponentDevice
from haystack.utils.type_serialization import thread_safe_import

logger = logging.getLogger(__name__)


T = TypeVar("T")


Expand Down Expand Up @@ -183,8 +184,8 @@ def default_to_dict(obj: Any, **init_parameters: Any) -> dict[str, Any]:
instance of `obj` with `from_dict`. Omitting them might cause deserialisation
errors or unexpected behaviours later, when calling `from_dict`.

Secret and ComponentDevice instances in `init_parameters` are automatically
serialized by calling their `to_dict()` method.
Objects in `init_parameters` that have a `to_dict()` method are automatically
serialized by calling that method.

An example usage:

Expand Down Expand Up @@ -214,10 +215,10 @@ def to_dict(self):
:returns:
A dictionary representation of the instance.
"""
# Automatically serialize Secret and ComponentDevice instances
# Automatically serialize objects that have a to_dict method
serialized_params = {}
for key, value in init_parameters.items():
if isinstance(value, (Secret, ComponentDevice)):
if value is not None and hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
serialized_params[key] = value.to_dict()
else:
serialized_params[key] = value
Expand Down Expand Up @@ -266,6 +267,10 @@ def default_from_dict(cls: type[T], data: dict[str, Any]) -> T:
and deserialized. A dictionary is considered a serialized ComponentDevice if it has a
"type" key with value "single" or "multiple".

Objects in `init_parameters` that are dictionaries with a "type" key containing a fully
qualified class name are automatically detected and deserialized if the class has a
`from_dict()` method.

:param cls:
The class to be used for deserialization.
:param data:
Expand All @@ -282,19 +287,26 @@ def default_from_dict(cls: type[T], data: dict[str, Any]) -> T:
if data["type"] != generate_qualified_class_name(cls):
raise DeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'")

# Automatically detect and deserialize Secret instances
secret_keys = []
for key, value in init_params.items():
if isinstance(value, dict) and value.get("type") == "env_var":
secret_keys.append(key)

if secret_keys:
deserialize_secrets_inplace(init_params, keys=secret_keys)

# Automatically detect and deserialize ComponentDevice instances
# Automatically detect and deserialize objects with from_dict methods
for key, value in init_params.items():
if _is_serialized_component_device(value):
init_params[key] = ComponentDevice.from_dict(value)
if isinstance(value, dict) and "type" in value:
type_value = value.get("type")
# Special handling for Secret (type == "env_var")
if type_value == "env_var":
init_params[key] = Secret.from_dict(value)
# Special handling for ComponentDevice (type == "single" or "multiple")
elif _is_serialized_component_device(value):
init_params[key] = ComponentDevice.from_dict(value)
# If type looks like a fully qualified class name, try to import it and deserialize
elif isinstance(type_value, str) and "." in type_value:
try:
imported_class = import_class_by_name(type_value)
if hasattr(imported_class, "from_dict") and callable(getattr(imported_class, "from_dict")):
init_params[key] = imported_class.from_dict(value)
else:
init_params[key] = default_from_dict(imported_class, value)
except (ImportError, DeserializationError) as e:
raise type(e)(f"Failed to deserialize '{key}': {e}") from e

return cls(**init_params)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
component_from_dict/component_to_dict now work out of the box with custom components that have an object as init parameter as long as the object defines to_dict/from_dict methods.
Users no longer need to explicitly define to_dict/from_dict methods in their custom components in such cases.
For example, a custom retriever, which has a DocumentStore as an init parameter, does not need explicitly defined to_dict/from_dict methods.
component_from_dict/component_to_dict now handle such cases automatically.
139 changes: 139 additions & 0 deletions test/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
generate_qualified_class_name,
import_class_by_name,
)
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.testing import factory
from haystack.utils import ComponentDevice, Secret
from haystack.utils.auth import EnvVarSecret
Expand Down Expand Up @@ -431,3 +432,141 @@ def test_component_to_dict_and_from_dict_roundtrip_with_component_device():
assert deserialized_comp.device.to_torch_str() == "cpu"
assert deserialized_comp.other_device.to_torch_str() == "cuda:0"
assert deserialized_comp.name == "test"


@component
class CustomComponentWithDocumentStore:
def __init__(self, document_store: InMemoryDocumentStore | None = None, name: str | None = None):
self.document_store = document_store
self.name = name

@component.output_types(value=str)
def run(self, value: str) -> dict[str, str]:
return {"value": value}


def test_component_to_dict_with_document_store():
"""Test that DocumentStore instances are automatically serialized in component_to_dict."""
# Test with InMemoryDocumentStore
doc_store = InMemoryDocumentStore()
comp = CustomComponentWithDocumentStore(document_store=doc_store)
res = component_to_dict(comp, "test_component")
assert "type" in res["init_parameters"]["document_store"]
assert "init_parameters" in res["init_parameters"]["document_store"]
assert (
res["init_parameters"]["document_store"]["type"]
== "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore"
)

# Test with None
comp = CustomComponentWithDocumentStore(document_store=None)
res = component_to_dict(comp, "test_component")
assert res["init_parameters"]["document_store"] is None


def test_component_from_dict_with_document_store():
"""Test that serialized DocumentStore dictionaries are automatically deserialized in component_from_dict."""
# Test with InMemoryDocumentStore
doc_store = InMemoryDocumentStore()
serialized_doc_store = doc_store.to_dict()
data = {
"type": generate_qualified_class_name(CustomComponentWithDocumentStore),
"init_parameters": {"document_store": serialized_doc_store, "name": "test"},
}
comp = component_from_dict(CustomComponentWithDocumentStore, data, "test_component")
assert isinstance(comp, CustomComponentWithDocumentStore)
assert isinstance(comp.document_store, InMemoryDocumentStore)
assert comp.name == "test"

# Test with None
data = {
"type": generate_qualified_class_name(CustomComponentWithDocumentStore),
"init_parameters": {"document_store": None, "name": "test"},
}
comp = component_from_dict(CustomComponentWithDocumentStore, data, "test_component")
assert comp.document_store is None
assert comp.name == "test"


def test_component_to_dict_and_from_dict_roundtrip_with_document_store():
"""Test that serialization and deserialization work together for DocumentStore."""
# Test roundtrip with InMemoryDocumentStore
original_doc_store = InMemoryDocumentStore()
comp = CustomComponentWithDocumentStore(document_store=original_doc_store)

serialized = component_to_dict(comp, "test_component")
assert "type" in serialized["init_parameters"]["document_store"]
assert (
serialized["init_parameters"]["document_store"]["type"]
== "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore"
)

deserialized_comp = component_from_dict(CustomComponentWithDocumentStore, serialized, "test_component")
assert isinstance(deserialized_comp.document_store, InMemoryDocumentStore)
assert deserialized_comp.document_store.bm25_algorithm == original_doc_store.bm25_algorithm
assert (
deserialized_comp.document_store.embedding_similarity_function
== original_doc_store.embedding_similarity_function
)

# Test roundtrip with custom parameters
original_doc_store = InMemoryDocumentStore(
bm25_algorithm="BM25Okapi", embedding_similarity_function="cosine", return_embedding=False
)
comp = CustomComponentWithDocumentStore(document_store=original_doc_store)

serialized = component_to_dict(comp, "test_component")
deserialized_comp = component_from_dict(CustomComponentWithDocumentStore, serialized, "test_component")
assert isinstance(deserialized_comp.document_store, InMemoryDocumentStore)
assert deserialized_comp.document_store.bm25_algorithm == "BM25Okapi"
assert deserialized_comp.document_store.embedding_similarity_function == "cosine"
assert deserialized_comp.document_store.return_embedding is False


def test_default_to_dict_with_document_store():
"""Test that DocumentStore instances are automatically serialized in default_to_dict."""
doc_store = InMemoryDocumentStore()
res = default_to_dict(doc_store)
assert res["type"] == "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore"
assert "init_parameters" in res

# Test that DocumentStore is serialized when passed as a parameter
doc_store = InMemoryDocumentStore()
comp = CustomComponentWithDocumentStore(document_store=doc_store)
res = default_to_dict(comp, document_store=doc_store, name="test")
assert "type" in res["init_parameters"]["document_store"]
assert res["init_parameters"]["name"] == "test"


def test_default_from_dict_with_document_store():
"""Test that serialized DocumentStore dictionaries are automatically deserialized in default_from_dict."""
doc_store = InMemoryDocumentStore()
serialized = doc_store.to_dict()

# Test direct deserialization
deserialized = default_from_dict(InMemoryDocumentStore, serialized)
assert isinstance(deserialized, InMemoryDocumentStore)
assert deserialized.bm25_algorithm == doc_store.bm25_algorithm

# Test deserialization when DocumentStore is in init_parameters
data = {
"type": generate_qualified_class_name(CustomComponentWithDocumentStore),
"init_parameters": {"document_store": serialized, "name": "test"},
}
comp = default_from_dict(CustomComponentWithDocumentStore, data)
assert isinstance(comp.document_store, InMemoryDocumentStore)
assert comp.name == "test"


def test_default_from_dict_with_invalid_class_name():
"""Test that deserialization raises ImportError with improved message when class cannot be imported."""
data = {
"type": generate_qualified_class_name(CustomComponentWithDocumentStore),
"init_parameters": {
"document_store": {"type": "nonexistent.module.Class", "init_parameters": {}},
"name": "test",
},
}
# Verify the error message includes the parameter key and original error
with pytest.raises(ImportError, match=r"Failed to deserialize 'document_store':.*nonexistent\.module\.Class"):
default_from_dict(CustomComponentWithDocumentStore, data)
Loading