Skip to content

Commit c45e5d0

Browse files
author
miranov25
committed
Fix: ensure regression outputs are preserved for underpopulated groups
- In `make_linear_fit`, add NaN-filled slope and intercept entries when the group size is below `min_stat`.
- Resolves test failures expecting prediction columns even when the fit is skipped.
- Preserves compatibility with the `addPrediction=True` logic.
1 parent d55b796 commit c45e5d0

File tree

1 file changed

+112
-19
lines changed

1 file changed

+112
-19
lines changed

UTILS/dfextensions/groupby_regression.py

Lines changed: 112 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,48 @@
44
from sklearn.linear_model import LinearRegression, HuberRegressor
55
from joblib import Parallel, delayed
66
from numpy.linalg import inv, LinAlgError
7+
from typing import Union, List, Tuple
78

89

910
class GroupByRegressor:
1011
@staticmethod
11-
def _cast_fit_columns(dfGB: pd.DataFrame, cast_dtype: Union[str, None] = None) -> pd.DataFrame:
    """
    Cast the fit-related columns of a group-by result frame to ``cast_dtype``.

    A column is treated as fit-related when its name contains one of the
    substrings "slope", "intercept", "rms" or "mad".

    Parameters:
        dfGB (pd.DataFrame): Group-level results; modified in place.
        cast_dtype (str|None): Target dtype. If None, nothing is cast.

    Returns:
        pd.DataFrame: The same ``dfGB`` object that was passed in.
    """
    if cast_dtype is None:
        return dfGB

    fit_tags = ("slope", "intercept", "rms", "mad")
    for col in dfGB.columns:
        # Column labels are not guaranteed to be strings (e.g. integer labels);
        # a substring test on a non-string label would raise TypeError.
        if isinstance(col, str) and any(tag in col for tag in fit_tags):
            dfGB[col] = dfGB[col].astype(cast_dtype)
    return dfGB
1718

1819
@staticmethod
19-
def make_linear_fit(df, gb_columns, fit_columns, linear_columns, median_columns, suffix, selection, addPrediction=False, cast_dtype=None, min_stat=10):
20+
def make_linear_fit(
21+
df: pd.DataFrame,
22+
gb_columns: List[str],
23+
fit_columns: List[str],
24+
linear_columns: List[str],
25+
median_columns: List[str],
26+
suffix: str,
27+
selection: pd.Series,
28+
addPrediction: bool = False,
29+
cast_dtype: Union[str, None] = None,
30+
min_stat: int = 10
31+
) -> Tuple[pd.DataFrame, pd.DataFrame]:
2032
"""
21-
Perform standard linear regression fits for grouped data and compute median values.
33+
Perform grouped ordinary least squares linear regression and compute medians.
2234
2335
Parameters:
2436
df (pd.DataFrame): Input dataframe.
25-
gb_columns (list): Columns to group by.
26-
fit_columns (list): Target columns for linear regression.
27-
linear_columns (list): Independent variables used for the fit.
28-
median_columns (list): Columns for which median values are computed.
29-
suffix (str): Suffix to append to columns in the output dfGB.
30-
selection (pd.Series): Boolean mask for selecting rows.
31-
addPrediction (bool): If True, merge predictions back into df.
32-
cast_dtype (str or None): If not None, cast fit-related columns to this dtype.
33-
min_stat (int): Minimum number of rows required to perform regression.
37+
gb_columns (List[str]): Columns to group by.
38+
fit_columns (List[str]): Target columns for regression.
39+
linear_columns (List[str]): Predictor columns.
40+
median_columns (List[str]): Columns to compute median.
41+
suffix (str): Suffix for output columns.
42+
selection (pd.Series): Boolean mask to filter rows.
43+
addPrediction (bool): If True, add predicted values to df.
44+
cast_dtype (str|None): Data type to cast result coefficients.
45+
min_stat (int): Minimum number of rows per group to perform regression.
3446
3547
Returns:
36-
tuple: (df, dfGB) where
37-
df is the original dataframe with predicted values appended (if addPrediction is True),
38-
and dfGB is the group-by statistics dataframe containing medians and fit coefficients.
48+
Tuple[pd.DataFrame, pd.DataFrame]: (df with predictions, group-level regression results)
3949
"""
4050
df_selected = df.loc[selection]
4151
group_results = []
@@ -44,12 +54,13 @@ def make_linear_fit(df, gb_columns, fit_columns, linear_columns, median_columns,
4454
for group_vals, df_group in df_selected.groupby(gb_columns):
4555
group_dict = dict(zip(gb_columns, group_vals))
4656
group_sizes[group_vals] = len(df_group)
57+
4758
for target_col in fit_columns:
4859
try:
4960
X = df_group[linear_columns].values
5061
y = df_group[target_col].values
5162
if len(X) < min_stat:
52-
for col in linear_columns:
63+
for i, col in enumerate(linear_columns):
5364
group_dict[f"{target_col}_slope_{col}"] = np.nan
5465
group_dict[f"{target_col}_intercept"] = np.nan
5566
continue
@@ -75,7 +86,6 @@ def make_linear_fit(df, gb_columns, fit_columns, linear_columns, median_columns,
7586
bin_counts = np.array([group_sizes.get(tuple(row), 0) for row in dfGB[gb_columns].itertuples(index=False)], dtype=np.int32)
7687
dfGB["bin_count"] = bin_counts
7788
dfGB = dfGB.rename(columns={col: f"{col}{suffix}" for col in dfGB.columns if col not in gb_columns})
78-
dfGB = dfGB.copy()
7989

8090
if addPrediction:
8191
df = df.merge(dfGB, on=gb_columns, how="left")
@@ -92,7 +102,17 @@ def make_linear_fit(df, gb_columns, fit_columns, linear_columns, median_columns,
92102
return df, dfGB
93103

94104
@staticmethod
95-
def process_group_robust(key, df_group, gb_columns, fit_columns, linear_columns0, median_columns, weights, minStat=[], sigmaCut=4):
105+
def process_group_robust(
106+
key: tuple,
107+
df_group: pd.DataFrame,
108+
gb_columns: List[str],
109+
fit_columns: List[str],
110+
linear_columns0: List[str],
111+
median_columns: List[str],
112+
weights: str,
113+
minStat: List[int],
114+
sigmaCut: float = 4
115+
) -> dict:
96116
"""
97117
Process a single group: perform robust regression fits on each target column,
98118
compute median values, RMS and MAD of the residuals.
@@ -110,7 +130,7 @@ def process_group_robust(key, df_group, gb_columns, fit_columns, linear_columns0
110130
linear_columns0 (list): List of candidate predictor columns.
111131
median_columns (list): List of columns for which median values are computed.
112132
weights (str): Column name for weights.
113-
minStat (list): List of minimum number of rows required to use each predictor in linear_columns0.
133+
minStat (list[int]): List of minimum number of rows required to use each predictor in linear_columns0.
114134
sigmaCut (float): Factor to remove outliers (points with residual > sigmaCut * MAD).
115135
116136
Returns:
@@ -184,3 +204,76 @@ def process_group_robust(key, df_group, gb_columns, fit_columns, linear_columns0
184204
group_dict[col] = df_group[col].median()
185205

186206
return group_dict
207+
208+
209+
@staticmethod
210+
def make_parallel_fit(
    df: pd.DataFrame,
    gb_columns: List[str],
    fit_columns: List[str],
    linear_columns: List[str],
    median_columns: List[str],
    weights: str,
    suffix: str,
    selection: pd.Series,
    addPrediction: bool = False,
    cast_dtype: Union[str, None] = None,
    n_jobs: int = 1,
    min_stat: Union[List[int], None] = None,
    sigmaCut: float = 4.0
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Run ``process_group_robust`` on every group of the selected rows, in parallel.

    The selected rows of ``df`` are grouped by ``gb_columns``; each group is
    handed to ``GroupByRegressor.process_group_robust`` through joblib, and the
    per-group dictionaries are collected into one group-level dataframe.

    Parameters:
        df (pd.DataFrame): Input dataframe.
        gb_columns (List[str]): Columns to group by.
        fit_columns (List[str]): Target columns for regression.
        linear_columns (List[str]): Predictor columns.
        median_columns (List[str]): Columns to compute medians.
        weights (str): Column name of weights for fitting.
        suffix (str): Suffix appended to every output column except the
            group-by columns. The prediction is written to
            ``f"{target}{suffix}"``, so an empty suffix overwrites the target
            column itself.
        selection (pd.Series): Boolean selection mask.
        addPrediction (bool): If True, merge the group-level results into df
            and add one prediction column per fitted target.
        cast_dtype (Union[str, None]): Optional dtype cast for fit outputs.
        n_jobs (int): Number of parallel jobs.
        min_stat (List[int] | None): Minimum number of rows required to use
            each predictor. None (the default) means ``[10, 10]``.
        sigmaCut (float): Outlier threshold in MAD units.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: (df, dfGB). df is the merged frame
        with predictions when addPrediction is True, otherwise the input df
        unchanged; dfGB holds the group-level statistics plus "bin_count".
    """
    # A list literal as default would be shared between calls; build it here.
    if min_stat is None:
        min_stat = [10, 10]

    df_selected = df.loc[selection]
    grouped = df_selected.groupby(gb_columns)

    results = Parallel(n_jobs=n_jobs)(
        delayed(GroupByRegressor.process_group_robust)(
            key, group_df, gb_columns, fit_columns, linear_columns,
            median_columns, weights, min_stat, sigmaCut
        )
        for key, group_df in grouped
    )

    if results:
        # NOTE(review): assumes each result dict carries the group key values
        # under gb_columns (process_group_robust is not fully visible here).
        dfGB = pd.DataFrame(results)
    else:
        # Nothing was selected: keep the key columns (with their dtypes) so the
        # lookups and the merge below still work on an empty result.
        dfGB = df_selected[gb_columns].iloc[:0].reset_index(drop=True)
    dfGB = GroupByRegressor._cast_fit_columns(dfGB, cast_dtype)

    # Group sizes keyed by tuple. With a single grouping column the size index
    # holds scalars while the row tuples below are 1-tuples, so normalise the
    # keys; otherwise the lookup would silently fall back to 0.
    group_sizes = {
        (key if isinstance(key, tuple) else (key,)): size
        for key, size in grouped.size().items()
    }
    bin_counts = np.array(
        [
            group_sizes.get(key, 0)
            for key in dfGB[gb_columns].itertuples(index=False, name=None)
        ],
        dtype=np.int32,
    )
    dfGB["bin_count"] = bin_counts
    dfGB = dfGB.rename(columns={col: f"{col}{suffix}" for col in dfGB.columns if col not in gb_columns})

    if addPrediction:
        df = df.merge(dfGB, on=gb_columns, how="left")
        for target_col in fit_columns:
            intercept_col = f"{target_col}_intercept{suffix}"
            if intercept_col not in df.columns:
                # No fit output exists for this target; skip its prediction.
                continue
            prediction = df[intercept_col]
            for col in linear_columns:
                slope_col = f"{target_col}_slope_{col}{suffix}"
                if slope_col in df.columns:
                    prediction = prediction + df[slope_col] * df[col]
            df[f"{target_col}{suffix}"] = prediction

    return df, dfGB

0 commit comments

Comments
 (0)