@@ -102,7 +102,7 @@ def make_linear_fit(
         return df, dfGB

     @staticmethod
-    def process_group_robust(
+    def process_group_robustBackup(
         key: tuple,
         df_group: pd.DataFrame,
         gb_columns: List[str],
@@ -114,7 +114,7 @@ def process_group_robust(
         sigmaCut: float = 4,
         fitter: Union[str, Callable] = "auto"
     ) -> dict:
-        # TODO 0handle the case os singl gb column
+        # TODO handle the case of single gb column
         group_dict = dict(zip(gb_columns, key))
         predictors = []
         if isinstance(weights, str) and weights not in df_group.columns:
@@ -213,6 +213,248 @@ def process_group_robust(

         return group_dict

+    @staticmethod
+    def process_group_robust(
+        key: tuple,
+        df_group: pd.DataFrame,
+        gb_columns: List[str],
+        fit_columns: List[str],
+        linear_columns0: List[str],
+        median_columns: List[str],
+        weights: str,
+        minStat: List[int],
+        sigmaCut: float = 4,
+        fitter: Union[str, Callable] = "auto",
+        # --- NEW (optional) diagnostics ---
+        diag: bool = False,
+        diag_prefix: str = "diag_",
+    ) -> dict:
+        """
+        Per-group robust/OLS fit with optional diagnostics.
+
+        Diagnostics (only when diag=True; added once per group to the result dict):
+        - {diag_prefix}n_refits      : int, number of extra fits after the initial one (0 or 1 in this implementation)
+        - {diag_prefix}frac_rejected : float, fraction of rows rejected by the final sigmaCut mask
+        - {diag_prefix}hat_max       : float, max leverage proxy via QR (max row-wise ||Q||^2)
+        - {diag_prefix}cond_xtx      : float, condition number of X^T X
+        - {diag_prefix}time_ms       : float, wall time per group (ms), excluding the leverage/cond computation
+        - {diag_prefix}n_rows        : int, number of rows in the group (after dropna on predictors/target/weights)
+
+        Notes:
+        - n_refits counts *additional* iterations beyond the first fit. With this one-pass
+          sigmaCut scheme it is 0 (no re-fit) or 1 (one re-fit on inliers).
+        """
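+        # Illustrative result shape (hypothetical names/values) for gb_columns=["sector"],
+        # fit_columns=["dY"], linear_columns0=["x"], diag=True:
+        #   {"sector": 3, "dY_slope_x": 0.12, "dY_err_x": 0.01, "dY_intercept": -0.4,
+        #    "dY_rms": 0.9, "dY_mad": 0.6, "diag_n_refits": 1, "diag_frac_rejected": 0.03, ...}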
+        import time
+        import numpy as np
+        import logging
+        from sklearn.linear_model import HuberRegressor, LinearRegression
+
+        # TODO handle the case of single gb column
+        group_dict = dict(zip(gb_columns, key))
+
+        if isinstance(weights, str) and weights not in df_group.columns:
+            raise ValueError(f"Weight column '{weights}' not found in input DataFrame.")
+
+        # Select predictors that meet per-predictor minStat (based on non-null rows with target+weights)
+        predictors: List[str] = []
+        for i, col in enumerate(linear_columns0):
+            required_columns = [col] + fit_columns + [weights]
+            df_valid = df_group[required_columns].dropna()
+            if len(df_valid) >= minStat[i]:
+                predictors.append(col)
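+        # NB: minStat is indexed positionally against linear_columns0, so it must be at
+        # least as long as linear_columns0; otherwise the loop above raises IndexError.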
+
+        # Prepare diagnostics state (group-level)
+        n_refits_group = 0  # extra fits after initial fit
+        frac_rejected_group = np.nan
+        hat_max_group = np.nan
+        cond_xtx_group = np.nan
+        time_ms_group = np.nan
+        n_rows_group = int(len(df_group))  # raw group size (will refine to cleaned size later)
+
+        # Start timing the *fitting* work (we will stop before leverage/cond to avoid polluting the timing)
+        t0_group = time.perf_counter()
+
+        # Loop over target columns
+        for target_col in fit_columns:
+            try:
+                if not predictors:
+                    # No valid predictors met minStat; emit NaNs for this target
+                    for col in linear_columns0:
+                        group_dict[f"{target_col}_slope_{col}"] = np.nan
+                        group_dict[f"{target_col}_err_{col}"] = np.nan
+                    group_dict[f"{target_col}_intercept"] = np.nan
+                    group_dict[f"{target_col}_rms"] = np.nan
+                    group_dict[f"{target_col}_mad"] = np.nan
+                    continue
+
+                subset_columns = predictors + [target_col, weights]
+                df_clean = df_group.dropna(subset=subset_columns)
+                if len(df_clean) < min(minStat):
+                    # Not enough rows to fit
+                    for col in linear_columns0:
+                        group_dict[f"{target_col}_slope_{col}"] = np.nan
+                        group_dict[f"{target_col}_err_{col}"] = np.nan
+                    group_dict[f"{target_col}_intercept"] = np.nan
+                    group_dict[f"{target_col}_rms"] = np.nan
+                    group_dict[f"{target_col}_mad"] = np.nan
+                    continue
+
+                # Update cleaned group size for diagnostics
+                n_rows_group = int(len(df_clean))
+
+                X = df_clean[predictors].to_numpy(copy=False)
+                y = df_clean[target_col].to_numpy(copy=False)
+                w = df_clean[weights].to_numpy(copy=False)
+
+                # Choose model
+                if callable(fitter):
+                    model = fitter()
+                elif fitter == "robust":
+                    model = HuberRegressor(tol=1e-4)
+                elif fitter == "ols":
+                    model = LinearRegression()
+                else:
+                    model = HuberRegressor(tol=1e-4)
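+                    # (this else branch also covers fitter="auto" and any unrecognized string)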
+
+                # Initial fit
+                try:
+                    model.fit(X, y, sample_weight=w)
+                except Exception as e:
+                    logging.warning(
+                        f"{model.__class__.__name__} failed for {target_col} in group {key}: {e}. "
+                        f"Falling back to LinearRegression."
+                    )
+                    model = LinearRegression()
+                    model.fit(X, y, sample_weight=w)
+
+                # Residuals and robust stats
+                predicted = model.predict(X)
+                residuals = y - predicted
+                rms = float(np.sqrt(np.mean(residuals ** 2)))
+                mad = float(np.median(np.abs(residuals)))
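+                # NB: `mad` is the raw median absolute residual, not scaled by ~1.4826,
+                # so for Gaussian residuals sigmaCut*mad is a tighter cut than
+                # sigmaCut standard deviations.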
+
+                # One-pass sigmaCut masking (current implementation supports at most a single re-fit)
+                final_mask = None
+                if np.isfinite(mad) and mad > 0 and sigmaCut is not None and sigmaCut < np.inf:
+                    mask = (np.abs(residuals) <= sigmaCut * mad)
+                    if mask.sum() >= min(minStat):
+                        # Re-fit on inliers
+                        n_refits_group += 1  # <-- counts *extra* fits beyond the first
+                        try:
+                            model.fit(X[mask], y[mask], sample_weight=w[mask])
+                        except Exception as e:
+                            logging.warning(
+                                f"{model.__class__.__name__} re-fit with outlier mask failed for {target_col} "
+                                f"in group {key}: {e}. Falling back to LinearRegression."
+                            )
+                            model = LinearRegression()
+                            model.fit(X[mask], y[mask], sample_weight=w[mask])
+
+                        # Recompute residuals on full X (to report global rms/mad)
+                        predicted = model.predict(X)
+                        residuals = y - predicted
+                        rms = float(np.sqrt(np.mean(residuals ** 2)))
+                        mad = float(np.median(np.abs(residuals)))
+                        final_mask = mask
+                    else:
+                        final_mask = np.ones_like(residuals, dtype=bool)
+                else:
+                    final_mask = np.ones_like(residuals, dtype=bool)
+
+                # Parameter errors from final fit (on the design actually used to fit)
+                try:
+                    if final_mask is not None and final_mask.any():
+                        X_used = X[final_mask]
+                        y_used = y[final_mask]
+                    else:
+                        X_used = X
+                        y_used = y
+
+                    n, p = X_used.shape
+                    denom = n - p if n > p else 1e-9
+                    s2 = float(np.sum((y_used - model.predict(X_used)) ** 2) / denom)
+                    cov_matrix = np.linalg.inv(X_used.T @ X_used) * s2
+                    std_errors = np.sqrt(np.diag(cov_matrix))
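+                    # NB: classical OLS errors, Var(beta) = s^2 * (X^T X)^{-1}; the fitted
+                    # intercept column and the sample weights are not included here, so the
+                    # reported errors are approximate for weighted/Huber fits.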
+                except np.linalg.LinAlgError:
+                    std_errors = np.full(len(predictors), np.nan, dtype=float)
+
+                # Store results for this target
+                for col in linear_columns0:
+                    if col in predictors:
+                        idx = predictors.index(col)
+                        group_dict[f"{target_col}_slope_{col}"] = float(model.coef_[idx])
+                        group_dict[f"{target_col}_err_{col}"] = float(std_errors[idx]) if idx < len(std_errors) else np.nan
+                    else:
+                        group_dict[f"{target_col}_slope_{col}"] = np.nan
+                        group_dict[f"{target_col}_err_{col}"] = np.nan
+
+                group_dict[f"{target_col}_intercept"] = float(model.intercept_) if hasattr(model, "intercept_") else np.nan
+                group_dict[f"{target_col}_rms"] = rms
+                group_dict[f"{target_col}_mad"] = mad
+
+                # Update group-level diagnostics that depend on the final mask
+                if diag:
+                    # Capture timing up to here (pure fitting + residuals + errors); exclude leverage/cond below
+                    time_ms_group = (time.perf_counter() - t0_group) * 1e3
+                    if final_mask is not None and len(final_mask) > 0:
+                        frac_rejected_group = 1.0 - (float(np.count_nonzero(final_mask)) / float(len(final_mask)))
+                    else:
+                        frac_rejected_group = np.nan
+
+            except Exception as e:
+                logging.warning(f"Robust regression failed for {target_col} in group {key}: {e}")
+                for col in linear_columns0:
+                    group_dict[f"{target_col}_slope_{col}"] = np.nan
+                    group_dict[f"{target_col}_err_{col}"] = np.nan
+                group_dict[f"{target_col}_intercept"] = np.nan
+                group_dict[f"{target_col}_rms"] = np.nan
+                group_dict[f"{target_col}_mad"] = np.nan
+
+        # Medians
+        for col in median_columns:
+            try:
+                group_dict[col] = df_group[col].median()
+            except Exception:
+                group_dict[col] = np.nan
+
+        # Compute leverage & conditioning proxies (kept OUTSIDE the timed span)
+        if diag:
+            try:
+                X_cols = [c for c in linear_columns0 if c in df_group.columns and c in predictors]
+                if X_cols:
+                    X_diag = df_group[X_cols].dropna().to_numpy(dtype=np.float64, copy=False)
+                else:
+                    X_diag = None
+
+                hat_max_group = np.nan
+                cond_xtx_group = np.nan
+                if X_diag is not None and X_diag.size and X_diag.shape[1] > 0:
+                    # cond(X^T X)
+                    try:
+                        s = np.linalg.svd(X_diag.T @ X_diag, compute_uv=False)
+                        cond_xtx_group = float(s[0] / s[-1]) if (s.size > 0 and s[-1] > 0) else float("inf")
+                    except Exception:
+                        cond_xtx_group = float("inf")
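+                    # Note: in the 2-norm, cond(X^T X) = cond(X)^2, so this metric grows
+                    # quadratically with the conditioning of the design matrix itself.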
+                    # leverage via QR
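+                    # With the reduced QR factorization X = QR, the hat matrix is
+                    # H = Q Q^T, so its diagonal (the leverages) equals the row-wise
+                    # squared norm of Q; values near 1 flag highly influential rows.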
+                    try:
+                        Q, _ = np.linalg.qr(X_diag, mode="reduced")
+                        hat_max_group = float(np.max(np.sum(Q * Q, axis=1)))
+                    except Exception:
+                        pass
+            except Exception:
+                pass
+
+        # Attach diagnostics (once per group, only when requested; keeping them out of the
+        # diag=False path matches the docstring and the missing-column check in summarize_diagnostics)
+        if diag:
+            group_dict[f"{diag_prefix}n_refits"] = int(n_refits_group)
+            group_dict[f"{diag_prefix}frac_rejected"] = float(frac_rejected_group) if np.isfinite(frac_rejected_group) else np.nan
+            group_dict[f"{diag_prefix}hat_max"] = float(hat_max_group) if np.isfinite(hat_max_group) else np.nan
+            group_dict[f"{diag_prefix}cond_xtx"] = float(cond_xtx_group) if np.isfinite(cond_xtx_group) else np.nan
+            group_dict[f"{diag_prefix}time_ms"] = float(time_ms_group) if np.isfinite(time_ms_group) else np.nan
+            group_dict[f"{diag_prefix}n_rows"] = int(n_rows_group)
+
+        return group_dict
+
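+    # Hypothetical direct call (normally this runs per group inside make_parallel_fit;
+    # column names below are illustrative):
+    #   key, df_g = next(iter(df.groupby(["sector"])))
+    #   row = process_group_robust(key, df_g, ["sector"], ["dY"], ["x"], [], "w",
+    #                              minStat=[10], sigmaCut=4, fitter="robust", diag=True)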
+
     @staticmethod
     def make_parallel_fit(
         df: pd.DataFrame,
@@ -229,7 +471,10 @@ def make_parallel_fit(
         min_stat: List[int] = [10, 10],
         sigmaCut: float = 4.0,
         fitter: Union[str, Callable] = "auto",
-        batch_size: Union[int, None] = None  # ← new argument
+        batch_size: Union[int, None] = None,  # ← new argument
+        # --- NEW: diagnostics switch ---
+        diag: bool = False,
+        diag_prefix: str = "diag_"
     ) -> Tuple[pd.DataFrame, pd.DataFrame]:
         """
         Perform grouped robust linear regression using HuberRegressor in parallel.
@@ -292,3 +537,47 @@ def make_parallel_fit(
                 df[f"{target_col}{suffix}"] += df[slope_col] * df[col]

         return df, dfGB
+
+def summarize_diagnostics(dfGB, diag_prefix: str = "diag_", top: int = 50):
+    """
+    Quick look at diagnostic columns emitted by make_parallel_fit(..., diag=True).
+    Returns a dict of small DataFrames for top offenders, and prints a short summary.
+
+    Example:
+        summ = summarize_diagnostics(dfGB, top=20)
+        summ["slowest"].head()
+    """
+    import pandas as pd
+    cols = {
+        "time": f"{diag_prefix}time_ms",
+        "refits": f"{diag_prefix}n_refits",
+        "rej": f"{diag_prefix}frac_rejected",
+        "lev": f"{diag_prefix}hat_max",
+        "cond": f"{diag_prefix}cond_xtx",
+        "nrows": f"{diag_prefix}n_rows",
+    }
+    missing = [c for c in cols.values() if c not in dfGB.columns]
+    if missing:
+        print("[diagnostics] Missing columns (did you run diag=True?):", missing)
+        return {}
+
+    summary = {}
+    # Defensive: numeric coerce
+    d = dfGB.copy()
+    for k, c in cols.items():
+        d[c] = pd.to_numeric(d[c], errors="coerce")
+
+    # Keep a few leading identifier columns plus all diagnostic columns
+    keep_cols = list({*dfGB.columns[:len(dfGB.columns) // 4], *cols.values()})
+    summary["slowest"] = d.sort_values(cols["time"], ascending=False).head(top)[keep_cols]
+    summary["most_refits"] = d.sort_values(cols["refits"], ascending=False).head(top)[keep_cols]
+    summary["most_rejected"] = d.sort_values(cols["rej"], ascending=False).head(top)[keep_cols]
+    summary["highest_leverage"] = d.sort_values(cols["lev"], ascending=False).head(top)[keep_cols]
+    summary["worst_conditioned"] = d.sort_values(cols["cond"], ascending=False).head(top)[keep_cols]
+
+    # Console summary
+    print("[diagnostics] Groups:", len(dfGB))
+    print("[diagnostics] mean time (ms):", float(d[cols["time"]].mean()))
+    print("[diagnostics] pct with refits>0:", float((d[cols["refits"]] > 0).mean()) * 100.0)
+    print("[diagnostics] mean frac_rejected:", float(d[cols["rej"]].mean()))
+    print("[diagnostics] 99p cond_xtx:", float(d[cols["cond"]].quantile(0.99)))
+    print("[diagnostics] 99p hat_max:", float(d[cols["lev"]].quantile(0.99)))
+    return summary
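+
+
+if __name__ == "__main__":
+    # Minimal, self-contained smoke-test sketch for summarize_diagnostics using
+    # synthetic diagnostics columns: values are fabricated, no fitting is run,
+    # and the "sector" group key is hypothetical.
+    import numpy as np
+    import pandas as pd
+    rng = np.random.default_rng(42)
+    n_groups = 100
+    dfGB_demo = pd.DataFrame({
+        "sector": np.arange(n_groups),
+        "diag_time_ms": rng.exponential(2.0, n_groups),
+        "diag_n_refits": rng.integers(0, 2, n_groups),
+        "diag_frac_rejected": rng.uniform(0.0, 0.1, n_groups),
+        "diag_hat_max": rng.uniform(0.0, 1.0, n_groups),
+        "diag_cond_xtx": rng.lognormal(2.0, 1.0, n_groups),
+        "diag_n_rows": rng.integers(50, 500, n_groups),
+    })
+    summ = summarize_diagnostics(dfGB_demo, top=5)
+    print(sorted(summ))  # keys of the returned dict of top-offender tables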