Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/openbench/dataset/dataset_transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class TranscriptionExtraInfo(TypedDict, total=False):

language: str
dictionary: list[str]
metric_keywords: list[str]


class TranscriptionRow(TypedDict):
Expand All @@ -24,6 +25,7 @@ class TranscriptionRow(TypedDict):
word_timestamps_end: NotRequired[list[float]]
language: NotRequired[str]
dictionary: NotRequired[list[str]]
metric_keywords: NotRequired[list[str]]


class TranscriptionSample(BaseSample[Transcript, TranscriptionExtraInfo]):
Expand All @@ -39,6 +41,11 @@ def dictionary(self) -> list[str] | None:
"""Convenience property to access dictionary from extra_info."""
return self.extra_info.get("dictionary")

@property
def metric_keywords(self) -> list[str] | None:
"""Convenience property to access metric keywords from extra_info."""
return self.extra_info.get("metric_keywords")


class TranscriptionDataset(BaseDataset[TranscriptionSample]):
"""Dataset for transcription pipelines with optional keyword support."""
Expand All @@ -60,4 +67,7 @@ def prepare_sample(self, row: TranscriptionRow) -> tuple[Transcript, Transcripti
extra_info["language"] = row["language"]
if "dictionary" in row:
extra_info["dictionary"] = row["dictionary"]
metric_keywords = row.get("metric-keywords") or row.get("metric_keywords")
if metric_keywords:
extra_info["metric_keywords"] = metric_keywords
return reference, extra_info
17 changes: 13 additions & 4 deletions src/openbench/runner/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,11 @@ def _process_single_sample(

for metric_name, metric in metrics_dict.items():
reference = sample.reference
kwargs = sample.extra_info
kwargs = dict(sample.extra_info)

# Use metric_keywords for metric calculation if available, falling back to dictionary
if "metric_keywords" in kwargs:
kwargs["dictionary"] = kwargs.pop("metric_keywords")

# The metric returns a dictionary that is also stored in the metric object as a state to compute the global result
# We copy to avoid any side effects that may happen while interacting with dictionary for reporting
Expand Down Expand Up @@ -254,10 +258,15 @@ def _run_pipeline_on_dataset_parallel(
# Update metric with all results
for sample_result in per_sample_results:
sample = dataset[sample_result.sample_id]
# Get UEM from extra_info if available
kwargs = {}
if hasattr(sample, "extra_info") and "uem" in sample.extra_info:
kwargs["uem"] = sample.extra_info["uem"]
if hasattr(sample, "extra_info"):
if "uem" in sample.extra_info:
kwargs["uem"] = sample.extra_info["uem"]
# Use metric_keywords for metric calculation if available, falling back to dictionary
if "metric_keywords" in sample.extra_info:
kwargs["dictionary"] = sample.extra_info["metric_keywords"]
elif "dictionary" in sample.extra_info:
kwargs["dictionary"] = sample.extra_info["dictionary"]

metric(hypothesis=sample_result.prediction, reference=sample.reference, detailed=True, **kwargs)

Expand Down
Loading