Skip to content

Commit b2810cc

Browse files
committed
get_predict_data, tests need adapting
1 parent f2ef277 commit b2810cc

26 files changed

+191
-393
lines changed

NAMESPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ export(flatline_forecaster)
167167
export(flusight_hub_formatter)
168168
export(forecast)
169169
export(frosting)
170-
export(get_test_data)
170+
export(get_predict_data)
171171
export(is_epi_recipe)
172172
export(is_epi_workflow)
173173
export(is_layer)

R/cdc_baseline_forecaster.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ cdc_baseline_forecaster <- function(
7878
# target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
7979

8080

81-
latest <- get_test_data(epi_recipe(epi_data), epi_data)
81+
latest <- get_predict_data(epi_recipe(epi_data), epi_data)
8282

8383
f <- frosting() %>%
8484
layer_predict() %>%

R/epi_workflow.R

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
132132
#' @param new_data A data frame containing the new predictors to preprocess
133133
#' and predict on
134134
#'
135+
#' @param reference_dates A vector matching the type of `time_value` in
136+
#' `new_data` giving the dates of the predictions to keep. Defaults to the `reference_date` of the `object`'s recipe.
137+
#'
135138
#' @inheritParams parsnip::predict.model_fit
136139
#'
137140
#' @return
@@ -155,14 +158,13 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
155158
#'
156159
#' preds <- predict(wf, latest)
157160
#' preds
158-
predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), ...) {
161+
predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), reference_dates = NULL, ...) {
159162
if (!workflows::is_trained_workflow(object)) {
160163
cli_abort(c(
161164
"Can't predict on an untrained epi_workflow.",
162165
i = "Do you need to call `fit()`?"
163166
))
164167
}
165-
browser()
166168
components <- list()
167169
components$mold <- workflows::extract_mold(object)
168170
components$forged <- hardhat::forge(new_data,
@@ -171,7 +173,8 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), .
171173

172174
components$keys <- grab_forged_keys(components$forged, object, new_data)
173175
components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
174-
components$predictions
176+
reference_dates <- reference_dates %||% extract_recipe(object)$reference_date
177+
components$predictions %>% filter(time_value %in% reference_dates)
175178
}
176179

177180

@@ -267,11 +270,12 @@ forecast.epi_workflow <- function(object, ..., n_recent = NULL, forecast_date =
267270
}
268271
}
269272

270-
test_data <- get_test_data(
273+
predict_data <- get_predict_data(
271274
hardhat::extract_preprocessor(object),
272-
object$original_data
275+
object$original_data,
276+
reference_date = forecast_date
273277
)
278+
predict_data$time_value %>% max
274279

275-
predictions <- predict(object, new_data = test_data)
276-
280+
predict(object, new_data = predict_data, forecast_date)
277281
}

R/get_predict_data.R

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#' Get test data for prediction based on longest lag period
2+
#'
3+
#' Based on the longest lag period in the recipe,
4+
#' `get_predict_data()` creates an [epi_df][epiprocess::as_epi_df]
5+
#' with columns `geo_value`, `time_value`
6+
#' and other variables in the original dataset,
7+
#' which will be used to create features necessary to produce forecasts.
8+
#'
9+
#' The minimum required (recent) data to produce a forecast is equal to
10+
#' the maximum lag requested (on any predictor) plus the longest horizon
11+
#' used if growth rate calculations are requested by the recipe. This is
12+
#' calculated internally.
13+
#'
14+
#' @param recipe A recipe object.
15+
#' @param x An epi_df. The typical usage is to
16+
#' pass the same data as that used for fitting the recipe.
17+
#' @param test_interval A time interval or integer. The length of time before
18+
#' the `forecast_date` to consider for the forecast. The default is 1 year,
19+
#' which you will likely only need to make longer if you are doing long
20+
#' forecast horizons, or shorter if you are forecasting using an expensive
21+
#' model.
22+
#'
23+
#' @return An object of the same type as `x` with columns `geo_value`, `time_value`, any additional
24+
#' keys, as well other variables in the original dataset.
25+
#' @examples
26+
#' # create recipe
27+
#' rec <- epi_recipe(covid_case_death_rates) %>%
28+
#' step_epi_ahead(death_rate, ahead = 7) %>%
29+
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
30+
#' step_epi_lag(case_rate, lag = c(0, 7, 14))
31+
#' get_predict_data(recipe = rec, x = covid_case_death_rates)
32+
#' @importFrom rlang %@%
33+
#' @importFrom stats na.omit
34+
#' @export
35+
get_predict_data <- function(recipe,
36+
x,
37+
test_interval = NULL,
38+
reference_date = NULL) {
39+
if (!is_epi_df(x)) cli_abort("`x` must be an `epi_df`.")
40+
check <- hardhat::check_column_names(x, colnames(recipe$template))
41+
if (!check$ok) {
42+
cli_abort(c(
43+
"Some variables used for training are not available in {.arg x}.",
44+
i = "The following required columns are missing: {check$missing_names}"
45+
))
46+
}
47+
reference_date <- reference_date %||% recipe$reference_date
48+
test_interval <- test_interval %||% as.difftime(365, units = "days")
49+
trimmed_x <- x %>%
50+
filter((reference_date - time_value) < test_interval)
51+
52+
if (nrow(trimmed_x) == 0) {
53+
cli_abort(
54+
"predict data is filtered to no rows; check your `test_interval = {test_interval}` and `reference_date= {reference_date}`",
55+
class = "epipredict__get_predict_data__no_predict_data"
56+
)
57+
}
58+
59+
trimmed_x
60+
}

R/get_test_data.R

Lines changed: 0 additions & 113 deletions
This file was deleted.

R/tidy.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
#' step_epi_naomit()
3636
#'
3737
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
38-
#' latest <- get_test_data(recipe = r, x = jhu)
38+
#' latest <- get_predict_data(recipe = r, x = jhu)
3939
#'
4040
#' f <- frosting() %>%
4141
#' layer_predict() %>%

_pkgdown.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ reference:
8484
contents:
8585
- frosting
8686
- ends_with("_frosting")
87-
- get_test_data
87+
- get_predict_data
8888
- tidy.frosting
8989

9090
- title: Frosting layers

man/get_test_data.Rd renamed to man/get_predict_data.Rd

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/tidy.frosting.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/check_enough_data.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
# check_enough_data only checks train data when skip = FALSE
3838

3939
Code
40-
forecaster %>% predict(new_data = toy_test_data %>% filter(time_value >
40+
forecaster %>% predict(new_data = toy_predict_data %>% filter(time_value >
4141
"2020-01-08"))
4242
Condition
4343
Error in `check_enough_data_core()`:

0 commit comments

Comments
 (0)