Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions src/openbench/engine/whisperkitpro_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@

logger = get_logger(__name__)


def _config_str_provided(value: str | None) -> bool:
return value is not None and value.strip() != ""


COMPUTE_UNITS_MAPPER = {
ct.ComputeUnit.CPU_ONLY: "cpuOnly",
ct.ComputeUnit.CPU_AND_NE: "cpuAndNeuralEngine",
Expand All @@ -26,9 +31,10 @@
class WhisperKitProConfig(BaseModel):
"""Configuration for transcription operations.

Supports two modes:
1. Legacy: model_version, model_prefix, model_repo_name
2. New: repo_id, model_variant (downloads models locally)
Supports three modes:
1. Local: model_dir only (existing directory on disk; no Hugging Face download)
2. Hugging Face: repo_id + model_variant (downloads unless model_dir already exists)
3. Legacy: model_version, model_prefix, model_repo_name
"""

# Legacy fields
Expand Down Expand Up @@ -56,7 +62,10 @@ class WhisperKitProConfig(BaseModel):
)
model_dir: str | None = Field(
None,
description="Local path to model directory, if this is provided it will ignore the repo_id and model_variant and use the provided path directly",
description=(
"Local directory passed as --model-path. If set, must exist; repo_id/model_variant are not used for "
"download (no Hugging Face fetch when only model_dir is configured)."
),
)
word_timestamps: bool = Field(
True,
Expand Down Expand Up @@ -128,7 +137,7 @@ def generate_cli_args(self, model_path: Path | None = None) -> list[str]:
# Use either --model-path (new) or legacy model args
if self.use_model_path:
if model_path is None:
raise ValueError("model_path required when using repo_id/model_variant")
raise ValueError("model_path required when using --model-path mode")
args = [
"--model-path",
str(model_path),
Expand Down Expand Up @@ -185,26 +194,34 @@ def generate_cli_args(self, model_path: Path | None = None) -> list[str]:

@property
def use_model_path(self) -> bool:
"""Check if we should use --model-path vs legacy args."""
return self.repo_id is not None and self.model_variant is not None
"""Use --model-path when model_dir is set, or when HF repo_id + model_variant are set."""
if _config_str_provided(self.model_dir):
return True
return _config_str_provided(self.repo_id) and _config_str_provided(self.model_variant)

def download_and_prepare_model(self) -> Path:
"""Download model from HuggingFace and prepare folder.
"""Resolve local model directory or download from Hugging Face.

Returns:
Path to model directory for --model-path
"""
if not self.use_model_path:
raise ValueError("download_and_prepare_model requires repo_id and model_variant")
raise ValueError("download_and_prepare_model requires model_dir or repo_id/model_variant")

if _config_str_provided(self.model_dir):
p = Path(self.model_dir).expanduser().resolve()
if not p.is_dir():
raise FileNotFoundError(
f"model_dir must be an existing directory (no Hugging Face download when model_dir is set): {self.model_dir}"
)
logger.info(f"Using local model at: {p}")
return p

# Check if model already exists
if self.model_dir is not None and os.path.exists(self.model_dir):
logger.info(f"Model already exists at: {self.model_dir}")
return Path(self.model_dir)
if not (_config_str_provided(self.repo_id) and _config_str_provided(self.model_variant)):
raise ValueError("repo_id and model_variant are required when model_dir is not set")

logger.info(f"Downloading model from {self.repo_id}, variant: {self.model_variant}")

# Download specific model variant folder from HuggingFace
try:
downloaded_path = snapshot_download(repo_id=self.repo_id, allow_patterns=f"{self.model_variant}/*")
return Path(f"{downloaded_path}/{self.model_variant}")
Expand Down Expand Up @@ -252,10 +269,20 @@ def __init__(
# Download and prepare model if using new model management
self.model_path = None
if self.transcription_config.use_model_path:
logger.debug("Using model path management with repo_id/model_variant")
logger.debug("Using --model-path (local model_dir and/or Hugging Face ids)")
self.model_path = self.transcription_config.download_and_prepare_model()
else:
logger.debug("Using legacy model management")
if not (
_config_str_provided(self.transcription_config.model_version)
and _config_str_provided(self.transcription_config.model_prefix)
and _config_str_provided(self.transcription_config.model_repo_name)
):
raise ValueError(
"WhisperKitPro requires one of: model_dir (existing directory), "
"(repo_id and model_variant for Hugging Face), or "
"(model_version, model_prefix, model_repo_name) for legacy CLI args."
)

# Generate CLI args (with model_path if available)
self.transcription_args = self.transcription_config.generate_cli_args(model_path=self.model_path)
Expand Down
14 changes: 14 additions & 0 deletions src/openbench/pipeline/pipeline_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,20 @@ def register_pipeline_aliases() -> None:
description="WhisperKitPro transcription pipeline using the parakeet-v3 version of the model compressed to 494MB. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.",
)

PipelineRegistry.register_alias(
"whisperkitpro-local-model",
WhisperKitProTranscriptionPipeline,
default_config={
"model_dir": os.getenv("WHISPERKITPRO_LOCAL_MODEL_PATH"),
"cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
},
description=(
"WhisperKitPro transcription using only a local model directory (no default Hugging Face repo). "
"Set `WHISPERKITPRO_LOCAL_MODEL_PATH` to the folder passed as `--model-path` on the CLI; it must exist. "
"Requires `WHISPERKITPRO_CLI_PATH` and may require `WHISPERKITPRO_API_KEY`."
),
)

PipelineRegistry.register_alias(
"groq-whisper-large-v3-turbo",
GroqTranscriptionPipeline,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
class WhisperKitProTranscriptionConfig(TranscriptionConfig):
"""Configuration for WhisperKitPro transcription pipeline.

Supports two modes:
1. Legacy: model_version, model_prefix, model_repo_name
2. New: repo_id, model_variant (downloads and manages models)
Supports local model_dir only, Hugging Face repo_id + model_variant, or legacy model fields.
"""

cli_path: str = Field(
Expand Down Expand Up @@ -59,7 +57,7 @@ class WhisperKitProTranscriptionConfig(TranscriptionConfig):
)
model_dir: str | None = Field(
None,
description="Directory to cache downloaded models",
description="Existing local model directory for --model-path (no Hugging Face download when this is the only model source).",
)

audio_encoder_compute_units: ComputeUnit = Field(
Expand Down
Loading