@@ -107,7 +107,7 @@ def fit_quantile_linear_nd(
107107 nuisance_axes : Dict [str , str ] = None , # e.g. {"z": "z_vtx", "eta": "eta"}
108108 n_bins_axes : Dict [str , int ] = None , # e.g. {"z": 10}
109109 mask_col : Optional [str ] = "is_outlier" ,
110- b_min_option : str = "auto" , # "auto" or "fixed"
110+ b_min_option : str = "auto" , # "auto" or "fixed"
111111 b_min_value : float = 1e-6 ,
112112 fit_mode : str = "ols" ,
113113 kappa_w : float = 1.3 ,
@@ -117,54 +117,52 @@ def fit_quantile_linear_nd(
117117 Fit local linear inverse-CDF per channel, per (q_center, nuisance bins).
118118 Returns a flat DataFrame (calibration table) with coefficients and diagnostics.
119119
120- Columns expected:
120+ Columns expected in df :
121121 - channel_key, Q, X, and nuisance columns per nuisance_axes dict.
122122 - mask_col (optional): True rows are excluded.
123123
124124 Notes:
125- - degree -1 only, Δq-centered model.
125+ - Degree -1 only, Δq-centered model: X = a + b*(Q - q_center) .
126126 - b>0 enforced via floor (auto/fixed).
127127 - sigma_Q = sigma_X|Q / |b|
128- - sigma_Q_irr optional (needs dX/dN proxy; here left NaN by default) .
128+ - sigma_Q_irr left NaN unless a multiplicity model is provided downstream .
129129 """
130130 if nuisance_axes is None :
131131 nuisance_axes = {}
132132 if n_bins_axes is None :
133133 n_bins_axes = {ax : 10 for ax in nuisance_axes }
134+
134135 df = df .copy ()
135136
137+ # Ensure a boolean keep-mask exists
136138 if mask_col is None or mask_col not in df .columns :
137139 df ["_mask_keep" ] = True
138140 mask_col_use = "_mask_keep"
139141 else :
140142 mask_col_use = mask_col
141143
142- # Prepare nuisance bin centers per axis
144+ # ------------------------ build nuisance binning ------------------------
143145 axis_to_centers : Dict [str , np .ndarray ] = {}
144146 axis_to_idxcol : Dict [str , str ] = {}
145-
146147 for ax , col in nuisance_axes .items ():
147148 centers = _build_uniform_centers (df [col ].to_numpy (np .float64 ), int (n_bins_axes .get (ax , 10 )))
148149 axis_to_centers [ax ] = centers
149150 idxcol = f"__bin_{ ax } "
150151 df [idxcol ] = _assign_bin_indices (df [col ].to_numpy (np .float64 ), centers )
151152 axis_to_idxcol [ax ] = idxcol
152153
153- # Group by channel and nuisance bin tuple
154154 bin_cols = [axis_to_idxcol [a ] for a in nuisance_axes ]
155- out_rows = []
155+ out_rows : list [ dict ] = []
156156
157- # iterate per channel
157+ # --------------------------- iterate channels --------------------------
158158 for ch_val , df_ch in df .groupby (channel_key , sort = False , dropna = False ):
159159 # iterate bins of nuisance axes
160160 if bin_cols :
161161 if len (bin_cols ) == 1 :
162- # avoid FutureWarning: use scalar grouper when only one column
163- gb = df_ch .groupby (bin_cols [0 ], sort = False , dropna = False )
162+ gb = df_ch .groupby (bin_cols [0 ], sort = False , dropna = False ) # avoid FutureWarning
164163 else :
165164 gb = df_ch .groupby (bin_cols , sort = False , dropna = False )
166165 else :
167- # single group with empty tuple key
168166 df_ch = df_ch .copy ()
169167 df_ch ["__bin_dummy__" ] = 0
170168 gb = df_ch .groupby (["__bin_dummy__" ], sort = False , dropna = False )
@@ -174,52 +172,68 @@ def fit_quantile_linear_nd(
174172 bin_key = (bin_key ,)
175173
176174 # select non-outliers
177- gmask = (df_cell [mask_col_use ] == False ) if mask_col_use in df_cell .columns else np .ones (len (df_cell ), dtype = bool )
178- if gmask .sum () < 6 :
179- # record empty cells as NaN rows for all q_centers (optional)
175+ keep = (df_cell [mask_col_use ] == False ) if mask_col_use in df_cell .columns else np .ones (len (df_cell ), dtype = bool )
176+ n_keep = int (keep .sum ())
177+ masked_frac = 1.0 - float (keep .mean ())
178+
179+ X_all = df_cell .loc [keep , "X" ].to_numpy (np .float64 )
180+ Q_all = df_cell .loc [keep , "Q" ].to_numpy (np .float64 )
181+
182+ # If too few points overall, emit NaNs for all q-centers in this cell
183+ if n_keep < 6 :
180184 for q0 in q_centers :
181185 row = {
182186 "channel_id" : ch_val ,
183187 "q_center" : float (q0 ),
184188 "a" : np .nan , "b" : np .nan , "sigma_Q" : np .nan ,
185189 "sigma_Q_irr" : np .nan , "dX_dN" : np .nan ,
186- "fit_stats" : json .dumps ({"n_used" : int ( gmask . sum ()) , "ok" : False , "masked_frac" : float (1.0 - gmask . mean () )})
190+ "fit_stats" : json .dumps ({"n_used" : n_keep , "ok" : False , "masked_frac" : float (masked_frac )})
187191 }
188- # write nuisance centers
189192 for ax_i , ax in enumerate (nuisance_axes ):
190193 row [f"{ ax } _center" ] = float (axis_to_centers [ax ][bin_key [ax_i ]])
191194 if timestamp is not None :
192195 row ["timestamp" ] = timestamp
193196 out_rows .append (row )
194197 continue
195198
196- X_all = df_cell .loc [gmask , "X" ].to_numpy (np .float64 )
197- Q_all = df_cell .loc [gmask , "Q" ].to_numpy (np .float64 )
198-
199- # stats for auto floor
200- sigmaX_cell = float (np .std (X_all )) if X_all .size > 1 else 0.0
201- bmin = _auto_b_min (sigmaX_cell , dq ) if b_min_option == "auto" else float (b_min_value )
202-
203- masked_frac = 1.0 - float (gmask .mean ())
204-
199+ # -------------------- per-q_center sliding window --------------------
205200 for q0 in q_centers :
206201 in_win = (Q_all >= q0 - dq ) & (Q_all <= q0 + dq )
207- if in_win .sum () < 6 :
202+ n_win = int (in_win .sum ())
203+
204+ # window-local auto b_min (compute BEFORE branching to avoid NameError)
205+ if b_min_option == "auto" :
206+ if n_win > 1 :
207+ sigmaX_win = float (np .std (X_all [in_win ]))
208+ else :
209+ # fallback to overall scatter in this cell
210+ sigmaX_win = float (np .std (X_all )) if X_all .size > 1 else 0.0
211+ bmin = _auto_b_min (sigmaX_win , dq )
212+ else :
213+ bmin = float (b_min_value )
214+
215+ if n_win < 6 :
208216 row = {
209217 "channel_id" : ch_val ,
210218 "q_center" : float (q0 ),
211219 "a" : np .nan , "b" : np .nan , "sigma_Q" : np .nan ,
212220 "sigma_Q_irr" : np .nan , "dX_dN" : np .nan ,
213- "fit_stats" : json .dumps ({"n_used" : int (in_win .sum ()), "ok" : False , "masked_frac" : masked_frac })
221+ "fit_stats" : json .dumps ({
222+ "n_used" : n_win , "ok" : False ,
223+ "masked_frac" : float (masked_frac ),
224+ "b_min" : float (bmin )
225+ })
214226 }
215227 else :
216228 a , b , sigX , n_used , stats = _local_fit_delta_q (Q_all [in_win ], X_all [in_win ], q0 )
229+
217230 # monotonicity floor
218231 if not np .isfinite (b ) or b <= 0.0 :
219232 b = bmin
220233 clipped = True
221234 else :
222235 clipped = False
236+
223237 sigma_Q = _sigma_Q_from_sigmaX (b , sigX )
224238 fit_stats = {
225239 "n_used" : int (n_used ),
@@ -237,7 +251,7 @@ def fit_quantile_linear_nd(
237251 "fit_stats" : json .dumps (fit_stats )
238252 }
239253
240- # write nuisance centers
254+ # write nuisance centers and optional timestamp
241255 for ax_i , ax in enumerate (nuisance_axes ):
242256 row [f"{ ax } _center" ] = float (axis_to_centers [ax ][bin_key [ax_i ]])
243257 if timestamp is not None :
@@ -246,7 +260,7 @@ def fit_quantile_linear_nd(
246260
247261 table = pd .DataFrame (out_rows )
248262
249- # Attach metadata
263+ # ------------------------------ metadata ------------------------------
250264 table .attrs .update ({
251265 "model" : "X = a + b*(Q - q_center)" ,
252266 "dq" : float (dq ),
@@ -258,21 +272,17 @@ def fit_quantile_linear_nd(
258272 "channel_key" : channel_key ,
259273 })
260274
261- # Finite-diff derivatives along nuisance axes (db_d<axis>)
275+ # --------- finite-difference derivatives along nuisance axes ----------
262276 for ax in nuisance_axes :
263- # compute per-channel, per-q_center derivative across axis centers
264277 der_col = f"db_d{ ax } "
265278 table [der_col ] = np .nan
266- # group by channel & q_center
267279 for (ch , q0 ), g in table .groupby (["channel_id" , "q_center" ], sort = False ):
268280 centers = np .unique (g [f"{ ax } _center" ].to_numpy (np .float64 ))
269281 if centers .size < 2 :
270282 continue
271- # sort by center
272283 gg = g .sort_values (f"{ ax } _center" )
273284 bvals = gg ["b" ].to_numpy (np .float64 )
274285 xc = gg [f"{ ax } _center" ].to_numpy (np .float64 )
275- # central differences
276286 d = np .full_like (bvals , np .nan )
277287 if bvals .size >= 2 :
278288 d [0 ] = (bvals [1 ] - bvals [0 ]) / max (xc [1 ] - xc [0 ], 1e-12 )
0 commit comments