Skip to content

Commit ee5cc40

Browse files
committed
feat: add step_/layer_ epi_YeoJohnson
1 parent 7f08d40 commit ee5cc40

File tree

7 files changed

+949
-2
lines changed

7 files changed

+949
-2
lines changed

NAMESPACE

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ S3method(bake,check_enough_train_data)
1919
S3method(bake,epi_recipe)
2020
S3method(bake,step_adjust_latency)
2121
S3method(bake,step_climate)
22+
S3method(bake,step_epi_YeoJohnson)
2223
S3method(bake,step_epi_ahead)
2324
S3method(bake,step_epi_lag)
2425
S3method(bake,step_epi_slide)
@@ -53,6 +54,7 @@ S3method(prep,check_enough_train_data)
5354
S3method(prep,epi_recipe)
5455
S3method(prep,step_adjust_latency)
5556
S3method(prep,step_climate)
57+
S3method(prep,step_epi_YeoJohnson)
5658
S3method(prep,step_epi_ahead)
5759
S3method(prep,step_epi_lag)
5860
S3method(prep,step_epi_slide)
@@ -74,6 +76,7 @@ S3method(print,flatline)
7476
S3method(print,frosting)
7577
S3method(print,layer_add_forecast_date)
7678
S3method(print,layer_add_target_date)
79+
S3method(print,layer_epi_YeoJohnson)
7780
S3method(print,layer_naomit)
7881
S3method(print,layer_point_from_distn)
7982
S3method(print,layer_population_scaling)
@@ -84,6 +87,7 @@ S3method(print,layer_threshold)
8487
S3method(print,layer_unnest)
8588
S3method(print,step_adjust_latency)
8689
S3method(print,step_climate)
90+
S3method(print,step_epi_YeoJohnson)
8791
S3method(print,step_epi_ahead)
8892
S3method(print,step_epi_lag)
8993
S3method(print,step_epi_slide)
@@ -99,6 +103,7 @@ S3method(run_mold,default_epi_recipe_blueprint)
99103
S3method(slather,layer_add_forecast_date)
100104
S3method(slather,layer_add_target_date)
101105
S3method(slather,layer_cdc_flatline_quantiles)
106+
S3method(slather,layer_epi_YeoJohnson)
102107
S3method(slather,layer_naomit)
103108
S3method(slather,layer_point_from_distn)
104109
S3method(slather,layer_population_scaling)
@@ -112,6 +117,7 @@ S3method(snap,quantile_pred)
112117
S3method(tidy,check_enough_train_data)
113118
S3method(tidy,frosting)
114119
S3method(tidy,layer)
120+
S3method(tidy,step_epi_YeoJohnson)
115121
S3method(update,layer)
116122
S3method(vec_arith,quantile_pred)
117123
S3method(vec_arith.numeric,quantile_pred)
@@ -174,6 +180,7 @@ export(layer)
174180
export(layer_add_forecast_date)
175181
export(layer_add_target_date)
176182
export(layer_cdc_flatline_quantiles)
183+
export(layer_epi_YeoJohnson)
177184
export(layer_naomit)
178185
export(layer_point_from_distn)
179186
export(layer_population_scaling)
@@ -205,6 +212,7 @@ export(smooth_quantile_reg)
205212
export(snap)
206213
export(step_adjust_latency)
207214
export(step_climate)
215+
export(step_epi_YeoJohnson)
208216
export(step_epi_ahead)
209217
export(step_epi_lag)
210218
export(step_epi_naomit)

R/layer_yeo_johnson.r

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
#' Unnormalizing transformation
#'
#' Will undo a step_epi_YeoJohnson transformation.
#'
#' @param frosting a `frosting` postprocessor. The layer will be added to the
#'   sequence of operations for this frosting.
#' @param ... One or more selector functions to choose the prediction columns
#'   to untransform. See [recipes::selections()] for more details.
#' @param lambdas Optional tibble (with at least one row) holding the
#'   Yeo-Johnson parameters: the key columns to join on plus one
#'   `.lambda_<outcome_col>` column per transformed variable. If `NULL` (the
#'   default), the lambdas stored by `step_epi_YeoJohnson` in the workflow's
#'   recipe are used when the layer is applied.
#' @param by A (possibly named) character vector of variables to join by. If
#'   `NULL` (the default), the join columns are inferred from the key columns
#'   shared by the predictions and `lambdas`.
#' @param id a random id string
#'
#' @return an updated `frosting` postprocessor
#' @export
#' @examples
#' library(dplyr)
#' jhu <- epidatasets::cases_deaths_subset %>%
#'   filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
#'   select(geo_value, time_value, cases)
#'
#' # Create a recipe with a Yeo-Johnson transformation.
#' r <- epi_recipe(jhu) %>%
#'   step_epi_YeoJohnson(cases) %>%
#'   step_epi_lag(cases, lag = 0) %>%
#'   step_epi_ahead(cases, ahead = 0, role = "outcome") %>%
#'   step_epi_naomit()
#'
#' # Create a frosting layer that will undo the Yeo-Johnson transformation.
#' f <- frosting() %>%
#'   layer_predict() %>%
#'   layer_epi_YeoJohnson(.pred)
#'
#' # Create a workflow and fit it.
#' wf <- epi_workflow(r, linear_reg()) %>%
#'   fit(jhu) %>%
#'   add_frosting(f)
#'
#' # Forecast the workflow, which should reverse the Yeo-Johnson transformation.
#' forecast(wf)
#' # Compare to the original data.
#' plot(density(jhu$cases))
#' plot(density(forecast(wf)$cases))
layer_epi_YeoJohnson <- function(frosting, ..., lambdas = NULL, by = NULL, id = rand_id("epi_YeoJohnson")) {
  # `lambdas` may be omitted; when supplied it must be a non-empty tibble.
  checkmate::assert_tibble(lambdas, min.rows = 1, null.ok = TRUE)

  add_layer(
    frosting,
    layer_epi_YeoJohnson_new(
      lambdas = lambdas,
      by = by,
      # Capture the column selectors unevaluated; they are resolved against
      # the predictions when the layer is slathered.
      terms = dplyr::enquos(...),
      id = id
    )
  )
}
57+
58+
# Internal constructor: builds the "epi_YeoJohnson" layer object from the
# user-supplied lambdas table, join columns, captured column selectors and id.
layer_epi_YeoJohnson_new <- function(lambdas, by, terms, id) {
  # `layer()` lives in this package's own namespace, so it is called directly
  # rather than through `epipredict:::` (which R CMD check flags).
  layer("epi_YeoJohnson", lambdas = lambdas, by = by, terms = terms, id = id)
}
61+
62+
# Applies the inverse Yeo-Johnson transformation to the selected prediction
# columns:
#   1. take the lambdas from the layer, or else from the recipe's
#      `step_epi_YeoJohnson`;
#   2. work out the join columns (`by`) if the user did not give them;
#   3. join the lambdas onto `components$predictions`;
#   4. untransform each selected column with `yj_inverse()`, row by row;
#   5. drop the `.lambda_*` helper columns and return `components`.
#' @export
#' @importFrom workflows extract_preprocessor
slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data, ...) {
  rlang::check_dots_empty()

  # Get the lambdas from the layer or from the workflow.
  lambdas <- object$lambdas %||% get_lambdas_in_layer(workflow)

  # If the by is not specified, try to infer it from the lambdas.
  if (is.null(object$by)) {
    # Assume `layer_predict` has calculated the prediction keys and other
    # layers don't change the prediction key colnames:
    prediction_key_colnames <- names(components$keys)
    lhs_potential_keys <- prediction_key_colnames
    # Parameter columns are named `.lambda_<col>` (see the lookups below), so
    # everything else in `lambdas` is a candidate key column.
    rhs_potential_keys <- colnames(select(lambdas, -starts_with(".lambda_")))
    object$by <- intersect(lhs_potential_keys, rhs_potential_keys)
    suggested_min_keys <- setdiff(lhs_potential_keys, "time_value")
    if (!all(suggested_min_keys %in% object$by)) {
      # NOTE(review): the condition class below still carries the
      # `layer_population_scaling` name; it is left unchanged so existing
      # handlers keep matching -- consider a layer-specific class.
      cli_warn(
        c(
          "{setdiff(suggested_min_keys, object$by)} {?was an/were} epikey column{?s} in the predictions,
          but {?wasn't/weren't} found in the `lambdas` table.",
          "i" = "Defaulting to join by {object$by}",
          ">" = "Double-check whether column names on the `lambdas` table match those expected in your predictions",
          ">" = "Consider using lambdas with breakdowns by {suggested_min_keys}",
          ">" = "Manually specify `by =` to silence"
        ),
        class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys"
      )
    }
  }

  # Establish the join columns. The `%||%` fallback only applies if
  # `object$by` is still NULL at this point.
  object$by <- object$by %||%
    intersect(
      epi_keys_only(components$predictions),
      colnames(select(lambdas, -starts_with(".lambda_")))
    )
  # For a named `by`, the names are the prediction-side columns and the values
  # are the lambda-side columns.
  joinby <- list(x = names(object$by) %||% object$by, y = object$by)
  hardhat::validate_column_names(components$predictions, joinby$x)
  hardhat::validate_column_names(lambdas, joinby$y)

  # Join the lambdas. Each prediction row must match exactly one lambda row:
  # unmatched prediction rows are an error, unmatched lambda rows are dropped.
  components$predictions <- inner_join(
    components$predictions,
    lambdas,
    by = object$by,
    relationship = "many-to-one",
    unmatched = c("error", "drop")
  )

  exprs <- rlang::expr(c(!!!object$terms))
  pos <- tidyselect::eval_select(exprs, components$predictions)
  col_names <- names(pos)

  # The `object$terms` is where the user specifies the columns they want to
  # untransform. We need to match the outcomes with their lambda columns in our
  # parameter table and then apply the inverse transformation.
  if (identical(col_names, ".pred")) {
    # In this case, we don't get a hint for the outcome column name, so we need
    # to infer it from the mold.
    if (length(components$mold$outcomes) > 1) {
      cli_abort("Only one outcome is allowed when specifying `.pred`.", call = rlang::caller_env())
    }
    # `outcomes` is a vector of objects like ahead_1_cases, ahead_7_cases, etc.
    # We want to extract the cases part.
    outcome_cols <- names(components$mold$outcomes) %>%
      stringr::str_match("ahead_\\d+_(.*)") %>%
      magrittr::extract(, 2)

    # `yj_inverse()` takes a single lambda, hence the row-wise evaluation.
    components$predictions <- components$predictions %>%
      rowwise() %>%
      mutate(.pred := yj_inverse(.pred, !!sym(paste0(".lambda_", outcome_cols))))
  } else if (identical(col_names, character(0))) {
    # Wish I could suggest `all_outcomes()` here, but currently it's the same as
    # not specifying any terms. I don't want to spend time with dealing with
    # this case until someone asks for it.
    cli::cli_abort("Not specifying columns to layer Yeo-Johnson is not implemented.
    If you had a single outcome, you can use `.pred` as a column name.
    If you had multiple outcomes, you'll need to specify them like
    `.pred_ahead_1_<outcome_col>`, `.pred_ahead_7_<outcome_col>`, etc.
    ", call = rlang::caller_env())
  } else {
    # In this case, we assume that the user has specified the columns they want
    # transformed here. We then need to determine the lambda columns for each of
    # these columns. That is, we need to convert a vector of column names like
    # c(".pred_ahead_1_case_rate", ".pred_ahead_7_case_rate") to the matching
    # lambda columns, here c(".lambda_case_rate", ".lambda_case_rate").
    original_outcome_cols <- stringr::str_match(col_names, ".pred_ahead_\\d+_(.*)")[, 2]
    outcomes_wout_ahead <- stringr::str_match(names(components$mold$outcomes), "ahead_\\d+_(.*)")[, 2]
    if (any(original_outcome_cols %nin% outcomes_wout_ahead)) {
      cli_abort("All columns specified in `...` must be outcome columns.
      They must be of the form `.pred_ahead_1_<outcome_col>`, `.pred_ahead_7_<outcome_col>`, etc.
      ", call = rlang::caller_env())
    }

    for (i in seq_along(col_names)) {
      col <- col_names[i]
      lambda_col <- paste0(".lambda_", original_outcome_cols[i])
      components$predictions <- components$predictions %>%
        rowwise() %>%
        mutate(!!sym(col) := yj_inverse(!!sym(col), !!sym(lambda_col)))
    }
  }

  # Remove the lambda columns and the row-wise grouping added above.
  components$predictions <- components$predictions %>%
    select(-any_of(starts_with(".lambda_"))) %>%
    ungroup()
  components
}
173+
174+
# Print method for the layer: a fixed title followed by the selected terms,
# truncated to `width` characters by `print_layer()`.
#' @export
print.layer_epi_YeoJohnson <- function(x, width = max(20, options()$width - 30), ...) {
  title <- "Yeo-Johnson transformation (see `lambdas` object for values) on "
  # `print_layer()` is an internal helper of this package; call it directly
  # instead of through `epipredict:::` (which R CMD check flags).
  print_layer(x$terms, title = title, width = width)
}
179+
180+
# Inverse Yeo-Johnson transformation
#
# Inverse of `yj_transform` in step_yeo_johnson.R. Note that this function is
# vectorized in x, but not in lambda.
#
# Arguments:
#   x      Transformed values. A data frame or list is flattened to a plain
#          vector with `unlist()` (names are dropped).
#   lambda A single Yeo-Johnson parameter. If `NA`, `x` is returned unchanged.
#   eps    Tolerance for treating `lambda` as exactly 0 (non-negative values)
#          or exactly 2 (negative values), where the log form applies.
#
# Returns the untransformed values as a vector the same length as the
# flattened `x`; `NA` entries are left as `NA`.
yj_inverse <- function(x, lambda, eps = 0.001) {
  if (is.na(lambda)) {
    return(x)
  }
  # Always flatten: the previous `!inherits(x, "tbl_df") || is.data.frame(x)`
  # guard was TRUE for every input (a tibble is also a data frame), so its
  # alternative branch could never run.
  x <- unlist(x, use.names = FALSE)

  # `which()` skips NA, so missing values fall in neither group.
  dat_neg <- x < 0
  not_neg <- which(!dat_neg)
  is_neg <- which(dat_neg)

  # Inverse for values that were non-negative before the transformation.
  nn_inv_trans <- function(x, lambda) {
    if (abs(lambda) < eps) {
      # forward: log(x + 1)
      exp(x) - 1
    } else {
      # forward: ((x + 1)^lambda - 1) / lambda
      (lambda * x + 1)^(1 / lambda) - 1
    }
  }

  # Inverse for values that were negative before the transformation.
  ng_inv_trans <- function(x, lambda) {
    if (abs(lambda - 2) < eps) {
      # forward: -log(-x + 1)
      -(exp(-x) - 1)
    } else {
      # forward: -((-x + 1)^(2 - lambda) - 1) / (2 - lambda)
      -(((lambda - 2) * x + 1)^(1 / (2 - lambda)) - 1)
    }
  }

  if (length(not_neg) > 0) {
    x[not_neg] <- nn_inv_trans(x[not_neg], lambda)
  }

  if (length(is_neg) > 0) {
    x[is_neg] <- ng_inv_trans(x[is_neg], lambda)
  }
  x
}
230+
231+
# Fetch the `lambdas` stored on the first `step_epi_YeoJohnson` step of the
# workflow's recipe. Aborts if the recipe has no such step.
get_lambdas_in_layer <- function(workflow) {
  wf_recipe <- hardhat::extract_recipe(workflow)
  has_yj_step <- recipes::detect_step(wf_recipe, "epi_YeoJohnson")
  if (!has_yj_step) {
    cli_abort("`layer_epi_YeoJohnson` requires `step_epi_YeoJohnson` in the recipe.", call = rlang::caller_env())
  }
  # `Find()` returns the first step for which the predicate holds.
  is_yj_step <- function(candidate) inherits(candidate, "step_epi_YeoJohnson")
  yj_step <- Find(is_yj_step, wf_recipe$steps)
  yj_step$lambdas
}
244+
245+
# Names of the columns that `step_epi_YeoJohnson` transformed in the
# workflow's recipe, read off the `.lambda_<col>` columns of the step's
# `lambdas` table. Aborts if the recipe has no such step.
#
# NOTE(review): previously this function looked up `lambdas` and then ended on
# the `for` loop, so it always returned invisible NULL. Returning the column
# names is the presumed intent, based on the function name and the
# `.lambda_<col>` naming used elsewhere in this file -- please confirm.
get_transformed_cols_in_layer <- function(workflow) {
  this_recipe <- hardhat::extract_recipe(workflow)
  if (!(this_recipe %>% recipes::detect_step("epi_YeoJohnson"))) {
    cli_abort("`layer_epi_YeoJohnson` requires `step_epi_YeoJohnson` in the recipe.", call = rlang::caller_env())
  }
  lambdas <- NULL
  for (step in this_recipe$steps) {
    if (inherits(step, "step_epi_YeoJohnson")) {
      lambdas <- step$lambdas
      break
    }
  }
  # Keep only the parameter columns and strip their prefix; yields
  # character(0) when the step carries no lambdas.
  lambda_cols <- grep("^\\.lambda_", names(lambdas), value = TRUE)
  sub("^\\.lambda_", "", lambda_cols)
}

0 commit comments

Comments
 (0)