Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,24 @@
# SPDX-License-Identifier: Apache-2.0

import importlib.metadata
from collections.abc import Iterable
from typing import Any, ClassVar

from haystack import component, default_from_dict, default_to_dict
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.embedders import OpenAIDocumentEmbedder
from haystack.utils.auth import Secret
from more_itertools import batched
from openai import APIError
from openai.types.create_embedding_response import CreateEmbeddingResponse
from tqdm import tqdm
from tqdm.asyncio import tqdm as async_tqdm

from haystack_integrations.components.embedders.perplexity.embedding_encoding import (
_decode_embedding,
_validate_encoding_format,
)

logger = logging.getLogger(__name__)

_INTEGRATION_SLUG = "haystack"
_PACKAGE_NAME = "perplexity-haystack"
Expand Down Expand Up @@ -73,6 +86,7 @@ def __init__(
progress_bar: bool = True,
meta_fields_to_embed: list[str] | None = None,
embedding_separator: str = "\n",
encoding_format: str = "base64_int8",
timeout: float | None = None,
max_retries: int | None = None,
http_client_kwargs: dict[str, Any] | None = None,
Expand All @@ -99,6 +113,8 @@ def __init__(
List of meta fields that should be embedded along with the Document text.
:param embedding_separator:
Separator used to concatenate the meta fields to the Document text.
:param encoding_format:
The Perplexity embedding encoding format. Supported values are `base64_int8` and `base64_binary`.
:param timeout:
Timeout for Perplexity client calls. If not set, it defaults to either the `OPENAI_TIMEOUT` environment
variable, or 30 seconds.
Expand All @@ -109,6 +125,7 @@ def __init__(
A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
"""
self.encoding_format = _validate_encoding_format(encoding_format)
super(PerplexityDocumentEmbedder, self).__init__( # noqa: UP008
api_key=api_key,
model=model,
Expand All @@ -129,6 +146,94 @@ def __init__(
self.timeout = timeout
self.max_retries = max_retries

def _decode_response_embeddings(self, response: CreateEmbeddingResponse) -> list[list[float]]:
    """
    Decode every embedding of an embeddings response into a list of floats.

    :param response:
        The embeddings response whose `data` entries carry the encoded embeddings.
    :returns:
        One decoded vector per entry of `response.data`, in the same order.
    """
    # Each payload is passed through `str()` before decoding, then decoded with
    # the encoding format configured on this component.
    encoded_payloads = (str(item.embedding) for item in response.data)
    return [_decode_embedding(payload, self.encoding_format) for payload in encoded_payloads]

def _embed_batch(
    self, texts_to_embed: dict[str, str], batch_size: int
) -> tuple[dict[str, list[float]], dict[str, Any]]:
    """
    Embed texts in batches and collect the decoded embeddings per document id.

    :param texts_to_embed:
        Mapping of document id to the text that should be embedded.
    :param batch_size:
        Number of texts sent to the embeddings endpoint per request.
    :returns:
        A tuple of (document id -> decoded embedding, metadata). The metadata holds the
        `model` of the first successful response and a `usage` dict whose `prompt_tokens`
        and `total_tokens` are summed over all successful batches.
    :raises APIError:
        If a request fails and `raise_on_failure` is set. Otherwise the failed batch is
        logged and skipped.
    """
    embeddings_by_id: dict[str, list[float]] = {}
    meta: dict[str, Any] = {}

    chunks = batched(texts_to_embed.items(), batch_size)
    for chunk in tqdm(chunks, disable=not self.progress_bar, desc="Calculating embeddings"):
        doc_ids = [doc_id for doc_id, _ in chunk]
        request: dict[str, Any] = {
            "model": self.model,
            "input": [text for _, text in chunk],
            "encoding_format": self.encoding_format,
        }

        try:
            response = self.client.embeddings.create(**request)
        except APIError as exc:
            ids = ", ".join(doc_ids)
            msg = "Failed embedding of documents {ids} caused by {exc}"
            logger.exception(msg, ids=ids, exc=exc)
            if self.raise_on_failure:
                raise exc
            # Best effort: documents of this batch simply get no embedding.
            continue

        vectors = self._decode_response_embeddings(response)
        # strict=True surfaces a response that returns a different number of
        # embeddings than texts were sent.
        embeddings_by_id.update(dict(zip(doc_ids, vectors, strict=True)))

        meta.setdefault("model", response.model)
        usage = response.usage
        if "usage" in meta:
            meta["usage"]["prompt_tokens"] += usage.prompt_tokens
            meta["usage"]["total_tokens"] += usage.total_tokens
        else:
            meta["usage"] = dict(usage)

    return embeddings_by_id, meta

async def _embed_batch_async(
    self, texts_to_embed: dict[str, str], batch_size: int
) -> tuple[dict[str, list[float]], dict[str, Any]]:
    """
    Asynchronously embed texts in batches and collect the decoded embeddings per document id.

    :param texts_to_embed:
        Mapping of document id to the text that should be embedded.
    :param batch_size:
        Number of texts sent to the embeddings endpoint per request.
    :returns:
        A tuple of (document id -> decoded embedding, metadata). The metadata holds the
        `model` of the first successful response and a `usage` dict whose `prompt_tokens`
        and `total_tokens` are summed over all successful batches.
    :raises APIError:
        If a request fails and `raise_on_failure` is set. Otherwise the failed batch is
        logged and skipped.
    """
    embeddings_by_id: dict[str, list[float]] = {}
    meta: dict[str, Any] = {}

    # The batches are materialized up front; the progress bar only wraps them when enabled.
    chunks: Iterable[tuple[tuple[str, str], ...]] = list(batched(texts_to_embed.items(), batch_size))
    if self.progress_bar:
        chunks = async_tqdm(chunks, desc="Calculating embeddings")

    # Batches are awaited one after another, not concurrently.
    for chunk in chunks:
        doc_ids = [doc_id for doc_id, _ in chunk]
        request: dict[str, Any] = {
            "model": self.model,
            "input": [text for _, text in chunk],
            "encoding_format": self.encoding_format,
        }

        try:
            response = await self.async_client.embeddings.create(**request)
        except APIError as exc:
            ids = ", ".join(doc_ids)
            msg = "Failed embedding of documents {ids} caused by {exc}"
            logger.exception(msg, ids=ids, exc=exc)
            if self.raise_on_failure:
                raise exc
            # Best effort: documents of this batch simply get no embedding.
            continue

        vectors = self._decode_response_embeddings(response)
        # strict=True surfaces a response that returns a different number of
        # embeddings than texts were sent.
        embeddings_by_id.update(dict(zip(doc_ids, vectors, strict=True)))

        meta.setdefault("model", response.model)
        usage = response.usage
        if "usage" in meta:
            meta["usage"]["prompt_tokens"] += usage.prompt_tokens
            meta["usage"]["total_tokens"] += usage.total_tokens
        else:
            meta["usage"] = dict(usage)

    return embeddings_by_id, meta

def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand All @@ -147,6 +252,7 @@ def to_dict(self) -> dict[str, Any]:
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
encoding_format=self.encoding_format,
timeout=self.timeout,
max_retries=self.max_retries,
http_client_kwargs=self.http_client_kwargs,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import base64

import numpy as np

PERPLEXITY_FLOAT_ENCODING_FORMAT_ERROR = (
"Perplexity's /v1/embeddings does not support encoding_format='float'; use 'base64_int8' or 'base64_binary'."
)

SUPPORTED_ENCODING_FORMATS = {"base64_int8", "base64_binary"}


def _validate_encoding_format(encoding_format: str) -> str:
    """
    Validate Perplexity's embedding encoding format.

    :param encoding_format:
        The requested encoding format.
    :returns:
        The unchanged `encoding_format` when it is one of `SUPPORTED_ENCODING_FORMATS`.
    :raises ValueError:
        If the format is not supported. `"float"` gets its own dedicated message.
    """
    if encoding_format in SUPPORTED_ENCODING_FORMATS:
        return encoding_format

    # NOTE(review): a reviewer suggested not distinguishing "float" from other
    # unsupported formats and always using the generic message below; the
    # dedicated message is kept here until that is decided.
    if encoding_format == "float":
        raise ValueError(PERPLEXITY_FLOAT_ENCODING_FORMAT_ERROR)

    # Sorted so the message is deterministic regardless of set iteration order.
    supported_formats = "', '".join(sorted(SUPPORTED_ENCODING_FORMATS))
    msg = f"Unsupported encoding_format='{encoding_format}'. Use '{supported_formats}'."
    raise ValueError(msg)


def _decode_embedding(embedding: str, encoding_format: str) -> list[float]:
    """
    Decode a Perplexity base64 embedding into Haystack's list[float] representation.

    :param embedding:
        The base64-encoded embedding payload.
    :param encoding_format:
        Either `base64_int8` or `base64_binary`.
    :returns:
        For `base64_int8`, one float per payload byte, each byte read as a signed 8-bit integer.
        For `base64_binary`, eight floats (0.0 or 1.0) per payload byte, most significant bit first.
    :raises ValueError:
        If `encoding_format` is not supported, or if the payload is not valid base64.
    """
    # Check the format before touching the payload so an unsupported format is always
    # reported as such, not masked by a base64 decoding error on a malformed payload.
    if encoding_format not in ("base64_int8", "base64_binary"):
        msg = f"Unsupported encoding_format='{encoding_format}'."
        raise ValueError(msg)

    raw_embedding = base64.b64decode(embedding)
    if encoding_format == "base64_int8":
        return np.frombuffer(raw_embedding, dtype=np.int8).astype(np.float32).tolist()
    # base64_binary: every byte packs eight dimensions.
    return np.unpackbits(np.frombuffer(raw_embedding, dtype=np.uint8)).astype(np.float32).tolist()
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from haystack import component, default_from_dict, default_to_dict
from haystack.components.embedders import OpenAITextEmbedder
from haystack.utils.auth import Secret
from openai.types.create_embedding_response import CreateEmbeddingResponse

from haystack_integrations.components.embedders.perplexity.embedding_encoding import (
_decode_embedding,
_validate_encoding_format,
)

_INTEGRATION_SLUG = "haystack"
_PACKAGE_NAME = "perplexity-haystack"
Expand Down Expand Up @@ -64,6 +70,7 @@ def __init__(
api_base_url: str | None = "https://api.perplexity.ai/v1",
prefix: str = "",
suffix: str = "",
encoding_format: str = "base64_int8",
timeout: float | None = None,
max_retries: int | None = None,
http_client_kwargs: dict[str, Any] | None = None,
Expand All @@ -81,6 +88,8 @@ def __init__(
A string to add to the beginning of each text.
:param suffix:
A string to add to the end of each text.
:param encoding_format:
The Perplexity embedding encoding format. Supported values are `base64_int8` and `base64_binary`.
:param timeout:
Timeout for Perplexity client calls. If not set, it defaults to either the `OPENAI_TIMEOUT` environment
variable, or 30 seconds.
Expand All @@ -92,6 +101,7 @@ def __init__(
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
"""

self.encoding_format = _validate_encoding_format(encoding_format)
super(PerplexityTextEmbedder, self).__init__( # noqa: UP008
api_key=api_key,
model=model,
Expand All @@ -108,6 +118,18 @@ def __init__(
self.timeout = timeout
self.max_retries = max_retries

def _prepare_input(self, text: str) -> dict[str, Any]:
    """
    Build the keyword arguments for the embeddings request of a single text.

    Starts from the arguments prepared by `OpenAITextEmbedder._prepare_input`, wraps the
    prepared `input` in a one-element list and adds the configured `encoding_format`.

    :param text:
        The text to embed.
    :returns:
        The keyword arguments for the embeddings request.
    """
    request = OpenAITextEmbedder._prepare_input(self, text=text)
    request.update(input=[request["input"]], encoding_format=self.encoding_format)
    return request

def _prepare_output(self, result: CreateEmbeddingResponse) -> dict[str, Any]:
    """
    Convert an embeddings response into the component's output dictionary.

    :param result:
        The embeddings response; only its first `data` entry is used.
    :returns:
        A dictionary with the decoded `embedding` and a `meta` dict holding the
        response `model` and `usage`.
    """
    first_item = result.data[0]
    vector = _decode_embedding(str(first_item.embedding), self.encoding_format)
    meta = {"model": result.model, "usage": dict(result.usage)}
    return {"embedding": vector, "meta": meta}

def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand All @@ -122,6 +144,7 @@ def to_dict(self) -> dict[str, Any]:
api_base_url=self.api_base_url,
prefix=self.prefix,
suffix=self.suffix,
encoding_format=self.encoding_format,
timeout=self.timeout,
max_retries=self.max_retries,
http_client_kwargs=self.http_client_kwargs,
Expand Down
24 changes: 24 additions & 0 deletions integrations/perplexity/tests/test_embedding_encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import pytest

from haystack_integrations.components.embedders.perplexity.embedding_encoding import (
_decode_embedding,
_validate_encoding_format,
)


def test_validate_encoding_format_rejects_unsupported_format():
    """An unknown format is rejected with a message listing the supported formats."""
    expected_message = "Unsupported encoding_format='base64_float16'. Use 'base64_binary', 'base64_int8'."

    with pytest.raises(ValueError) as exc_info:
        _validate_encoding_format("base64_float16")

    actual_message = str(exc_info.value)
    assert actual_message == expected_message


def test_decode_embedding_rejects_unsupported_format():
    """Decoding with an unknown format raises a ValueError naming that format."""
    expected_message = "Unsupported encoding_format='base64_float16'."

    with pytest.raises(ValueError) as exc_info:
        _decode_embedding("", "base64_float16")

    actual_message = str(exc_info.value)
    assert actual_message == expected_message
Loading
Loading