Skip to content

Commit 161f0f0

Browse files
author
miranov25
committed
Commit latest working version of groupby_regression.py
1 parent fc54430 commit 161f0f0

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

UTILS/dfextensions/groupby_regression.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from joblib import Parallel, delayed
66
from numpy.linalg import inv, LinAlgError
77
from typing import Union, List, Tuple, Callable
8-
8+
from random import shuffle
99

1010
class GroupByRegressor:
1111
@staticmethod
@@ -114,6 +114,7 @@ def process_group_robust(
114114
sigmaCut: float = 4,
115115
fitter: Union[str, Callable] = "auto"
116116
) -> dict:
117+
# TODO 0handle the case os singl gb column
117118
group_dict = dict(zip(gb_columns, key))
118119
predictors = []
119120
if isinstance(weights, str) and weights not in df_group.columns:
@@ -227,7 +228,8 @@ def make_parallel_fit(
227228
n_jobs: int = 1,
228229
min_stat: List[int] = [10, 10],
229230
sigmaCut: float = 4.0,
230-
fitter: Union[str, Callable] = "auto"
231+
fitter: Union[str, Callable] = "auto",
232+
batch_size: Union[int, None] = None # ← new argument
231233
) -> Tuple[pd.DataFrame, pd.DataFrame]:
232234
"""
233235
Perform grouped robust linear regression using HuberRegressor in parallel.
@@ -256,12 +258,15 @@ def make_parallel_fit(
256258
df_selected = df.loc[selection]
257259
grouped = df_selected.groupby(gb_columns)
258260

259-
results = Parallel(n_jobs=n_jobs)(
261+
filtered_items = [(key, idxs) for key, idxs in grouped.groups.items() if len(idxs) >= min_stat[0]/2]
262+
# shuffle(filtered_items) # Shuffle to ensure random order in parallel processing - should be an option
263+
264+
results = Parallel(n_jobs=n_jobs,batch_size=batch_size)(
260265
delayed(GroupByRegressor.process_group_robust)(
261-
key, group_df, gb_columns, fit_columns, linear_columns,
266+
key, df_selected.loc[idxs], gb_columns, fit_columns, linear_columns,
262267
median_columns, weights, min_stat, sigmaCut, fitter
263268
)
264-
for key, group_df in grouped
269+
for key, idxs in filtered_items
265270
)
266271

267272
dfGB = pd.DataFrame(results)

0 commit comments

Comments
 (0)