|
5 | 5 | from joblib import Parallel, delayed |
6 | 6 | from numpy.linalg import inv, LinAlgError |
7 | 7 | from typing import Union, List, Tuple, Callable |
8 | | - |
| 8 | +from random import shuffle |
9 | 9 |
|
10 | 10 | class GroupByRegressor: |
11 | 11 | @staticmethod |
@@ -114,6 +114,7 @@ def process_group_robust( |
114 | 114 | sigmaCut: float = 4, |
115 | 115 | fitter: Union[str, Callable] = "auto" |
116 | 116 | ) -> dict: |
| 117 | + # TODO: handle the case of a single gb column |
117 | 118 | group_dict = dict(zip(gb_columns, key)) |
118 | 119 | predictors = [] |
119 | 120 | if isinstance(weights, str) and weights not in df_group.columns: |
@@ -227,7 +228,8 @@ def make_parallel_fit( |
227 | 228 | n_jobs: int = 1, |
228 | 229 | min_stat: List[int] = [10, 10], |
229 | 230 | sigmaCut: float = 4.0, |
230 | | - fitter: Union[str, Callable] = "auto" |
| 231 | + fitter: Union[str, Callable] = "auto", |
| 232 | + batch_size: Union[int, None] = None # ← new argument |
231 | 233 | ) -> Tuple[pd.DataFrame, pd.DataFrame]: |
232 | 234 | """ |
233 | 235 | Perform grouped robust linear regression using HuberRegressor in parallel. |
@@ -256,12 +258,15 @@ def make_parallel_fit( |
256 | 258 | df_selected = df.loc[selection] |
257 | 259 | grouped = df_selected.groupby(gb_columns) |
258 | 260 |
|
259 | | - results = Parallel(n_jobs=n_jobs)( |
| 261 | + filtered_items = [(key, idxs) for key, idxs in grouped.groups.items() if len(idxs) >= min_stat[0]/2] |
| 262 | + # shuffle(filtered_items) # Shuffle to ensure random order in parallel processing - should be an option |
| 263 | + |
| 264 | + results = Parallel(n_jobs=n_jobs,batch_size=batch_size)( |
260 | 265 | delayed(GroupByRegressor.process_group_robust)( |
261 | | - key, group_df, gb_columns, fit_columns, linear_columns, |
| 266 | + key, df_selected.loc[idxs], gb_columns, fit_columns, linear_columns, |
262 | 267 | median_columns, weights, min_stat, sigmaCut, fitter |
263 | 268 | ) |
264 | | - for key, group_df in grouped |
| 269 | + for key, idxs in filtered_items |
265 | 270 | ) |
266 | 271 |
|
267 | 272 | dfGB = pd.DataFrame(results) |
|
0 commit comments