Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions src/sharp/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,17 @@ def _finalize_prediction(
target_size_wh=(1536, 1536),
dtype=torch.float32,
)
if aux.get("highlight_rolloff_applied"):
LOGGER.info(
"Highlight rolloff enabled [trigger=%s] for %s (white_frac=%.6f, "
"sat_frac=%.6f; white_thr=%.6f, sat_fallback=%.6f)",
aux.get("highlight_rolloff_trigger", "unknown"),
image_path,
aux.get("highlight_white_frac", 0.0),
aux.get("highlight_sat_frac", 0.0),
aux.get("highlight_white_frac_threshold", 0.0),
aux.get("highlight_sat_frac_fallback_threshold", 0.0),
)
aux["metrics"] = metrics
preprocess_elapsed = 0.0
if metrics and preprocess_start is not None:
Expand Down Expand Up @@ -1010,6 +1021,42 @@ def preprocess_one(
image_np = image_np.copy()
image_pt = torch.from_numpy(image_np).to(dtype=dtype, device=device).permute(2, 0, 1)
image_pt = image_pt / 255.0
sat_threshold = 0.98
white_threshold = 0.98
chroma_threshold = 0.08
white_frac_threshold = 0.01
sat_frac_fallback_threshold = 0.30
knee_start = 0.98
rolloff_strength = 4.0
max_rgb = image_pt.max(dim=0).values
min_rgb = image_pt.min(dim=0).values
chroma = max_rgb - min_rgb
sat_mask = max_rgb >= sat_threshold
white_mask = (min_rgb >= white_threshold) & (chroma <= chroma_threshold)
sat_frac = float(sat_mask.float().mean().item())
white_frac = float(white_mask.float().mean().item())
highlight_rolloff_applied = False
highlight_rolloff_trigger = "none"
if white_frac >= white_frac_threshold:
highlight_rolloff_applied = True
highlight_rolloff_trigger = "white"
elif sat_frac >= sat_frac_fallback_threshold:
highlight_rolloff_applied = True
highlight_rolloff_trigger = "sat_fallback"
if highlight_rolloff_applied:
knee_start_pt = torch.tensor(knee_start, dtype=image_pt.dtype, device=image_pt.device)
one_minus_knee = torch.tensor(
1.0 - knee_start, dtype=image_pt.dtype, device=image_pt.device
)
rolloff_strength_pt = torch.tensor(
rolloff_strength, dtype=image_pt.dtype, device=image_pt.device
)
normalized = (image_pt - knee_start_pt) / one_minus_knee
rolloff = knee_start_pt + one_minus_knee * (
1.0 - torch.exp(-rolloff_strength_pt * normalized)
)
image_pt = torch.where(image_pt > knee_start_pt, rolloff, image_pt)
image_pt = image_pt.clamp(0.0, 1.0)
_, height, width = image_pt.shape
disparity_factor_pt = torch.tensor([f_px / width], dtype=dtype, device=device)
image_resized_pt = F.interpolate(
Expand All @@ -1031,6 +1078,15 @@ def preprocess_one(
"f_px": f_px,
"target_w": target_w,
"target_h": target_h,
"highlight_rolloff_applied": highlight_rolloff_applied,
"highlight_rolloff_trigger": highlight_rolloff_trigger,
"highlight_sat_frac": sat_frac,
"highlight_white_frac": white_frac,
"highlight_sat_threshold": sat_threshold,
"highlight_white_threshold": white_threshold,
"highlight_chroma_threshold": chroma_threshold,
"highlight_white_frac_threshold": white_frac_threshold,
"highlight_sat_frac_fallback_threshold": sat_frac_fallback_threshold,
}
return image_resized_pt, disparity_factor_pt, aux

Expand Down Expand Up @@ -1203,6 +1259,17 @@ def predict_image(
target_size_wh=target_size_wh,
dtype=torch.float32,
)
if aux.get("highlight_rolloff_applied"):
LOGGER.info(
"Highlight rolloff enabled [trigger=%s] for %s (white_frac=%.6f, sat_frac=%.6f; "
"white_thr=%.6f, sat_fallback=%.6f)",
aux.get("highlight_rolloff_trigger", "unknown"),
"<single image>",
aux.get("highlight_white_frac", 0.0),
aux.get("highlight_sat_frac", 0.0),
aux.get("highlight_white_frac_threshold", 0.0),
aux.get("highlight_sat_frac_fallback_threshold", 0.0),
)
aux["metrics"] = metrics
if metrics and preprocess_start is not None:
metrics.add_time("preprocess", perf_counter() - preprocess_start)
Expand Down