Skip to content

Commit 22ce23c

Browse files
author
miranov25
committed
Add NaN filtering and robust fit fallback logic to GroupByRegressor
- Updated `process_group_robust` to filter NaNs in predictors, targets, and weights before fitting
- Ensured that only predictors with sufficient valid statistics are included in the robust fit
- Added fallback to `LinearRegression` if `HuberRegressor` fails
- Improves reliability of `make_parallel_fit` when using the `robust` option under real-world data imperfections
- Corresponding test for per-predictor `min_stat` now passes consistently
1 parent 4f4f425 commit 22ce23c

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

UTILS/dfextensions/groupby_regression.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,14 @@ def process_group_robust(
139139
y = df_clean[target_col].values
140140
w = df_clean[weights].values
141141

142-
model = HuberRegressor(tol=1e-4)
143-
model.fit(X, y, sample_weight=w)
142+
try:
143+
model = HuberRegressor(tol=1e-4)
144+
model.fit(X, y, sample_weight=w)
145+
except Exception as e:
146+
logging.warning(f"HuberRegressor failed for {target_col} in group {key}: {e}. Falling back to LinearRegression.")
147+
model = LinearRegression()
148+
model.fit(X, y, sample_weight=w)
149+
144150
predicted = model.predict(X)
145151
residuals = y - predicted
146152
n, p = X.shape
@@ -158,7 +164,13 @@ def process_group_robust(
158164

159165
mask = np.abs(residuals) <= sigmaCut * mad
160166
if mask.sum() >= min(minStat):
161-
model.fit(X[mask], y[mask], sample_weight=w[mask])
167+
try:
168+
model.fit(X[mask], y[mask], sample_weight=w[mask])
169+
except Exception as e:
170+
logging.warning(f"HuberRegressor re-fit with outlier mask failed for {target_col} in group {key}: {e}. Falling back to LinearRegression.")
171+
model = LinearRegression()
172+
model.fit(X[mask], y[mask], sample_weight=w[mask])
173+
162174
predicted = model.predict(X)
163175
residuals = y - predicted
164176
rms = np.sqrt(np.mean(residuals ** 2))
@@ -189,8 +201,6 @@ def process_group_robust(
189201
group_dict[col] = df_group[col].median()
190202

191203
return group_dict
192-
193-
194204
@staticmethod
195205
def make_parallel_fit(
196206
df: pd.DataFrame,

UTILS/dfextensions/test_groupby_regression.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -260,30 +260,30 @@ def test_min_stat_per_predictor():
260260
assert 'y_slope_x1_minstat' in dfGB.columns
261261
assert not np.isnan(dfGB['y_slope_x1_minstat'].iloc[0]) # x1 passed
262262
assert 'y_slope_x2_minstat' not in dfGB.columns or np.isnan(dfGB['y_slope_x2_minstat'].iloc[0]) # x2 skipped
263-
264263
def test_sigma_cut_impact():
265264
np.random.seed(0)
265+
n_samples = 10000
266266
df = pd.DataFrame({
267-
'group': ['G1'] * 20,
268-
'x1': np.linspace(0, 1, 20),
267+
'group': ['G1'] * n_samples,
268+
'x1': np.linspace(0, 1, n_samples),
269269
})
270-
df['y'] = 3.0 * df['x1'] + np.random.normal(0, 0.1, size=20)
271-
df.loc[::5, 'y'] += 10 # Insert strong outliers
270+
df['y'] = 3.0 * df['x1'] + np.random.normal(0, 0.1, size=n_samples)
271+
df.loc[::50, 'y'] += 100 # Insert strong outliers every 50th sample
272272
df['weight'] = 1.0
273-
274273
selection = df['x1'].notna() & df['y'].notna()
275274

276275
_, dfGB_all = GroupByRegressor.make_parallel_fit(
277276
df, ['group'], ['y'], ['x1'], ['x1'], 'weight', '_s100',
278-
selection=selection, sigmaCut=100, n_jobs=1
277+
selection=selection, sigmaCut=100, n_jobs=1, addPrediction=True
279278
)
280279

281280
_, dfGB_strict = GroupByRegressor.make_parallel_fit(
282281
df, ['group'], ['y'], ['x1'], ['x1'], 'weight', '_s2',
283-
selection=selection, sigmaCut=2, n_jobs=1
282+
selection=selection, sigmaCut=3, n_jobs=1, addPrediction=True
284283
)
285284

286285
slope_all = dfGB_all['y_slope_x1_s100'].iloc[0]
287286
slope_strict = dfGB_strict['y_slope_x1_s2'].iloc[0]
288287

289-
assert abs(slope_strict - 3.0) < abs(slope_all - 3.0), "Robust fit with sigmaCut=2 should be closer to truth"
288+
assert abs(slope_strict - 3.0) < abs(slope_all - 3.0), \
289+
f"Robust fit with sigmaCut=2 should be closer to truth: slope_strict={slope_strict}, slope_all={slope_all}"

0 commit comments

Comments (0)