Skip to content

Commit b4b5b41

Browse files
author
miranov25
committed
fix(dfextensions/quantile_fit_nd): evaluator axis bug + window-local b_min + stable inversion
- QuantileEvaluator: exclude 'q_center' from nuisance axes (fix AxisError in moveaxis) - Groupby: use scalar grouper for single nuisance bin column (silence FutureWarning) - Fit: compute b_min per |Q−q0|≤dq window (avoid over-clipping b in low-b regions) - Inversion: implement self-consistent candidate + 2-step fixed-point refine (invert_rank) - Keep API/metadata unchanged; prepare for ND nuisances and time
1 parent 6d65a12 commit b4b5b41

File tree

1 file changed

+45
-35
lines changed

1 file changed

+45
-35
lines changed

UTILS/dfextensions/quantile_fit_nd/quantile_fit_nd.py

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def fit_quantile_linear_nd(
107107
nuisance_axes: Dict[str, str] = None, # e.g. {"z": "z_vtx", "eta": "eta"}
108108
n_bins_axes: Dict[str, int] = None, # e.g. {"z": 10}
109109
mask_col: Optional[str] = "is_outlier",
110-
b_min_option: str = "auto", # "auto" or "fixed"
110+
b_min_option: str = "auto", # "auto" or "fixed"
111111
b_min_value: float = 1e-6,
112112
fit_mode: str = "ols",
113113
kappa_w: float = 1.3,
@@ -117,54 +117,52 @@ def fit_quantile_linear_nd(
117117
Fit local linear inverse-CDF per channel, per (q_center, nuisance bins).
118118
Returns a flat DataFrame (calibration table) with coefficients and diagnostics.
119119
120-
Columns expected:
120+
Columns expected in df:
121121
- channel_key, Q, X, and nuisance columns per nuisance_axes dict.
122122
- mask_col (optional): True rows are excluded.
123123
124124
Notes:
125-
- degree-1 only, Δq-centered model.
125+
- Degree-1 only, Δq-centered model: X = a + b*(Q - q_center).
126126
- b>0 enforced via floor (auto/fixed).
127127
- sigma_Q = sigma_X|Q / |b|
128-
- sigma_Q_irr optional (needs dX/dN proxy; here left NaN by default).
128+
- sigma_Q_irr left NaN unless a multiplicity model is provided downstream.
129129
"""
130130
if nuisance_axes is None:
131131
nuisance_axes = {}
132132
if n_bins_axes is None:
133133
n_bins_axes = {ax: 10 for ax in nuisance_axes}
134+
134135
df = df.copy()
135136

137+
# Ensure a boolean keep-mask exists
136138
if mask_col is None or mask_col not in df.columns:
137139
df["_mask_keep"] = True
138140
mask_col_use = "_mask_keep"
139141
else:
140142
mask_col_use = mask_col
141143

142-
# Prepare nuisance bin centers per axis
144+
# ------------------------ build nuisance binning ------------------------
143145
axis_to_centers: Dict[str, np.ndarray] = {}
144146
axis_to_idxcol: Dict[str, str] = {}
145-
146147
for ax, col in nuisance_axes.items():
147148
centers = _build_uniform_centers(df[col].to_numpy(np.float64), int(n_bins_axes.get(ax, 10)))
148149
axis_to_centers[ax] = centers
149150
idxcol = f"__bin_{ax}"
150151
df[idxcol] = _assign_bin_indices(df[col].to_numpy(np.float64), centers)
151152
axis_to_idxcol[ax] = idxcol
152153

153-
# Group by channel and nuisance bin tuple
154154
bin_cols = [axis_to_idxcol[a] for a in nuisance_axes]
155-
out_rows = []
155+
out_rows: list[dict] = []
156156

157-
# iterate per channel
157+
# --------------------------- iterate channels --------------------------
158158
for ch_val, df_ch in df.groupby(channel_key, sort=False, dropna=False):
159159
# iterate bins of nuisance axes
160160
if bin_cols:
161161
if len(bin_cols) == 1:
162-
# avoid FutureWarning: use scalar grouper when only one column
163-
gb = df_ch.groupby(bin_cols[0], sort=False, dropna=False)
162+
gb = df_ch.groupby(bin_cols[0], sort=False, dropna=False) # avoid FutureWarning
164163
else:
165164
gb = df_ch.groupby(bin_cols, sort=False, dropna=False)
166165
else:
167-
# single group with empty tuple key
168166
df_ch = df_ch.copy()
169167
df_ch["__bin_dummy__"] = 0
170168
gb = df_ch.groupby(["__bin_dummy__"], sort=False, dropna=False)
@@ -174,52 +172,68 @@ def fit_quantile_linear_nd(
174172
bin_key = (bin_key,)
175173

176174
# select non-outliers
177-
gmask = (df_cell[mask_col_use] == False) if mask_col_use in df_cell.columns else np.ones(len(df_cell), dtype=bool)
178-
if gmask.sum() < 6:
179-
# record empty cells as NaN rows for all q_centers (optional)
175+
keep = (df_cell[mask_col_use] == False) if mask_col_use in df_cell.columns else np.ones(len(df_cell), dtype=bool)
176+
n_keep = int(keep.sum())
177+
masked_frac = 1.0 - float(keep.mean())
178+
179+
X_all = df_cell.loc[keep, "X"].to_numpy(np.float64)
180+
Q_all = df_cell.loc[keep, "Q"].to_numpy(np.float64)
181+
182+
# If too few points overall, emit NaNs for all q-centers in this cell
183+
if n_keep < 6:
180184
for q0 in q_centers:
181185
row = {
182186
"channel_id": ch_val,
183187
"q_center": float(q0),
184188
"a": np.nan, "b": np.nan, "sigma_Q": np.nan,
185189
"sigma_Q_irr": np.nan, "dX_dN": np.nan,
186-
"fit_stats": json.dumps({"n_used": int(gmask.sum()), "ok": False, "masked_frac": float(1.0 - gmask.mean())})
190+
"fit_stats": json.dumps({"n_used": n_keep, "ok": False, "masked_frac": float(masked_frac)})
187191
}
188-
# write nuisance centers
189192
for ax_i, ax in enumerate(nuisance_axes):
190193
row[f"{ax}_center"] = float(axis_to_centers[ax][bin_key[ax_i]])
191194
if timestamp is not None:
192195
row["timestamp"] = timestamp
193196
out_rows.append(row)
194197
continue
195198

196-
X_all = df_cell.loc[gmask, "X"].to_numpy(np.float64)
197-
Q_all = df_cell.loc[gmask, "Q"].to_numpy(np.float64)
198-
199-
# stats for auto floor
200-
sigmaX_cell = float(np.std(X_all)) if X_all.size > 1 else 0.0
201-
bmin = _auto_b_min(sigmaX_cell, dq) if b_min_option == "auto" else float(b_min_value)
202-
203-
masked_frac = 1.0 - float(gmask.mean())
204-
199+
# -------------------- per-q_center sliding window --------------------
205200
for q0 in q_centers:
206201
in_win = (Q_all >= q0 - dq) & (Q_all <= q0 + dq)
207-
if in_win.sum() < 6:
202+
n_win = int(in_win.sum())
203+
204+
# window-local auto b_min (compute BEFORE branching to avoid NameError)
205+
if b_min_option == "auto":
206+
if n_win > 1:
207+
sigmaX_win = float(np.std(X_all[in_win]))
208+
else:
209+
# fallback to overall scatter in this cell
210+
sigmaX_win = float(np.std(X_all)) if X_all.size > 1 else 0.0
211+
bmin = _auto_b_min(sigmaX_win, dq)
212+
else:
213+
bmin = float(b_min_value)
214+
215+
if n_win < 6:
208216
row = {
209217
"channel_id": ch_val,
210218
"q_center": float(q0),
211219
"a": np.nan, "b": np.nan, "sigma_Q": np.nan,
212220
"sigma_Q_irr": np.nan, "dX_dN": np.nan,
213-
"fit_stats": json.dumps({"n_used": int(in_win.sum()), "ok": False, "masked_frac": masked_frac})
221+
"fit_stats": json.dumps({
222+
"n_used": n_win, "ok": False,
223+
"masked_frac": float(masked_frac),
224+
"b_min": float(bmin)
225+
})
214226
}
215227
else:
216228
a, b, sigX, n_used, stats = _local_fit_delta_q(Q_all[in_win], X_all[in_win], q0)
229+
217230
# monotonicity floor
218231
if not np.isfinite(b) or b <= 0.0:
219232
b = bmin
220233
clipped = True
221234
else:
222235
clipped = False
236+
223237
sigma_Q = _sigma_Q_from_sigmaX(b, sigX)
224238
fit_stats = {
225239
"n_used": int(n_used),
@@ -237,7 +251,7 @@ def fit_quantile_linear_nd(
237251
"fit_stats": json.dumps(fit_stats)
238252
}
239253

240-
# write nuisance centers
254+
# write nuisance centers and optional timestamp
241255
for ax_i, ax in enumerate(nuisance_axes):
242256
row[f"{ax}_center"] = float(axis_to_centers[ax][bin_key[ax_i]])
243257
if timestamp is not None:
@@ -246,7 +260,7 @@ def fit_quantile_linear_nd(
246260

247261
table = pd.DataFrame(out_rows)
248262

249-
# Attach metadata
263+
# ------------------------------ metadata ------------------------------
250264
table.attrs.update({
251265
"model": "X = a + b*(Q - q_center)",
252266
"dq": float(dq),
@@ -258,21 +272,17 @@ def fit_quantile_linear_nd(
258272
"channel_key": channel_key,
259273
})
260274

261-
# Finite-diff derivatives along nuisance axes (db_d<axis>)
275+
# --------- finite-difference derivatives along nuisance axes ----------
262276
for ax in nuisance_axes:
263-
# compute per-channel, per-q_center derivative across axis centers
264277
der_col = f"db_d{ax}"
265278
table[der_col] = np.nan
266-
# group by channel & q_center
267279
for (ch, q0), g in table.groupby(["channel_id", "q_center"], sort=False):
268280
centers = np.unique(g[f"{ax}_center"].to_numpy(np.float64))
269281
if centers.size < 2:
270282
continue
271-
# sort by center
272283
gg = g.sort_values(f"{ax}_center")
273284
bvals = gg["b"].to_numpy(np.float64)
274285
xc = gg[f"{ax}_center"].to_numpy(np.float64)
275-
# central differences
276286
d = np.full_like(bvals, np.nan)
277287
if bvals.size >= 2:
278288
d[0] = (bvals[1] - bvals[0]) / max(xc[1] - xc[0], 1e-12)

0 commit comments

Comments
 (0)