Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 47 additions & 47 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -2331,26 +2331,26 @@ predict.bartmodel <- function(

#' @title Print a summary of the BART model
#' @description Prints a summary of the BART model, including the model terms and their specifications.
#' @param bart_model The BART model object
#' @param x The BART model object
#' @param ... Additional arguments
#' @export
#' @return BART model object unchanged after printing summary
print.bartmodel <- function(bart_model, ...) {
print.bartmodel <- function(x, ...) {
# What type of model was run
model_terms <- c()
if (bart_model$model_params$include_mean_forest) {
if (x$model_params$include_mean_forest) {
model_terms <- c(model_terms, "mean forest")
}
if (bart_model$model_params$include_variance_forest) {
if (x$model_params$include_variance_forest) {
model_terms <- c(model_terms, "variance forest")
}
if (bart_model$model_params$has_rfx) {
if (x$model_params$has_rfx) {
model_terms <- c(model_terms, "additive random effects")
}
if (bart_model$model_params$sample_sigma2_global) {
if (x$model_params$sample_sigma2_global) {
model_terms <- c(model_terms, "global error variance model")
}
if (bart_model$model_params$sample_sigma2_leaf) {
if (x$model_params$sample_sigma2_leaf) {
model_terms <- c(model_terms, "mean forest leaf scale model")
}
if (length(model_terms) > 2) {
Expand All @@ -2372,27 +2372,27 @@ print.bartmodel <- function(bart_model, ...) {
}

# Outcome and leaf model details
if (bart_model$model_params$leaf_regression) {
if (x$model_params$leaf_regression) {
summary_message <- paste0(
summary_message,
"\n",
"Outcome was modeled ",
ifelse(
bart_model$model_params$probit_outcome_model,
x$model_params$probit_outcome_model,
"with a probit link",
"as gaussian"
),
" with a leaf regression prior with ",
bart_model$model_params$leaf_dimension,
x$model_params$leaf_dimension,
" bases for the mean forest"
)
} else if (bart_model$model_params$include_mean_forest) {
} else if (x$model_params$include_mean_forest) {
summary_message <- paste0(
summary_message,
"\n",
"Outcome was modeled ",
ifelse(
bart_model$model_params$probit_outcome_model,
x$model_params$probit_outcome_model,
"with a probit link",
"as gaussian"
),
Expand All @@ -2404,15 +2404,15 @@ print.bartmodel <- function(bart_model, ...) {
"\n",
"Outcome was modeled ",
ifelse(
bart_model$model_params$probit_outcome_model,
x$model_params$probit_outcome_model,
"with a probit link",
"as gaussian"
),
)
}

# Standardization
if (bart_model$model_params$standardize) {
if (x$model_params$standardize) {
summary_message <- paste0(
summary_message,
"\n",
Expand All @@ -2421,14 +2421,14 @@ print.bartmodel <- function(bart_model, ...) {
}

# Random effects details
if (bart_model$model_params$has_rfx) {
if (bart_model$model_params$rfx_model_spec == "custom") {
if (x$model_params$has_rfx) {
if (x$model_params$rfx_model_spec == "custom") {
summary_message <- paste0(
summary_message,
"\n",
"Random effects were fit with a user-supplied basis"
)
} else if (bart_model$model_params$rfx_model_spec == "intercept_only") {
} else if (x$model_params$rfx_model_spec == "intercept_only") {
summary_message <- paste0(
summary_message,
"\n",
Expand All @@ -2442,24 +2442,24 @@ print.bartmodel <- function(bart_model, ...) {
summary_message,
"\n",
"The sampler was run for ",
bart_model$model_params$num_gfr,
x$model_params$num_gfr,
" GFR iterations, with ",
bart_model$model_params$num_chains,
x$model_params$num_chains,
ifelse(
bart_model$model_params$num_chains == 1,
x$model_params$num_chains == 1,
" chain of ",
" chains of "
),
bart_model$model_params$num_burnin,
x$model_params$num_burnin,
" burn-in iterations and ",
bart_model$model_params$num_mcmc,
x$model_params$num_mcmc,
" MCMC iterations, ",
ifelse(
bart_model$model_params$keep_every == 1,
x$model_params$keep_every == 1,
"retaining every iteration (i.e. no thinning)",
paste0(
"retaining every ",
bart_model$model_params$keep_every,
x$model_params$keep_every,
"th iteration (i.e. thinning)"
)
)
Expand All @@ -2469,24 +2469,24 @@ print.bartmodel <- function(bart_model, ...) {
cat(summary_message, "\n")

# Return bart_model invisibly
invisible(bart_model)
invisible(x)
}

#' @title Summarize the BART model fit and sampled terms.
#' @description Summarize the BART with a description of the model that was fit and numeric summaries of any sampled quantities.
#' @param bart_model The BART model object
#' @param object The BART model object
#' @param ... Additional arguments
#' @export
#' @return BART model object unchanged after summarizing
summary.bartmodel <- function(bart_model, ...) {
summary.bartmodel <- function(object, ...) {
# First, print the BART model
tmp <- print(bart_model)
tmp <- print(object)

# Summarize any sampled quantities

# Global error scale
if (bart_model$model_params$sample_sigma2_global) {
sigma2_samples <- bart_model$sigma2_global_samples
if (object$model_params$sample_sigma2_global) {
sigma2_samples <- object$sigma2_global_samples
n_samples <- length(sigma2_samples)
mean_sigma2 <- mean(sigma2_samples)
sd_sigma2 <- sd(sigma2_samples)
Expand All @@ -2504,8 +2504,8 @@ summary.bartmodel <- function(bart_model, ...) {
}

# Leaf scale
if (bart_model$model_params$sample_sigma2_leaf) {
sigma2_leaf_samples <- bart_model$sigma2_leaf_samples
if (object$model_params$sample_sigma2_leaf) {
sigma2_leaf_samples <- object$sigma2_leaf_samples
n_samples <- length(sigma2_leaf_samples)
mean_sigma2 <- mean(sigma2_leaf_samples)
sd_sigma2 <- sd(sigma2_leaf_samples)
Expand All @@ -2523,8 +2523,8 @@ summary.bartmodel <- function(bart_model, ...) {
}

# In-sample predictions
if (!is.null(bart_model$y_hat_train)) {
y_hat_train_mean <- rowMeans(bart_model$y_hat_train)
if (!is.null(object$y_hat_train)) {
y_hat_train_mean <- rowMeans(object$y_hat_train)
n_y_hat_train <- length(y_hat_train_mean)
mean_y_hat_train <- mean(y_hat_train_mean)
sd_y_hat_train <- sd(y_hat_train_mean)
Expand All @@ -2542,8 +2542,8 @@ summary.bartmodel <- function(bart_model, ...) {
}

# Test-set predictions
if (!is.null(bart_model$y_hat_test)) {
y_hat_test_mean <- rowMeans(bart_model$y_hat_test)
if (!is.null(object$y_hat_test)) {
y_hat_test_mean <- rowMeans(object$y_hat_test)
n_y_hat_test <- length(y_hat_test_mean)
mean_y_hat_test <- mean(y_hat_test_mean)
sd_y_hat_test <- sd(y_hat_test_mean)
Expand All @@ -2562,38 +2562,38 @@ summary.bartmodel <- function(bart_model, ...) {

# Random effects
# TODO: add random effects summaries once indexing is fixed
if (bart_model$model_params$has_rfx) {
# rfx_summary <- getRandomEffectSamples(bart_model)
if (object$model_params$has_rfx) {
# rfx_summary <- getRandomEffectSamples(object)
# ...
}

# Return bart_model invisibly
invisible(bart_model)
invisible(object)
}

#' @title Plot the BART model fit.
#' @description Plot the BART model fit and any relevant sampled quantities. This will default to a traceplot of the global error scale and the in-sample mean forest predictions for the first train set observation. Since `stochtree::bart()` is flexible and it's possible to sample a model with a fixed global error scale and no mean forest, this procedure is adaptive and will attempt to plot a trace of whichever model terms are included if these two default terms are omitted.
#' @param bart_model The BART model object
#' @param x The BART model object
#' @param ... Additional arguments
#' @export
#' @return BART model object unchanged after summarizing
plot.bartmodel <- function(bart_model, ...) {
plot.bartmodel <- function(x, ...) {
# Check if model has global error scale samples
has_sigma2_samples <- bart_model$model_params$sample_sigma2_global
has_mean_forest_preds <- !is.null(bart_model$y_hat_train)
has_sigma2_samples <- x$model_params$sample_sigma2_global
has_mean_forest_preds <- !is.null(x$y_hat_train)

# First try combinations of sigma2 and mean forest predictions
if (has_sigma2_samples || has_mean_forest_preds) {
if (has_sigma2_samples) {
plot(
bart_model$sigma2_global_samples,
x$sigma2_global_samples,
type = "l",
ylab = "Sigma^2",
main = "Global error scale traceplot"
)
} else if (has_mean_forest_preds) {
plot(
bart_model$y_hat_train[1, ],
x$y_hat_train[1, ],
type = "l",
ylab = "Predictions",
main = "In-sample mean function trace for the first train set observation"
Expand All @@ -2605,8 +2605,8 @@ plot.bartmodel <- function(bart_model, ...) {
)
}

# Return bart_model invisibly
invisible(bart_model)
# Return x invisibly
invisible(x)
}

#' Extract raw sample values for each of the random effect parameter terms.
Expand Down
Loading
Loading