Skip to content
Open
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ S3method(get_pit_histogram,forecast_quantile)
S3method(get_pit_histogram,forecast_sample)
S3method(head,forecast)
S3method(print,forecast)
S3method(print,forecast_multivariate_point)
S3method(print,forecast_multivariate_sample)
S3method(score,default)
S3method(score,forecast_binary)
S3method(score,forecast_multivariate_point)
Expand Down
15 changes: 15 additions & 0 deletions R/class-forecast-multivariate-point.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,21 @@ is_forecast_multivariate_point <- function(x) {
# nolint end


#' @title Print information about a multivariate point forecast object
#' @description
#' This function prints information about a multivariate point forecast object,
#' including "Forecast type", "Forecast unit", and "Joint across" columns.
#'
#' @param x A forecast object of class `forecast_multivariate_point`.
#' @param ... Additional arguments for [print()].
#' @returns Returns `x` invisibly.
#' @export
#' @keywords gain-insights
print.forecast_multivariate_point <- function(x, ...) {
print_multivariate_forecast(x, ...)
}


#' @importFrom stats na.omit
#' @importFrom data.table setattr copy
#' @rdname score
Expand Down
76 changes: 76 additions & 0 deletions R/class-forecast-multivariate-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,82 @@ is_forecast_multivariate_sample <- function(x) {
# nolint end


#' @title Print information about a multivariate forecast object
#' @description
#' This function prints information about a multivariate forecast object,
#' including "Forecast type", "Forecast unit", and "Joint across" columns.
#'
#' @param x A forecast object of class `forecast_multivariate_sample`.
#' @param ... Additional arguments for [print()].
#' @returns Returns `x` invisibly.
#' @importFrom cli col_blue cli_text
#' @export
#' @keywords gain-insights
print.forecast_multivariate_sample <- function(x, ...) {
print_multivariate_forecast(x, ...)
}


#' Print helper for multivariate forecast objects
#'
#' Shared implementation for printing multivariate forecast objects.
#' Displays forecast type, forecast unit, and "Joint across" columns.
#' @param x A multivariate forecast object.
#' @param ... Additional arguments passed to [print()].
#' @returns Returns `x` invisibly.
#' @importFrom cli col_blue cli_text cli_inform
#' @keywords internal
print_multivariate_forecast <- function(x, ...) {
forecast_type <- try(
do.call(get_forecast_type, list(forecast = x)),
silent = TRUE
)
forecast_unit <- try(
do.call(get_forecast_unit, list(data = x)),
silent = TRUE
)

if (inherits(forecast_type, "try-error")) {
cli_inform(
c(
"!" = "Could not determine forecast type due to error in validation."
)
)
} else {
cli_text(
col_blue("Forecast type: "),
"{forecast_type}"
)
}

if (inherits(forecast_unit, "try-error")) {
cli_inform(
c(
"!" = "Could not determine forecast unit."
)
)
} else {
cli_text(col_blue("Forecast unit:"))
cli_text("{forecast_unit}")
}

joint_across <- try(
setdiff(get_forecast_unit(x), get_grouping(x)),
silent = TRUE
)
if (!inherits(joint_across, "try-error") && length(joint_across) > 0) {
cli_text(col_blue("Joint across:"))
cli_text("{joint_across}")
}

cat("\n")

print(as.data.table(x), ...)

return(invisible(x))
}


#' @importFrom stats na.omit
#' @importFrom data.table setattr copy
#' @importFrom methods formalArgs
Expand Down
22 changes: 0 additions & 22 deletions R/class-forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -451,28 +451,6 @@ print.forecast <- function(x, ...) {
)
}

# For multivariate forecasts, show joint_across
if (".mv_group_id" %in% names(x)) {
joint_across <- try(
setdiff(
get_forecast_unit(x),
get_grouping(x)
),
silent = TRUE
)
if (!inherits(joint_across, "try-error") &&
length(joint_across) > 0) {
cli_text(
col_blue(
"Joint across:"
)
)
cli_text(
"{joint_across}"
)
}
}

cat("\n")

NextMethod()
Expand Down
21 changes: 21 additions & 0 deletions man/print.forecast_multivariate_point.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions man/print.forecast_multivariate_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions man/print_multivariate_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 41 additions & 0 deletions tests/testthat/_snaps/class-forecast-multivariate-point.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,44 @@
Output
Number of non-NA scores: 224

# print.forecast_multivariate_point() displays joint_across columns

Code
print(result)
Message
Forecast type: multivariate_point
Forecast unit:
location, model, target_type, target_end_date, and horizon
Joint across:
location
Output

Key: <location, target_end_date, target_type>
Index: <.mv_group_id>
observed predicted location model target_type
<num> <int> <char> <char> <char>
1: 106987 119258 DE EuroCOVIDhub-ensemble Cases
2: 106987 132607 DE EuroCOVIDhub-baseline Cases
3: 106987 151179 DE epiforecasts-EpiNow2 Cases
4: 1582 1568 DE EuroCOVIDhub-ensemble Deaths
5: 1582 1597 DE EuroCOVIDhub-baseline Deaths
---
883: 78 131 IT EuroCOVIDhub-baseline Deaths
884: 78 79 IT UMass-MechBayes Deaths
885: 78 124 IT UMass-MechBayes Deaths
886: 78 104 IT epiforecasts-EpiNow2 Deaths
887: 78 186 IT epiforecasts-EpiNow2 Deaths
target_end_date horizon .mv_group_id
<Date> <num> <int>
1: 2021-05-08 1 1
2: 2021-05-08 1 2
3: 2021-05-08 1 3
4: 2021-05-08 1 4
5: 2021-05-08 1 5
---
883: 2021-07-24 2 220
884: 2021-07-24 3 221
885: 2021-07-24 2 222
886: 2021-07-24 3 223
887: 2021-07-24 2 224

79 changes: 79 additions & 0 deletions tests/testthat/_snaps/class-forecast-multivariate-sample.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,82 @@
Output
Energy score range: 37.8373892350605 to 433525.521054322

# print.forecast_multivariate_sample() displays joint_across columns

Code
print(example_multivariate_sample)
Message
Forecast type: multivariate_sample
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Joint across:
location and location_name
Output

location location_name target_end_date target_type forecast_date
<char> <char> <Date> <char> <Date>
1: DE Germany 2021-01-02 Cases <NA>
2: DE Germany 2021-01-02 Deaths <NA>
3: DE Germany 2021-01-09 Cases <NA>
4: DE Germany 2021-01-09 Deaths <NA>
5: DE Germany 2021-01-16 Cases <NA>
---
35620: IT Italy 2021-07-24 Deaths 2021-07-12
35621: IT Italy 2021-07-24 Deaths 2021-07-12
35622: IT Italy 2021-07-24 Deaths 2021-07-12
35623: IT Italy 2021-07-24 Deaths 2021-07-12
35624: IT Italy 2021-07-24 Deaths 2021-07-12
model horizon predicted sample_id observed .mv_group_id
<char> <num> <num> <int> <num> <int>
1: <NA> NA NA NA 127300 1
2: <NA> NA NA NA 4534 2
3: <NA> NA NA NA 154922 3
4: <NA> NA NA NA 6117 4
5: <NA> NA NA NA 110183 5
---
35620: epiforecasts-EpiNow2 2 159.84534 36 78 260
35621: epiforecasts-EpiNow2 2 128.21214 37 78 260
35622: epiforecasts-EpiNow2 2 190.52560 38 78 260
35623: epiforecasts-EpiNow2 2 141.06659 39 78 260
35624: epiforecasts-EpiNow2 2 24.43419 40 78 260

# print.forecast_multivariate_sample() shows correct joint_across for single-column grouping

Code
print(result)
Message
Forecast type: multivariate_sample
Forecast unit:
location, model, target_type, target_end_date, and horizon
Joint across:
location
Output

predicted sample_id observed location model
<num> <int> <num> <char> <char>
1: 102672.00034 1 106987 DE EuroCOVIDhub-ensemble
2: 164763.08492 2 106987 DE EuroCOVIDhub-ensemble
3: 153042.63536 3 106987 DE EuroCOVIDhub-ensemble
4: 119544.25389 4 106987 DE EuroCOVIDhub-ensemble
5: 81230.71875 5 106987 DE EuroCOVIDhub-ensemble
---
35476: 159.84534 36 78 IT epiforecasts-EpiNow2
35477: 128.21214 37 78 IT epiforecasts-EpiNow2
35478: 190.52560 38 78 IT epiforecasts-EpiNow2
35479: 141.06659 39 78 IT epiforecasts-EpiNow2
35480: 24.43419 40 78 IT epiforecasts-EpiNow2
target_type target_end_date horizon .mv_group_id
<char> <Date> <num> <int>
1: Cases 2021-05-08 1 1
2: Cases 2021-05-08 1 1
3: Cases 2021-05-08 1 1
4: Cases 2021-05-08 1 1
5: Cases 2021-05-08 1 1
---
35476: Deaths 2021-07-24 2 224
35477: Deaths 2021-07-24 2 224
35478: Deaths 2021-07-24 2 224
35479: Deaths 2021-07-24 2 224
35480: Deaths 2021-07-24 2 224

24 changes: 24 additions & 0 deletions tests/testthat/test-class-forecast-multivariate-point.R
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,27 @@ test_that(
)
}
)


# ==============================================================================
# print.forecast_multivariate_point()
# ==============================================================================
test_that("print.forecast_multivariate_point() displays joint_across columns", {
result <- make_mv_point()
expect_snapshot(print(result))
})

test_that("print.forecast_multivariate_point() returns object invisibly", {
result <- make_mv_point()
expect_invisible(print(result))
out <- print(result)
expect_identical(out, result)
})

test_that("print.forecast_multivariate_point() shows expected sections", {
result <- make_mv_point()
out <- capture.output(print(result), type = "message")
expect_true(any(grepl("Forecast type:", out, fixed = TRUE)))
expect_true(any(grepl("Forecast unit:", out, fixed = TRUE)))
expect_true(any(grepl("Joint across:", out, fixed = TRUE)))
})
Loading
Loading