Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@

# stochtree 0.3.1

## New Features

* Added `print`, `summary`, `plot`, and `extract_parameter` generic functions in R for the `bartmodel` and `bcfmodel` classes ([#271](https://github.com/StochasticTree/stochtree/pull/271))
* Added sklearn-compatible estimator wrapper for `BARTModel` in Python ([#270](https://github.com/StochasticTree/stochtree/pull/270))

## Bug Fixes

* Fix R bug where our approach to temporarily modifying users' RNG state failed if `.Random.seed` did not exist (i.e. if the R RNG hadn't yet been accessed by an R session) ([#258](https://github.com/StochasticTree/stochtree/issues/258))
* Fix prediction bug for R BART models with random effects whose group labels aren't straightforward `1:num_groups` integers when only `y_hat` is requested ([#256](https://github.com/StochasticTree/stochtree/pull/256))
* Fix issue with C++ standard specification in Windows R package config ([#276](https://github.com/StochasticTree/stochtree/pull/276))

# stochtree 0.2.1

Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# stochtree 0.3.1

## New Features

* Added `print`, `summary`, `plot`, and `extract_parameter` generic functions in R for the `bartmodel` and `bcfmodel` classes ([#271](https://github.com/StochasticTree/stochtree/pull/271))

## Bug Fixes

* Fix R bug where our approach to temporarily modifying users' RNG state failed if `.Random.seed` did not exist (i.e. if the R RNG hadn't yet been accessed by an R session) ([#258](https://github.com/StochasticTree/stochtree/issues/258))
* Fix prediction bug for R BART models with random effects whose group labels aren't straightforward `1:num_groups` integers when only `y_hat` is requested ([#256](https://github.com/StochasticTree/stochtree/pull/256))
* Fix issue with C++ standard specification in Windows R package config ([#276](https://github.com/StochasticTree/stochtree/pull/276))

# stochtree 0.2.1

Expand Down
9 changes: 4 additions & 5 deletions stochtree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
)
from .serialization import JSONSerializer
from .sklearn import (
StochTreeRegressor,
StochTreeBinaryClassifier
StochTreeBARTRegressor,
StochTreeBARTBinaryClassifier,
)
from .utils import (
NotSampledError,
Expand All @@ -35,14 +35,13 @@
_check_matrix_square,
_standardize_array_to_list,
_standardize_array_to_np,
_expand_dims_1d,
_expand_dims_2d,
_expand_dims_2d_diag
)

__all__ = [
"BARTModel",
"BCFModel",
"StochTreeBARTRegressor",
"StochTreeBARTBinaryClassifier",
"Dataset",
"Residual",
"ForestContainer",
Expand Down
2 changes: 1 addition & 1 deletion stochtree/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def sample(
Parameters
----------
X_train : np.array
Training set covariates on which trees may be partitioned.
Training set covariates on which trees are partitioned.
y_train : np.array
Training set outcome.
leaf_basis_train : np.array, optional
Expand Down
2 changes: 2 additions & 0 deletions stochtree/bcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,7 @@ def sample(
num_gfr=num_gfr_propensity,
num_burnin=num_burnin_propensity,
num_mcmc=num_mcmc_propensity,
general_params={"random_seed": random_seed},
)
propensity_train = np.mean(
self.bart_propensity_model.y_hat_train, axis=1, keepdims=True
Expand All @@ -1379,6 +1380,7 @@ def sample(
num_gfr=num_gfr_propensity,
num_burnin=num_burnin_propensity,
num_mcmc=num_mcmc_propensity,
general_params={"random_seed": random_seed},
)
propensity_train = np.mean(
self.bart_propensity_model.y_hat_train, axis=1, keepdims=True
Expand Down
20 changes: 9 additions & 11 deletions stochtree/sklearn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from numpy import isin
from copy import copy
import numpy as np
from scipy.stats import norm
from stochtree import BARTModel, BCFModel
from stochtree import BARTModel
from sklearn.utils._array_api import (
get_namespace,
indexing_dtype,
Expand All @@ -19,7 +17,7 @@



class StochTreeRegressor(RegressorMixin, BaseEstimator):
class StochTreeBARTRegressor(RegressorMixin, BaseEstimator):
"""A scikit-learn-compatible estimator that implements a BART regression model.

Parameters
Expand All @@ -42,10 +40,10 @@ class StochTreeRegressor(RegressorMixin, BaseEstimator):
Attributes
----------
X_ : ndarray, shape (n_samples, n_features)
The input passed during :meth:`fit`.
The covariates (or features) used to define tree partitions.

y_ : ndarray, shape (n_samples,)
The labels passed during :meth:`fit`.
The outcome variable (or labels) used to evaluate tree partitions.

leaf_regression_basis_ : ndarray, shape (n_samples, n_bases)
The basis functions used for leaf regression model if requested.
Expand All @@ -70,7 +68,7 @@ class StochTreeRegressor(RegressorMixin, BaseEstimator):
>>> data = load_boston()
>>> X = data.data
>>> y = data.target
>>> reg = StochTreeRegressor().fit(X, y)
>>> reg = StochTreeBARTRegressor().fit(X, y)
>>> reg.predict(X)
"""

Expand Down Expand Up @@ -265,7 +263,7 @@ def __setstate__(self, state):
self.__dict__.update(state)


class StochTreeBinaryClassifier(ClassifierMixin, BaseEstimator):
class StochTreeBARTBinaryClassifier(ClassifierMixin, BaseEstimator):
"""A scikit-learn-compatible estimator that implements a binary probit BART classifier.

Parameters
Expand All @@ -288,10 +286,10 @@ class StochTreeBinaryClassifier(ClassifierMixin, BaseEstimator):
Attributes
----------
X_ : ndarray, shape (n_samples, n_features)
The input passed during :meth:`fit`.
The covariates (or features) used to define tree partitions.

y_ : ndarray, shape (n_samples,)
The labels passed during :meth:`fit`.
The outcome variable (or labels) used to evaluate tree partitions.

leaf_regression_basis_ : ndarray, shape (n_samples, n_bases)
The basis functions used for leaf regression model if requested.
Expand All @@ -316,7 +314,7 @@ class StochTreeBinaryClassifier(ClassifierMixin, BaseEstimator):
>>> data = load_wine()
>>> X = data.data
>>> y = data.target
>>> clf = StochTreeBinaryClassifier().fit(X, y)
>>> clf = StochTreeBARTBinaryClassifier().fit(X, y)
>>> clf.predict(X)
"""

Expand Down
54 changes: 35 additions & 19 deletions tools/debug/stochtree_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,28 @@
from sklearn.datasets import load_wine, load_breast_cancer
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from stochtree import StochTreeRegressor, StochTreeBinaryClassifier
from stochtree import (
StochTreeBARTRegressor,
StochTreeBARTBinaryClassifier,
)

# Generate data
# Generate supervised learning data
n = 100
p = 10
rng = np.random.default_rng(42)
X = rng.normal(size=(n, p))
y = X[:, 0] * 3 + rng.normal(size=n)

# Fit and predict a model
reg1 = StochTreeRegressor(general_params={"random_seed": 42})
reg1 = StochTreeBARTRegressor(general_params={"random_seed": 42})
reg1.fit(X, y)
pred1 = reg1.predict(X)

# Check that we get the same results with the same seed
# Also check that we can run the model on pandas inputs
X_df = pd.DataFrame(X)
y_series = pd.Series(y)
reg2 = StochTreeRegressor(general_params={"random_seed": 42})
reg2 = StochTreeBARTRegressor(general_params={"random_seed": 42})
reg2.fit(X_df, y_series)
pred2 = reg2.predict(X_df)

Expand All @@ -36,25 +39,29 @@
plt.title("Comparison of Predictions")
plt.show()

# Check that StochTreeRegressor is a valid estimator
check_estimator(StochTreeRegressor(general_params={"random_seed": 42}, mean_forest_params={"min_samples_leaf": 1}))
# Check that StochTreeBARTRegressor is a valid estimator
check_estimator(
StochTreeBARTRegressor(
general_params={"random_seed": 42}, mean_forest_params={"min_samples_leaf": 1}
)
)

# Check that we can cross validate stochtree BART parameters
param_grid = {
'num_gfr': [10, 40],
'num_mcmc': [0, 1000],
'mean_forest_params': [
{'num_trees': 50, 'alpha': 0.95, 'beta': 2.0},
{'num_trees': 100, 'alpha': 0.90, 'beta': 1.5},
{'num_trees': 200, 'alpha': 0.85, 'beta': 1.0}
]
"num_gfr": [10, 40],
"num_mcmc": [0, 1000],
"mean_forest_params": [
{"num_trees": 50, "alpha": 0.95, "beta": 2.0},
{"num_trees": 100, "alpha": 0.90, "beta": 1.5},
{"num_trees": 200, "alpha": 0.85, "beta": 1.0},
],
}
grid_search = GridSearchCV(
estimator=StochTreeRegressor(),
estimator=StochTreeBARTRegressor(),
param_grid=param_grid,
cv=5,
scoring='r2',
n_jobs=-1
scoring="r2",
n_jobs=-1,
)
grid_search.fit(X, y)
# grid_search.cv_results_
Expand All @@ -66,7 +73,7 @@
y = dataset.target

# Check that we can fit and predict on this dataset
clf = StochTreeBinaryClassifier(general_params={"random_seed": 42})
clf = StochTreeBARTBinaryClassifier(general_params={"random_seed": 42})
clf.fit(X=X, y=y)

# Load a multiclass classification dataset
Expand All @@ -75,8 +82,17 @@
y = dataset.target

# Check that we can fit and predict on this dataset by wrapping in the OneVsRest meta-estimator
clf = OneVsRestClassifier(StochTreeBinaryClassifier(general_params={"random_seed": 42}))
clf = OneVsRestClassifier(
StochTreeBARTBinaryClassifier(general_params={"random_seed": 42})
)
clf.fit(X=X, y=y)

# Check that we have a valid general purpose classifier when wrapping this estimator in the OneVsRest meta-estimator
check_estimator(OneVsRestClassifier(StochTreeBinaryClassifier(general_params={"random_seed": 42}, mean_forest_params={"min_samples_leaf": 1})))
check_estimator(
OneVsRestClassifier(
StochTreeBARTBinaryClassifier(
general_params={"random_seed": 42},
mean_forest_params={"min_samples_leaf": 1},
)
)
)
Loading