Skip to content

Commit d55b796

Browse files
author
miranov25
committed
Refactor GroupByRegressor with robust fit logic, dtype casting, and unified min_stat
- Refactored make_linear_fit and make_parallel_fit to support `cast_dtype` for output precision control - Unified min_stat interface across OLS and robust fits - Improved coefficient indexing and error handling in robust fits (e.g. fallback for singular matrices) - Enhanced test coverage: - Outlier robustness - Exact coefficient recovery - Predictor dropout via min_stat thresholds - dtype casting validation - Replaced print statements with logging for integration readiness - Updated groupby_regression.md: - Added flowchart, use cases, and test coverage summary - Documented cast_dtype and fallback logic
1 parent 718259a commit d55b796

File tree

3 files changed

+556
-0
lines changed

3 files changed

+556
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# GroupBy Linear Regression Utilities
2+
3+
This module provides utilities for computing group-wise linear fits and robust statistics on pandas DataFrames. It is designed to support workflows that require fitting separate models across grouped subsets of data.
4+
5+
Originally developed for **distortion correction** and **dE/dx calibration** in high-energy physics experiments, the code has since been generalized to support broader applications involving grouped linear regression and statistical feature extraction.
6+
7+
## Functions
8+
9+
### `GroupByRegressor.make_linear_fit(...)`
10+
11+
Performs group-wise **ordinary least squares (OLS)** regression fits.
12+
13+
#### Parameters:
14+
15+
* `df (pd.DataFrame)`: Input data
16+
* `gb_columns (list[str])`: Columns to group by
17+
* `fit_columns (list[str])`: Dependent (target) variables
18+
* `linear_columns (list[str])`: Independent variables
19+
* `median_columns (list[str])`: Columns for which medians are computed
20+
* `suffix (str)`: Suffix for generated columns
21+
* `selection (pd.Series)`: Boolean mask selecting rows to use
22+
* `addPrediction (bool)`: If True, predictions are added to the original DataFrame
23+
* `cast_dtype (str | None)`: Optional type casting (e.g., 'float32', 'float16') for fit results
24+
* `min_stat (int)`: Minimum number of rows in a group to perform fitting
25+
26+
#### Returns:
27+
28+
* `(df_out, dfGB)`:
29+
30+
* `df_out`: Original DataFrame with predictions (if enabled)
31+
* `dfGB`: Per-group statistics, including slopes, intercepts, medians, and bin counts
32+
33+
---
34+
35+
### `GroupByRegressor.make_parallel_fit(...)`
36+
37+
Performs **robust group-wise regression** using `HuberRegressor`, with optional parallelization.
38+
39+
#### Additional Parameters:
40+
41+
* `weights (str)`: Column to use as weights during regression
42+
* `n_jobs (int)`: Number of parallel processes to use
43+
* `min_stat (list[int])`: Minimum number of points required for each predictor in `linear_columns`
44+
* `sigmaCut (float)`: Threshold multiplier for MAD to reject outliers
45+
46+
#### Notes:
47+
48+
* Supports partial predictor exclusion per group based on `min_stat`
49+
* Uses robust iteration with outlier rejection (MAD filtering)
50+
* Falls back to NaNs when fits are ill-conditioned or predictors are skipped
51+
52+
## Example
53+
54+
```python
55+
from groupby_regression import GroupByRegressor
56+
57+
df_out, dfGB = GroupByRegressor.make_parallel_fit(
58+
df,
59+
gb_columns=['detector_sector'],
60+
fit_columns=['dEdx'],
61+
linear_columns=['path_length', 'momentum'],
62+
median_columns=['path_length'],
63+
weights='w_dedx',
64+
suffix='_calib',
65+
selection=(df['track_quality'] > 0.9),
66+
cast_dtype='float32',
67+
addPrediction=True,
68+
min_stat=[20, 20],
69+
n_jobs=4
70+
)
71+
```
72+
73+
## Output Columns (in `dfGB`):
74+
75+
| Column Name | Description |
76+
| ----------------------------------------- | ---------------------------------------- |
77+
| `<target>_slope_<predictor>_<suffix>` | Regression slope for predictor |
78+
| `<target>_intercept_<suffix>` | Regression intercept |
79+
| `<target>_rms_<suffix>` / `_mad_<suffix>` | Residual stats (robust only) |
80+
| `<median_column>_<suffix>` | Median of the specified column per group |
81+
| `bin_count_<suffix>` | Number of entries in each group |
82+
83+
## Regression Flowchart
84+
85+
```text
86+
+-------------+
87+
| Input Data |
88+
+------+------+
89+
|
90+
v
91+
+------+------+
92+
| Apply mask |
93+
| (selection)|
94+
+------+------+
95+
|
96+
v
97+
+----------------------------+
98+
| Group by gb_columns |
99+
+----------------------------+
100+
|
101+
v
102+
+----------------------------+
103+
| For each group: |
104+
| - Check min_stat |
105+
| - Fit model |
106+
| - Estimate residual stats |
107+
+----------------------------+
108+
|
109+
v
110+
+-------------+ +-------------+
111+
| df_out | | dfGB |
112+
| (with preds)| | (fit params)|
113+
+-------------+ +-------------+
114+
```
115+
116+
## Use Cases
117+
118+
* Detector distortion correction
119+
* dE/dx signal calibration
120+
* Grouped trend removal in sensor data
121+
* Statistical correction of multi-source measurements
122+
123+
## Test Coverage
124+
125+
* Basic regression fit and prediction verification
126+
* Edge case handling (missing data, small groups)
127+
* Outlier injection and robust fit evaluation
128+
* Exact recovery of known coefficients
129+
* `cast_dtype` precision testing
130+
131+
## Tips
132+
133+
💡 Use `cast_dtype='float16'` for storage savings, but ensure it's compatible with downstream numerical precision requirements.
134+
135+
## Recent Changes
136+
137+
* ✅ Unified `min_stat` interface for both OLS and robust fits
138+
* ✅ Type casting via `cast_dtype` param (e.g. `'float16'` for storage efficiency)
139+
* ✅ Stable handling of singular matrices and small group sizes
140+
* ✅ Test coverage for missing values, outliers, and exact recovery scenarios
141+
* ✅ Logging replaces print-based diagnostics for cleaner integration
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import numpy as np
2+
import pandas as pd
3+
import logging
4+
from sklearn.linear_model import LinearRegression, HuberRegressor
5+
from joblib import Parallel, delayed
6+
from numpy.linalg import inv, LinAlgError
7+
8+
9+
class GroupByRegressor:
10+
@staticmethod
11+
def _cast_fit_columns(dfGB, cast_dtype=None):
12+
if cast_dtype is not None:
13+
for col in dfGB.columns:
14+
if ("slope" in col or "intercept" in col or "rms" in col or "mad" in col):
15+
dfGB[col] = dfGB[col].astype(cast_dtype)
16+
return dfGB
17+
18+
@staticmethod
19+
def make_linear_fit(df, gb_columns, fit_columns, linear_columns, median_columns, suffix, selection, addPrediction=False, cast_dtype=None, min_stat=10):
20+
"""
21+
Perform standard linear regression fits for grouped data and compute median values.
22+
23+
Parameters:
24+
df (pd.DataFrame): Input dataframe.
25+
gb_columns (list): Columns to group by.
26+
fit_columns (list): Target columns for linear regression.
27+
linear_columns (list): Independent variables used for the fit.
28+
median_columns (list): Columns for which median values are computed.
29+
suffix (str): Suffix to append to columns in the output dfGB.
30+
selection (pd.Series): Boolean mask for selecting rows.
31+
addPrediction (bool): If True, merge predictions back into df.
32+
cast_dtype (str or None): If not None, cast fit-related columns to this dtype.
33+
min_stat (int): Minimum number of rows required to perform regression.
34+
35+
Returns:
36+
tuple: (df, dfGB) where
37+
df is the original dataframe with predicted values appended (if addPrediction is True),
38+
and dfGB is the group-by statistics dataframe containing medians and fit coefficients.
39+
"""
40+
df_selected = df.loc[selection]
41+
group_results = []
42+
group_sizes = {}
43+
44+
for group_vals, df_group in df_selected.groupby(gb_columns):
45+
group_dict = dict(zip(gb_columns, group_vals))
46+
group_sizes[group_vals] = len(df_group)
47+
for target_col in fit_columns:
48+
try:
49+
X = df_group[linear_columns].values
50+
y = df_group[target_col].values
51+
if len(X) < min_stat:
52+
for col in linear_columns:
53+
group_dict[f"{target_col}_slope_{col}"] = np.nan
54+
group_dict[f"{target_col}_intercept"] = np.nan
55+
continue
56+
model = LinearRegression()
57+
model.fit(X, y)
58+
for i, col in enumerate(linear_columns):
59+
group_dict[f"{target_col}_slope_{col}"] = model.coef_[i]
60+
group_dict[f"{target_col}_intercept"] = model.intercept_
61+
except Exception as e:
62+
logging.warning(f"Linear regression failed for {target_col} in group {group_vals}: {e}")
63+
for col in linear_columns:
64+
group_dict[f"{target_col}_slope_{col}"] = np.nan
65+
group_dict[f"{target_col}_intercept"] = np.nan
66+
67+
for col in median_columns:
68+
group_dict[col] = df_group[col].median()
69+
70+
group_results.append(group_dict)
71+
72+
dfGB = pd.DataFrame(group_results)
73+
dfGB = GroupByRegressor._cast_fit_columns(dfGB, cast_dtype)
74+
75+
bin_counts = np.array([group_sizes.get(tuple(row), 0) for row in dfGB[gb_columns].itertuples(index=False)], dtype=np.int32)
76+
dfGB["bin_count"] = bin_counts
77+
dfGB = dfGB.rename(columns={col: f"{col}{suffix}" for col in dfGB.columns if col not in gb_columns})
78+
dfGB = dfGB.copy()
79+
80+
if addPrediction:
81+
df = df.merge(dfGB, on=gb_columns, how="left")
82+
for target_col in fit_columns:
83+
intercept_col = f"{target_col}_intercept{suffix}"
84+
if intercept_col not in df.columns:
85+
continue
86+
df[f"{target_col}{suffix}"] = df[intercept_col]
87+
for col in linear_columns:
88+
slope_col = f"{target_col}_slope_{col}{suffix}"
89+
if slope_col in df.columns:
90+
df[f"{target_col}{suffix}"] += df[slope_col] * df[col]
91+
92+
return df, dfGB
93+
94+
@staticmethod
95+
def process_group_robust(key, df_group, gb_columns, fit_columns, linear_columns0, median_columns, weights, minStat=[], sigmaCut=4):
96+
"""
97+
Process a single group: perform robust regression fits on each target column,
98+
compute median values, RMS and MAD of the residuals.
99+
After an initial Huber fit, points with residuals > sigmaCut * MAD are removed and the fit is redone
100+
if enough points remain.
101+
102+
For each predictor in linear_columns0, the predictor is used only if the number of rows in the group
103+
is greater than the corresponding value in minStat.
104+
105+
Parameters:
106+
key: Group key.
107+
df_group (pd.DataFrame): Data for the group.
108+
gb_columns (list): Columns used for grouping.
109+
fit_columns (list): Target columns to be fit.
110+
linear_columns0 (list): List of candidate predictor columns.
111+
median_columns (list): List of columns for which median values are computed.
112+
weights (str): Column name for weights.
113+
minStat (list): List of minimum number of rows required to use each predictor in linear_columns0.
114+
sigmaCut (float): Factor to remove outliers (points with residual > sigmaCut * MAD).
115+
116+
Returns:
117+
dict: A dictionary containing group keys, fit parameters, RMS, and MAD.
118+
"""
119+
group_dict = dict(zip(gb_columns, key))
120+
n_rows = len(df_group)
121+
predictors = []
122+
123+
for i, col in enumerate(linear_columns0):
124+
if n_rows > minStat[i]:
125+
predictors.append(col)
126+
127+
for target_col in fit_columns:
128+
try:
129+
if not predictors:
130+
continue
131+
X = df_group[predictors].values
132+
y = df_group[target_col].values
133+
w = df_group[weights].values
134+
if len(y) < min(minStat):
135+
continue
136+
137+
model = HuberRegressor(tol=1e-4)
138+
model.fit(X, y, sample_weight=w)
139+
predicted = model.predict(X)
140+
residuals = y - predicted
141+
n, p = X.shape
142+
denom = n - p if n > p else 1e-9
143+
s2 = np.sum(residuals ** 2) / denom
144+
145+
try:
146+
cov_matrix = inv(X.T @ X) * s2
147+
std_errors = np.sqrt(np.diag(cov_matrix))
148+
except LinAlgError:
149+
std_errors = np.full(len(predictors), np.nan)
150+
151+
rms = np.sqrt(np.mean(residuals ** 2))
152+
mad = np.median(np.abs(residuals))
153+
154+
mask = np.abs(residuals) <= sigmaCut * mad
155+
if mask.sum() >= min(minStat):
156+
model.fit(X[mask], y[mask], sample_weight=w[mask])
157+
predicted = model.predict(X)
158+
residuals = y - predicted
159+
rms = np.sqrt(np.mean(residuals ** 2))
160+
mad = np.median(np.abs(residuals))
161+
162+
for col in linear_columns0:
163+
if col in predictors:
164+
idx = predictors.index(col)
165+
group_dict[f"{target_col}_slope_{col}"] = model.coef_[idx]
166+
group_dict[f"{target_col}_err_{col}"] = std_errors[idx] if idx < len(std_errors) else np.nan
167+
else:
168+
group_dict[f"{target_col}_slope_{col}"] = np.nan
169+
group_dict[f"{target_col}_err_{col}"] = np.nan
170+
171+
group_dict[f"{target_col}_intercept"] = model.intercept_
172+
group_dict[f"{target_col}_rms"] = rms
173+
group_dict[f"{target_col}_mad"] = mad
174+
except Exception as e:
175+
logging.warning(f"Robust regression failed for {target_col} in group {key}: {e}")
176+
for col in linear_columns0:
177+
group_dict[f"{target_col}_slope_{col}"] = np.nan
178+
group_dict[f"{target_col}_err_{col}"] = np.nan
179+
group_dict[f"{target_col}_intercept"] = np.nan
180+
group_dict[f"{target_col}_rms"] = np.nan
181+
group_dict[f"{target_col}_mad"] = np.nan
182+
183+
for col in median_columns:
184+
group_dict[col] = df_group[col].median()
185+
186+
return group_dict

0 commit comments

Comments
 (0)