-
Notifications
You must be signed in to change notification settings - Fork 25
Handle missing forecasts before summarisation #1156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
42bf377
9342685
8d81290
3e1298d
b15c5cf
6253cab
26fd2d8
1d50511
06b65e7
dbcddb2
6ca383f
797b926
048be7c
3d266ab
a84aea2
3fde364
3ff432a
3490100
e11bc62
eb1ebcd
8856e8d
bf60b74
12ef1a7
d9aa0fd
3cfcfcc
22c790f
a7934c6
7c88db2
1803d9c
cfdc3c4
1cf4beb
f9f91f1
5da130a
8e380f8
1953157
0d3cf4b
5a4beee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,228 @@ | ||
| #' @title Filter scores | ||
| #' | ||
| #' @description | ||
| #' Filter a `scores` object using a supplied strategy function. | ||
| #' `filter_scores()` is responsible for preserving the `scores` | ||
| #' class and the `metrics` attribute; the strategy is | ||
| #' responsible only for the filtering logic. | ||
| #' | ||
| #' Strategies are constructed by helpers such as | ||
| #' [filter_to_intersection()] and [filter_to_include()] and can | ||
| #' also be user-defined. A strategy is a function with | ||
| #' signature `function(scores, compare)` that returns a | ||
| #' filtered data.table with the same columns as its input. | ||
| #' | ||
| #' @param scores An object of class `scores` (a data.table with | ||
| #' an additional `metrics` attribute as produced by [score()]). | ||
| #' @param strategy A strategy function. See Description for the | ||
| #' expected signature. Default: [filter_to_intersection()]. | ||
| #' @param compare Character string (default `"model"`) naming the | ||
| #' column whose values are compared when deciding which | ||
| #' target combinations to keep. | ||
| #' | ||
| #' @return A `scores` object with the same class and `metrics` | ||
| #' attribute as the input, with rows filtered according to | ||
| #' `strategy`. | ||
| #' | ||
| #' @seealso [filter_to_intersection()], [filter_to_include()], | ||
| #' \code{vignette("handling-missing-forecasts")} | ||
| #' @importFrom cli cli_inform | ||
| #' @importFrom checkmate assert_class assert_character | ||
| #' assert_subset | ||
| #' @export | ||
| #' @keywords postprocess-scores | ||
| #' @examples | ||
| #' \dontshow{ | ||
| #' data.table::setDTthreads(2) | ||
| #' } | ||
| #' scores <- example_quantile |> | ||
| #' as_forecast_quantile() |> | ||
| #' score() | ||
| #' | ||
| #' # Keep only targets covered by every model (the default) | ||
| #' filter_scores(scores) | ||
| #' | ||
| #' # Keep targets covered by at least 75% of models | ||
| #' filter_scores( | ||
| #' scores, | ||
| #' strategy = filter_to_intersection(min_coverage = 0.75) | ||
| #' ) | ||
| #' | ||
| #' # Keep only targets covered by a named model | ||
| #' filter_scores( | ||
| #' scores, | ||
| #' strategy = filter_to_include("EuroCOVIDhub-baseline") | ||
| #' ) | ||
| filter_scores <- function( | ||
|   scores, | ||
|   strategy = filter_to_intersection(), | ||
|
sbfnk marked this conversation as resolved.
|
||
|   compare = "model" | ||
|
sbfnk marked this conversation as resolved.
|
||
| ) { | ||
|   # Check every input before any filtering work is done | ||
|   assert_class(scores, "scores") | ||
|   assert_character(compare, len = 1) | ||
|   assert_subset(compare, names(scores)) | ||
|   assert_strategy(strategy, required = "compare") | ||
|
|
||
|   # The strategy only filters rows, so hold on to the metrics here; | ||
|   # they are needed to rebuild the scores object afterwards | ||
|   metrics_attr <- attr(scores, "metrics") | ||
|   filtered <- strategy(scores, compare = compare) | ||
|
sbfnk marked this conversation as resolved.
|
||
|
|
||
|   # These three names are interpolated by the cli messages below | ||
|   n_before <- nrow(scores) | ||
|   n_after <- nrow(filtered) | ||
|   n_dropped <- n_before - n_after | ||
|
|
||
|   if (n_dropped != 0) { | ||
|     cli_inform(c( | ||
|       i = "Filtered out {n_dropped} rows.", | ||
|       i = "{n_after} of {n_before} rows remaining." # nolint: duplicate_argument_linter | ||
|     )) | ||
|     return(new_scores(filtered, metrics_attr)) | ||
|   } | ||
|
|
||
|   # Nothing was removed: hand back the input object itself | ||
|   cli_inform(c( | ||
|     i = "No rows filtered. Returning scores unchanged." | ||
|   )) | ||
|   return(scores) | ||
| } | ||
|
|
||
|
|
||
| #' @title Filter to target combinations meeting a coverage threshold | ||
| #' | ||
| #' @description | ||
| #' Strategy for [filter_scores()] that keeps target combinations | ||
| #' covered by at least `min_coverage` of the values in the | ||
| #' `compare` column. With the default `min_coverage = 1`, only | ||
| #' target combinations present for every compare value are kept | ||
| #' (strict intersection across the full set). | ||
| #' | ||
| #' To restrict to the targets covered by a named subset of | ||
| #' compare values instead of by a proportion, use | ||
| #' [filter_to_include()]. | ||
| #' | ||
| #' @param min_coverage Numeric between 0 and 1 (default `1`). | ||
| #' Minimum proportion of compare values that must cover a | ||
| #' target combination for it to be kept. | ||
| #' | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looking through the code I think
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. More generally, the documentation is quite dense (I didn't understand it before looking at the code) and really could do with some examples.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Split:
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Docs rewritten and examples added on both strategy factories. 0d3cf4b. |
||
| #' @return A strategy function for [filter_scores()]. Intended | ||
| #' to be passed to `filter_scores()` rather than called | ||
| #' directly — `filter_scores()` is where the `scores` class | ||
| #' and `metrics` attribute are preserved. | ||
| #' | ||
| #' @seealso [filter_scores()], [filter_to_include()] | ||
| #' @importFrom data.table as.data.table setkeyv uniqueN | ||
| #' @importFrom checkmate assert_number | ||
| #' @export | ||
| #' @keywords postprocess-scores | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. examples?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| #' @examples | ||
| #' \dontshow{ | ||
| #' data.table::setDTthreads(2) | ||
| #' } | ||
| #' scores <- example_quantile |> | ||
| #' as_forecast_quantile() |> | ||
| #' score() | ||
| #' filter_scores( | ||
| #' scores, | ||
| #' strategy = filter_to_intersection(min_coverage = 0.75) | ||
| #' ) | ||
| filter_to_intersection <- function(min_coverage = 1) { | ||
|   assert_number(min_coverage, lower = 0, upper = 1) | ||
|
|
||
|   # Slack for the threshold comparison below. A computed threshold | ||
|   # such as `1 - 1/3` is slightly larger than `2 / 3` in floating | ||
|   # point, so an exact `>=` would drop targets that sit exactly on | ||
|   # the requested coverage. | ||
|   tolerance <- sqrt(.Machine$double.eps) | ||
|
|
||
|   function(scores, compare = "model") { | ||
|     scores <- data.table::as.data.table(scores) | ||
|     forecast_unit <- get_forecast_unit(scores) | ||
|     # A target combination is the forecast unit without `compare` | ||
|     target_cols <- setdiff(forecast_unit, compare) | ||
|
|
||
|     # Distinct compare values across the whole input | ||
|     n_total <- data.table::uniqueN(scores[[compare]]) | ||
|
|
||
|     # Distinct compare values present for each target combination | ||
|     target_coverage <- scores[, | ||
|       .(n_compare = data.table::uniqueN(get(compare))), | ||
|       by = target_cols | ||
|     ] | ||
|
|
||
|     coverage <- target_coverage$n_compare / n_total | ||
|     keep <- coverage >= min_coverage - tolerance | ||
|     qualifying <- target_coverage[keep, target_cols, with = FALSE] | ||
|
|
||
|     # Keyed join: keep the rows of `scores` whose target combination | ||
|     # is among the qualifying ones | ||
|     data.table::setkeyv(scores, target_cols) | ||
|     data.table::setkeyv(qualifying, target_cols) | ||
|     scores[qualifying, nomatch = NULL] | ||
|   } | ||
| } | ||
|
|
||
|
|
||
| #' @title Filter to targets covered by named compare values | ||
| #' | ||
| #' @description | ||
| #' Strategy for [filter_scores()] that restricts the kept | ||
| #' target combinations to those covered by every value listed | ||
| #' in `include`. With a single value this keeps only that | ||
| #' value's targets; with several values, the intersection of | ||
| #' their target sets is kept. | ||
| #' | ||
| #' To use a proportion-based threshold over all compare values | ||
| #' instead, use [filter_to_intersection()]. | ||
| #' | ||
| #' @param include Character vector of length one or more. Values | ||
| #' from the `compare` column whose target sets should be | ||
| #' intersected. | ||
| #' | ||
| #' @return A strategy function for [filter_scores()]. Intended | ||
| #' to be passed to `filter_scores()` rather than called | ||
| #' directly — `filter_scores()` is where the `scores` class | ||
| #' and `metrics` attribute are preserved. | ||
| #' | ||
| #' @seealso [filter_scores()], [filter_to_intersection()] | ||
| #' @importFrom data.table as.data.table setkeyv | ||
| #' @importFrom checkmate assert_character | ||
| #' @importFrom cli cli_abort | ||
| #' @export | ||
| #' @keywords postprocess-scores | ||
| #' @examples | ||
| #' \dontshow{ | ||
| #' data.table::setDTthreads(2) | ||
| #' } | ||
| #' scores <- example_quantile |> | ||
| #' as_forecast_quantile() |> | ||
| #' score() | ||
| #' filter_scores( | ||
| #' scores, | ||
| #' strategy = filter_to_include("EuroCOVIDhub-baseline") | ||
| #' ) | ||
| filter_to_include <- function(include) { | ||
|   assert_character(include, min.len = 1) | ||
|
|
||
|   # Repeated entries add nothing to the intersection | ||
|   wanted <- unique(include) | ||
|
|
||
|   function(scores, compare = "model") { | ||
|     scores <- data.table::as.data.table(scores) | ||
|     forecast_unit <- get_forecast_unit(scores) | ||
|     # A target combination is the forecast unit without `compare` | ||
|     target_cols <- setdiff(forecast_unit, compare) | ||
|
|
||
|     # Every requested value must occur in the compare column | ||
|     unknown <- setdiff(wanted, unique(scores[[compare]])) | ||
|     if (length(unknown) > 0) { | ||
|       cli_abort(c( | ||
|         "!" = paste0( | ||
|           "{.val {unknown}} not found in ", | ||
|           "{.arg {compare}} column." | ||
|         ) | ||
|       )) | ||
|     } | ||
|
|
||
|     # `%in%` rather than `==`: an `NA` entry in `include` that passed | ||
|     # the check above must match the `NA` rows of the column, whereas | ||
|     # `== NA` matches nothing and silently empties the result | ||
|     is_wanted <- scores[[compare]] %in% wanted | ||
|
|
||
|     # For each target combination, count how many of the requested | ||
|     # values cover it (one grouped pass instead of chained merges) | ||
|     included_coverage <- scores[ | ||
|       is_wanted, | ||
|       .(n_included = data.table::uniqueN(get(compare))), | ||
|       by = target_cols | ||
|     ] | ||
|
|
||
|     # Keep targets covered by all requested values | ||
|     keep <- included_coverage$n_included == length(wanted) | ||
|     qualifying <- included_coverage[keep, target_cols, with = FALSE] | ||
|
|
||
|     # Keyed join: keep the rows of `scores` whose target combination | ||
|     # is among the qualifying ones | ||
|     data.table::setkeyv(scores, target_cols) | ||
|     data.table::setkeyv(qualifying, target_cols) | ||
|     scores[qualifying, nomatch = NULL] | ||
|   } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.