@@ -47,20 +47,19 @@ arx_forecaster <- function(
47
47
if (! is_regression(trainer )) {
48
48
cli_abort(" `trainer` must be a {.pkg parsnip} model of mode 'regression'." )
49
49
}
50
-
51
50
wf <- arx_fcast_epi_workflow(epi_data , outcome , predictors , trainer , args_list )
52
51
wf <- fit(wf , epi_data )
53
52
54
53
# get the forecast date for the forecast function
55
54
if (args_list $ adjust_latency == " none" ) {
56
- forecast_date_default <- max(epi_data $ time_value )
55
+ reference_date_default <- max(epi_data $ time_value )
57
56
} else {
58
- forecast_date_default <- attributes(epi_data )$ metadata $ as_of
57
+ reference_date_default <- attributes(epi_data )$ metadata $ as_of
59
58
}
60
- forecast_date <- args_list $ forecast_date %|| % forecast_date_default
61
-
59
+ reference_date <- args_list $ reference_date %|| % reference_date_default
60
+ predict_interval <- args_list $ predict_interval
62
61
63
- preds <- forecast(wf , forecast_date = forecast_date ) %> %
62
+ preds <- forecast(wf , reference_dates = reference_date , predict_interval = predict_interval ) %> %
64
63
as_tibble() %> %
65
64
select(- time_value )
66
65
@@ -126,21 +125,21 @@ arx_fcast_epi_workflow <- function(
126
125
# if they don't and they're not adjusting latency, it defaults to the max time_value
127
126
# if they're adjusting, it defaults to the as_of
128
127
if (args_list $ adjust_latency == " none" ) {
129
- forecast_date_default <- max(epi_data $ time_value )
130
- if (! is.null(args_list $ forecast_date ) && args_list $ forecast_date != forecast_date_default ) {
128
+ reference_date_default <- max(epi_data $ time_value )
129
+ if (! is.null(args_list $ reference_date ) && args_list $ reference_date != reference_date_default ) {
131
130
cli_warn(
132
- " The specified forecast date {args_list$forecast_date } doesn't match the date from which the forecast is actually occurring {forecast_date_default }." ,
131
+ " The specified forecast date {args_list$reference_date } doesn't match the date from which the forecast is actually occurring {reference_date_default }." ,
133
132
class = " epipredict__arx_forecaster__forecast_date_defaulting"
134
133
)
135
134
}
136
135
} else {
137
- forecast_date_default <- attributes(epi_data )$ metadata $ as_of
136
+ reference_date_default <- attributes(epi_data )$ metadata $ as_of
138
137
}
139
- forecast_date <- args_list $ forecast_date %|| % forecast_date_default
140
- target_date <- args_list $ target_date %|| % (forecast_date + args_list $ ahead )
141
- if (forecast_date + args_list $ ahead != target_date ) {
142
- cli_abort(" `forecast_date ` {.val {forecast_date }} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}." ,
143
- class = " epipredict__arx_forecaster__inconsistent_target_ahead_forecaste_date "
138
+ reference_date <- args_list $ reference_date %|| % reference_date_default
139
+ target_date <- args_list $ target_date %|| % (reference_date + args_list $ ahead )
140
+ if (reference_date + args_list $ ahead != target_date ) {
141
+ cli_abort(" `reference_date ` {.val {reference_date }} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}." ,
142
+ class = " epipredict__arx_forecaster__inconsistent_target_ahead_forecast_date "
144
143
)
145
144
}
146
145
@@ -153,12 +152,12 @@ arx_fcast_epi_workflow <- function(
153
152
if (! is.null(method_adjust_latency )) {
154
153
if (method_adjust_latency == " extend_ahead" ) {
155
154
r <- r %> % step_adjust_latency(all_outcomes(),
156
- fixed_forecast_date = forecast_date ,
155
+ fixed_reference_date = reference_date ,
157
156
method = method_adjust_latency
158
157
)
159
158
} else if (method_adjust_latency == " extend_lags" ) {
160
159
r <- r %> % step_adjust_latency(all_predictors(),
161
- fixed_forecast_date = forecast_date ,
160
+ fixed_reference_date = reference_date ,
162
161
method = method_adjust_latency
163
162
)
164
163
}
@@ -218,7 +217,7 @@ arx_fcast_epi_workflow <- function(
218
217
by_key = args_list $ quantile_by_key
219
218
)
220
219
}
221
- f <- layer_add_forecast_date(f , forecast_date = forecast_date ) %> %
220
+ f <- layer_add_forecast_date(f , forecast_date = reference_date ) %> %
222
221
layer_add_target_date(target_date = target_date )
223
222
if (args_list $ nonneg ) f <- layer_threshold(f , dplyr :: starts_with(" .pred" ))
224
223
@@ -238,19 +237,19 @@ arx_fcast_epi_workflow <- function(
238
237
# ' @param n_training Integer. An upper limit for the number of rows per
239
238
# ' key that are used for training
240
239
# ' (in the time unit of the `epi_df`).
241
- # ' @param forecast_date Date. The date from which the forecast is occurring.
240
+ # ' @param reference_date Date. The date from which the forecast is occurring.
242
241
# ' The default `NULL` will determine this automatically from either
243
242
# ' 1. the maximum time value for which there's data if there is no latency
244
243
# ' adjustment (the default case), or
245
244
# ' 2. the `as_of` date of `epi_data` if `adjust_latency` is
246
245
# ' non-`NULL`.
247
246
# ' @param target_date Date. The date that is being forecast. The default `NULL`
248
- # ' will determine this automatically as `forecast_date + ahead`.
247
+ # ' will determine this automatically as `reference_date + ahead`.
249
248
# ' @param adjust_latency Character. One of the `method`s of
250
249
# ' [step_adjust_latency()], or `"none"` (in which case there is no adjustment).
251
- # ' If the `forecast_date ` is after the last day of data, this determines how
250
+ # ' If the `reference_date ` is after the last day of data, this determines how
252
251
# ' to shift the model to account for this difference. The options are:
253
- # ' - `"none"` the default, assumes the `forecast_date ` is the last day of data
252
+ # ' - `"none"` the default, assumes the `reference_date ` is the last day of data
254
253
# ' - `"extend_ahead"`: increase the `ahead` by the latency so it's relative to
255
254
# ' the last day of data. For example, if the last day of data was 3 days ago,
256
255
# ' the ahead becomes `ahead+3`.
@@ -280,6 +279,7 @@ arx_fcast_epi_workflow <- function(
280
279
# ' column names on which to group the data and check threshold within each
281
280
# ' group. Useful if training per group (for example, per geo_value).
282
281
# ' @param ... Space to handle future expansions (unused).
282
+ # ' @inheritParams get_predict_data
283
283
# '
284
284
# '
285
285
# ' @return A list containing updated parameter choices with class `arx_flist`.
@@ -294,7 +294,7 @@ arx_args_list <- function(
294
294
lags = c(0L , 7L , 14L ),
295
295
ahead = 7L ,
296
296
n_training = Inf ,
297
- forecast_date = NULL ,
297
+ reference_date = NULL ,
298
298
target_date = NULL ,
299
299
adjust_latency = c(" none" , " extend_ahead" , " extend_lags" , " locf" ),
300
300
warn_latency = TRUE ,
@@ -304,6 +304,7 @@ arx_args_list <- function(
304
304
quantile_by_key = character (0L ),
305
305
check_enough_data_n = NULL ,
306
306
check_enough_data_epi_keys = NULL ,
307
+ predict_interval = NULL ,
307
308
... ) {
308
309
# error checking if lags is a list
309
310
rlang :: check_dots_empty()
@@ -313,8 +314,8 @@ arx_args_list <- function(
313
314
adjust_latency <- rlang :: arg_match(adjust_latency )
314
315
arg_is_scalar(ahead , n_training , symmetrize , nonneg , adjust_latency , warn_latency )
315
316
arg_is_chr(quantile_by_key , allow_empty = TRUE )
316
- arg_is_scalar(forecast_date , target_date , allow_null = TRUE )
317
- arg_is_date(forecast_date , target_date , allow_null = TRUE )
317
+ arg_is_scalar(reference_date , target_date , allow_null = TRUE )
318
+ arg_is_date(reference_date , target_date , allow_null = TRUE )
318
319
arg_is_nonneg_int(ahead , lags )
319
320
arg_is_lgl(symmetrize , nonneg )
320
321
arg_is_probabilities(quantile_levels , allow_null = TRUE )
@@ -323,9 +324,9 @@ arx_args_list <- function(
323
324
arg_is_pos(check_enough_data_n , allow_null = TRUE )
324
325
arg_is_chr(check_enough_data_epi_keys , allow_null = TRUE )
325
326
326
- if (! is.null(forecast_date ) && ! is.null(target_date )) {
327
- if (forecast_date + ahead != target_date ) {
328
- cli_abort(" `forecast_date ` {.val {forecast_date }} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}." ,
327
+ if (! is.null(reference_date ) && ! is.null(target_date )) {
328
+ if (reference_date + ahead != target_date ) {
329
+ cli_abort(" `reference_date ` {.val {reference_date }} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}." ,
329
330
class = " epipredict__arx_args__inconsistent_target_ahead_forecaste_date"
330
331
)
331
332
}
@@ -338,8 +339,9 @@ arx_args_list <- function(
338
339
ahead ,
339
340
n_training ,
340
341
quantile_levels ,
341
- forecast_date ,
342
+ reference_date ,
342
343
target_date ,
344
+ predict_interval ,
343
345
adjust_latency ,
344
346
warn_latency ,
345
347
symmetrize ,
0 commit comments