Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"numpy",
"openai",
"pillow",
"rapidfuzz",
"rich",
"scipy",
"stamina",
Expand Down
33 changes: 33 additions & 0 deletions src/ocr_bench/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
load_existing_metadata,
publish_results,
)
from ocr_bench.standard_eval import evaluate_against_ground_truth
from ocr_bench.task_config import DEFAULT_GROUND_TRUTH_COLUMN

logger = structlog.get_logger()
console = Console()
Expand All @@ -47,6 +49,11 @@ def build_parser() -> argparse.ArgumentParser:
# Dataset
judge.add_argument("dataset", help="HF dataset repo id")
judge.add_argument("--split", default="train", help="Dataset split (default: train)")
judge.add_argument(
"--ground-truth-column",
default=DEFAULT_GROUND_TRUTH_COLUMN,
help=f"Ground truth column for standard metrics (default: {DEFAULT_GROUND_TRUTH_COLUMN})",
)
judge.add_argument("--columns", nargs="+", default=None, help="Explicit OCR column names")
judge.add_argument(
"--configs", nargs="+", default=None, help="Config-per-model: list of config names"
Expand Down Expand Up @@ -284,6 +291,32 @@ def cmd_judge(args: argparse.Namespace) -> None:
for col, model in ocr_columns.items():
console.print(f" {col} → {model}")

standard_metrics = evaluate_against_ground_truth(
ds,
ocr_columns,
ground_truth_column=args.ground_truth_column,
)
if standard_metrics:
metrics_table = Table(title="Standard Evaluation (Dummy Scaffold)")
metrics_table.add_column("Model")
metrics_table.add_column("Samples", justify="right")
metrics_table.add_column("Global F1", justify="right")
metrics_table.add_column("Jury global F1", justify="right")
for metric in standard_metrics:
metrics_table.add_row(
metric.model,
str(metric.samples),
f"{metric.global_f1:.3f}",
f"{metric.jury_global_f1:.3f}",
)
console.print()
console.print(metrics_table)
else:
console.print(
f"[yellow]Standard evaluation skipped:[/yellow] "
f"missing or incompatible '{args.ground_truth_column}' column."
)

# --- Incremental: load existing comparisons ---
existing_results: list[ComparisonResult] = []
existing_meta_rows: list[dict] = []
Expand Down
62 changes: 24 additions & 38 deletions src/ocr_bench/run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""OCR model orchestration — launch HF Jobs for multiple OCR models."""
"""VLM model orchestration — launch HF Jobs for metadata extraction models."""

from __future__ import annotations

Expand All @@ -8,6 +8,12 @@
import structlog
from huggingface_hub import HfApi, get_token

from ocr_bench.task_config import (
DEFAULT_IMAGE_COLUMN,
DEFAULT_SOURCE_DATASET,
build_default_task_prompt,
)

logger = structlog.get_logger()


Expand All @@ -23,52 +29,28 @@ class ModelConfig:


MODEL_REGISTRY: dict[str, ModelConfig] = {
"glm-ocr": ModelConfig(
script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/glm-ocr.py",
model_id="zai-org/GLM-OCR",
size="0.9B",
default_flavor="l4x1",
),
"deepseek-ocr": ModelConfig(
script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/deepseek-ocr-vllm.py",
model_id="deepseek-ai/DeepSeek-OCR",
"qwen3-vl-4b-instruct": ModelConfig(
script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/vlm-metadata-extraction.py",
model_id="Qwen/Qwen3-VL-4B-Instruct",
size="4B",
default_flavor="l4x1",
default_args=["--prompt-mode", "free"],
),
"lighton-ocr-2": ModelConfig(
script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/lighton-ocr2.py",
model_id="lightonai/LightOnOCR-2-1B",
size="1B",
default_flavor="a100-large",
),
"dots-ocr": ModelConfig(
script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/dots-ocr.py",
model_id="rednote-hilab/dots.ocr",
size="1.7B",
default_flavor="l4x1",
),
"firered-ocr": ModelConfig(
script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/firered-ocr.py",
model_id="FireRedTeam/FireRed-OCR",
size="2.1B",
default_flavor="l4x1",
),
"qianfan-ocr": ModelConfig(
script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/qianfan-ocr.py",
model_id="baidu/Qianfan-OCR",
size="4.7B",
"nanonets-ocr2-3b": ModelConfig(
script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/vlm-metadata-extraction.py",
model_id="nanonets/Nanonets-OCR2-3B",
size="3B",
default_flavor="l4x1",
),
"dots-mocr": ModelConfig(
script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/dots-mocr.py",
model_id="rednote-hilab/dots.mocr",
size="3B",
"gemma-4-e4b-it": ModelConfig(
script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/vlm-metadata-extraction.py",
model_id="google/gemma-4-E4B-it",
size="4B",
default_flavor="l4x1",
),
}

DEFAULT_MODELS = ["glm-ocr", "deepseek-ocr", "lighton-ocr-2", "dots-ocr", "firered-ocr"]
DEFAULT_MODELS = ["qwen3-vl-4b-instruct", "nanonets-ocr2-3b", "gemma-4-e4b-it"]
DEFAULT_TASK_PROMPT = build_default_task_prompt()


@dataclass
Expand Down Expand Up @@ -103,6 +85,10 @@ def build_script_args(
"--config",
config_name,
"--create-pr",
"--image-column",
DEFAULT_IMAGE_COLUMN,
"--prompt",
DEFAULT_TASK_PROMPT,
]
if max_samples is not None:
args += ["--max-samples", str(max_samples)]
Expand Down
Loading