Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions aiola/clients/stt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,17 @@ def _build_query_and_headers(
query = {
"execution_id": execution_id,
"flow_id": resolved_workflow_id,
"lang_code": lang_code or "en",
"time_zone": time_zone or "UTC",
"keywords": json.dumps(keywords or {}),
"tasks_config": json.dumps(tasks_config or {}),
"x-aiola-api-token": access_token,
}

if lang_code is not None:
query["lang_code"] = lang_code
if keywords is not None:
query["keywords"] = json.dumps(keywords)
if tasks_config is not None:
query["tasks_config"] = json.dumps(tasks_config)

headers = {
"Authorization": f"Bearer {access_token}",
}
Expand Down
51 changes: 9 additions & 42 deletions aiola/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import enum
from collections.abc import Mapping
from dataclasses import dataclass
from typing import IO, Union
from typing import IO, Any, Union

from .constants import DEFAULT_AUTH_BASE_URL, DEFAULT_BASE_URL, DEFAULT_HTTP_TIMEOUT, DEFAULT_WORKFLOW_ID

Expand Down Expand Up @@ -46,16 +46,7 @@ def __post_init__(self) -> None:
class LiveEvents(str, enum.Enum):
Transcript = "transcript"
Translation = "translation"
SentimentAnalysis = "sentiment_analysis"
Summarization = "summarization"
TopicDetection = "topic_detection"
ContentModeration = "content_moderation"
AutoChapters = "auto_chapters"
FormFilling = "form_filling"
EntityDetection = "entity_detection"
EntityDetectionFromList = "entity_detection_from_list"
KeyPhrases = "key_phrases"
PiiRedaction = "pii_redaction"
Structured = "structured"
Error = "error"
Disconnect = "disconnect"
Connect = "connect"
Expand Down Expand Up @@ -115,6 +106,13 @@ def from_dict(cls, data: dict) -> TranscriptionResponse:
)


@dataclass
class StructuredResponse:
"""Response from structured API."""

results: dict[str, Any]


@dataclass
class SessionCloseResponse:
"""Response from session close API."""
Expand All @@ -137,40 +135,9 @@ class TranslationPayload:
dst_lang_code: str


@dataclass
class EntityDetectionFromListPayload:
entity_list: list[str]


@dataclass
class _EmptyPayload:
pass


EntityDetectionPayload = _EmptyPayload
KeyPhrasesPayload = _EmptyPayload
PiiRedactionPayload = _EmptyPayload
SentimentAnalysisPayload = _EmptyPayload
SummarizationPayload = _EmptyPayload
TopicDetectionPayload = _EmptyPayload
ContentModerationPayload = _EmptyPayload
AutoChaptersPayload = _EmptyPayload
FormFillingPayload = _EmptyPayload


@dataclass
class TasksConfig:
FORM_FILLING: FormFillingPayload | None = None
TRANSLATION: TranslationPayload | None = None
ENTITY_DETECTION: EntityDetectionPayload | None = None
ENTITY_DETECTION_FROM_LIST: EntityDetectionFromListPayload | None = None
KEY_PHRASES: KeyPhrasesPayload | None = None
PII_REDACTION: PiiRedactionPayload | None = None
SENTIMENT_ANALYSIS: SentimentAnalysisPayload | None = None
SUMMARIZATION: SummarizationPayload | None = None
TOPIC_DETECTION: TopicDetectionPayload | None = None
CONTENT_MODERATION: ContentModerationPayload | None = None
AUTO_CHAPTERS: AutoChaptersPayload | None = None


FileContent = Union[IO[bytes], bytes, str]
Expand Down
9 changes: 4 additions & 5 deletions tests/unit/stt/test_stt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def test_stt_stream_with_empty_tasks_config(patch_dummy_socket):


def test_stt_stream_with_no_tasks_config(patch_dummy_socket):
"""``SttClient.stream`` handles None tasks_config properly."""
"""``SttClient.stream`` handles None tasks_config properly by not including it in URL."""

client = AiolaClient(api_key="secret-key", base_url="https://speech.example")

Expand All @@ -341,15 +341,14 @@ def test_stt_stream_with_no_tasks_config(patch_dummy_socket):
# Access the underlying socket to validate connection parameters
sio = connection._sio

# Verify None tasks_config is serialized as empty JSON object
# Verify None tasks_config is not included in URL
kwargs = sio.connect_kwargs
url = kwargs["url"]
parsed = urllib.parse.urlparse(url)
query = urllib.parse.parse_qs(parsed.query)

tasks_config_json = query["tasks_config"][0]
parsed_tasks_config = json.loads(tasks_config_json)
assert parsed_tasks_config == {}
# tasks_config should not be present when None
assert "tasks_config" not in query


def test_stt_stream_with_all_tasks_config(patch_dummy_socket):
Expand Down