Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion haystack/components/generators/chat/openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def _prepare_api_call( # noqa: PLR0913
# Convert all tool objects to the correct OpenAI-compatible structure
else:
# mypy can't infer that tools is ToolsType here
flattened_tools = flatten_tools_or_toolsets(tools)
flattened_tools = flatten_tools_or_toolsets(tools) # type: ignore[arg-type]
_check_duplicate_tool_names(flattened_tools)
for t in flattened_tools:
function_spec = {**t.tool_spec}
Expand Down
13 changes: 6 additions & 7 deletions haystack/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# ruff: noqa: I001 (ignore import order as we need to import Tool before ComponentTool and PipelineTool)

from collections.abc import Sequence
from haystack.tools.from_function import create_tool_from_function, tool
from haystack.tools.tool import Tool, _check_duplicate_tool_names
from haystack.tools.toolset import Toolset
Expand All @@ -15,13 +16,11 @@
from haystack.tools.serde_utils import deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
from haystack.tools.utils import flatten_tools_or_toolsets, warm_up_tools

# Type alias for tools parameter - allows mixing Tools and Toolsets in a list
# Explicitly list all valid combinations due to list invariance:
# - list[Tool]: Most common pattern - list of Tool objects
# - list[Toolset]: Less common pattern - list of Toolset objects
# - list[Union[Tool, Toolset]]: Mixing Tools and Toolsets in one list
# - Toolset: Single Toolset (not in a list)
ToolsType = list[Tool] | list[Toolset] | list[Tool | Toolset] | Toolset
# Type alias for tools parameter - allows mixing Tools and Toolsets in a sequence
# Accepts either:
# - Sequence[Tool | Toolset]: Any sequence (list, tuple, etc.) containing Tools, Toolsets, or a mix of both
# - Toolset: A single Toolset (not in a sequence)
ToolsType = Sequence[Tool | Toolset] | Toolset

__all__ = [
"_check_duplicate_tool_names",
Expand Down
26 changes: 25 additions & 1 deletion haystack/tools/from_function.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anakin87 do the changes to this file (addition of the overload) fix the typing issues you were seeing?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes!

from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.tools import tool


@tool
def my_tool(number: int) -> int:
    return number * 2


generator = OpenAIChatGenerator(tools=[my_tool])

running mypy on this example no longer gives errors

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import inspect
from collections.abc import Callable
from typing import Any
from typing import Any, overload

from pydantic import create_model

Expand Down Expand Up @@ -188,6 +188,30 @@ def get_weather(
)


@overload
def tool(
function: Callable,
*,
name: str | None = None,
description: str | None = None,
inputs_from_state: dict[str, str] | None = None,
outputs_to_state: dict[str, dict[str, Any]] | None = None,
outputs_to_string: dict[str, Any] | None = None,
) -> Tool: ...


@overload
def tool(
function: None = None,
*,
name: str | None = None,
description: str | None = None,
inputs_from_state: dict[str, str] | None = None,
outputs_to_state: dict[str, dict[str, Any]] | None = None,
outputs_to_string: dict[str, Any] | None = None,
) -> Callable[[Callable], Tool]: ...


def tool(
function: Callable | None = None,
*,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ integration-only-slow = 'pytest --maxfail=5 -m "integration and slow" {args:test
all = 'pytest {args:test}'

# TODO We want to eventually type the whole test folder
types = "mypy --install-types --non-interactive --cache-dir=.mypy_cache/ {args:haystack test/core/ test/marshal/ test/testing/ test/tracing/ test/human_in_the_loop test/evaluation test/document_stores test/dataclasses}"
types = "mypy --install-types --non-interactive --cache-dir=.mypy_cache/ {args:haystack test/core/ test/marshal/ test/testing/ test/tracing/ test/tools/ test/human_in_the_loop test/evaluation test/document_stores test/dataclasses}"

[tool.hatch.envs.e2e]
template = "test"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Update ``ToolsType`` to improve type checking for the ``tools`` parameter. Any class that inherits from either ``Tool`` or ``Toolset`` is now accepted in any sequence (list, tuple, etc).
35 changes: 20 additions & 15 deletions test/tools/test_component_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def run(self, messages: list[ChatMessage]) -> dict[str, str]:
class SimpleComponent:
"""A simple component that generates text."""

def warm_up(self):
"""
Prepare the component for use.
"""

@component.output_types(reply=str)
def run(self, text: str) -> dict[str, str]:
"""
Expand Down Expand Up @@ -143,7 +148,7 @@ def run(self, documents: list[Document], top_k: int = 5) -> dict[str, str]:
:param top_k: The number of top documents to concatenate
:returns: Dictionary containing the concatenated document contents
"""
return {"concatenated": "\n".join(doc.content for doc in documents[:top_k])}
return {"concatenated": "\n".join(doc.content for doc in documents[:top_k] if doc.content)}


@component
Expand Down Expand Up @@ -215,7 +220,7 @@ def test_from_component_with_inputs_from_state_different_names(self):
def test_from_component_with_invalid_inputs_from_state_nested_dict(self):
"""Test that ComponentTool rejects nested dict format for inputs_from_state"""
with pytest.raises(TypeError, match="must be str, not dict"):
ComponentTool(component=SimpleComponent(), inputs_from_state={"documents": {"source": "documents"}})
ComponentTool(component=SimpleComponent(), inputs_from_state={"documents": {"source": "documents"}}) # type: ignore[dict-item]

def test_from_component_with_outputs_to_state(self):
tool = ComponentTool(component=SimpleComponent(), outputs_to_state={"replies": {"source": "reply"}})
Expand Down Expand Up @@ -369,13 +374,13 @@ def test_from_component_with_dynamic_input_types(self):

def test_from_component_with_invalid_component(self):
class NotAComponent:
def foo(self, text: str):
def foo(self, text: str) -> dict[str, str]:
return {"reply": f"Hello, {text}!"}

not_a_component = NotAComponent()

with pytest.raises(TypeError):
ComponentTool(component=not_a_component, name="invalid_tool", description="This should fail")
ComponentTool(component=not_a_component, name="invalid_tool", description="This should fail") # type: ignore[arg-type]

def test_component_invoker_with_chat_message_input(self):
tool = ComponentTool(
Expand All @@ -392,7 +397,7 @@ class AnnotatedComponent:
"""An annotated component with descriptive parameter docstrings."""

@component.output_types(result=str)
def run(self, text: str, number: int = 42):
def run(self, text: str, number: int = 42) -> dict[str, str]:
"""
Process inputs and return result.

Expand Down Expand Up @@ -447,7 +452,7 @@ class ComponentA:
"""Component A with descriptive docstrings."""

@component.output_types(output_a=str)
def run(self, query: str):
def run(self, query: str) -> dict[str, str]:
"""
Process query in component A.

Expand All @@ -460,7 +465,7 @@ class ComponentB:
"""Component B with descriptive docstrings."""

@component.output_types(output_b=str)
def run(self, text: str):
def run(self, text: str) -> dict[str, str]:
"""
Process text in component B.

Expand Down Expand Up @@ -503,20 +508,20 @@ def run(self, text: str):

def test_warm_up_is_idempotent(self):
"""Test that calling warm_up multiple times only warms up the component once."""
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

component = SimpleComponent()
component.warm_up = MagicMock()

tool = ComponentTool(component=component)

# Call warm_up multiple times
tool.warm_up()
tool.warm_up()
tool.warm_up()
with patch.object(component, "warm_up", MagicMock()) as mock_warm_up:
# Call warm_up multiple times
tool.warm_up()
tool.warm_up()
tool.warm_up()

# Component's warm_up should only be called once
component.warm_up.assert_called_once()
# Component's warm_up should only be called once
mock_warm_up.assert_called_once()

def test_from_component_with_callable_params_skipped(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
Expand Down
4 changes: 2 additions & 2 deletions test/tools/test_from_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,15 @@ def function_with_annotations(


def test_from_function_missing_type_hint():
def function_missing_type_hint(city) -> str:
def function_missing_type_hint(city) -> str: # type: ignore[no-untyped-def]
return f"Weather report for {city}: 20°C, sunny"

with pytest.raises(ValueError):
create_tool_from_function(function=function_missing_type_hint)


def test_from_function_schema_generation_error():
def function_with_invalid_type_hint(city: "invalid") -> str: # noqa: F821
def function_with_invalid_type_hint(city: "invalid") -> str: # type: ignore[name-defined] # noqa: F821
return f"Weather report for {city}: 20°C, sunny"

with pytest.raises(SchemaGenerationError):
Expand Down
4 changes: 2 additions & 2 deletions test/tools/test_pipeline_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_init_invalid_pipeline(self):
with pytest.raises(
TypeError, match="The 'pipeline' parameter must be an instance of Pipeline or AsyncPipeline."
):
PipelineTool(pipeline="invalid_pipeline", name="test_tool", description="A test tool")
PipelineTool(pipeline="invalid_pipeline", name="test_tool", description="A test tool") # type: ignore[arg-type]

def test_to_dict(self, sample_pipeline, sample_pipeline_dict):
tool = PipelineTool(
Expand Down Expand Up @@ -381,7 +381,7 @@ def test_pipeline_tool_with_invalid_inputs_from_state_nested_dict(self, sample_p
output_mapping={"ranker.documents": "documents"},
name="test_tool",
description="A test tool",
inputs_from_state={"user_query": {"source": "query"}},
inputs_from_state={"user_query": {"source": "query"}}, # type: ignore[dict-item]
)

def test_pipeline_tool_with_valid_outputs_to_state(self, sample_pipeline):
Expand Down
36 changes: 19 additions & 17 deletions test/tools/test_searchable_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@


import os
from collections.abc import Callable
from typing import Any

import pytest

Expand Down Expand Up @@ -82,30 +84,28 @@ def small_catalog(weather_tool, add_tool, multiply_tool):
@pytest.fixture
def large_catalog():
"""Larger catalog that requires discovery (>= 8 tools)."""
return [
create_tool_from_function(fn)
for fn in [
get_weather,
add_numbers,
multiply_numbers,
get_stock_price,
search_database,
send_email,
calculate_tax,
convert_currency,
]
functions: list[Callable[..., Any]] = [
get_weather,
add_numbers,
multiply_numbers,
get_stock_price,
search_database,
send_email,
calculate_tax,
convert_currency,
]
return [create_tool_from_function(fn) for fn in functions]


class TestSearchableToolset:
def test_init_with_invalid_catalog(self):
with pytest.raises(TypeError):
SearchableToolset(catalog=123)
SearchableToolset(catalog=123) # type: ignore[arg-type]
with pytest.raises(TypeError):
SearchableToolset(catalog=[123])
SearchableToolset(catalog=[123]) # type: ignore[list-item]
with pytest.raises(TypeError):
SearchableToolset(
catalog=Tool(
catalog=Tool( # type: ignore[arg-type]
name="test",
description="test",
parameters={"type": "object", "properties": {}},
Expand All @@ -132,6 +132,7 @@ def test_not_implemented_methods(self):
def test_clear(self, large_catalog):
toolset = SearchableToolset(catalog=large_catalog)
toolset.warm_up()
assert toolset._bootstrap_tool is not None
toolset._bootstrap_tool.invoke(tool_keywords="weather temperature city")
assert len(toolset._discovered_tools) > 0
toolset.clear()
Expand Down Expand Up @@ -187,7 +188,7 @@ def test_passthrough_contains_by_tool_invalid_type(self, small_catalog):
toolset.warm_up()

with pytest.raises(TypeError):
123 in toolset # noqa: B015
123 in toolset # type: ignore[operator] # noqa: B015

def test_custom_search_threshold(self, large_catalog):
"""Test that custom search_threshold changes passthrough behavior."""
Expand Down Expand Up @@ -318,6 +319,7 @@ def test_contains_bootstrap_tool(self, large_catalog):
toolset.warm_up()

assert "search_tools" in toolset
assert toolset._bootstrap_tool is not None
assert toolset._bootstrap_tool in toolset

def test_contains_discovered_tool(self, large_catalog):
Expand Down Expand Up @@ -680,7 +682,7 @@ def warm_up(self) -> None:
for i in range(5)
]

toolset = SearchableToolset(catalog=[LazyToolset()] + eager_tools)
toolset = SearchableToolset(catalog=[LazyToolset(), *eager_tools])
toolset.warm_up()

# Should have 5 lazy + 5 eager = 10 tools
Expand Down
15 changes: 10 additions & 5 deletions test/tools/test_serde_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import pytest

from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
Expand Down Expand Up @@ -91,7 +93,7 @@ def test_deserialize_tools_inplace(self):
name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
)

data = {"tools": [tool.to_dict()]}
data: dict[str, Any] = {"tools": [tool.to_dict()]}
deserialize_tools_or_toolset_inplace(data)
assert data["tools"] == [tool]

Expand All @@ -104,7 +106,7 @@ def test_deserialize_tools_inplace(self):
assert data == {"no_tools": 123}

def test_deserialize_tools_inplace_failures(self):
data = {"key": "value"}
data: dict[str, Any] = {"key": "value"}
deserialize_tools_or_toolset_inplace(data)
assert data == {"key": "value"}

Expand Down Expand Up @@ -186,7 +188,8 @@ def test_deserialize_list_of_toolsets_inplace(self):

assert isinstance(data["tools"], list)
assert len(data["tools"]) == 2
assert all(isinstance(ts, Toolset) for ts in data["tools"])
assert isinstance(data["tools"][0], Toolset)
assert isinstance(data["tools"][1], Toolset)
assert data["tools"][0][0].name == "weather"
assert data["tools"][1][0].name == "calculator"

Expand All @@ -201,7 +204,8 @@ def test_serialize_mixed_list_tools_and_toolsets(self):

toolset = Toolset([tool2])

data = serialize_tools_or_toolset([tool1, toolset])
tools: list[Tool | Toolset] = [tool1, toolset]
data = serialize_tools_or_toolset(tools)

assert isinstance(data, list)
assert len(data) == 2
Expand Down Expand Up @@ -230,7 +234,8 @@ def test_serialize_mixed_list_multiple_tools_and_toolsets(self):

toolset = Toolset([tool4, tool5])

data = serialize_tools_or_toolset([tool1, tool2, toolset, tool3])
tools: list[Tool | Toolset] = [tool1, tool2, toolset, tool3]
data = serialize_tools_or_toolset(tools)

assert isinstance(data, list)
assert len(data) == 4
Expand Down
4 changes: 2 additions & 2 deletions test/tools/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_init_invalid_output_structure_config_not_dict(self):
description="irrelevant",
parameters={"type": "object", "properties": {"city": {"type": "string"}}},
function=get_weather_report,
outputs_to_state={"documents": ["some_value"]},
outputs_to_state={"documents": ["some_value"]}, # type: ignore[dict-item]
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -258,7 +258,7 @@ def test_inputs_from_state_validation_with_non_string_value(self):
description="Get weather report",
parameters=parameters,
function=get_weather_report,
inputs_from_state={"state_key": {"source": "city"}},
inputs_from_state={"state_key": {"source": "city"}}, # type: ignore[dict-item]
)

def test_inputs_from_state_validation_with_valid_parameter(self):
Expand Down
Loading
Loading