Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# bayesplot (development version)

* Validate equal chain lengths in `validate_df_with_chain()`, reject missing
chain labels, and renumber data-frame chain labels internally when converting
to arrays.
* Added unit tests for previously untested edge cases in `param_range()`, `param_glue()`, and `tidyselect_parameters()` (no-match, partial-match, and negation behavior).
* Bumped minimum version for `rstantools` from `>= 1.5.0` to `>= 2.0.0` .
* Use `rlang::warn()` and `rlang::inform()` for selected PPC user messages instead of base `warning()` and `message()`.
Expand Down
12 changes: 11 additions & 1 deletion R/helpers-mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ validate_df_with_chain <- function(x) {
x$chain <- NULL
}
x$Chain <- as.integer(x$Chain)
if (anyNA(x$Chain)) {
abort("Chain values must not be NA.")
}
rows_per_chain <- table(x$Chain)
if (length(unique(rows_per_chain)) != 1) {
abort("All chains must have the same number of iterations.")
}
x
}

Expand All @@ -218,11 +225,14 @@ validate_df_with_chain <- function(x) {
df_with_chain2array <- function(x) {
x <- validate_df_with_chain(x)
chain <- x$Chain
# Renumber arbitrary chain labels to the contiguous 1:N indices used internally.
chain <- match(chain, sort(unique(chain)))
n_chain <- length(unique(chain))
a <- x[, !colnames(x) %in% "Chain", drop = FALSE]
parnames <- colnames(a)
a <- as.matrix(a)
x <- array(NA, dim = c(ceiling(nrow(a) / n_chain), n_chain, ncol(a)))
n_iter <- nrow(a) %/% n_chain
x <- array(NA, dim = c(n_iter, n_chain, ncol(a)))
for (j in seq_len(n_chain)) {
x[, j, ] <- a[chain == j,, drop=FALSE]
}
Expand Down
4 changes: 3 additions & 1 deletion R/mcmc-overview.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
#' frame with one column per parameter (if only a single chain or all chains
#' have already been merged), or a data frame with one column per parameter plus
#' an additional column `"Chain"` that contains the chain number (an integer)
#' corresponding to each row in the data frame.
#' corresponding to each row in the data frame. When a `"Chain"` column is
#' supplied, each chain must have the same number of iterations. Chain labels
#' are used to identify groups and are renumbered internally to `1:N`.
#' * __draws__: Any of the `draws` formats supported by the
#' \pkg{posterior} package.
#'
Expand Down
4 changes: 3 additions & 1 deletion man/MCMC-overview.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions tests/testthat/test-helpers-mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,43 @@ test_that("validate_df_with_chain works", {
tbl <- tibble::tibble(parameter=rnorm(n=40), Chain=rep(1:4, each=10))
a <- validate_df_with_chain(tbl)
expect_type(a$Chain, "integer")

missing_chain_df <- data.frame(
Chain = c(1L, 1L, NA_integer_, NA_integer_),
V1 = rnorm(4),
V2 = rnorm(4)
)
expect_error(validate_df_with_chain(missing_chain_df),
"Chain values must not be NA")
})

test_that("df_with_chain2array works", {
a <- df_with_chain2array(dframe_multiple_chains)
expect_mcmc_array(a)

expect_error(df_with_chain2array(dframe), "is_df_with_chain")

# Unequal chain lengths should error via validate_df_with_chain
unequal_df <- data.frame(
Chain = c(1L, 1L, 1L, 1L, 2L, 2L, 2L),
V1 = rnorm(7),
V2 = rnorm(7)
)
expect_error(validate_df_with_chain(unequal_df),
"All chains must have the same number of iterations")
expect_error(df_with_chain2array(unequal_df),
"All chains must have the same number of iterations")

renumbered_df <- data.frame(
Chain = c(2L, 2L, 3L, 3L),
V1 = 1:4,
V2 = 5:8
)
a <- df_with_chain2array(renumbered_df)
expect_equal(dim(a), c(2, 2, 2))
expect_identical(unname(a[, 1, "V1"]), c(1L, 2L))
expect_identical(unname(a[, 2, "V1"]), c(3L, 4L))
expect_identical(as.character(dimnames(a)$Chain), c("1", "2"))
})


Expand Down Expand Up @@ -305,6 +335,7 @@ test_that("diagnostic_factor.rhat works", {
)
expect_identical(levels(r), c("low", "ok", "high"))
})

test_that("diagnostic_factor.neff_ratio works", {
ratios <- new_neff_ratio(c(low = 0.05, low = 0.01,
ok = 0.2, ok = 0.49,
Expand Down
Loading