Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 8f3cb72

Browse files
committed May 5, 2025 ·
partial test fixes, reference_date vs forecast_date
1 parent b2810cc commit 8f3cb72

16 files changed

+137
-127
lines changed
 

‎R/arx_forecaster.R

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,19 @@ arx_forecaster <- function(
4747
if (!is_regression(trainer)) {
4848
cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")
4949
}
50-
5150
wf <- arx_fcast_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
5251
wf <- fit(wf, epi_data)
5352

5453
# get the forecast date for the forecast function
5554
if (args_list$adjust_latency == "none") {
56-
forecast_date_default <- max(epi_data$time_value)
55+
reference_date_default <- max(epi_data$time_value)
5756
} else {
58-
forecast_date_default <- attributes(epi_data)$metadata$as_of
57+
reference_date_default <- attributes(epi_data)$metadata$as_of
5958
}
60-
forecast_date <- args_list$forecast_date %||% forecast_date_default
61-
59+
reference_date <- args_list$reference_date %||% reference_date_default
60+
predict_interval <- args_list$predict_interval
6261

63-
preds <- forecast(wf, forecast_date = forecast_date) %>%
62+
preds <- forecast(wf, reference_dates = reference_date, predict_interval = predict_interval) %>%
6463
as_tibble() %>%
6564
select(-time_value)
6665

@@ -126,21 +125,21 @@ arx_fcast_epi_workflow <- function(
126125
# if they don't and they're not adjusting latency, it defaults to the max time_value
127126
# if they're adjusting, it defaults to the as_of
128127
if (args_list$adjust_latency == "none") {
129-
forecast_date_default <- max(epi_data$time_value)
130-
if (!is.null(args_list$forecast_date) && args_list$forecast_date != forecast_date_default) {
128+
reference_date_default <- max(epi_data$time_value)
129+
if (!is.null(args_list$reference_date) && args_list$reference_date != reference_date_default) {
131130
cli_warn(
132-
"The specified forecast date {args_list$forecast_date} doesn't match the date from which the forecast is actually occurring {forecast_date_default}.",
131+
"The specified forecast date {args_list$reference_date} doesn't match the date from which the forecast is actually occurring {reference_date_default}.",
133132
class = "epipredict__arx_forecaster__forecast_date_defaulting"
134133
)
135134
}
136135
} else {
137-
forecast_date_default <- attributes(epi_data)$metadata$as_of
136+
reference_date_default <- attributes(epi_data)$metadata$as_of
138137
}
139-
forecast_date <- args_list$forecast_date %||% forecast_date_default
140-
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
141-
if (forecast_date + args_list$ahead != target_date) {
142-
cli_abort("`forecast_date` {.val {forecast_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.",
143-
class = "epipredict__arx_forecaster__inconsistent_target_ahead_forecaste_date"
138+
reference_date <- args_list$reference_date %||% reference_date_default
139+
target_date <- args_list$target_date %||% (reference_date + args_list$ahead)
140+
if (reference_date + args_list$ahead != target_date) {
141+
cli_abort("`reference_date` {.val {reference_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.",
142+
class = "epipredict__arx_forecaster__inconsistent_target_ahead_forecast_date"
144143
)
145144
}
146145

@@ -153,12 +152,12 @@ arx_fcast_epi_workflow <- function(
153152
if (!is.null(method_adjust_latency)) {
154153
if (method_adjust_latency == "extend_ahead") {
155154
r <- r %>% step_adjust_latency(all_outcomes(),
156-
fixed_forecast_date = forecast_date,
155+
fixed_reference_date = reference_date,
157156
method = method_adjust_latency
158157
)
159158
} else if (method_adjust_latency == "extend_lags") {
160159
r <- r %>% step_adjust_latency(all_predictors(),
161-
fixed_forecast_date = forecast_date,
160+
fixed_reference_date = reference_date,
162161
method = method_adjust_latency
163162
)
164163
}
@@ -218,7 +217,7 @@ arx_fcast_epi_workflow <- function(
218217
by_key = args_list$quantile_by_key
219218
)
220219
}
221-
f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
220+
f <- layer_add_forecast_date(f, forecast_date = reference_date) %>%
222221
layer_add_target_date(target_date = target_date)
223222
if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))
224223

@@ -238,19 +237,19 @@ arx_fcast_epi_workflow <- function(
238237
#' @param n_training Integer. An upper limit for the number of rows per
239238
#' key that are used for training
240239
#' (in the time unit of the `epi_df`).
241-
#' @param forecast_date Date. The date from which the forecast is occurring.
240+
#' @param reference_date Date. The date from which the forecast is occurring.
242241
#' The default `NULL` will determine this automatically from either
243242
#' 1. the maximum time value for which there's data if there is no latency
244243
#' adjustment (the default case), or
245244
#' 2. the `as_of` date of `epi_data` if `adjust_latency` is
246245
#' not `"none"`.
247246
#' @param target_date Date. The date that is being forecast. The default `NULL`
248-
#' will determine this automatically as `forecast_date + ahead`.
247+
#' will determine this automatically as `reference_date + ahead`.
249248
#' @param adjust_latency Character. One of the `method`s of
250249
#' [step_adjust_latency()], or `"none"` (in which case there is no adjustment).
251-
#' If the `forecast_date` is after the last day of data, this determines how
250+
#' If the `reference_date` is after the last day of data, this determines how
252251
#' to shift the model to account for this difference. The options are:
253-
#' - `"none"` the default, assumes the `forecast_date` is the last day of data
252+
#' - `"none"` the default, assumes the `reference_date` is the last day of data
254253
#' - `"extend_ahead"`: increase the `ahead` by the latency so it's relative to
255254
#' the last day of data. For example, if the last day of data was 3 days ago,
256255
#' the ahead becomes `ahead+3`.
@@ -280,6 +279,7 @@ arx_fcast_epi_workflow <- function(
280279
#' column names on which to group the data and check threshold within each
281280
#' group. Useful if training per group (for example, per geo_value).
282281
#' @param ... Space to handle future expansions (unused).
282+
#' @inheritParams get_predict_data
283283
#'
284284
#'
285285
#' @return A list containing updated parameter choices with class `arx_flist`.
@@ -294,7 +294,7 @@ arx_args_list <- function(
294294
lags = c(0L, 7L, 14L),
295295
ahead = 7L,
296296
n_training = Inf,
297-
forecast_date = NULL,
297+
reference_date = NULL,
298298
target_date = NULL,
299299
adjust_latency = c("none", "extend_ahead", "extend_lags", "locf"),
300300
warn_latency = TRUE,
@@ -304,6 +304,7 @@ arx_args_list <- function(
304304
quantile_by_key = character(0L),
305305
check_enough_data_n = NULL,
306306
check_enough_data_epi_keys = NULL,
307+
predict_interval = NULL,
307308
...) {
308309
# error checking if lags is a list
309310
rlang::check_dots_empty()
@@ -313,8 +314,8 @@ arx_args_list <- function(
313314
adjust_latency <- rlang::arg_match(adjust_latency)
314315
arg_is_scalar(ahead, n_training, symmetrize, nonneg, adjust_latency, warn_latency)
315316
arg_is_chr(quantile_by_key, allow_empty = TRUE)
316-
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
317-
arg_is_date(forecast_date, target_date, allow_null = TRUE)
317+
arg_is_scalar(reference_date, target_date, allow_null = TRUE)
318+
arg_is_date(reference_date, target_date, allow_null = TRUE)
318319
arg_is_nonneg_int(ahead, lags)
319320
arg_is_lgl(symmetrize, nonneg)
320321
arg_is_probabilities(quantile_levels, allow_null = TRUE)
@@ -323,9 +324,9 @@ arx_args_list <- function(
323324
arg_is_pos(check_enough_data_n, allow_null = TRUE)
324325
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)
325326

326-
if (!is.null(forecast_date) && !is.null(target_date)) {
327-
if (forecast_date + ahead != target_date) {
328-
cli_abort("`forecast_date` {.val {forecast_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.",
327+
if (!is.null(reference_date) && !is.null(target_date)) {
328+
if (reference_date + ahead != target_date) {
329+
cli_abort("`reference_date` {.val {reference_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.",
329330
class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date"
330331
)
331332
}
@@ -338,8 +339,9 @@ arx_args_list <- function(
338339
ahead,
339340
n_training,
340341
quantile_levels,
341-
forecast_date,
342+
reference_date,
342343
target_date,
344+
predict_interval,
343345
adjust_latency,
344346
warn_latency,
345347
symmetrize,

‎R/epi_workflow.R

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,18 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), r
174174
components$keys <- grab_forged_keys(components$forged, object, new_data)
175175
components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
176176
reference_dates <- reference_dates %||% extract_recipe(object)$reference_date
177-
components$predictions %>% filter(time_value %in% reference_dates)
177+
#browser()
178+
predictions <- components$predictions %>% filter(time_value %in% reference_dates)
179+
predictions
180+
if (nrow(predictions) == 0) {
181+
last_pred_date <- components$predictions %>% pull(time_value) %>% max()
182+
last_data_date <- new_data %>% pull(time_value) %>% max()
183+
cli_warn(
184+
"no predictions on the reference date(s) {reference_dates}. The last prediction was on {last_pred_date}. The most recent prediction data is on {last_data_date}",
185+
class = "epipredict__predict_epi_workflow__no_predictions"
186+
)
187+
}
188+
predictions
178189
}
179190

180191

@@ -242,14 +253,12 @@ print.epi_workflow <- function(x, ...) {
242253
#' example, suppose n_recent = 3, then if the 3 most recent observations in any
243254
#' geo_value are all NA’s, we won’t be able to fill anything, and an error
244255
#' message will be thrown. (See details.)
245-
#' @param forecast_date By default, this is set to the maximum time_value in x.
246-
#' But if there is data latency such that recent NA's should be filled, this may
247-
#' be after the last available time_value.
256+
#' @inheritParams get_predict_data
248257
#'
249258
#' @return A forecast tibble.
250259
#'
251260
#' @export
252-
forecast.epi_workflow <- function(object, ..., n_recent = NULL, forecast_date = NULL) {
261+
forecast.epi_workflow <- function(object, ..., n_recent = NULL, reference_dates = NULL, predict_interval = NULL) {
253262
rlang::check_dots_empty()
254263

255264
if (!object$trained) {
@@ -259,6 +268,7 @@ forecast.epi_workflow <- function(object, ..., n_recent = NULL, forecast_date =
259268
))
260269
}
261270

271+
#browser()
262272
frosting_fd <- NULL
263273
if (has_postprocessor(object) && detect_layer(object, "layer_add_forecast_date")) {
264274
frosting_fd <- extract_argument(object, "layer_add_forecast_date", "forecast_date")
@@ -273,9 +283,9 @@ forecast.epi_workflow <- function(object, ..., n_recent = NULL, forecast_date =
273283
predict_data <- get_predict_data(
274284
hardhat::extract_preprocessor(object),
275285
object$original_data,
276-
reference_date = forecast_date
286+
reference_date = reference_dates,
287+
predict_interval = predict_interval
277288
)
278-
predict_data$time_value %>% max
279289

280-
predict(object, new_data = predict_data, forecast_date)
290+
predict(object, new_data = predict_data, reference_dates = reference_dates)
281291
}

‎R/get_predict_data.R

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,18 @@
1414
#' @param recipe A recipe object.
1515
#' @param x An epi_df. The typical usage is to
1616
#' pass the same data as that used for fitting the recipe.
17-
#' @param test_interval A time interval or integer. The length of time before
17+
#' @param predict_interval A time interval or integer. The length of time before
1818
#' the `reference_date` to consider for the forecast. The default is 1 year,
1919
#' which you will likely only need to make longer if you are doing long
2020
#' forecast horizons, or shorter if you are forecasting using an expensive
2121
#' model.
22+
#' @param reference_date By default, this is set to the maximum time_value in x.
23+
#' But if there is data latency such that recent NA's should be filled, this may
24+
#' be after the last available time_value.
2225
#'
23-
#' @return An object of the same type as `x` with columns `geo_value`, `time_value`, any additional
24-
#' keys, as well other variables in the original dataset.
26+
#' @return An object of the same type as `x` with columns `geo_value`,
27+
#' `time_value`, any additional keys, as well as other variables in the original
28+
#' dataset.
2529
#' @examples
2630
#' # create recipe
2731
#' rec <- epi_recipe(covid_case_death_rates) %>%
@@ -34,7 +38,7 @@
3438
#' @export
3539
get_predict_data <- function(recipe,
3640
x,
37-
test_interval = NULL,
41+
predict_interval = NULL,
3842
reference_date = NULL) {
3943
if (!is_epi_df(x)) cli_abort("`x` must be an `epi_df`.")
4044
check <- hardhat::check_column_names(x, colnames(recipe$template))
@@ -45,13 +49,13 @@ get_predict_data <- function(recipe,
4549
))
4650
}
4751
reference_date <- reference_date %||% recipe$reference_date
48-
test_interval <- test_interval %||% as.difftime(365, units = "days")
52+
predict_interval <- predict_interval %||% as.difftime(365, units = "days")
4953
trimmed_x <- x %>%
50-
filter((reference_date - time_value) < test_interval)
54+
filter((reference_date - time_value) < predict_interval)
5155

5256
if (nrow(trimmed_x) == 0) {
5357
cli_abort(
54-
"predict data is filtered to no rows; check your `test_interval = {test_interval}` and `reference_date= {reference_date}`",
58+
"predict data is filtered to no rows; check your `predict_interval = {predict_interval}`, `reference_date= {reference_date}` and latest data {max(x$time_value)}",
5559
class = "epipredict__get_predict_data__no_predict_data"
5660
)
5761
}

‎man/arx_args_list.Rd

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/arx_class_args_list.Rd

Lines changed: 0 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/cdc_baseline_args_list.Rd

Lines changed: 0 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/climate_args_list.Rd

Lines changed: 0 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/flatline_args_list.Rd

Lines changed: 0 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/forecast.epi_workflow.Rd

Lines changed: 14 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/get_predict_data.Rd

Lines changed: 11 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/grf_quantiles.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/predict-epi_workflow.Rd

Lines changed: 11 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/step_adjust_latency.Rd

Lines changed: 19 additions & 27 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎tests/testthat/_snaps/get_predict_data.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
get_predict_data(recipe = r, x = covid_case_death_rates)
55
Condition
66
Error in `get_predict_data()`:
7-
! predict data is filtered to no rows; check your `test_interval = 365` and `reference_date= 2023-03-10`
7+
! predict data is filtered to no rows; check your `predict_interval = 365` and `reference_date= 2023-03-10`
88

99
# expect error that geo_value or time_value does not exist
1010

‎tests/testthat/test-arx_forecaster.R

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
train_data <- epidatasets::cases_deaths_subset
22
test_that("arx_forecaster warns if forecast date beyond the implicit one", {
33
bad_date <- max(train_data$time_value) + 300
4-
expect_error(
4+
expect_warning(
55
expect_warning(
66
arx1 <- arx_forecaster(
77
train_data,
88
"death_rate_7d_av",
99
c("death_rate_7d_av", "case_rate_7d_av"),
10-
args_list = (arx_args_list(forecast_date = bad_date))
10+
args_list = (arx_args_list(reference_date = bad_date))
1111
),
1212
class = "epipredict__arx_forecaster__forecast_date_defaulting"
1313
),
14-
class = "epipredict__get_predict_data__no_predict_data")
14+
class = "epipredict__predict_epi_workflow__no_predictions")
1515
})
1616

17-
test_that("arx_forecaster errors if forecast date, target date, and ahead are inconsistent", {
17+
test_that("arx_forecaster errors if reference date, target date, and ahead are inconsistent", {
1818
max_date <- max(train_data$time_value)
1919
expect_error(
2020
arx1 <- arx_forecaster(
2121
train_data,
2222
"death_rate_7d_av",
2323
c("death_rate_7d_av", "case_rate_7d_av"),
24-
args_list = (arx_args_list(ahead = 5, target_date = max_date, forecast_date = max_date))
24+
args_list = (arx_args_list(ahead = 5, target_date = max_date, reference_date = max_date))
2525
),
2626
class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date"
2727
)
@@ -38,10 +38,9 @@ test_that("warns if there's not enough data to predict", {
3838
# and actually, pretend we're around mid-October 2022:
3939
filter(time_value <= as.Date("2022-10-12")) %>%
4040
as_epi_df(as_of = as.Date("2022-10-12"))
41-
edf %>% filter(time_value > "2022-08-01")
4241

4342
expect_error(
44-
edf %>% arx_forecaster("value"),
45-
class = "epipredict__not_enough_data"
43+
edf %>% arx_forecaster("value", args_list = arx_args_list(predict_interval = as.difftime(0, units = "days"))),
44+
class = "epipredict__get_predict_data__no_predict_data"
4645
)
4746
})

‎tests/testthat/test-get_predict_data.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
suppressPackageStartupMessages(library(dplyr))
22
forecast_date <- max(covid_case_death_rates$time_value)
3-
test_that("return expected number of rows for various `test_intervals`", {
3+
test_that("return expected number of rows for various `predict_intervals`", {
44
r <- epi_recipe(covid_case_death_rates, reference_date = forecast_date) %>%
55
step_epi_ahead(death_rate, ahead = 7) %>%
66
step_epi_lag(death_rate, lag = c(0, 7, 14, 21, 28)) %>%
@@ -15,14 +15,14 @@ test_that("return expected number of rows for various `test_intervals`", {
1515
dplyr::n_distinct(covid_case_death_rates$geo_value) * 365
1616
)
1717

18-
predict_data <- get_predict_data(recipe = r, test_interval = 5, x = covid_case_death_rates)
18+
predict_data <- get_predict_data(recipe = r, predict_interval = 5, x = covid_case_death_rates)
1919

2020
expect_equal(
2121
nrow(predict_data),
2222
dplyr::n_distinct(covid_case_death_rates$geo_value) * 5
2323
)
2424

25-
predict_data <- get_predict_data(recipe = r, test_interval = as.difftime(35, units = "days"), x = covid_case_death_rates)
25+
predict_data <- get_predict_data(recipe = r, predict_interval = as.difftime(35, units = "days"), x = covid_case_death_rates)
2626

2727
expect_equal(
2828
nrow(predict_data),

0 commit comments

Comments
 (0)
Please sign in to comment.