Skip to content

Commit 2940f95

Browse files
committed
actually passing all the tests
1 parent b06e7e6 commit 2940f95

File tree

5 files changed

+13
-10
lines changed

5 files changed

+13
-10
lines changed

R/arx_forecaster.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ arx_fcast_epi_workflow <- function(
172172
r <- r %>%
173173
step_epi_naomit() %>%
174174
step_training_window(n_recent = args_list$n_training) %>%
175-
check_enough_train_data(all_predictors(), skip = FALSE)
175+
check_enough_train_data(all_predictors(), n = args_list$check_enough_data_n, skip = FALSE)
176176

177177
if (!is.null(args_list$check_enough_data_n)) {
178178
r <- r %>% check_enough_train_data(

R/check_enough_train_data.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,11 @@ prep.check_enough_train_data <- function(x, training, info = NULL, ...) {
123123
bake.check_enough_train_data <- function(object, new_data, ...) {
124124
col_names <- object$columns
125125
if (object$drop_na) {
126-
newish_data <- tidyr::drop_na(new_data, any_of(unname(col_names)))
126+
non_na_data <- tidyr::drop_na(new_data, any_of(unname(col_names)))
127+
} else {
128+
non_na_data <- new_data
127129
}
128-
cols_not_enough_data <- newish_data %>%
130+
cols_not_enough_data <- non_na_data %>%
129131
group_by(across(all_of(.env$object$epi_keys))) %>%
130132
summarise(across(all_of(.env$col_names), ~ dplyr::n() < .env$object$n), .groups = "drop") %>%
131133
summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%

tests/testthat/_snaps/check_enough_train_data.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
Code
4040
epi_recipe(toy_epi_df) %>% step_epi_lag(x, lag = c(1, 2)) %>%
41-
check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>% prep(
41+
check_enough_train_data(all_predictors(), y, n = 2 * n - 4) %>% prep(
4242
toy_epi_df) %>% bake(new_data = NULL)
4343
Condition
4444
Error in `prep()`:

tests/testthat/test-check_enough_train_data.R

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,33 +94,34 @@ test_that("check_enough_train_data only checks train data", {
9494
epiprocess::as_epi_df()
9595
expect_no_error(
9696
epi_recipe(toy_epi_df) %>%
97-
check_enough_train_data(x, y, n = n - 2, epi_keys = "geo_value") %>%
97+
check_enough_train_data(x, y, n = n - 2, epi_keys = "geo_value", skip = TRUE) %>%
9898
prep(toy_epi_df) %>%
9999
bake(new_data = toy_test_data)
100100
)
101101
# Same thing, but skip = FALSE
102102
expect_no_error(
103103
epi_recipe(toy_epi_df) %>%
104-
check_enough_train_data(y, n = n - 2, epi_keys = "geo_value", skip = FALSE) %>%
104+
check_enough_train_data(y, n = n - 2, epi_keys = "geo_value") %>%
105105
prep(toy_epi_df) %>%
106106
bake(new_data = toy_test_data)
107107
)
108108
})
109109

110110
test_that("check_enough_train_data works with all_predictors() downstream of constructed terms", {
111-
# With a lag of 2, we will get 2 * n - 6 non-NA rows
111+
# With a lag of 2, we will get 2 * n - 5 non-NA rows (NA's in x but not in the
112+
# lags don't count)
112113
expect_no_error(
113114
epi_recipe(toy_epi_df) %>%
114115
step_epi_lag(x, lag = c(1, 2)) %>%
115-
check_enough_train_data(all_predictors(), y, n = 2 * n - 6) %>%
116+
check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>%
116117
prep(toy_epi_df) %>%
117118
bake(new_data = NULL)
118119
)
119120
expect_snapshot(
120121
error = TRUE,
121122
epi_recipe(toy_epi_df) %>%
122123
step_epi_lag(x, lag = c(1, 2)) %>%
123-
check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>%
124+
check_enough_train_data(all_predictors(), y, n = 2 * n - 4) %>%
124125
prep(toy_epi_df) %>%
125126
bake(new_data = NULL)
126127
)

tests/testthat/test-layer_residual_quantiles.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ test_that("Canned forecasters work with / without", {
103103
)
104104

105105
expect_silent(
106-
arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"))
106+
arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), args_list = arx_args_list(check_enough_data_n = 1))
107107
)
108108
expect_silent(
109109
flatline_forecaster(

0 commit comments

Comments
 (0)