@@ -287,3 +287,80 @@ def test_sigma_cut_impact():
287287
288288 assert abs (slope_strict - 3.0 ) < abs (slope_all - 3.0 ), \
289289 f"Robust fit with sigmaCut=2 should be closer to truth: slope_strict={ slope_strict } , slope_all={ slope_all } "
290+
291+
292+
def test_make_parallel_fit_robust(sample_data):
    """Smoke-test ``make_parallel_fit`` with ``fitter="robust"``.

    Runs the fit on a copy of the ``sample_data`` fixture and checks that
    the per-group result is non-empty and that the prediction, slope and
    intercept columns carrying the ``_rob`` suffix are present.
    """
    frame = sample_data.copy()
    # Mask is built from the copied frame, exactly as it is passed to the fit.
    keep_rows = frame['x1'] > -10

    fit_options = {
        'gb_columns': ['group'],
        'fit_columns': ['y'],
        'linear_columns': ['x1', 'x2'],
        'median_columns': ['x1'],
        'weights': 'weight',
        'suffix': '_rob',
        'selection': keep_rows,
        'addPrediction': True,
        'n_jobs': 1,
        'min_stat': [5, 5],
        'fitter': "robust",
    }
    fitted, group_params = GroupByRegressor.make_parallel_fit(frame, **fit_options)

    assert not group_params.empty
    # Prediction column lands in the row-level output frame.
    assert 'y_rob' in fitted.columns
    # Fit parameters land in the per-group frame.
    for expected in ('y_slope_x1_rob', 'y_intercept_rob'):
        assert expected in group_params.columns
313+
314+
def test_make_parallel_fit_with_linear_regression(sample_data):
    """Smoke-test ``make_parallel_fit`` with ``fitter="ols"``.

    Runs the fit on a copy of the ``sample_data`` fixture and checks that
    the per-group result is non-empty and that the prediction, slope and
    intercept columns carrying the ``_ols`` suffix are present.
    """
    frame = sample_data.copy()
    # Mask is built from the copied frame, exactly as it is passed to the fit.
    keep_rows = frame['x1'] > -10

    fit_options = {
        'gb_columns': ['group'],
        'fit_columns': ['y'],
        'linear_columns': ['x1', 'x2'],
        'median_columns': ['x1'],
        'weights': 'weight',
        'suffix': '_ols',
        'selection': keep_rows,
        'addPrediction': True,
        'n_jobs': 1,
        'min_stat': [5, 5],
        'fitter': "ols",
    }
    fitted, group_params = GroupByRegressor.make_parallel_fit(frame, **fit_options)

    assert not group_params.empty
    # Prediction column lands in the row-level output frame.
    assert 'y_ols' in fitted.columns
    # Fit parameters land in the per-group frame.
    for expected in ('y_slope_x1_ols', 'y_intercept_ols'):
        assert expected in group_params.columns
335+
def test_make_parallel_fit_with_custom_fitter(sample_data):
    """Check ``make_parallel_fit`` when ``fitter`` is a user-supplied class.

    ``DummyFitter`` ignores its training data: it always reports zero
    coefficients and an intercept of 42, and predicts 42 for every row.
    The test passes the class itself (not an instance) as ``fitter`` and
    verifies those constants show up in the outputs.
    """

    class DummyFitter:
        """Constant model: zero slopes, intercept fixed at 42."""

        def fit(self, X, y, sample_weight=None):
            # One zero coefficient per column of X; y and weights are unused.
            n_features = X.shape[1]
            self.coef_ = np.zeros(n_features)
            self.intercept_ = 42
            return self

        def predict(self, X):
            # One constant prediction per row of X.
            n_rows = X.shape[0]
            return np.full(n_rows, self.intercept_)

    frame = sample_data.copy()
    # Mask is built from the copied frame, exactly as it is passed to the fit.
    keep_rows = frame['x1'] > -10

    fit_options = {
        'gb_columns': ['group'],
        'fit_columns': ['y'],
        'linear_columns': ['x1'],
        'median_columns': ['x1'],
        'weights': 'weight',
        'suffix': '_dummy',
        'selection': keep_rows,
        'addPrediction': True,
        'n_jobs': 1,
        'min_stat': [5],
        'fitter': DummyFitter,
    }
    fitted, group_params = GroupByRegressor.make_parallel_fit(frame, **fit_options)

    # Rows without a prediction are dropped before checking the constant.
    predictions = fitted['y_dummy'].dropna()
    assert not predictions.empty
    assert np.allclose(predictions.unique(), 42)

    assert 'y_slope_x1_dummy' in group_params.columns
    # Only the first group's parameters are inspected.
    first_slope = group_params['y_slope_x1_dummy'].iloc[0]
    first_intercept = group_params['y_intercept_dummy'].iloc[0]
    assert first_slope == 0
    assert first_intercept == 42
0 commit comments