@@ -51,6 +51,26 @@ def _make_groups(n_rows: int, n_groups: int, rng: np.random.Generator) -> np.nda
5151 rng .shuffle (base )
5252 return base
5353
54+ def _find_diag_col (df : pd .DataFrame , base : str , dp : str , suffix : str | None = None ) -> str | None :
55+ """
56+ Return diagnostics column for a given base (e.g. 'time_ms'), handling suffixes.
57+ If suffix is provided, match startswith(dp+base) and endswith(suffix).
58+ """
59+ exact = dp + base
60+ if suffix is None and exact in df .columns :
61+ return exact
62+ pref = dp + base
63+ for c in df .columns :
64+ if not isinstance (c , str ):
65+ continue
66+ if not c .startswith (pref ):
67+ continue
68+ if suffix is not None and not c .endswith (suffix ):
69+ continue
70+ return c
71+ return None
72+
73+
5474def create_clean_data (n_rows : int , n_groups : int , * , seed : int = 42 , noise_sigma : float = 1.0 , x_corr : float = 0.0 ) -> pd .DataFrame :
5575 rng = np .random .default_rng (seed )
5676 group = _make_groups (n_rows , n_groups , rng )
@@ -89,8 +109,7 @@ class Scenario:
89109 fitter : str
90110 sigmaCut : float
91111
92- def _run_one (df : pd .DataFrame , scenario : Scenario ) -> Dict [str , Any ]:
93- # Workaround for module expecting tuple keys: duplicate group
112+ def _run_one (df : pd .DataFrame , scenario : Scenario , args ) -> Dict [str , Any ]:
94113 df = df .copy ()
95114 df ["group2" ] = df ["group" ].astype (np .int32 )
96115 df ["weight" ] = 1.0
@@ -112,10 +131,13 @@ def _run_one(df: pd.DataFrame, scenario: Scenario) -> Dict[str, Any]:
112131 sigmaCut = scenario .sigmaCut ,
113132 fitter = scenario .fitter ,
114133 batch_size = "auto" ,
134+ diag = getattr (args , "diag" , False ),
135+ diag_prefix = getattr (args , "diag_prefix" , "diag_" ),
115136 )
116137 dt = time .perf_counter () - t0
117- n_groups = int (df_params .shape [0 ])
118- per_1k = dt / (n_groups / 1000.0 ) if n_groups else float ("nan" )
138+ n_groups_eff = int (df_params .shape [0 ])
139+ per_1k = dt / (n_groups_eff / 1000.0 ) if n_groups_eff else float ("nan" )
140+
119141 return {
120142 "scenario" : scenario .name ,
121143 "config" : {
@@ -130,8 +152,9 @@ def _run_one(df: pd.DataFrame, scenario: Scenario) -> Dict[str, Any]:
130152 "result" : {
131153 "total_sec" : dt ,
132154 "sec_per_1k_groups" : per_1k ,
133- "n_groups_effective" : n_groups ,
155+ "n_groups_effective" : n_groups_eff ,
134156 },
157+ "df_params" : df_params if getattr (args , "diag" , False ) else None , # <-- add this
135158 }
136159
137160def _make_df (s : Scenario , seed : int = 7 ) -> pd .DataFrame :
@@ -200,20 +223,61 @@ def run_suite(args) -> Tuple[List[Dict[str, Any]], str, str, str | None]:
200223 # Prepare output
201224 out_dir = Path (args .out ).resolve ()
202225 out_dir .mkdir (parents = True , exist_ok = True )
203-
226+ diag_rows = []
227+ human_summaries : List [Tuple [str , str ]] = []
204228 # Run
205229 results : List [Dict [str , Any ]] = []
206230 for s in scenarios :
207231 df = _make_df (s , seed = args .seed )
208- results .append (_run_one (df , s ))
232+ # PASS ARGS HERE
233+ out = _run_one (df , s , args )
234+ results .append (out )
235+ if args .diag and out .get ("df_params" ) is not None :
236+ dfp = out ["df_params" ]
237+ dp = args .diag_prefix
238+ # Try to infer a suffix from any diag column (optional). If you know your suffix, set it via CLI later.
239+ # For now we won’t guess; we’ll just use dp and allow both suffixed or unsuffixed.
240+
241+ # 2a) Write top-10 violators per scenario
242+ safe = (s .name .replace (" " , "_" )
243+ .replace ("%" ,"pct" )
244+ .replace ("(" ,"" ).replace (")" ,"" )
245+ .replace ("σ" ,"sigma" ))
246+ tcol = _find_diag_col (dfp , "time_ms" , dp )
247+ if tcol :
248+ dfp .sort_values (tcol , ascending = False ).head (10 ).to_csv (
249+ out_dir / f"diag_top10_time__{ safe } .csv" , index = False
250+ )
251+ rcol = _find_diag_col (dfp , "n_refits" , dp )
252+ if rcol :
253+ dfp .sort_values (rcol , ascending = False ).head (10 ).to_csv (
254+ out_dir / f"diag_top10_refits__{ safe } .csv" , index = False
255+ )
256+
257+ # 2b) Class-level summary (machine + human)
258+ summary = GroupByRegressor .summarize_diagnostics (dfp , diag_prefix = dp ,diag_suffix = "_fit" )
259+ summary_row = {"scenario" : s .name , ** summary }
260+ diag_rows .append (summary_row )
261+ human = GroupByRegressor .format_diagnostics_summary (summary )
262+ human_summaries .append ((s .name , human ))
263+ if args .diag :
264+ txt_path = out_dir / "benchmark_report.txt"
265+ with open (txt_path , "a" ) as f :
266+ f .write ("\n Diagnostics summary (per scenario):\n " )
267+ for name , human in human_summaries :
268+ f .write (f" - { name } : { human } \n " )
269+ f .write ("\n Top-10 violators were saved per scenario as:\n " )
270+ f .write (" diag_top10_time__<scenario>.csv, diag_top10_refits__<scenario>.csv\n " )
271+
209272
210273 # Save
211274 txt_path = out_dir / "benchmark_report.txt"
212275 json_path = out_dir / "benchmark_results.json"
213276 with open (txt_path , "w" ) as f :
214277 f .write (_format_report (results ))
278+ results_slim = [{k : v for k , v in r .items () if k != "df_params" } for r in results ]
215279 with open (json_path , "w" ) as f :
216- json .dump (results , f , indent = 2 )
280+ json .dump (results_slim , f , indent = 2 )
217281
218282 csv_path = None
219283 if args .emit_csv :
@@ -226,6 +290,15 @@ def run_suite(args) -> Tuple[List[Dict[str, Any]], str, str, str | None]:
226290 cfg = r ["config" ]; res = r ["result" ]
227291 w .writerow ([r ["scenario" ], cfg ["n_jobs" ], cfg ["sigmaCut" ], cfg ["fitter" ], cfg ["rows_per_group" ], cfg ["n_groups" ], cfg ["outlier_pct" ], cfg ["outlier_mag" ], res ["total_sec" ], res ["sec_per_1k_groups" ], res ["n_groups_effective" ]])
228292
293+ # --- Append diagnostics summaries to the text report ---
294+ if args .diag and 'human_summaries' in locals () and human_summaries :
295+ with open (txt_path , "a" ) as f :
296+ f .write ("\n Diagnostics summary (per scenario):\n " )
297+ for name , human in human_summaries :
298+ f .write (f" - { name } : { human } \n " )
299+ f .write ("\n Top-10 violators saved as diag_top10_time__<scenario>.csv "
300+ "and diag_top10_refits__<scenario>.csv\n " )
301+
229302 return results , str (txt_path ), str (json_path ), (str (csv_path ) if csv_path else None )
230303
231304def parse_args ():
@@ -240,11 +313,18 @@ def parse_args():
240313 p .add_argument ("--emit-csv" , action = "store_true" , help = "Also emit CSV summary." )
241314 p .add_argument ("--serial-only" , action = "store_true" , help = "Skip parallel scenarios." )
242315 p .add_argument ("--quick" , action = "store_true" , help = "Small quick run: groups=200." )
316+ p .add_argument ("--diag" , action = "store_true" ,
317+ help = "Collect per-group diagnostics into dfGB (diag_* columns)." )
318+ p .add_argument ("--diag-prefix" , type = str , default = "diag_" ,
319+ help = "Prefix for diagnostic columns (default: diag_)." )
320+
243321 args = p .parse_args ()
244322 if args .quick :
245323 args .groups = min (args .groups , 200 )
246324 return args
247325
326+
327+
248328def main ():
249329 args = parse_args ()
250330 results , txt_path , json_path , csv_path = run_suite (args )
0 commit comments