Skip to content

Commit 5d9cacd

Browse files
author
miranov25
committed
fix(quantile_fit_nd): do not floor degenerate Δq windows; keep NaN and record reason
- Apply b_min only when a valid fit yields b <= 0 (monotonicity enforcement).
- For low-Q-spread / low-N windows, keep NaN (no floor) and record the reason in fit_stats.
- Greatly reduces bias in the Poisson case; z-bin averages use informative windows only.
1 parent a578c17 commit 5d9cacd

File tree

1 file changed

+173
-93
lines changed

1 file changed

+173
-93
lines changed

UTILS/dfextensions/quantile_fit_nd/quantile_fit_nd.py

Lines changed: 173 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -61,28 +61,57 @@ def _linear_interp_1d(xc: np.ndarray, yc: np.ndarray, x: float) -> float:
6161
return float((1 - t) * y0 + t * y1)
6262

6363

64-
def _local_fit_delta_q(Qw: np.ndarray, Xw: np.ndarray, q0: float) -> Tuple[float, float, float, int, Dict[str, Any]]:
    """
    Two-parameter OLS fit of the Δq-centered line
        X = a + b * (Q - q0)
    on the (Q, X) pairs of one window.

    Pairs where Q or X is non-finite are dropped before anything else.
    The window is rejected (a, b, sigma all NaN, stats["ok"] False) when:
      - fewer than 3 finite pairs remain            -> reason "n<3"
      - q0 itself is not finite                     -> reason "nonfinite_q0"
      - Q has fewer than 3 distinct values (after rounding to 6 decimals)
        or spans less than 1e-3                     -> reason "low_Q_spread"

    Returns:
      a, b        intercept at Q == q0 and slope
      sigma_X|Q   residual standard deviation, sqrt(RSS / (n - 2));
                  this is the dof-corrected estimate, not the plain RMS
      n_used      number of finite pairs (also reported for rejected windows)
      stats       dict; always has "ok" and "n_used"; rejected windows add
                  "reason"; spread checks and accepted fits add "n_unique_q"
                  and "span_q"; accepted fits add "rms" (same value as
                  sigma_X|Q, key name kept for existing readers of fit_stats)
    """
    Qw = np.asarray(Qw, dtype=np.float64)
    Xw = np.asarray(Xw, dtype=np.float64)
    finite = np.isfinite(Qw) & np.isfinite(Xw)
    Qw, Xw = Qw[finite], Xw[finite]
    n = int(Qw.size)

    def _rejected(reason: str, **extra: Any) -> Tuple[float, float, float, int, Dict[str, Any]]:
        # "n_used" goes last so the key order seen by callers that merge
        # {**stats, ..., "n_used": ...} is the same as before.
        return np.nan, np.nan, np.nan, n, {"ok": False, "reason": reason, **extra, "n_used": n}

    if n < 3:
        return _rejected("n<3")
    if not np.isfinite(q0):
        # A NaN/inf center would put non-finite values into the design matrix.
        return _rejected("nonfinite_q0")

    # Degeneracy checks for discrete/plateau windows: the slope cannot be
    # estimated when Q takes too few distinct values or barely varies.
    uq = np.unique(np.round(Qw, 6))  # rounding collapses near-duplicates
    span_q = float(np.max(Qw) - np.min(Qw))
    if uq.size < 3 or span_q < 1e-3:
        return _rejected("low_Q_spread", n_unique_q=int(uq.size), span_q=span_q)

    # Design matrix [1, (Q - q0)] solved by least squares.
    dq = Qw - q0
    A = np.column_stack([np.ones(n, dtype=np.float64), dq])
    sol, resid, _rank, _svals = np.linalg.lstsq(A, Xw, rcond=None)
    a, b = float(sol[0]), float(sol[1])

    # lstsq returns an empty residual array when the system is judged
    # rank-deficient; recompute the residual sum of squares in that case.
    if resid.size > 0:
        rss = float(resid[0])
    else:
        rss = float(np.sum((Xw - (a + b * dq)) ** 2))
    # n >= 3 is guaranteed here, so n - 2 >= 1.
    sigmaX = float(np.sqrt(max(rss, 0.0) / (n - 2)))

    stats = {
        "ok": True,
        "rms": sigmaX,
        "n_used": n,
        "n_unique_q": int(uq.size),
        "span_q": span_q,
    }
    return a, b, sigmaX, n, stats
86115

87116

88117
def _sigma_Q_from_sigmaX(b: float, sigma_X_given_Q: float) -> float:
@@ -115,17 +144,15 @@ def fit_quantile_linear_nd(
115144
) -> pd.DataFrame:
116145
"""
117146
Fit local linear inverse-CDF per channel, per (q_center, nuisance bins).
118-
Returns a flat DataFrame (calibration table) with coefficients and diagnostics.
147+
Degree-1, Δq-centered model: X = a + b*(Q - q_center).
119148
120-
Columns expected in df:
121-
- channel_key, Q, X, and nuisance columns per nuisance_axes dict.
122-
- mask_col (optional): True rows are excluded.
149+
Monotonicity:
150+
- Enforce floor b>=b_min ONLY for valid fits with non-positive b.
151+
- Degenerate windows (low Q spread / too few unique Q) remain NaN (no flooring).
123152
124-
Notes:
125-
- Degree-1 only, Δq-centered model: X = a + b*(Q - q_center).
126-
- b>0 enforced via floor (auto/fixed).
127-
- sigma_Q = sigma_X|Q / |b|
128-
- sigma_Q_irr left NaN unless a multiplicity model is provided downstream.
153+
sigma_Q = sigma_X|Q / |b|
154+
155+
Returns a flat DataFrame with coefficients and diagnostics.
129156
"""
130157
if nuisance_axes is None:
131158
nuisance_axes = {}
@@ -187,7 +214,7 @@ def fit_quantile_linear_nd(
187214
"q_center": float(q0),
188215
"a": np.nan, "b": np.nan, "sigma_Q": np.nan,
189216
"sigma_Q_irr": np.nan, "dX_dN": np.nan,
190-
"fit_stats": json.dumps({"n_used": n_keep, "ok": False, "masked_frac": float(masked_frac)})
217+
"fit_stats": json.dumps({"n_used": n_keep, "ok": False, "reason": "cell_n<6", "masked_frac": float(masked_frac)})
191218
}
192219
for ax_i, ax in enumerate(nuisance_axes):
193220
row[f"{ax}_center"] = float(axis_to_centers[ax][bin_key[ax_i]])
@@ -201,12 +228,11 @@ def fit_quantile_linear_nd(
201228
in_win = (Q_all >= q0 - dq) & (Q_all <= q0 + dq)
202229
n_win = int(in_win.sum())
203230

204-
# window-local auto b_min (compute BEFORE branching to avoid NameError)
231+
# window-local b_min (compute BEFORE branching)
205232
if b_min_option == "auto":
206233
if n_win > 1:
207234
sigmaX_win = float(np.std(X_all[in_win]))
208235
else:
209-
# fallback to overall scatter in this cell
210236
sigmaX_win = float(np.std(X_all)) if X_all.size > 1 else 0.0
211237
bmin = _auto_b_min(sigmaX_win, dq)
212238
else:
@@ -219,37 +245,48 @@ def fit_quantile_linear_nd(
219245
"a": np.nan, "b": np.nan, "sigma_Q": np.nan,
220246
"sigma_Q_irr": np.nan, "dX_dN": np.nan,
221247
"fit_stats": json.dumps({
222-
"n_used": n_win, "ok": False,
223-
"masked_frac": float(masked_frac),
224-
"b_min": float(bmin)
248+
"n_used": n_win, "ok": False, "reason": "win_n<6",
249+
"masked_frac": float(masked_frac), "b_min": float(bmin)
225250
})
226251
}
227252
else:
228253
a, b, sigX, n_used, stats = _local_fit_delta_q(Q_all[in_win], X_all[in_win], q0)
229254

230-
# monotonicity floor
231-
if not np.isfinite(b) or b <= 0.0:
232-
b = bmin
233-
clipped = True
255+
# If fit is NOT ok (e.g. low_Q_spread), keep NaNs (do NOT floor here)
256+
if not bool(stats.get("ok", True)):
257+
row = {
258+
"channel_id": ch_val,
259+
"q_center": float(q0),
260+
"a": np.nan, "b": np.nan, "sigma_Q": np.nan,
261+
"sigma_Q_irr": np.nan, "dX_dN": np.nan,
262+
"fit_stats": json.dumps({
263+
**stats, "ok": False, "n_used": int(n_used),
264+
"masked_frac": float(masked_frac), "b_min": float(bmin)
265+
})
266+
}
234267
else:
268+
# Valid fit: enforce b floor only if b<=0 (monotonicity)
235269
clipped = False
236-
237-
sigma_Q = _sigma_Q_from_sigmaX(b, sigX)
238-
fit_stats = {
239-
"n_used": int(n_used),
240-
"ok": bool(stats.get("ok", True)),
241-
"rms": float(stats.get("rms", np.nan)),
242-
"masked_frac": float(masked_frac),
243-
"clipped": bool(clipped),
244-
"b_min": float(bmin),
245-
}
246-
row = {
247-
"channel_id": ch_val,
248-
"q_center": float(q0),
249-
"a": float(a), "b": float(b), "sigma_Q": float(sigma_Q),
250-
"sigma_Q_irr": np.nan, "dX_dN": np.nan,
251-
"fit_stats": json.dumps(fit_stats)
252-
}
270+
if not np.isfinite(b) or b <= 0.0:
271+
b = max(bmin, 1e-9)
272+
clipped = True
273+
274+
sigma_Q = _sigma_Q_from_sigmaX(b, sigX)
275+
fit_stats = {
276+
**stats,
277+
"n_used": int(n_used),
278+
"ok": True,
279+
"masked_frac": float(masked_frac),
280+
"clipped": bool(clipped),
281+
"b_min": float(bmin),
282+
}
283+
row = {
284+
"channel_id": ch_val,
285+
"q_center": float(q0),
286+
"a": float(a), "b": float(b), "sigma_Q": float(sigma_Q),
287+
"sigma_Q_irr": np.nan, "dX_dN": np.nan,
288+
"fit_stats": json.dumps(fit_stats)
289+
}
253290

254291
# write nuisance centers and optional timestamp
255292
for ax_i, ax in enumerate(nuisance_axes):
@@ -295,6 +332,7 @@ def fit_quantile_linear_nd(
295332
return table
296333

297334

335+
298336
# --------------------------- Evaluator API -------------------------------
299337

300338
@dataclass
@@ -378,55 +416,97 @@ def params(self, *, channel_id: Any, q: float, **coords) -> Tuple[float, float,
378416

379417
def invert_rank(self, X: float, *, channel_id: Any, **coords) -> float:
    """
    Invert amplitude -> rank by bisection on a piecewise-blended segment model.

    For q in [q_k, q_{k+1}] the model is
        X_blend(q) = (1-t)*(a_k + b_k*(q - q_k)) + t*(a_{k+1} + b_{k+1}*(q - q_{k+1})),
        t = (q - q_k) / (q_{k+1} - q_k).
    Outside [q_centers[0], q_centers[-1]] the first/last segment's formula is
    evaluated as-is (t leaves [0, 1]).

    Only q-centers whose a and b are both finite with b > 0 are trusted;
    a and b at every other center are replaced by linear interpolation over
    the trusted centers (edge values are held constant beyond them).

    The root of X_blend(q) - X is bracketed on [0, 1]. If both ends lie on the
    same side of X the result is clamped to 0.0 or 1.0. Positive slopes make
    X_blend increasing for smoothly varying coefficients, but this is not
    enforced here.

    Returns q in [0, 1], or NaN when the channel is unknown, fewer than two
    q-centers exist, fewer than two centers are trusted, or the model is
    non-finite at q = 0 or q = 1.
    """
    item = self.store.get(channel_id)
    if item is None:
        return np.nan

    # NOTE(review): assumes self.q_centers is a 1-D, strictly increasing array;
    # np.interp and np.searchsorted below rely on that - confirm in the class.
    qc = self.q_centers
    if qc.size < 2:
        return np.nan

    # Interpolate nuisance -> vectors over q-centers
    # (presumably one value per q-center; verify against _interp_nuisance_vector).
    a_vec = np.asarray(self._interp_nuisance_vector(item["A"], coords), dtype=np.float64)
    b_vec = np.asarray(self._interp_nuisance_vector(item["B"], coords), dtype=np.float64)

    valid = np.isfinite(a_vec) & np.isfinite(b_vec) & (b_vec > 0.0)
    if valid.sum() < 2:
        return np.nan

    # Fill every untrusted center from the trusted ones. Masking a and b
    # together keeps an intercept whose slope was unusable from anchoring the
    # model. Interpolating positive slopes yields positive slopes, so no
    # extra floor on b is needed afterwards.
    a_f = a_vec.copy()
    b_f = b_vec.copy()
    if not valid.all():
        gaps = ~valid
        a_f[gaps] = np.interp(qc[gaps], qc[valid], a_vec[valid])
        b_f[gaps] = np.interp(qc[gaps], qc[valid], b_vec[valid])

    last_seg = qc.size - 2

    def _x_blend(q: float) -> float:
        # Locate the segment [qc[k], qc[k+1]]; out-of-range q uses an end segment.
        if q <= qc[0]:
            k = 0
        elif q >= qc[-1]:
            k = last_seg
        else:
            k = int(np.clip(np.searchsorted(qc, q) - 1, 0, last_seg))
        qk, qk1 = qc[k], qc[k + 1]
        t = (q - qk) / (qk1 - qk) if qk1 > qk else 0.0
        x_left = a_f[k] + b_f[k] * (q - qk)
        x_right = a_f[k + 1] + b_f[k + 1] * (q - qk1)
        return float((1.0 - t) * x_left + t * x_right)

    # Bracket on [0, 1]
    f0 = _x_blend(0.0) - X
    f1 = _x_blend(1.0) - X
    if not np.isfinite(f0) or not np.isfinite(f1):
        return np.nan

    # Exact hit at an end, or no sign change: clamp to the nearest end.
    if f0 == 0.0:
        return 0.0
    if f1 == 0.0:
        return 1.0
    if f0 > 0.0 and f1 > 0.0:
        return 0.0
    if f0 < 0.0 and f1 < 0.0:
        return 1.0

    # Bisection: at most 40 halvings, stopping once the bracket is < 1e-6 wide.
    lo, hi = 0.0, 1.0
    f_lo = f0
    for _ in range(40):
        mid = 0.5 * (lo + hi)
        f_mid = _x_blend(mid) - X
        if not np.isfinite(f_mid):
            break
        # Sign change (or exact zero) between lo and mid -> root is in [lo, mid].
        if (f_lo <= 0.0 <= f_mid) or (f_mid <= 0.0 <= f_lo):
            hi = mid
        else:
            lo, f_lo = mid, f_mid
        if hi - lo < 1e-6:
            break
    return float(0.5 * (lo + hi))
430510

431511

432512
# ------------------------------ I/O helpers ------------------------------

0 commit comments

Comments
 (0)