Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/engine/openvino/qwen3_asr/qwen3_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pathlib import Path
import logging
import gc
from typing import Any, AsyncIterator, Dict, Optional, Union
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union

import numpy as np
import openvino as ov
Expand Down Expand Up @@ -368,7 +368,7 @@ def audio_chunks(self, chunk_audio: np.ndarray, max_tokens: int):
)
return raw, metrics

async def transcribe(self, gen_config: OV_Qwen3ASRGenConfig) -> AsyncIterator[Union[Dict[str, Any], str]]:
async def transcribe(self, gen_config: OV_Qwen3ASRGenConfig) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
t_transcribe_start = time.perf_counter()
audio_input = gen_config.audio_base64
assert audio_input, "audio_base64 is required"
Expand All @@ -382,9 +382,7 @@ async def transcribe(self, gen_config: OV_Qwen3ASRGenConfig) -> AsyncIterator[Un

audio_seconds = len(audio_array) / SAMPLE_RATE
if audio_seconds <= 0:
yield {}
yield ""
return
return "", {}, []

max_chunk_sec = min(float(gen_config.max_chunk_sec), float(MAX_ASR_INPUT_SECONDS))
chunk_items = await asyncio.to_thread(
Expand All @@ -399,6 +397,7 @@ async def transcribe(self, gen_config: OV_Qwen3ASRGenConfig) -> AsyncIterator[Un

langs = []
texts = []
segments = []
agg = {
"feature_sec": 0.0,
"encoder_sec": 0.0,
Expand All @@ -422,6 +421,12 @@ async def transcribe(self, gen_config: OV_Qwen3ASRGenConfig) -> AsyncIterator[Un
langs.append(lang)
if text:
texts.append(text)
segments.append({
"id": idx,
"start": float(chunk_offset_sec),
"end": float(chunk_offset_sec) + float(chunk_sec),
"text": text,
})
agg["feature_sec"] += chunk_metrics["feature_sec"]
agg["encoder_sec"] += chunk_metrics["encoder_sec"]
agg["prefill_sec"] += chunk_metrics["prefill_sec"]
Expand Down Expand Up @@ -453,8 +458,7 @@ async def transcribe(self, gen_config: OV_Qwen3ASRGenConfig) -> AsyncIterator[Un
if merged_language:
metrics["language"] = merged_language

yield metrics
yield text
return text, metrics, segments

async def unload_model(self, registry: ModelRegistry, model_name: str) -> bool:
removed = await registry.register_unload(model_name)
Expand Down
12 changes: 10 additions & 2 deletions src/server/routes/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ async def openai_audio_transcriptions(
gen_config = OVGenAI_WhisperGenConfig(audio_base64=audio_base64)
result = await _workers.transcribe_whisper(model, gen_config)

metrics = result.get("metrics", {})
metrics: Dict[str, Any] = result.get("metrics", {})
logger.info(f"[audio/transcriptions] model={model} metrics={metrics}")

if response_format == "json":
Expand All @@ -544,9 +544,17 @@ async def openai_audio_transcriptions(
return {
"text": result.get("text", ""),
"language": metrics.get("language"),
"duration": metrics.get("duration"),
"duration": metrics.get("duration") or metrics.get("audio_duration_sec"),
"segments": result.get("segments", []),
"metrics": metrics,
}
elif response_format == "diarized_json":
return {
"duration": metrics.get("duration") or metrics.get("audio_duration_sec"),
"segments": result.get("segments", []),
"task": "transcribe",
"text": result.get("text", ""),
}
else:
return result.get("text", "")

Expand Down
28 changes: 14 additions & 14 deletions src/server/worker_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import soundfile as sf
from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, Optional, Union
from typing import Any, AsyncIterator, Dict, List, Optional, Union

from src.engine.ov_genai.llm import OVGenAI_LLM
from src.engine.ov_genai.vlm import OVGenAI_VLM
Expand Down Expand Up @@ -66,6 +66,7 @@ class WorkerPacket:
]
response: Optional[str] = None
metrics: Optional[Dict[str, Any]] = None
segments: Optional[List[Dict[str, Any]]] = None
# Orchestration plumbing
result_future: Optional[asyncio.Future] = None
stream_queue: Optional[asyncio.Queue] = None
Expand Down Expand Up @@ -191,21 +192,13 @@ async def infer_whisper(packet: WorkerPacket, whisper_model: OVGenAI_Whisper) ->
async def infer_qwen3_asr(packet: WorkerPacket, asr_model: OVQwen3ASR) -> WorkerPacket:
"""Transcribe audio for a single packet using the OVQwen3ASR pipeline."""
metrics = None
final_text = ""

try:
async for item in asr_model.transcribe(packet.gen_config):
if isinstance(item, dict):
metrics = item
else:
final_text = item

packet.response = final_text
packet.metrics = metrics
assert isinstance(packet.gen_config, OV_Qwen3ASRGenConfig), "Expected OV_Qwen3ASRGenConfig for Qwen3 ASR inference"
packet.response, packet.metrics, packet.segments = await asr_model.transcribe(packet.gen_config)
except Exception as e:
logger.error("Qwen3 ASR inference failed!", exc_info=True)
packet.response = f"Error: {str(e)}"
packet.metrics = None
packet.response, packet.metrics, packet.segments = f"Error: {str(e)}", None, None

return packet

Expand Down Expand Up @@ -899,7 +892,7 @@ async def transcribe_whisper(self, model_name: str, gen_config: OVGenAI_WhisperG
async def transcribe_qwen3_asr(self, model_name: str, gen_config: OV_Qwen3ASRGenConfig) -> Dict[str, Any]:
"""Transcribe audio using Qwen3 ASR model."""
request_id = uuid.uuid4().hex
result_future: asyncio.Future = asyncio.get_running_loop().create_future()
result_future: asyncio.Future[WorkerPacket] = asyncio.get_running_loop().create_future()
packet = WorkerPacket(
request_id=request_id,
id_model=model_name,
Expand All @@ -909,7 +902,14 @@ async def transcribe_qwen3_asr(self, model_name: str, gen_config: OV_Qwen3ASRGen
q = self._get_qwen3_asr_queue(model_name)
await q.put(packet)
completed = await result_future
return {"text": completed.response or "", "metrics": completed.metrics or {}}

response: Dict[str, Any] = {"text": completed.response or ""}
if completed.metrics:
response["metrics"] = completed.metrics
if completed.segments:
response["segments"] = completed.segments

return response

async def generate_speech_qwen3_tts(self, model_name: str, gen_config: OV_Qwen3TTSGenConfig) -> Dict[str, Any]:
"""Generate speech using a loaded Qwen3 TTS model.
Expand Down
101 changes: 101 additions & 0 deletions src/tests/test_openai_audio_transcriptions_unit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import asyncio
import base64
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest # type: ignore[import]

import src.server.routes.openai as openai_module
from src.server.models.registration import ModelType


_AUDIO_BYTES = b"audio-bytes"
_RESULT = {
"text": "hello world",
"metrics": {"language": "english", "audio_duration_sec": 4.0, "rtf": 0.5},
"segments": [{"id": 0, "start": 0.0, "end": 4.0, "text": "hello world"}],
}


class _FakeUpload:
    """Minimal stand-in for an uploaded file: exposes only an async ``read``."""

    async def read(self) -> bytes:
        """Return the fixed module-level audio payload."""
        payload: bytes = _AUDIO_BYTES
        return payload


def _call(monkeypatch: pytest.MonkeyPatch, response_format: str, result=None, openarc_asr=None):
    """Run the transcription handler against a mocked, already-loaded qwen3-asr model.

    Args:
        monkeypatch: pytest fixture used to swap the module-level ``_registry``
            and ``_workers`` objects on the OpenAI routes module.
        response_format: Value passed through as the handler's ``response_format``.
        result: Value the mocked worker returns; ``_RESULT`` when omitted.
        openarc_asr: Passed through unchanged as the handler's ``openarc_asr``.

    Returns:
        A ``(response, transcribe_mock)`` tuple: the handler's return value and
        the ``AsyncMock`` standing in for ``transcribe_qwen3_asr``.
    """
    if result is None:
        result = _RESULT

    # Worker pool double: only the qwen3-asr transcription entry point exists.
    worker_mock = AsyncMock(return_value=result)
    workers_double = SimpleNamespace(transcribe_qwen3_asr=worker_mock)

    # Registry double holding a single loaded qwen3-asr model.
    loaded_model = SimpleNamespace(model_name="qwen3-asr", model_type=ModelType.QWEN3_ASR)
    registry_double = SimpleNamespace(
        _lock=asyncio.Lock(),
        _models={"qwen3": loaded_model},
    )

    monkeypatch.setattr(openai_module, "_registry", registry_double)
    monkeypatch.setattr(openai_module, "_workers", workers_double)

    handler_coro = openai_module.openai_audio_transcriptions(
        file=_FakeUpload(),
        model="qwen3-asr",
        response_format=response_format,
        openarc_asr=openarc_asr,
    )
    handler_response = asyncio.run(handler_coro)
    return handler_response, worker_mock


def test_json_returns_text_only(monkeypatch: pytest.MonkeyPatch) -> None:
    """The ``json`` format must return a body containing only the transcript text."""
    body, _worker_mock = _call(monkeypatch, "json")
    expected = {"text": "hello world"}
    assert body == expected


def test_verbose_json_includes_segments_and_duration(monkeypatch: pytest.MonkeyPatch) -> None:
    """``verbose_json`` must carry text, language, duration, segments and metrics."""
    body, _worker_mock = _call(monkeypatch, "verbose_json")

    expected = {
        "text": "hello world",
        "language": "english",
        # The default metrics have no "duration" key, so the value must come
        # from the "audio_duration_sec" fallback.
        "duration": 4.0,
        "segments": _RESULT["segments"],
        "metrics": _RESULT["metrics"],
    }
    assert body == expected


def test_diarized_json_shape(monkeypatch: pytest.MonkeyPatch) -> None:
    """``diarized_json`` must return exactly duration, segments, task and text."""
    body, _worker_mock = _call(monkeypatch, "diarized_json")

    expected = {
        "duration": 4.0,
        "segments": _RESULT["segments"],
        "task": "transcribe",
        "text": "hello world",
    }
    assert body == expected


def test_duration_prefers_explicit_metric(monkeypatch: pytest.MonkeyPatch) -> None:
    """When both keys are present, "duration" wins over "audio_duration_sec"."""
    worker_result = {
        "text": "hi",
        "metrics": {"duration": 9.9, "audio_duration_sec": 4.0},
        "segments": [],
    }

    body, _worker_mock = _call(monkeypatch, "verbose_json", result=worker_result)

    assert body["duration"] == 9.9


def test_missing_segments_defaults_to_empty_list(monkeypatch: pytest.MonkeyPatch) -> None:
    """A worker result without a "segments" key must surface as an empty list."""
    # Deliberately omit the "segments" key from the worker result.
    worker_result = {
        "text": "hi",
        "metrics": {"audio_duration_sec": 1.0},
    }

    body, _worker_mock = _call(monkeypatch, "verbose_json", result=worker_result)

    assert body["segments"] == []


def test_defaults_used_when_openarc_asr_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no ``openarc_asr`` supplied, the uploaded audio still reaches the worker."""
    # With openarc_asr=None the handler is expected to build a default qwen3
    # gen_config; this test only pins that the upload is forwarded as base64.
    _response, worker_mock = _call(monkeypatch, "json", openarc_asr=None)

    worker_mock.assert_awaited_once()

    # The worker is awaited positionally as (model, gen_config).
    positional_args = worker_mock.await_args.args
    _model_arg, gen_config = positional_args

    expected_b64 = base64.b64encode(_AUDIO_BYTES).decode("utf-8")
    assert gen_config.audio_base64 == expected_b64
Loading