Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ HF_HUB_ETAG_TIMEOUT=3
# want to complete jobs without word-level timestamps for selected languages.
WHISPERX_ALIGN_DISABLED_LANGUAGES=

# Alignment defaults to CPU to isolate wav2vec2 alignment from the GPU ASR and
# speaker-embedding runtimes. Set to pipeline/asr/cuda/cuda:0 only if you have
# validated the target CUDA stack is stable for WhisperX alignment.
WHISPERX_ALIGN_DEVICE=cpu

# Optional comma-separated language=model overrides.
# Example: WHISPERX_ALIGN_MODEL_MAP=zh=your-org/your-zh-align-model
WHISPERX_ALIGN_MODEL_MAP=
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/claude-code-review.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,6 @@ jobs:
synchronized English/Chinese documentation. Avoid formatting-only comments.
claude_args: |
--model ${{ env.CLAUDE_MODEL }}
--max-turns 30
env:
ANTHROPIC_BASE_URL: ${{ secrets.ANTHROPIC_BASE_URL }}
5 changes: 5 additions & 0 deletions app/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ RUN if [ -n "$PIP_INDEX_URL" ]; then \
else \
pip install --no-cache-dir -r requirements.txt; \
fi
# Install WhisperX pinned to 3.3.1 with --no-deps so pip does not re-resolve
# its transitive dependency set (presumably already satisfied by the
# requirements.txt install above — verify when bumping the pin). Honors the
# same optional PIP_INDEX_URL mirror override as the requirements install.
RUN if [ -n "$PIP_INDEX_URL" ]; then \
    pip install --no-cache-dir -i "$PIP_INDEX_URL" --no-deps whisperx==3.3.1; \
    else \
    pip install --no-cache-dir --no-deps whisperx==3.3.1; \
    fi

COPY --chown=app:app . .

Expand Down
3 changes: 2 additions & 1 deletion app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import Path


APP_VERSION = "0.7.5"
APP_VERSION = "0.7.6"


def _env_float(name: str, default: float) -> float:
Expand Down Expand Up @@ -96,6 +96,7 @@ def _env_mapping(name: str) -> dict[str, str]:

# WhisperX forced-alignment controls. Languages are attempted by default; use
# WHISPERX_ALIGN_DISABLED_LANGUAGES only for an explicit operational fallback.
WHISPERX_ALIGN_DEVICE: str = _env_str("WHISPERX_ALIGN_DEVICE", "cpu").lower()
WHISPERX_ALIGN_DISABLED_LANGUAGES: frozenset[str] = _env_csv_set(
"WHISPERX_ALIGN_DISABLED_LANGUAGES",
"",
Expand Down
15 changes: 12 additions & 3 deletions app/infra/job_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,27 @@ def flush_torch_cuda_cache(
"""Best-effort CUDA cache flush used around serialized GPU work."""

try:
import gc as _gc

import torch as _torch

_gc.collect()
# Full Python GC can hold the GIL long enough to make FastAPI liveness
# probes time out after large alignment results. Keep active job
# boundaries lightweight; the idle-unload path remains the heavy cleanup
# point because it runs after the GPU pipeline has been idle.
if phase == "idle-unload":
_collect_python_gc()
if _torch.cuda.is_available():
_torch.cuda.empty_cache()
except Exception as exc: # pragma: no cover - guarded for runtime-only failures
if logger is not None:
logger.warning("%s CUDA cache flush failed: %s", phase, exc)


def _collect_python_gc() -> None:
import gc as _gc

_gc.collect()


def run_serialized_gpu_work(
work: Callable[[], _T],
*,
Expand Down
1 change: 1 addition & 0 deletions app/nltk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Minimal NLTK compatibility surface required by WhisperX alignment."""
1 change: 1 addition & 0 deletions app/nltk/tokenize/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tokenization compatibility helpers for WhisperX."""
89 changes: 89 additions & 0 deletions app/nltk/tokenize/punkt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Small Punkt-compatible sentence span tokenizer for WhisperX.

WhisperX 3.3.1 imports ``PunktParameters`` and ``PunktSentenceTokenizer`` only
to split an already bounded segment into sentence spans. Pulling the full NLTK
distribution into the runtime introduces unrelated data/license surface, so this
module implements the small API shape WhisperX uses.
"""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Iterable


@dataclass
class PunktParameters:
    """Subset of NLTK's PunktParameters used by WhisperX.

    Only ``abbrev_types`` is consumed by the tokenizer: a set of lowercase
    abbreviation words (without the trailing dot) that must not be treated
    as sentence boundaries.
    """

    # Lowercase abbreviations, e.g. {"mr", "dr"}; populated by WhisperX.
    abbrev_types: set[str] = field(default_factory=set)


class PunktSentenceTokenizer:
    """Sentence span splitter compatible with WhisperX's use of NLTK Punkt."""

    # Sentence-final punctuation for both Latin and CJK text.
    _TERMINATORS = {".", "!", "?", "。", "!", "?"}
    # Latin terminators that additionally require trailing whitespace (or end
    # of string) before a split is emitted, so decimals like "3.14" survive.
    _LATIN_TERMINATORS = {".", "!", "?"}
    # Closing quotes/brackets that may trail a terminator and are kept inside
    # the sentence span. Hoisted so the set is not rebuilt per iteration.
    _CLOSERS = frozenset({'"', "'", ")", "]", "}", "”", "’"})
    # Trailing Latin word before a "." — precompiled instead of re-compiling
    # on every abbreviation check.
    _ABBREV_RE = re.compile(r"([A-Za-z]+)$")

    def __init__(self, params: PunktParameters | None = None) -> None:
        self.params = params or PunktParameters()

    def span_tokenize(self, text: str) -> Iterable[tuple[int, int]]:
        """Yield half-open sentence spans ``(start, end)`` in ``text``.

        This intentionally implements conservative splitting: common
        abbreviations configured by WhisperX are not treated as sentence
        boundaries, and Latin punctuation must be followed by whitespace or
        end of string before a split is emitted.
        """

        start = 0
        index = 0
        length = len(text)
        while index < length:
            char = text[index]
            if char not in self._TERMINATORS or self._is_abbreviation(text, index):
                index += 1
                continue

            # Keep any run of closing quotes/brackets inside the span.
            next_index = index + 1
            while next_index < length and text[next_index] in self._CLOSERS:
                next_index += 1

            # Latin terminator not followed by whitespace (e.g. "3.14",
            # "e.g.x") — not a boundary; CJK terminators split unconditionally.
            if (
                next_index < length
                and char in self._LATIN_TERMINATORS
                and not text[next_index].isspace()
            ):
                index += 1
                continue

            # Skip the inter-sentence whitespace; the next span starts after it.
            end = next_index
            while end < length and text[end].isspace():
                end += 1

            yield (start, next_index)
            start = end
            index = end

        # Trailing text without a final terminator still forms a span.
        # (The original `elif length == 0: return` was dead code: with an
        # empty text the loop never runs and the generator ends anyway.)
        if start < length:
            yield (start, length)

    def _is_abbreviation(self, text: str, dot_index: int) -> bool:
        """Return True if the "." at ``dot_index`` ends a configured abbreviation."""
        if text[dot_index] != ".":
            return False
        match = self._ABBREV_RE.search(text[:dot_index])
        if not match:
            return False
        return match.group(1).lower() in self.params.abbrev_types
15 changes: 14 additions & 1 deletion app/pipeline/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ def __init__(
self._whisper_device = None
self._diarization_device = None
self._embedding_device = None
self._alignment_cache_key = None
self._alignment_device = None
self._alignment_model = None
self._alignment_metadata = None
self.model_size = model_size or WHISPER_MODEL
self.hf_token = hf_token or HF_TOKEN
self._whisper = None
Expand All @@ -284,16 +288,25 @@ def runner(self) -> PipelineRunner:
def has_loaded_models(self) -> bool:
    """Report whether any heavyweight model attribute is currently set."""
    model_attrs = (
        "_whisper",
        "_diarization",
        "_embedding_model",
        "_alignment_model",
    )
    return any(getattr(self, attr, None) is not None for attr in model_attrs)

def unload_models(self) -> None:
    """Drop every cached model, cache key, and recorded device.

    After this, the next use triggers a cold lazy load on a freshly
    selected device.
    """
    for attr in (
        "_whisper",
        "_diarization",
        "_embedding_model",
        "_alignment_model",
        "_alignment_metadata",
        "_alignment_cache_key",
        "_whisper_device",
        "_diarization_device",
        "_embedding_device",
        "_alignment_device",
    ):
        setattr(self, attr, None)

def _select_device_for_lazy_load(self, device_attr: str) -> str:
selected_device = getattr(self, device_attr, None)
Expand Down
27 changes: 27 additions & 0 deletions app/providers/asr/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
"简体中文输出",
"以下是普通话的对话",
)
# Substrings typical of Whisper's hallucinated Chinese outro boilerplate
# ("please like/subscribe/tip, thanks for watching, see you next time"
# sign-offs), matched against normalized segment text by _outro_marker_score.
# NOTE(review): some entries overlap as substrings ("请不吝点赞" contains
# "点赞"; "打赏支持" contains "打赏"), so one phrase can count as two markers
# and its characters are summed twice — confirm this weighting is intentional.
_OUTRO_HALLUCINATION_MARKERS = (
    "请不吝点赞",
    "点赞",
    "订阅",
    "转发",
    "打赏",
    "打赏支持",
    "明镜与点点栏目",
    "谢谢观看",
    "感谢观看",
    "下期再见",
)


def _duration(segment: dict[str, Any]) -> float:
Expand Down Expand Up @@ -51,6 +63,17 @@ def _prompt_marker_key(normalized_text: str) -> str:
return ""


def _outro_marker_score(normalized_text: str) -> tuple[int, float]:
    """Score outro-boilerplate markers present in ``normalized_text``.

    Returns ``(marker_count, marker_char_ratio)``: the number of distinct
    markers found and the summed length of those markers divided by the text
    length. NOTE(review): markers that are substrings of one another are each
    counted, so overlapping characters can contribute more than once.
    """
    if not normalized_text:
        return 0, 0.0

    hits = {
        marker for marker in _OUTRO_HALLUCINATION_MARKERS if marker in normalized_text
    }
    total_marker_chars = sum(map(len, hits))
    return len(hits), total_marker_chars / len(normalized_text)
Comment on lines +66 to +74


def _dominant_repeated_unit(normalized_text: str) -> tuple[str, int, float]:
"""Return the dominant repeated short unit, repeat count, and coverage ratio."""

Expand Down Expand Up @@ -89,6 +112,10 @@ def _is_single_segment_hallucination(segment: dict[str, Any]) -> bool:
if duration >= 3.0 and marker_count >= 2 and marker_ratio >= 0.55:
return True

outro_count, outro_ratio = _outro_marker_score(normalized)
if 3.0 <= duration <= 60.0 and outro_count >= 3 and outro_ratio >= 0.40:
return True

unit, repeat_count, repeat_ratio = _dominant_repeated_unit(normalized)
return (
bool(unit) and duration >= 12.0 and repeat_count >= 4 and repeat_ratio >= 0.82
Expand Down
89 changes: 74 additions & 15 deletions app/providers/diarization/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Any

from config import (
WHISPERX_ALIGN_DEVICE,
WHISPERX_ALIGN_CACHE_ONLY,
WHISPERX_ALIGN_DISABLED_LANGUAGES,
WHISPERX_ALIGN_MODEL_DIR,
Expand Down Expand Up @@ -107,6 +108,35 @@ def _alignment_disabled(language: str) -> bool:
)


def _resolve_alignment_device(pipeline) -> str:
    """Map WHISPERX_ALIGN_DEVICE onto a concrete device string.

    ``pipeline``/``asr`` reuse the ASR pipeline's device; ``auto`` defers to
    the pipeline's lazy-load device selector when one is available; any other
    non-empty value passes through verbatim. Falls back to ``cpu``.
    """

    def _pipeline_device() -> str:
        # Read lazily so explicit device strings never touch the pipeline.
        return str(getattr(pipeline, "device", "cpu") or "cpu")

    choice = (WHISPERX_ALIGN_DEVICE or "cpu").strip().lower()
    if choice in ("pipeline", "asr"):
        return _pipeline_device()
    if choice == "auto":
        select = getattr(pipeline, "_select_device_for_lazy_load", None)
        if callable(select):
            return str(select("_alignment_device"))
        return _pipeline_device()
    return choice or "cpu"


def _alignment_cache_key(
    *,
    language: str,
    model_name: str | None,
    model_source: str,
    device: str,
) -> tuple[str, str | None, str, str | None, bool, str]:
    """Build the identity tuple for a cached WhisperX alignment model.

    A change in any component — language, explicit model override, model
    source, configured model directory, cache-only flag, or target device —
    must invalidate the cached model, so all six participate in the key.
    """
    return (
        language,
        model_name,
        model_source,
        # Module-level config also determines which weights get loaded, so it
        # is folded into the key even though it rarely changes at runtime.
        WHISPERX_ALIGN_MODEL_DIR,
        WHISPERX_ALIGN_CACHE_ONLY,
        device,
    )


def _language_disabled_hint(language: str) -> str:
return (
f"Remove {language} from WHISPERX_ALIGN_DISABLED_LANGUAGES to retry "
Expand Down Expand Up @@ -248,32 +278,61 @@ def align_diarized_segments_with_metadata(
preflight_message = _torch_preflight_message(language, model_name)
if preflight_message:
logger.info(preflight_message)
alignment_device = _resolve_alignment_device(pipeline)
audio = whisperx.load_audio(audio_path)
load_kwargs = _load_align_model_kwargs(
whisperx.load_align_model,
language,
pipeline.device,
alignment_device,
)
load_started = time.perf_counter()
with _cache_only_alignment_environment():
align_model, align_metadata = whisperx.load_align_model(
**load_kwargs,
)
logger.info(
"Loaded WhisperX alignment model in %.2fs "
"(cold_load=True, language=%s, model_source=%s, device=%s)",
time.perf_counter() - load_started,
language,
model_source,
pipeline.device,
cache_key = _alignment_cache_key(
language=language,
model_name=model_name,
model_source=model_source,
device=alignment_device,
)
cached_key = getattr(pipeline, "_alignment_cache_key", None)
align_model = getattr(pipeline, "_alignment_model", None)
align_metadata = getattr(pipeline, "_alignment_metadata", None)
if (
cached_key == cache_key
and align_model is not None
and align_metadata is not None
):
logger.info(
"Reusing WhisperX alignment model (hot reuse, language=%s, model_source=%s, device=%s)",
language,
model_source,
alignment_device,
)
else:
setattr(pipeline, "_alignment_model", None)
setattr(pipeline, "_alignment_metadata", None)
setattr(pipeline, "_alignment_cache_key", None)
load_started = time.perf_counter()
with _cache_only_alignment_environment():
align_model, align_metadata = whisperx.load_align_model(
**load_kwargs,
)
setattr(pipeline, "_alignment_model", align_model)
setattr(pipeline, "_alignment_metadata", align_metadata)
setattr(pipeline, "_alignment_cache_key", cache_key)
setattr(pipeline, "_alignment_device", alignment_device)
logger.info(
"Loaded WhisperX alignment model in %.2fs "
"(cold_load=True, language=%s, model_source=%s, device=%s)",
time.perf_counter() - load_started,
language,
model_source,
alignment_device,
)
processing_started = time.perf_counter()
aligned_result = whisperx.align(
segments,
align_model,
align_metadata,
audio,
pipeline.device,
alignment_device,
return_char_alignments=False,
)
processing_elapsed_s = time.perf_counter() - processing_started
Expand All @@ -283,7 +342,7 @@ def align_diarized_segments_with_metadata(
processing_elapsed_s,
language,
len(segments),
pipeline.device,
alignment_device,
)
logger.info("WhisperX forced alignment succeeded for language=%s", language)
metadata = {
Expand Down
Loading
Loading