Skip to content

Commit 40f1492

Browse files
committed
feat: add **kwargs support to XGBClassifier and XGBRegressor
This is in alignment with https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBRegressor, if BQML is considered to be a booster type.
1 parent 5ce5d63 commit 40f1492

File tree

3 files changed

+111
-5
lines changed

3 files changed

+111
-5
lines changed

bigframes/ml/ensemble.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from typing import Dict, List, Literal, Optional
20+
from typing import Dict, List, Literal, Optional, Union
2121

2222
import bigframes_vendored.sklearn.ensemble._forest
2323
import bigframes_vendored.xgboost.sklearn
@@ -78,6 +78,7 @@ def __init__(
7878
tol: float = 0.01,
7979
enable_global_explain: bool = False,
8080
xgboost_version: Literal["0.9", "1.1"] = "0.9",
81+
**kwargs: Union[str, str | int | bool | float | List[str]],
8182
):
8283
self.n_estimators = n_estimators
8384
self.booster = booster
@@ -99,6 +100,7 @@ def __init__(
99100
self.xgboost_version = xgboost_version
100101
self._bqml_model: Optional[core.BqmlModel] = None
101102
self._bqml_model_factory = globals.bqml_model_factory()
103+
self._extra_bqml_options = kwargs
102104

103105
@classmethod
104106
def _from_bq(
@@ -117,7 +119,7 @@ def _from_bq(
117119
@property
118120
def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
119121
"""The model options as they will be set for BQML"""
120-
return {
122+
options = {
121123
"model_type": "BOOSTED_TREE_REGRESSOR",
122124
"data_split_method": "NO_SPLIT",
123125
"early_stop": True,
@@ -139,6 +141,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
139141
"enable_global_explain": self.enable_global_explain,
140142
"xgboost_version": self.xgboost_version,
141143
}
144+
options.update(self._extra_bqml_options)
145+
return options # type: ignore
142146

143147
def _fit(
144148
self,
@@ -237,6 +241,7 @@ def __init__(
237241
tol: float = 0.01,
238242
enable_global_explain: bool = False,
239243
xgboost_version: Literal["0.9", "1.1"] = "0.9",
244+
**kwargs: Union[str, str | int | bool | float | List[str]],
240245
):
241246
self.n_estimators = n_estimators
242247
self.booster = booster
@@ -258,6 +263,7 @@ def __init__(
258263
self.xgboost_version = xgboost_version
259264
self._bqml_model: Optional[core.BqmlModel] = None
260265
self._bqml_model_factory = globals.bqml_model_factory()
266+
self._extra_bqml_options = kwargs
261267

262268
@classmethod
263269
def _from_bq(
@@ -276,7 +282,7 @@ def _from_bq(
276282
@property
277283
def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
278284
"""The model options as they will be set for BQML"""
279-
return {
285+
options = {
280286
"model_type": "BOOSTED_TREE_CLASSIFIER",
281287
"data_split_method": "NO_SPLIT",
282288
"early_stop": True,
@@ -298,6 +304,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
298304
"enable_global_explain": self.enable_global_explain,
299305
"xgboost_version": self.xgboost_version,
300306
}
307+
options.update(self._extra_bqml_options)
308+
return options # type: ignore
301309

302310
def _fit(
303311
self,

tests/unit/ml/test_golden_sql.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import textwrap
1516
from unittest import mock
1617

1718
from google.cloud import bigquery
1819
import pandas as pd
1920
import pytest
2021

2122
import bigframes
22-
from bigframes.ml import core, decomposition, linear_model
23+
from bigframes.ml import core, decomposition, ensemble, linear_model
2324
import bigframes.ml.core
2425
import bigframes.pandas as bpd
2526

@@ -286,3 +287,83 @@ def test_decomposition_mf_score_with_x(mock_session, bqml_model, mock_X):
286287
"SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql_property))",
287288
allow_large_results=True,
288289
)
290+
291+
292+
def test_xgb_classifier_kwargs_params_fit(
293+
bqml_model_factory, mock_session, mock_X, mock_y
294+
):
295+
model = ensemble.XGBClassifier(category_encoding_method="LABEL_ENCODING")
296+
model._bqml_model_factory = bqml_model_factory
297+
model.fit(mock_X, mock_y)
298+
299+
mock_session._start_query_ml_ddl.assert_called_once_with(
300+
textwrap.dedent(
301+
"""
302+
CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`
303+
OPTIONS(
304+
model_type='BOOSTED_TREE_CLASSIFIER',
305+
data_split_method='NO_SPLIT',
306+
early_stop=True,
307+
num_parallel_tree=1,
308+
booster_type='gbtree',
309+
tree_method='auto',
310+
min_tree_child_weight=1,
311+
colsample_bytree=1.0,
312+
colsample_bylevel=1.0,
313+
colsample_bynode=1.0,
314+
min_split_loss=0.0,
315+
max_tree_depth=6,
316+
subsample=1.0,
317+
l1_reg=0.0,
318+
l2_reg=1.0,
319+
learn_rate=0.3,
320+
max_iterations=20,
321+
min_rel_progress=0.01,
322+
enable_global_explain=False,
323+
xgboost_version='0.9',
324+
category_encoding_method='LABEL_ENCODING',
325+
INPUT_LABEL_COLS=['input_column_label'])
326+
AS input_X_y_no_index_sql
327+
"""
328+
).strip()
329+
)
330+
331+
332+
def test_xgb_regressor_kwargs_params_fit(
333+
bqml_model_factory, mock_session, mock_X, mock_y
334+
):
335+
model = ensemble.XGBRegressor(category_encoding_method="LABEL_ENCODING")
336+
model._bqml_model_factory = bqml_model_factory
337+
model.fit(mock_X, mock_y)
338+
339+
mock_session._start_query_ml_ddl.assert_called_once_with(
340+
textwrap.dedent(
341+
"""
342+
CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`
343+
OPTIONS(
344+
model_type='BOOSTED_TREE_REGRESSOR',
345+
data_split_method='NO_SPLIT',
346+
early_stop=True,
347+
num_parallel_tree=1,
348+
booster_type='gbtree',
349+
tree_method='auto',
350+
min_tree_child_weight=1,
351+
colsample_bytree=1.0,
352+
colsample_bylevel=1.0,
353+
colsample_bynode=1.0,
354+
min_split_loss=0.0,
355+
max_tree_depth=6,
356+
subsample=1.0,
357+
l1_reg=0.0,
358+
l2_reg=1.0,
359+
learn_rate=0.3,
360+
max_iterations=20,
361+
min_rel_progress=0.01,
362+
enable_global_explain=False,
363+
xgboost_version='0.9',
364+
category_encoding_method='LABEL_ENCODING',
365+
INPUT_LABEL_COLS=['input_column_label'])
366+
AS input_X_y_no_index_sql
367+
"""
368+
).strip()
369+
)

third_party/bigframes_vendored/xgboost/sklearn.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,18 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
9898
tol (Optional[float]):
9999
Minimum relative loss improvement necessary to continue training. Default to 0.01.
100100
enable_global_explain (Optional[bool]):
101-
Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False.
101+
Whether to compute global explanations using explainable AI to
102+
evaluate global feature importance to the model. Default to False.
102103
xgboost_version (Optional[str]):
103104
Specifies the Xgboost version for model training. Default to "0.9". Possible values: "0.9", "1.1".
105+
kwargs (dict):
106+
Keyword arguments for the ``model_option_list`` of the boosted tree
107+
BQML model. See
108+
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-boosted-tree
109+
110+
For example, to set ``CATEGORY_ENCODING_METHOD`` to
111+
``LABEL_ENCODING``, pass in the keyword argument
112+
``category_encoding_method='LABEL_ENCODING'``.
104113
"""
105114

106115

@@ -148,4 +157,12 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
148157
Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False.
149158
xgboost_version (Optional[str]):
150159
Specifies the Xgboost version for model training. Default to "0.9". Possible values: "0.9", "1.1".
160+
kwargs (dict):
161+
Keyword arguments for the ``model_option_list`` of the boosted tree
162+
BQML model. See
163+
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-boosted-tree
164+
165+
For example, to set ``CATEGORY_ENCODING_METHOD`` to
166+
``LABEL_ENCODING``, pass in the keyword argument
167+
``category_encoding_method='LABEL_ENCODING'``.
151168
"""

0 commit comments

Comments
 (0)