44from sklearn .linear_model import LinearRegression , HuberRegressor
55from joblib import Parallel , delayed
66from numpy .linalg import inv , LinAlgError
7+ from typing import Union , List , Tuple
78
89
910class GroupByRegressor :
1011 @staticmethod
11- def _cast_fit_columns (dfGB , cast_dtype = None ):
12+ def _cast_fit_columns (dfGB : pd . DataFrame , cast_dtype : Union [ str , None ] = None ) -> pd . DataFrame :
1213 if cast_dtype is not None :
1314 for col in dfGB .columns :
1415 if ("slope" in col or "intercept" in col or "rms" in col or "mad" in col ):
1516 dfGB [col ] = dfGB [col ].astype (cast_dtype )
1617 return dfGB
1718
1819 @staticmethod
19- def make_linear_fit (df , gb_columns , fit_columns , linear_columns , median_columns , suffix , selection , addPrediction = False , cast_dtype = None , min_stat = 10 ):
20+ def make_linear_fit (
21+ df : pd .DataFrame ,
22+ gb_columns : List [str ],
23+ fit_columns : List [str ],
24+ linear_columns : List [str ],
25+ median_columns : List [str ],
26+ suffix : str ,
27+ selection : pd .Series ,
28+ addPrediction : bool = False ,
29+ cast_dtype : Union [str , None ] = None ,
30+ min_stat : int = 10
31+ ) -> Tuple [pd .DataFrame , pd .DataFrame ]:
2032 """
21- Perform standard linear regression fits for grouped data and compute median values .
33+ Perform grouped ordinary least squares linear regression and compute medians .
2234
2335 Parameters:
2436 df (pd.DataFrame): Input dataframe.
25- gb_columns (list ): Columns to group by.
26- fit_columns (list ): Target columns for linear regression.
27- linear_columns (list ): Independent variables used for the fit .
28- median_columns (list ): Columns for which median values are computed .
29- suffix (str): Suffix to append to columns in the output dfGB .
30- selection (pd.Series): Boolean mask for selecting rows.
31- addPrediction (bool): If True, merge predictions back into df.
32- cast_dtype (str or None): If not None, cast fit-related columns to this dtype .
33- min_stat (int): Minimum number of rows required to perform regression.
37+ gb_columns (List[str] ): Columns to group by.
38+ fit_columns (List[str] ): Target columns for regression.
39+ linear_columns (List[str] ): Predictor columns .
40+ median_columns (List[str] ): Columns to compute median.
41+ suffix (str): Suffix for output columns .
42+ selection (pd.Series): Boolean mask to filter rows.
43+ addPrediction (bool): If True, add predicted values to df.
44+ cast_dtype (str| None): Data type to cast result coefficients .
45+ min_stat (int): Minimum number of rows per group to perform regression.
3446
3547 Returns:
36- tuple: (df, dfGB) where
37- df is the original dataframe with predicted values appended (if addPrediction is True),
38- and dfGB is the group-by statistics dataframe containing medians and fit coefficients.
48+ Tuple[pd.DataFrame, pd.DataFrame]: (df with predictions, group-level regression results)
3949 """
4050 df_selected = df .loc [selection ]
4151 group_results = []
@@ -44,12 +54,13 @@ def make_linear_fit(df, gb_columns, fit_columns, linear_columns, median_columns,
4454 for group_vals , df_group in df_selected .groupby (gb_columns ):
4555 group_dict = dict (zip (gb_columns , group_vals ))
4656 group_sizes [group_vals ] = len (df_group )
57+
4758 for target_col in fit_columns :
4859 try :
4960 X = df_group [linear_columns ].values
5061 y = df_group [target_col ].values
5162 if len (X ) < min_stat :
52- for col in linear_columns :
63+ for i , col in enumerate ( linear_columns ) :
5364 group_dict [f"{ target_col } _slope_{ col } " ] = np .nan
5465 group_dict [f"{ target_col } _intercept" ] = np .nan
5566 continue
@@ -75,7 +86,6 @@ def make_linear_fit(df, gb_columns, fit_columns, linear_columns, median_columns,
7586 bin_counts = np .array ([group_sizes .get (tuple (row ), 0 ) for row in dfGB [gb_columns ].itertuples (index = False )], dtype = np .int32 )
7687 dfGB ["bin_count" ] = bin_counts
7788 dfGB = dfGB .rename (columns = {col : f"{ col } { suffix } " for col in dfGB .columns if col not in gb_columns })
78- dfGB = dfGB .copy ()
7989
8090 if addPrediction :
8191 df = df .merge (dfGB , on = gb_columns , how = "left" )
@@ -92,7 +102,17 @@ def make_linear_fit(df, gb_columns, fit_columns, linear_columns, median_columns,
92102 return df , dfGB
93103
94104 @staticmethod
95- def process_group_robust (key , df_group , gb_columns , fit_columns , linear_columns0 , median_columns , weights , minStat = [], sigmaCut = 4 ):
105+ def process_group_robust (
106+ key : tuple ,
107+ df_group : pd .DataFrame ,
108+ gb_columns : List [str ],
109+ fit_columns : List [str ],
110+ linear_columns0 : List [str ],
111+ median_columns : List [str ],
112+ weights : str ,
113+ minStat : List [int ],
114+ sigmaCut : float = 4
115+ ) -> dict :
96116 """
97117 Process a single group: perform robust regression fits on each target column,
98118 compute median values, RMS and MAD of the residuals.
@@ -110,7 +130,7 @@ def process_group_robust(key, df_group, gb_columns, fit_columns, linear_columns0
110130 linear_columns0 (list): List of candidate predictor columns.
111131 median_columns (list): List of columns for which median values are computed.
112132 weights (str): Column name for weights.
113- minStat (list): List of minimum number of rows required to use each predictor in linear_columns0.
133+ minStat (list[int] ): List of minimum number of rows required to use each predictor in linear_columns0.
114134 sigmaCut (float): Factor to remove outliers (points with residual > sigmaCut * MAD).
115135
116136 Returns:
@@ -184,3 +204,76 @@ def process_group_robust(key, df_group, gb_columns, fit_columns, linear_columns0
184204 group_dict [col ] = df_group [col ].median ()
185205
186206 return group_dict
207+
208+
209+ @staticmethod
210+ def make_parallel_fit (
211+ df : pd .DataFrame ,
212+ gb_columns : List [str ],
213+ fit_columns : List [str ],
214+ linear_columns : List [str ],
215+ median_columns : List [str ],
216+ weights : str ,
217+ suffix : str ,
218+ selection : pd .Series ,
219+ addPrediction : bool = False ,
220+ cast_dtype : Union [str , None ] = None ,
221+ n_jobs : int = 1 ,
222+ min_stat : List [int ] = [10 , 10 ],
223+ sigmaCut : float = 4.0
224+ ) -> Tuple [pd .DataFrame , pd .DataFrame ]:
225+ """
226+ Perform grouped robust linear regression using HuberRegressor in parallel.
227+
228+ Parameters:
229+ df (pd.DataFrame): Input dataframe.
230+ gb_columns (List[str]): Columns to group by.
231+ fit_columns (List[str]): Target columns for regression.
232+ linear_columns (List[str]): Predictor columns.
233+ median_columns (List[str]): Columns to compute medians.
234+ weights (str): Column name of weights for fitting.
235+ suffix (str): Suffix to append to output columns.
236+ selection (pd.Series): Boolean selection mask.
237+ addPrediction (bool): If True, add prediction columns to df.
238+ cast_dtype (Union[str, None]): Optional dtype cast for fit outputs.
239+ n_jobs (int): Number of parallel jobs.
240+ min_stat (List[int]): Minimum number of rows required to use each predictor.
241+ sigmaCut (float): Outlier threshold in MAD units.
242+
243+ Returns:
244+ Tuple[pd.DataFrame, pd.DataFrame]: DataFrame with predictions and group-level statistics.
245+ """
246+ df_selected = df .loc [selection ]
247+ grouped = df_selected .groupby (gb_columns )
248+
249+ results = Parallel (n_jobs = n_jobs )(
250+ delayed (GroupByRegressor .process_group_robust )(
251+ key , group_df , gb_columns , fit_columns , linear_columns ,
252+ median_columns , weights , min_stat , sigmaCut
253+ )
254+ for key , group_df in grouped
255+ )
256+
257+ dfGB = pd .DataFrame (results )
258+ dfGB = GroupByRegressor ._cast_fit_columns (dfGB , cast_dtype )
259+
260+ bin_counts = np .array ([
261+ len (grouped .get_group (key )) if key in grouped .groups else 0
262+ for key in dfGB [gb_columns ].itertuples (index = False , name = None )
263+ ], dtype = np .int32 )
264+ dfGB ["bin_count" ] = bin_counts
265+ dfGB = dfGB .rename (columns = {col : f"{ col } { suffix } " for col in dfGB .columns if col not in gb_columns })
266+
267+ if addPrediction :
268+ df = df .merge (dfGB , on = gb_columns , how = "left" )
269+ for target_col in fit_columns :
270+ intercept_col = f"{ target_col } _intercept{ suffix } "
271+ if intercept_col not in df .columns :
272+ continue
273+ df [f"{ target_col } { suffix } " ] = df [intercept_col ]
274+ for col in linear_columns :
275+ slope_col = f"{ target_col } _slope_{ col } { suffix } "
276+ if slope_col in df .columns :
277+ df [f"{ target_col } { suffix } " ] += df [slope_col ] * df [col ]
278+
279+ return df , dfGB
0 commit comments