Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ Suggests:
withr (>= 3.0.0)
Encoding: UTF-8
Language: en-US
RoxygenNote: 7.3.3
Roxygen: list(markdown = TRUE)
Config/testthat/edition: 3
Config/testthat/parallel: true
Expand All @@ -169,3 +168,4 @@ Config/Needs/website:
easystats/easystatstemplate
Config/rcmdcheck/ignore-inconsequential-notes: true
Remotes: easystats/insight, easystats/modelbased
Config/roxygen2/version: 8.0.0
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ S3method(as.data.frame,check_outliers)
S3method(as.data.frame,icc)
S3method(as.data.frame,looic)
S3method(as.data.frame,performance_accuracy)
S3method(as.data.frame,performance_cv)
S3method(as.data.frame,performance_pcp)
S3method(as.data.frame,performance_score)
S3method(as.data.frame,r2_bayes)
Expand Down
282 changes: 148 additions & 134 deletions R/performance_cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
#' @param method Character string, indicating the cross-validation method to use:
#' whether holdout (`"holdout"`, aka train-test), k-fold (`"k_fold"`), or
#' leave-one-out (`"loo"`). If `data` is supplied, this argument is ignored.
#' @param metrics Can be `"all"`, `"common"` or a character vector of metrics to be
#' computed (some of `c("ELPD", "Deviance", "MSE", "RMSE", "R2")`). "common" will
#' compute R2 and RMSE.
#' @param metrics Can be `"all"`, `"common"` or a character vector of metrics to
#' be computed (some of `"MSE"`, `"RMSE"`, `"R2"`). `"common"` will compute R2
#' and RMSE.
#' @param prop If `method = "holdout"`, what proportion of the sample to hold
#' out as the test sample?
#' @param k If `method = "k_fold"`, the number of folds to use.
Expand Down Expand Up @@ -46,161 +46,187 @@
verbose = TRUE,
...
) {
# 1. Parse and standardize requested metrics
if (all(metrics == "all")) {
metrics <- c("MSE", "RMSE", "R2")
} else if (all(metrics == "common")) {
metrics <- c("RMSE", "R2")
} else {
metrics <- toupper(metrics)
metrics[metrics == "DEVIANCE"] <- "Deviance"
}

# Warn user about requested metrics that aren't implemented yet
missing_metrics <- setdiff(metrics, c("MSE", "RMSE", "R2"))
if (length(missing_metrics)) {
insight::format_error(
paste0(
"Metric",
ifelse(length(missing_metrics) > 1, "s '", " '"),
paste0(missing_metrics, collapse = "', '"),

Check warning on line 65 in R/performance_cv.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/performance_cv.R,line=65,col=9,[paste_linter] Use paste(), not paste0(), to collapse a character vector when sep= is not used.
"' not yet supported."
)
)
}

# 2. Validate cross-validation method (ignored if external data is provided)
if (is.null(data)) {
method <- insight::validate_argument(method, c("holdout", "k_fold", "loo"))
}

# 3. Model compatibility checks
if (!is.null(data) && inherits(model, "BFBayesFactor")) {
insight::format_error("Models of class 'BFBayesFactor' not yet supported.")
}

# Extract foundational model information using insight
resp.name <- insight::find_response(model)
model_data <- insight::get_data(model, verbose = FALSE)
info <- insight::model_info(model, verbose = verbose)
if (info$is_linear) {
if (!is.null(data)) {
method <- "holdout"
out <- NULL

if (!info$is_linear) {
insight::format_error("Only linear models currently supported.")
}

# 4. Core CV logic (currently restricted to linear models)
# Branch A: Out-of-sample prediction (external data provided)
if (!is.null(data)) {
method <- "holdout"
stack <- TRUE # Stacking is implicit for a single test set
test_resp <- data[, resp.name]
test_pred <- insight::get_predicted(model, ci = NULL, data = data)
test_resd <- test_resp - test_pred

# Branch B: Standard Holdout (Train-Test Split)
} else if (method == "holdout") {
# Sample indices for the training set based on 'prop'
train_i <- sample.int(
nrow(model_data),
size = round((1 - prop) * nrow(model_data)),
replace = FALSE
)

# Re-fit model on training data and predict on the holdout (test) data
model_upd <- stats::update(model, data = model_data[train_i, ])
test_resp <- model_data[-train_i, resp.name]
test_pred <- insight::get_predicted(
model_upd,
ci = NULL,
data = model_data[-train_i, ]
)
test_resd <- test_resp - test_pred

# Branch C: Leave-One-Out (LOO) Fast Approximation for Linear Models
} else if (method == "loo" && !info$is_bayesian) {
# Mathematical shortcut for exact LOO MSE using hat values (leverage),
# avoiding the need to refit the model N times.
model_response <- insight::get_response(model)
MSE <- mean(
insight::get_residuals(model, weighted = TRUE)^2 / (1 - stats::hatvalues(model))^2
)
RMSE <- sqrt(MSE)
# fmt: skip
R2 <- 1 - MSE / (mean(model_response^2, na.rm = TRUE) - mean(model_response, na.rm = TRUE)^2)
out <- data.frame(MSE = MSE, RMSE = RMSE, R2 = R2)

# Branch D: K-Fold CV (or explicit LOO for non-linear/Bayesian models)
} else {
if (method == "loo") {
if (info$is_bayesian && verbose) {
insight::format_alert(
"Simple LOO cross-validation can be very slow for MCMC models.",
"Try loo::loo() instead."
)
}
stack <- TRUE
test_resp <- data[, resp.name]
test_pred <- insight::get_predicted(model, ci = NULL, data = data)
test_resd <- test_resp - test_pred
} else if (method == "holdout") {
train_i <- sample.int(
nrow(model_data),
size = round((1 - prop) * nrow(model_data)),
replace = FALSE
)
model_upd <- stats::update(model, data = model_data[train_i, ])
test_resp <- model_data[-train_i, resp.name]
test_pred <- insight::get_predicted(
model_upd,
ci = NULL,
data = model_data[-train_i, ]
k <- nrow(model_data) # LOO is just k-fold where k = N
}

# Catch misconfigured k values
if (k > nrow(model_data)) {
insight::format_alert(
"Requested number of folds (k) larger than the sample size.",
"'k' set equal to the sample size (leave-one-out [LOO])."
)
k <- nrow(model_data)
}

# Generate fold indices and re-fit the model across all folds
cv_folds <- .crossv_kfold(model_data, k = k)
models_upd <- lapply(cv_folds, function(.x) {
stats::update(model, data = model_data[.x$train, ])
})

# Extract predictions and actual responses for each fold
test_pred <- mapply(

Check warning on line 161 in R/performance_cv.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/performance_cv.R,line=161,col=18,[undesirable_function_linter] Avoid undesirable function "mapply".
function(.x, .y) {
insight::get_predicted(.y, ci = NULL, data = model_data[.x$test, ])
},
cv_folds,
models_upd,
SIMPLIFY = FALSE
)
test_resp <- lapply(cv_folds, function(.x) {
as.data.frame(model_data[.x$test, ])[[resp.name]]
})
}

# 5. Aggregate Metrics (Stacking vs Averaging)

# Skip aggregation if the fast LOO approximation already computed 'out'
if (is.null(out)) {
if (isTRUE(stack)) {
# Pool all fold predictions/residuals together to calculate one global metric
test_resp <- unlist(test_resp)
test_pred <- unlist(test_pred)
test_resd <- test_resp - test_pred
} else if (method == "loo" && !info$is_bayesian) {
model_response <- insight::get_response(model)
MSE <- mean(
insight::get_residuals(model, weighted = TRUE)^2 /
(1 - stats::hatvalues(model))^2
)
mean(test_resd^2, na.rm = TRUE)

MSE <- mean(test_resd^2, na.rm = TRUE)
RMSE <- sqrt(MSE)
R2 <- 1 -
MSE /
(mean(model_response^2, na.rm = TRUE) - mean(model_response, na.rm = TRUE)^2)
R2 <- 1 - MSE / mean((test_resp - mean(test_resp, na.rm = TRUE))^2, na.rm = TRUE)
out <- data.frame(MSE = MSE, RMSE = RMSE, R2 = R2)
} else {
# Manual method for LOO, use this for non-linear and Bayesian models
if (method == "loo") {
if (info$is_bayesian && verbose) {
insight::format_alert(
"Simple LOO cross-validation can be very slow for MCMC models.",
"Try loo::loo() instead."
)
}
stack <- TRUE
k <- nrow(model_data)
}
if (k > nrow(model_data)) {
message(insight::color_text(
insight::format_message(
"Requested number of folds (k) larger than the sample size.",
"'k' set equal to the sample size (leave-one-out [LOO])."
),
color = "yellow"
))
k <- nrow(model_data)
}
cv_folds <- .crossv_kfold(model_data, k = k)
models_upd <- lapply(cv_folds, function(.x) {
stats::update(model, data = model_data[.x$train, ])
})
test_pred <- mapply(
# Calculate metrics *within* each fold, then average the results
test_resd <- mapply(

Check warning on line 190 in R/performance_cv.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/performance_cv.R,line=190,col=20,[undesirable_function_linter] Avoid undesirable function "mapply".
function(.x, .y) {
insight::get_predicted(.y, ci = NULL, data = model_data[.x$test, ])
.x - .y
},
cv_folds,
models_upd,
test_resp,
test_pred,
SIMPLIFY = FALSE
)
test_resp <- lapply(cv_folds, function(.x) {
as.data.frame(model_data[.x$test, ])[[resp.name]]

MSEs <- sapply(test_resd, function(x) mean(x^2, na.rm = TRUE))
RMSEs <- sqrt(MSEs)
resp_vars <- sapply(test_resp, function(x) {
mean((x - mean(x, na.rm = TRUE))^2, na.rm = TRUE)
})
R2s <- 1 - MSEs / resp_vars

out <- data.frame(
MSE = mean(MSEs),
MSE_SE = stats::sd(MSEs),
RMSE = mean(RMSEs),
RMSE_SE = stats::sd(RMSEs),
R2 = mean(R2s),
R2_SE = stats::sd(R2s)
)
}
} else {
insight::format_error("Only linear models currently supported.")
}
if (isTRUE(stack)) {
test_resp <- unlist(test_resp)
test_pred <- unlist(test_pred)
test_resd <- test_resp - test_pred
MSE <- mean(test_resd^2, na.rm = TRUE)
RMSE <- sqrt(MSE)
R2 <- 1 - MSE / mean((test_resp - mean(test_resp, na.rm = TRUE))^2, na.rm = TRUE)
out <- data.frame(MSE = MSE, RMSE = RMSE, R2 = R2)
} else {
test_resd <- mapply(
function(.x, .y) {
.x - .y
},
test_resp,
test_pred,
SIMPLIFY = FALSE
)
MSEs <- sapply(test_resd, function(x) mean(x^2, na.rm = TRUE))
RMSEs <- sqrt(MSEs)
resp_vars <- sapply(test_resp, function(x) {
mean((x - mean(x, na.rm = TRUE))^2, na.rm = TRUE)
})
R2s <- 1 - MSEs / resp_vars
out <- data.frame(
MSE = mean(MSEs),
MSE_SE = stats::sd(MSEs),
RMSE = mean(RMSEs),
RMSE_SE = stats::sd(RMSEs),
R2 = mean(R2s),
R2_SE = stats::sd(R2s)
)
}

out <- out[, colnames(out) %in% c(metrics, paste0(metrics, "_SE"))]
# 6. Format Final Output
# Filter the data frame to include only the requested metrics
out <- out[, colnames(out) %in% c(metrics, paste0(metrics, "_SE")), drop = FALSE]

attr(out, "method") <- method
attr(out, "k") <- if (method == "k_fold") k
attr(out, "prop") <- if (method == "holdout") prop
missing_metrics <- setdiff(metrics, c("MSE", "RMSE", "R2"))
if (length(missing_metrics)) {
message(insight::colour_text(
insight::format_message(
paste0(
"Metric",
ifelse(length(missing_metrics) > 1, "s '", " '"),
paste0(missing_metrics, collapse = "', '"),
"' not yet supported."
)
),
colour = "red"
))
}

class(out) <- c("performance_cv", "data.frame")
return(out)
out
}

# TODO: implement performance::log_lik() function for deviance/elpd metrics
# - When given a model, it should pass it to insight::get_loglikelihood, stats4::logLik, stats::logLik, or rstantools::log_lik
# - When given a model and new data, it should pass to rstantools::log_lik if stan
# or compute a df like this:
# df <- list(residuals = cv_residuals); class(df) <- class(model)
# then pass this df to stats4::logLik or stats::logLik
# - for model classes that do not compute their ll inside of logLik,
# then compute the ll by running:
# logLik(update(model, formula = {{response}} ~ 0, offset = predict(model, newdata), data = newdata))
# TODO: implement performance::log_lik() function for deviance/elpd metrics ...

# methods ----------------------------------

Expand Down Expand Up @@ -238,15 +264,3 @@
))
invisible(x)
}

#' @export
as.data.frame.performance_cv <- function(x, row.names = NULL, ...) {
data.frame(
Accuracy = x$Accuracy,
SE = x$SE,
Method = x$Method,
stringsAsFactors = FALSE,
row.names = row.names,
...
)
}
22 changes: 11 additions & 11 deletions man/check_autocorrelation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading