Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
(system: nixpkgsFor.${system}.nixpkgs-fmt);

overlays.default = final: prev: {
devShell = final.python3Packages.callPackage nixos/devShell { inherit self; };
devShell = final.python3Packages.callPackage nixos/devShell { };
whisper_api = final.python3Packages.callPackage nixos/pkgs/whisper_api { inherit self; };
# Our code is not compatible with pydantic version 2 yet.
python3 = prev.python3.override {
Expand Down
9 changes: 6 additions & 3 deletions nixos/devShell/default.nix
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
{ self, pkgs, ... }:
{ pkgs, ... }:
let
whisper_api = pkgs.whisper_api;
python-with-packages = pkgs.python3.withPackages (p: with p; [
# only needed for development
autopep8
black
httpx
isort
pylint
pip
] ++ self.packages.${pkgs.system}.whisper_api.propagatedBuildInputs);
pylint
pytest
] ++ whisper_api.propagatedBuildInputs);
in
pkgs.mkShell {
inputsFrom = [ whisper_api ];
buildInputs = with pkgs; [
# only needed for development
nixpkgs-fmt
Expand Down
8 changes: 6 additions & 2 deletions nixos/pkgs/whisper_api/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
uvicorn,

# tests
unittestCheckHook,
pytestCheckHook,
httpx,
}:
buildPythonApplication {
Expand Down Expand Up @@ -47,10 +47,14 @@ buildPythonApplication {
];

nativeCheckInputs = [
unittestCheckHook
pytestCheckHook
httpx
];

disabledTestPaths = [
"test/test_api.py"
];

pythonImportsCheck = [ "whisper_api" ];

meta = with lib; {
Expand Down
7 changes: 5 additions & 2 deletions src/whisper_api/decoding/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
"""

self.pipe_to_parent = pipe_to_parent
self.pipe_send_lock = threading.Lock()
# TODO: handle maxsize by making it configurable from outside and handle case where Queue reaches limit
# queue that stores tasks that wait for processing
# using FastQueue because it allows for position queries of queued objects
Expand Down Expand Up @@ -188,14 +189,16 @@ def send_status_update(self):
status_dict = self.get_status_dict()
self.logger.debug(f"{status_dict}")

self.pipe_to_parent.send(status_dict)
with self.pipe_send_lock:
self.pipe_to_parent.send(status_dict)

@staticmethod
def task_to_pipe_message(task: Task, /) -> dict:
    """
    Wrap a task in the message envelope used for task updates on the pipe.

    :param task: the task to report (positional-only); its ``to_json`` attribute becomes the payload
    :return: dict with ``"type"`` set to ``"task_update"`` and the task's ``to_json`` value under ``"data"``
    """
    # NOTE(review): `to_json` is accessed without being called - presumably a property; confirm on Task
    return {"type": "task_update", "data": task.to_json}

def send_task_update(self, task: Task, /):
self.pipe_to_parent.send(self.task_to_pipe_message(task))
with self.pipe_send_lock:
self.pipe_to_parent.send(self.task_to_pipe_message(task))

def handle_task(self, task: Task) -> Task:
"""
Expand Down
24 changes: 22 additions & 2 deletions src/whisper_api/log_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import multiprocessing
import os
import queue
import threading
from logging.handlers import TimedRotatingFileHandler
from multiprocessing.connection import Connection
Expand Down Expand Up @@ -63,6 +64,12 @@ def __init__(self, log_pipe: Connection, log_dir: str, log_file: str, **rotating
super().__init__(self.log_path, **rotating_file_handler_kwargs)
self.log_pipe = log_pipe

# Queue + drain thread for non-blocking pipe sends from child processes.
# Initialized lazily on first emit() in a child, because the handler is
# constructed in MainProcess and then inherited across fork().
self._send_queue: queue.Queue[logging.LogRecord] | None = None
self._drain_thread: threading.Thread | None = None

if multiprocessing.current_process().name == "MainProcess":
# start listening for logs from children
self.listener_thread = threading.Thread(target=self.listen_for_logs_from_children, args=(self.log_pipe,))
Expand All @@ -85,12 +92,25 @@ def wait_for_listener(self):
self.listener_thread.join()
print("Logger closed")

def _drain_send_queue(self):
    """
    Drain buffered log records into the pipe. Runs in child processes only.

    Loops forever: takes the next record from ``self._send_queue`` and sends it
    through ``self.log_pipe``. The loop ends on the first exception raised by
    either step.
    """
    while True:
        try:
            # Queue.get() without arguments blocks until a record is available
            record = self._send_queue.get()
            self.log_pipe.send(record)
        except Exception:
            # NOTE(review): every error ends the loop without being reported;
            # records queued afterwards are presumably never sent - confirm this is intended
            break

def emit(self, record: logging.LogRecord):
"""Emit the message or send it to the main"""

# if we're in a child process, send the record to the pipe to main process
# if we're in a child process, buffer the record for async sending
if not self.am_I_main:
self.log_pipe.send(record)
if self._send_queue is None:
self._send_queue = queue.Queue()
self._drain_thread = threading.Thread(target=self._drain_send_queue, daemon=True)
self._drain_thread.start()
self._send_queue.put_nowait(record)
return

# only write from main process
Expand Down
2 changes: 1 addition & 1 deletion src/whisper_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def exit_fn(signum: int):
# - the logger thread (registered in atexit)
# - uvicorn (it'll be fine and have its own shutdown procedures)

sys.exit(0)
os._exit(0)

def signal_worker_to_exit(signum: int, frame: Optional[FrameType]):
"""
Expand Down
Binary file not shown.
208 changes: 208 additions & 0 deletions test/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
import socket
import time
import unittest
from multiprocessing import Process
from pathlib import Path
from typing import Optional

import httpx
import uvicorn

from whisper_api import app

"""
Test that the API works.
"""

# Directory "files" located next to this test module.
TEST_FILES_DIR = Path(__file__).parent / "files"
# Audio sample that the tests upload to the transcribe/translate endpoints.
TEST_AUDIO_FILE = TEST_FILES_DIR / "En-Open_Source_Software_CD-article.ogg"


def do_test() -> tuple[bool, str]:
    """
    Tell the test class whether the API tests should run.

    The tests need a running whisper model.

    :return: ``(run_tests, skip_reason)`` - the flag is currently always ``True``;
             the reason is the message shown when the tests are skipped
    """
    run_tests = True
    skip_reason = "These tests require a whisper model and are skipped"
    return run_tests, skip_reason


def find_free_port() -> int:
    """
    Ask the OS for a currently unused TCP port on the loopback interface.

    Binding to port 0 lets the OS pick the port; the probing socket is closed
    again before returning.

    :return: the port number the OS assigned
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        address = probe.getsockname()
        return address[1]
    finally:
        probe.close()


def run_server(port: int):
    """
    Serve the API app with uvicorn on the loopback interface.

    :param port: TCP port to listen on
    """
    listen_options = {"host": "127.0.0.1", "port": port}
    uvicorn.run(app, **listen_options)


class TestAPI(unittest.TestCase):
    """
    Test basic features of the API.

    The API server is started once per class in a separate process
    (``setUpClass``) and stopped again in ``tearDownClass``. All tests talk to
    it through one shared ``httpx.Client``.
    """

    # whether to run the tests, and the reason reported when they are skipped
    do_test, reason = do_test()

    # seconds to wait for the server to answer its first status request
    SERVER_STARTUP_TIMEOUT_S = 30
    # seconds to wait for the decoder to report a loaded model
    MODEL_LOAD_TIMEOUT_S = 120
    # per-request timeout for uploads
    POST_TIMEOUT_S = 30
    # per-request timeout for status polls
    STATUS_REQUEST_TIMEOUT_S = 10
    # overall deadline for one transcription/translation task
    TRANSCRIPTION_DEADLINE_S = 240
    # number of consecutive transcriptions in test_stability
    STABILITY_RUNS = 20

    proc: Optional[Process] = None
    client: Optional[httpx.Client] = None
    server_port: Optional[int] = None

    @classmethod
    def setUpClass(cls):
        """
        Start the API server in a child process and wait until it answers.

        :raises unittest.SkipTest: if the tests are disabled (no server is started then)
        :raises RuntimeError: if the server exits or does not answer within SERVER_STARTUP_TIMEOUT_S
        """
        # skip before spawning anything; the per-method skip decorators alone
        # would not prevent this fixture from starting the server
        if not cls.do_test:
            raise unittest.SkipTest(cls.reason)

        cls.server_port = find_free_port()
        cls.client = httpx.Client(base_url=f"http://127.0.0.1:{cls.server_port}")
        cls.proc = Process(target=run_server, args=(cls.server_port,), daemon=False)
        cls.proc.start()

        # monotonic clock: the deadline must not be affected by wall-clock adjustments
        deadline = time.monotonic() + cls.SERVER_STARTUP_TIMEOUT_S
        while time.monotonic() < deadline:
            # no point in polling a server process that has already died
            if not cls.proc.is_alive():
                break
            try:
                response = cls.client.get("/api/v1/decoder_status", timeout=2)
                if response.status_code == 200:
                    return
            except httpx.HTTPError:
                # server is not accepting connections yet
                pass
            time.sleep(0.2)

        # unittest does not call tearDownClass when setUpClass fails, so clean up here
        exit_code = cls.proc.exitcode
        cls.tearDownClass()
        if exit_code is not None:
            raise RuntimeError(f"Server process exited with code {exit_code} before becoming ready in setUpClass")
        raise RuntimeError("Server did not become ready in setUpClass")

    @classmethod
    def tearDownClass(cls):
        """Close the HTTP client and stop the server process (terminate first, kill as fallback)."""
        if cls.client is not None:
            cls.client.close()
            cls.client = None

        if cls.proc is None:
            return

        if cls.proc.is_alive():
            cls.proc.terminate()
            cls.proc.join(timeout=5)

        # terminate() was not enough
        if cls.proc.is_alive():
            cls.proc.kill()
            cls.proc.join(timeout=5)

        cls.proc = None
        cls.server_port = None

    @unittest.skipIf(not do_test, reason)
    def test_version(self):
        """Test that the version endpoint returns a version string."""
        response = self.client.get("/api/v1/version")
        self.assertEqual(response.status_code, 200)
        self.assertIn("version", response.json())

    @unittest.skipIf(not do_test, reason)
    def test_4_loaded_model(self):
        """
        Test that the API is reachable and the model is loaded within MODEL_LOAD_TIMEOUT_S seconds.
        """
        deadline = time.monotonic() + self.MODEL_LOAD_TIMEOUT_S

        while time.monotonic() < deadline:
            try:
                response = self.client.get("/api/v1/decoder_status", timeout=self.STATUS_REQUEST_TIMEOUT_S)
            except httpx.TimeoutException:
                # same tolerance as in _run_transcription_and_wait: retry until the deadline
                print("Status request timed out, retrying...")
                time.sleep(1)
                continue

            if response.json().get("is_model_loaded") is True:
                return

            print("Waiting for model to load...")
            time.sleep(1)

        self.fail(f"Model did not load within {self.MODEL_LOAD_TIMEOUT_S} seconds")

    @unittest.skipIf(not do_test, reason)
    def test_status_invalid_task_id(self):
        """Test that requesting status for an unknown task_id returns 400."""
        response = self.client.get("/api/v1/status?task_id=00000000000000000000000000000000")
        self.assertEqual(response.status_code, 400)

    @unittest.skipIf(not do_test, reason)
    def test_transcribe_non_audio_file(self):
        """Test that uploading a non-audio file returns 400."""
        files = {"file": ("test.txt", b"this is not audio", "text/plain")}
        response = self.client.post("/api/v1/transcribe", files=files, timeout=self.POST_TIMEOUT_S)
        self.assertEqual(response.status_code, 400)

    @unittest.skipIf(not do_test, reason)
    def test_transcribe(self):
        """
        Test that the API can transcribe a given audio file.
        """
        self._assert_finished_with_transcript(self._run_transcription_and_wait())

    @unittest.skipIf(not do_test, reason)
    def test_translate(self):
        """
        Test that the API can translate a given audio file.
        """
        self._assert_finished_with_transcript(self._run_transcription_and_wait(endpoint="/api/v1/translate"))

    @unittest.skipIf(not do_test, reason)
    def test_stability(self):
        """
        Test the stability of the API by running transcription STABILITY_RUNS times.
        Only passes if all runs complete successfully.
        """
        failures = []
        for run_number in range(1, self.STABILITY_RUNS + 1):
            print(f"Stability run {run_number}/{self.STABILITY_RUNS}")
            try:
                self._assert_finished_with_transcript(self._run_transcription_and_wait())
            except Exception as exc:
                # deliberately broad: collect every failing run (assertion or transport error)
                # so the final report covers all runs instead of stopping at the first
                failures.append(f"Run {run_number}: {exc}")

        if failures:
            self.fail(f"{len(failures)}/{self.STABILITY_RUNS} runs failed:\n" + "\n".join(failures))

    def _assert_finished_with_transcript(self, result: dict):
        """Assert that a status payload reports a finished task with a non-empty transcript."""
        self.assertEqual(result.get("status"), "finished")
        self.assertTrue(result.get("transcript"), "transcript is missing or empty")

    def _run_transcription_and_wait(self, endpoint: str = "/api/v1/transcribe") -> dict:
        """
        Upload the test audio file and poll the task status until it is finished.

        :param endpoint: endpoint the audio file is posted to
        :return: the status payload of the finished task
        :raises AssertionError: if a request returns a non-200 status, no task_id is returned,
                                or the task does not finish within TRANSCRIPTION_DEADLINE_S
        """
        with open(TEST_AUDIO_FILE, "rb") as file:
            files = {"file": file}
            response = self.client.post(endpoint, files=files, timeout=self.POST_TIMEOUT_S)

        self.assertEqual(response.status_code, 200)
        task_id = response.json().get("task_id")
        self.assertIsNotNone(task_id)

        print(response.json())

        deadline = time.monotonic() + self.TRANSCRIPTION_DEADLINE_S

        while time.monotonic() < deadline:
            try:
                status_response = self.client.get(
                    f"/api/v1/status?task_id={task_id}", timeout=self.STATUS_REQUEST_TIMEOUT_S
                )
            except httpx.TimeoutException:
                # Status requests can occasionally time out under heavy CPU load; retry until deadline.
                print("Status request timed out, retrying...")
                time.sleep(1)
                continue

            self.assertEqual(status_response.status_code, 200)
            payload = status_response.json()

            if payload.get("status") == "finished":
                print(payload)
                return payload

            print("Waiting for transcription to complete...")
            time.sleep(1)

        self.fail(f"Transcription did not complete within {self.TRANSCRIPTION_DEADLINE_S} seconds")