Skip to content

Commit f421a2d

Browse files
committed
using check_enough_train_data in practice
1 parent 7f08d40 commit f421a2d

File tree

4 files changed

+44
-5
lines changed

4 files changed

+44
-5
lines changed

R/arx_forecaster.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ arx_fcast_epi_workflow <- function(
171171
step_epi_ahead(!!outcome, ahead = args_list$ahead)
172172
r <- r %>%
173173
step_epi_naomit() %>%
174-
step_training_window(n_recent = args_list$n_training)
174+
step_training_window(n_recent = args_list$n_training) %>%
175+
check_enough_train_data(all_predictors(), skip = FALSE)
176+
175177
if (!is.null(args_list$check_enough_data_n)) {
176178
r <- r %>% check_enough_train_data(
177179
all_predictors(),

R/check_enough_train_data.R

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ check_enough_train_data <-
4747
role = NA,
4848
trained = FALSE,
4949
columns = NULL,
50-
skip = TRUE,
50+
skip = FALSE,
5151
id = rand_id("enough_train_data")) {
5252
recipes::add_check(
5353
recipe,
@@ -90,7 +90,7 @@ prep.check_enough_train_data <- function(x, training, info = NULL, ...) {
9090
}
9191

9292
if (x$drop_na) {
93-
training <- tidyr::drop_na(training)
93+
training <- tidyr::drop_na(training, any_of(unname(col_names)))
9494
}
9595
cols_not_enough_data <- training %>%
9696
group_by(across(all_of(.env$x$epi_keys))) %>%
@@ -101,7 +101,8 @@ prep.check_enough_train_data <- function(x, training, info = NULL, ...) {
101101

102102
if (length(cols_not_enough_data) > 0) {
103103
cli_abort(
104-
"The following columns don't have enough data to predict: {cols_not_enough_data}."
104+
"The following columns don't have enough data to predict: {cols_not_enough_data}.",
105+
class = "epipredict__not_enough_train_data"
105106
)
106107
}
107108

@@ -120,6 +121,23 @@ prep.check_enough_train_data <- function(x, training, info = NULL, ...) {
120121

121122
#' @export
122123
bake.check_enough_train_data <- function(object, new_data, ...) {
  # Re-runs the "enough rows per key group" check on `new_data` at bake time
  # and aborts with class "epipredict__not_enough_train_data" if any checked
  # column has a group with fewer than `object$n` rows.
  #
  # object:   the check object; this function reads `columns`, `drop_na`,
  #           `epi_keys` and `n` from it.
  # new_data: the data being baked; it is returned unchanged when the check
  #           passes.
  col_names <- object$columns

  # Bug fix: start from `new_data` unconditionally. Previously the data to
  # check was only assigned inside the `drop_na` branch, so a check created
  # with `drop_na = FALSE` failed with "object 'newish_data' not found".
  checked_data <- new_data
  if (object$drop_na) {
    # Only rows with NAs in the checked columns are dropped.
    checked_data <- tidyr::drop_na(new_data, any_of(unname(col_names)))
  }

  # Per key group: is the row count below the threshold? Then collapse over
  # groups so each column is flagged if *any* group is too small.
  # NOTE(review): if `checked_data` has zero rows there are no groups, so
  # `any()` yields FALSE and the check passes -- confirm this is intended.
  has_too_few_rows <- checked_data %>%
    group_by(across(all_of(.env$object$epi_keys))) %>%
    summarise(
      across(all_of(.env$col_names), ~ dplyr::n() < .env$object$n),
      .groups = "drop"
    ) %>%
    summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%
    unlist()
  cols_not_enough_data <- names(has_too_few_rows)[has_too_few_rows]

  if (length(cols_not_enough_data) > 0) {
    cli_abort(
      "The following columns don't have enough data to predict: {cols_not_enough_data}.",
      class = "epipredict__not_enough_train_data"
    )
  }

  new_data
}
125143

R/epi_workflow.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,6 @@ forecast.epi_workflow <- function(object, ..., n_recent = NULL, forecast_date =
270270
hardhat::extract_preprocessor(object),
271271
object$original_data
272272
)
273-
273+
test_data
274274
predict(object, new_data = test_data)
275275
}

tests/testthat/test-arx_forecaster.R

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,22 @@ test_that("arx_forecaster errors if forecast date, target date, and ahead are in
2424
class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date"
2525
)
2626
})
27+
28+
# The forecaster is expected to raise a classed *error* (not a warning) when
# the most recent data are too sparse, so the description says "errors".
test_that("errors if there's not enough data to predict", {
  edf <- tibble(
    geo_value = "ct",
    time_value = seq(as.Date("2020-10-01"), as.Date("2023-05-31"), by = "day"),
  ) %>%
    mutate(value = seq_len(nrow(.)) + rnorm(nrow(.))) %>%
    # Oct to May (flu season, ish) only:
    filter(!between(as.POSIXlt(time_value)$mon + 1L, 6L, 9L)) %>%
    # and actually, pretend we're around mid-October 2022:
    filter(time_value <= as.Date("2022-10-12")) %>%
    as_epi_df(as_of = as.Date("2022-10-12"))
  # (Removed a bare `edf %>% filter(time_value > "2022-08-01")` whose result
  # was discarded; it had no effect on the test.)

  expect_error(
    edf %>% arx_forecaster("value"),
    class = "epipredict__not_enough_train_data"
  )
})

0 commit comments

Comments
 (0)