Skip to content

Commit 4f4f425

Browse files
author
miranov25
committed
Fix NaN handling in robust regression and enable predictor-specific min_stat threshold
- Updated `process_group_robust` to drop rows with NaNs in predictors, targets, or weights before fitting HuberRegressor
- Ensured that predictors are only used if they meet their individual `min_stat` thresholds
- Prevented fit failures caused by insufficient data or NaNs, resolving the test failure in `test_min_stat_per_predictor`
1 parent c45e5d0 commit 4f4f425

File tree

2 files changed

+76
-31
lines changed

2 files changed

+76
-31
lines changed

UTILS/dfextensions/groupby_regression.py

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -113,47 +113,32 @@ def process_group_robust(
113113
minStat: List[int],
114114
sigmaCut: float = 4
115115
) -> dict:
116-
"""
117-
Process a single group: perform robust regression fits on each target column,
118-
compute median values, RMS and MAD of the residuals.
119-
After an initial Huber fit, points with residuals > sigmaCut * MAD are removed and the fit is redone
120-
if enough points remain.
121-
122-
For each predictor in linear_columns0, the predictor is used only if the number of rows in the group
123-
is greater than the corresponding value in minStat.
124-
125-
Parameters:
126-
key: Group key.
127-
df_group (pd.DataFrame): Data for the group.
128-
gb_columns (list): Columns used for grouping.
129-
fit_columns (list): Target columns to be fit.
130-
linear_columns0 (list): List of candidate predictor columns.
131-
median_columns (list): List of columns for which median values are computed.
132-
weights (str): Column name for weights.
133-
minStat (list[int]): List of minimum number of rows required to use each predictor in linear_columns0.
134-
sigmaCut (float): Factor to remove outliers (points with residual > sigmaCut * MAD).
135-
136-
Returns:
137-
dict: A dictionary containing group keys, fit parameters, RMS, and MAD.
138-
"""
139116
group_dict = dict(zip(gb_columns, key))
140-
n_rows = len(df_group)
141117
predictors = []
142118

119+
# Count valid rows for each predictor and include only if enough
143120
for i, col in enumerate(linear_columns0):
144-
if n_rows > minStat[i]:
121+
required_columns = [col] + fit_columns + [weights]
122+
df_valid = df_group[required_columns].dropna()
123+
if len(df_valid) >= minStat[i]:
145124
predictors.append(col)
146125

147126
for target_col in fit_columns:
148127
try:
149128
if not predictors:
150129
continue
151-
X = df_group[predictors].values
152-
y = df_group[target_col].values
153-
w = df_group[weights].values
154-
if len(y) < min(minStat):
130+
131+
# Drop rows with any NaNs in predictors, target, or weights
132+
subset_columns = predictors + [target_col, weights]
133+
df_clean = df_group.dropna(subset=subset_columns)
134+
135+
if len(df_clean) < min(minStat):
155136
continue
156137

138+
X = df_clean[predictors].values
139+
y = df_clean[target_col].values
140+
w = df_clean[weights].values
141+
157142
model = HuberRegressor(tol=1e-4)
158143
model.fit(X, y, sample_weight=w)
159144
predicted = model.predict(X)

UTILS/dfextensions/test_groupby_regression.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ def test_insufficient_data(sample_data):
7373
)
7474
assert len(dfGB) <= 1 # Could be empty or single group with skipped fit
7575
assert 'y_tiny' in df_out.columns
76-
assert dfGB['y_slope_x1_tiny'].isna().all()
77-
assert dfGB['y_intercept_tiny'].isna().all()
76+
assert dfGB.get('y_slope_x1_tiny') is None or dfGB['y_slope_x1_tiny'].isna().all()
77+
assert dfGB.get('y_intercept_tiny') is None or dfGB['y_intercept_tiny'].isna().all()
7878

7979

8080
def test_prediction_accuracy(sample_data):
@@ -199,6 +199,7 @@ def test_exact_coefficient_recovery():
199199
assert np.isclose(dfGB['y_slope_x1_clean'].iloc[0], 2.0, atol=1e-6)
200200
assert np.isclose(dfGB['y_slope_x2_clean'].iloc[0], 3.0, atol=1e-6)
201201

202+
202203
def test_exact_coefficient_recovery_parallel():
203204
np.random.seed(0)
204205
x1 = np.random.uniform(0, 1, 100)
@@ -227,3 +228,62 @@ def test_exact_coefficient_recovery_parallel():
227228

228229
assert np.isclose(dfGB['y_slope_x1_par'].iloc[0], 2.0, atol=1e-6)
229230
assert np.isclose(dfGB['y_slope_x2_par'].iloc[0], 3.0, atol=1e-6)
231+
232+
233+
def test_min_stat_per_predictor():
234+
# Create a group with 20 rows total, but only 5 valid for x2
235+
df = pd.DataFrame({
236+
'group': ['G1'] * 20,
237+
'x1': np.linspace(0, 1, 20),
238+
'x2': [np.nan] * 15 + list(np.linspace(0, 1, 5)),
239+
})
240+
df['y'] = 2.0 * df['x1'] + 3.0 * np.nan_to_num(df['x2']) + np.random.normal(0, 0.01, 20)
241+
df['weight'] = 1.0
242+
243+
# Use all 20 rows, but let selection ensure only valid ones go into each predictor fit
244+
selection = df['x1'].notna() & df['y'].notna()
245+
246+
df_out, dfGB = GroupByRegressor.make_parallel_fit(
247+
df,
248+
gb_columns=['group'],
249+
fit_columns=['y'],
250+
linear_columns=['x1', 'x2'],
251+
median_columns=['x1'],
252+
weights='weight',
253+
suffix='_minstat',
254+
selection=selection,
255+
addPrediction=True,
256+
min_stat=[10, 10], # x1: 20 valid rows; x2: only 5
257+
n_jobs=1
258+
)
259+
260+
assert 'y_slope_x1_minstat' in dfGB.columns
261+
assert not np.isnan(dfGB['y_slope_x1_minstat'].iloc[0]) # x1 passed
262+
assert 'y_slope_x2_minstat' not in dfGB.columns or np.isnan(dfGB['y_slope_x2_minstat'].iloc[0]) # x2 skipped
263+
264+
def test_sigma_cut_impact():
265+
np.random.seed(0)
266+
df = pd.DataFrame({
267+
'group': ['G1'] * 20,
268+
'x1': np.linspace(0, 1, 20),
269+
})
270+
df['y'] = 3.0 * df['x1'] + np.random.normal(0, 0.1, size=20)
271+
df.loc[::5, 'y'] += 10 # Insert strong outliers
272+
df['weight'] = 1.0
273+
274+
selection = df['x1'].notna() & df['y'].notna()
275+
276+
_, dfGB_all = GroupByRegressor.make_parallel_fit(
277+
df, ['group'], ['y'], ['x1'], ['x1'], 'weight', '_s100',
278+
selection=selection, sigmaCut=100, n_jobs=1
279+
)
280+
281+
_, dfGB_strict = GroupByRegressor.make_parallel_fit(
282+
df, ['group'], ['y'], ['x1'], ['x1'], 'weight', '_s2',
283+
selection=selection, sigmaCut=2, n_jobs=1
284+
)
285+
286+
slope_all = dfGB_all['y_slope_x1_s100'].iloc[0]
287+
slope_strict = dfGB_strict['y_slope_x1_s2'].iloc[0]
288+
289+
assert abs(slope_strict - 3.0) < abs(slope_all - 3.0), "Robust fit with sigmaCut=2 should be closer to truth"

0 commit comments

Comments
 (0)