Skip to content

Simplify workflow #399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 17 commits into
base: v0.2.0
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -8,6 +8,8 @@ S3method(Remove_model,epi_workflow)
S3method(Remove_model,workflow)
S3method(Update_model,epi_workflow)
S3method(Update_model,workflow)
S3method(add_frosting,default)
S3method(add_frosting,epi_workflow)
S3method(adjust_epi_recipe,epi_recipe)
S3method(adjust_epi_recipe,epi_workflow)
S3method(adjust_frosting,epi_workflow)
@@ -97,6 +99,8 @@ S3method(quantile,dist_quantiles)
S3method(recipe,epi_df)
S3method(recipes::recipe,formula)
S3method(refresh_blueprint,default_epi_recipe_blueprint)
S3method(remove_frosting,default)
S3method(remove_frosting,epi_workflow)
S3method(residuals,flatline)
S3method(run_mold,default_epi_recipe_blueprint)
S3method(slather,layer_add_forecast_date)
@@ -119,6 +123,8 @@ S3method(tidy,check_enough_train_data)
S3method(tidy,frosting)
S3method(tidy,layer)
S3method(update,layer)
S3method(update_frosting,default)
S3method(update_frosting,epi_workflow)
S3method(vec_ptype_abbr,dist_quantiles)
S3method(vec_ptype_full,dist_quantiles)
S3method(weighted_interval_score,default)
@@ -271,6 +277,7 @@ importFrom(rlang,":=")
importFrom(rlang,abort)
importFrom(rlang,arg_match)
importFrom(rlang,as_function)
importFrom(rlang,caller_arg)
importFrom(rlang,caller_env)
importFrom(rlang,enquo)
importFrom(rlang,enquos)
45 changes: 0 additions & 45 deletions R/create-layer.R

This file was deleted.

151 changes: 29 additions & 122 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
@@ -95,23 +95,12 @@ add_epi_recipe <- function(
#' @rdname add_epi_recipe
#' @export
remove_epi_recipe <- function(x) {
workflows:::validate_is_workflow(x)

if (!workflows:::has_preprocessor_recipe(x)) {
rlang::warn("The workflow has no recipe preprocessor to remove.")
}

actions <- x$pre$actions
actions[["recipe"]] <- NULL

new_epi_workflow(
pre = workflows:::new_stage_pre(actions = actions),
fit = x$fit,
post = x$post,
trained = FALSE
)
x <- workflows::remove_recipe(x)
class(x) <- c("epi_workflow", class(x))
x
}


#' @rdname add_epi_recipe
#' @export
update_epi_recipe <- function(x, recipe, ..., blueprint = default_epi_recipe_blueprint()) {
@@ -180,15 +169,21 @@ adjust_epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_workflow <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
recipe <- adjust_epi_recipe(workflows::extract_preprocessor(x), which_step, ...)
adjust_epi_recipe.epi_workflow <- function(
x, which_step, ..., blueprint = default_epi_recipe_blueprint()
) {

update_epi_recipe(x, recipe, blueprint = blueprint)
rec <- adjust_epi_recipe(
workflows::extract_preprocessor(x), which_step, ...
)
update_epi_recipe(x, rec, blueprint = blueprint)
}

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
adjust_epi_recipe.epi_recipe <- function(
x, which_step, ..., blueprint = default_epi_recipe_blueprint()
) {
if (!(is.numeric(which_step) || is.character(which_step))) {
cli::cli_abort(
c("`which_step` must be a number or a character.",
@@ -294,109 +289,21 @@ kill_levels <- function(x, keys) {

#' @export
print.epi_recipe <- function(x, form_width = 30, ...) {
cli::cli_div(theme = list(.pkg = list("vec-trunc" = Inf, "vec-last" = ", ")))

cli::cli_h1("Epi Recipe")
cli::cli_h3("Inputs")

tab <- table(x$var_info$role, useNA = "ifany")
tab <- stats::setNames(tab, names(tab))
names(tab)[is.na(names(tab))] <- "undeclared role"

roles <- c("outcome", "predictor", "case_weights", "undeclared role")

tab <- c(
tab[names(tab) == roles[1]],
tab[names(tab) == roles[2]],
tab[names(tab) == roles[3]],
sort(tab[!names(tab) %in% roles], TRUE),
tab[names(tab) == roles[4]]
)

cli::cli_text("Number of variables by role")

spaces_needed <- max(nchar(names(tab))) - nchar(names(tab)) +
max(nchar(tab)) - nchar(tab)

cli::cli_verbatim(
glue::glue("{names(tab)}: {strrep('\ua0', spaces_needed)}{tab}")
)

if ("tr_info" %in% names(x)) {
cli::cli_h3("Training information")
nmiss <- x$tr_info$nrows - x$tr_info$ncomplete
nrows <- x$tr_info$nrows

cli::cli_text(
"Training data contained {nrows} data points and {cli::no(nmiss)} \\
incomplete row{?s}."
)
}

if (!is.null(x$steps)) {
cli::cli_h3("Operations")
}

fmt <- cli::cli_fmt({
for (step in x$steps) {
print(step, form_width = form_width)
}
})
cli::cli_ol(fmt)
cli::cli_end()

invisible(x)
}

# Currently only used in the workflow printing
print_preprocessor_recipe <- function(x, ...) {
recipe <- workflows::extract_preprocessor(x)
steps <- recipe$steps
n_steps <- length(steps)
cli::cli_text("{n_steps} Recipe step{?s}.")

if (n_steps == 0L) {
return(invisible(x))
}

step_names <- map_chr(steps, workflows:::pull_step_name)

if (n_steps <= 10L) {
cli::cli_ol(step_names)
return(invisible(x))
}

extra_steps <- n_steps - 10L
step_names <- step_names[1:10]

cli::cli_ol(step_names)
cli::cli_bullets("... and {extra_steps} more step{?s}.")
invisible(x)
}

print_preprocessor <- function(x) {
has_preprocessor_formula <- workflows:::has_preprocessor_formula(x)
has_preprocessor_recipe <- workflows:::has_preprocessor_recipe(x)
has_preprocessor_variables <- workflows:::has_preprocessor_variables(x)

no_preprocessor <- !has_preprocessor_formula && !has_preprocessor_recipe &&
!has_preprocessor_variables

if (no_preprocessor) {
return(invisible(x))
}

cli::cli_rule("Preprocessor")
cli::cli_text("")

if (has_preprocessor_formula) {
workflows:::print_preprocessor_formula(x)
}
if (has_preprocessor_recipe) {
print_preprocessor_recipe(x)
}
if (has_preprocessor_variables) {
workflows:::print_preprocessor_variables(x)
o <- cli::cli_fmt(NextMethod())
# Fix up the recipe name
rr <- unlist(strsplit(o[2], "Recipe"))
len <- nchar(rr[2])
h1_tail <- paste0(substr(rr[2], 1, len / 2 - 10), substr(rr[2], len / 2, len))
o[2] <- paste0(rr[1], "Epi Recipe", h1_tail)

# Number the operations
has_operations <- any(grepl(" Operations ", o, fixed = TRUE))
if (has_operations) {
ops <- seq(grep(" Operations ", o, fixed = TRUE) + 1, length(o))
# kills the \bullet
rep_ops <- sub("^\\033\\[36m.\\033\\[39m ", "", o[ops], perl = TRUE)
o[ops] <- paste0(ops - ops[1] + 1, ". ", rep_ops)
}
cli::cli_bullets(o)
invisible(x)
}
38 changes: 13 additions & 25 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
@@ -32,18 +32,15 @@
#'
#' wf
epi_workflow <- function(preprocessor = NULL, spec = NULL, postprocessor = NULL) {
out <- workflows::workflow(spec = spec)
class(out) <- c("epi_workflow", class(out))

out <- workflows::workflow(preprocessor, spec = spec)
if (is_epi_recipe(preprocessor)) {
out <- workflows::remove_recipe(out)
out <- add_epi_recipe(out, preprocessor)
} else if (!is_null(preprocessor)) {
out <- workflows:::add_preprocessor(out, preprocessor)
}
class(out) <- c("epi_workflow", class(out))
if (!is_null(postprocessor)) {
out <- add_postprocessor(out, postprocessor)
}

out
}

@@ -101,7 +98,6 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
as_of = attributes(data)$metadata$as_of
)
object$original_data <- data

NextMethod()
}

@@ -162,11 +158,14 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), .
}
components <- list()
components$mold <- workflows::extract_mold(object)
components$forged <- hardhat::forge(new_data,
components$forged <- hardhat::forge(
new_data,
blueprint = components$mold$blueprint
)
components$keys <- grab_forged_keys(components$forged, object, new_data)
components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
components <- apply_frosting(
object, components, new_data, type = type, opts = opts, ...
)
components$predictions
}

@@ -201,25 +200,14 @@ augment.epi_workflow <- function(x, new_data, ...) {
full_join(predictions, new_data, by = join_by)
}

# Low-level constructor for `epi_workflow` objects.
#
# Builds a plain workflow from the three stages via
# `workflows:::new_workflow()` and prepends "epi_workflow" to its class
# vector. Each stage defaults to the empty stage produced by the matching
# `workflows:::new_stage_*()` constructor.
new_epi_workflow <- function(
    pre = workflows:::new_stage_pre(),
    fit = workflows:::new_stage_fit(),
    post = workflows:::new_stage_post(),
    trained = FALSE) {
  wf <- workflows:::new_workflow(
    pre = pre,
    fit = fit,
    post = post,
    trained = trained
  )
  structure(wf, class = c("epi_workflow", class(wf)))
}


#' @export
print.epi_workflow <- function(x, ...) {
print_header(x)
print_preprocessor(x)
# workflows:::print_case_weights(x)
print_model(x)
trained <- ifelse(workflows::is_trained_workflow(x), " [trained]", "")
header <- glue::glue("Epi Workflow{trained}")
txt <- utils::capture.output(NextMethod())
txt[1] <- cli::rule(header, line = 2)
cli::cat_line(txt)
print_postprocessor(x)
invisible(x)
}
2 changes: 1 addition & 1 deletion R/extract.R
Original file line number Diff line number Diff line change
@@ -82,7 +82,7 @@ extract_argument.epi_workflow <- function(x, name, arg, ...) {
rlang::check_dots_empty()
type <- sub("_.*", "", name)
if (type %in% c("check", "step")) {
if (!workflows:::has_preprocessor_recipe(x)) {
if ("recipe" %nin% names(x$pre$actions)) {
cli_abort("The workflow must have a recipe preprocessor.")
}
out <- extract_argument(x$pre$actions$recipe$recipe, name, arg)
90 changes: 53 additions & 37 deletions R/frosting.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' Add frosting to a workflow
#' Add frosting to an epi_workflow
#'
#' @param x A workflow
#' @param x An epi_workflow
#' @param frosting A frosting object created using `frosting()`.
#' @param ... Not used.
#'
@@ -38,37 +38,54 @@
#' p3 <- predict(wf3, latest)
#' p3
#'
#'
add_frosting <- function(x, frosting, ...) {
rlang::check_dots_empty()
action <- workflows:::new_action_post(frosting = frosting)
epi_add_action(x, action, "frosting", ...)
UseMethod("add_frosting")
}


# Hacks around workflows `order_stage_post <- character(0)` ----------------
epi_add_action <- function(x, action, name, ..., call = caller_env()) {
workflows:::validate_is_workflow(x, call = call)
add_action_frosting(x, action, name, ..., call = call)
#' @export
add_frosting.default <- function(x, frosting, ..., arg = caller_arg(x)) {
  # Fallback method: reached when `x` has no more specific `add_frosting()`
  # method, i.e. it is not an `epi_workflow`. Always aborts.
  #
  # The message interpolates `arg` (by default the expression the caller
  # supplied for `x`) rather than `x` itself, and names the class that is
  # actually required, matching `remove_frosting.default()` and
  # `update_frosting.default()`.
  cli_abort("{arg} must be an {.cls epi_workflow}, not a {.cls {class(x)[1]}}.")
}
add_action_frosting <- function(x, action, name, ..., call = caller_env()) {
workflows:::check_singleton(x$post$actions, name, call = call)
x$post <- workflows:::add_action_to_stage(x$post, action, name, order_stage_frosting())
x

#' @export
add_frosting.epi_workflow <- function(x, frosting, ...) {
  # `...` exists only for S3 extensibility; anything passed there is an error.
  rlang::check_dots_empty()

  # Wrap the frosting in a post-stage action object.
  action <- list(frosting = frosting)
  class(action) <- c("action_post", "action")

  # At most one frosting may be attached to a workflow.
  if (is.element("frosting", names(x$post$actions))) {
    cli_abort("A `frosting` action has already been added to this workflow.")
  }

  add_frosting_postprocessor(x, action)
}
order_stage_frosting <- function() "frosting"
# End hacks. See cmu-delphi/epipredict#75

# Store `action` as the `frosting` entry of the workflow's post stage.
#
# The new action is appended to the existing post-stage actions and the list
# is then restricted to the entries named "frosting", so any other post-stage
# action is dropped. Returns the modified workflow.
add_frosting_postprocessor <- function(wf, action) {
  combined <- c(wf$post$actions, list(frosting = action))
  keep <- intersect("frosting", names(combined))
  wf$post$actions <- combined[keep]
  wf
}

#' @rdname add_frosting
#' @export
remove_frosting <- function(x) {
workflows:::validate_is_workflow(x)
remove_frosting <- function(x, ...) {
UseMethod("remove_frosting")
}

#' @export
remove_frosting.default <- function(x, ..., arg = caller_arg(x)) {
# Fallback method: reached when `x` has no more specific `remove_frosting()`
# method. It always aborts. The message interpolates `arg` (defaulting to
# `caller_arg(x)`, imported from rlang in NAMESPACE) and the first class of `x`.
cli_abort("{arg} must be an {.cls epi_workflow}, not a {.cls {class(x)[1]}}.")
}

#' @export
remove_frosting.epi_workflow <- function(x, ..., arg = caller_arg(x)) {
if (!has_postprocessor_frosting(x)) {
rlang::warn("The workflow has no frosting postprocessor to remove.")
cli_warn("The epi_workflow {arg} has no frosting postprocessor to remove.")
return(x)
}

x$post$actions[["frosting"]] <- NULL
x
}
@@ -85,18 +102,27 @@ validate_has_postprocessor <- function(x, ..., call = caller_env()) {
rlang::check_dots_empty()
has_postprocessor <- has_postprocessor_frosting(x)
if (!has_postprocessor) {
message <- c(
cli_abort(c(
"The workflow must have a frosting postprocessor.",
i = "Provide one with `add_frosting()`."
)
rlang::abort(message, call = call)
), call = call)
}
invisible(x)
}

#' @rdname add_frosting
#' @export
update_frosting <- function(x, frosting, ...) {
UseMethod("update_frosting")
}

#' @export
update_frosting.default <- function(x, frosting, ..., arg = caller_arg(x)) {
# Fallback method: reached when `x` has no more specific `update_frosting()`
# method. It always aborts; `frosting` is never inspected. The message
# interpolates `arg` (defaulting to `caller_arg(x)`) and the first class of `x`.
cli_abort("{arg} must be an {.cls epi_workflow}, not a {.cls {class(x)[1]}}.")
}

#' @export
update_frosting.epi_workflow <- function(x, frosting, ...) {
rlang::check_dots_empty()
x <- remove_frosting(x)
add_frosting(x, frosting)
@@ -225,8 +251,8 @@ is_frosting <- function(x) {
inherits(x, "frosting")
}

#' @importFrom rlang caller_env
validate_frosting <- function(x, ..., arg = "`x`", call = caller_env()) {
#' @importFrom rlang caller_env caller_arg
validate_frosting <- function(x, ..., arg = caller_arg(x), call = caller_env()) {
rlang::check_dots_empty()
if (!is_frosting(x)) {
cli_abort(
@@ -237,16 +263,6 @@ validate_frosting <- function(x, ..., arg = "`x`", call = caller_env()) {
invisible(x)
}

# Construct an empty frosting: a list with NULL `layers` and `requirements`
# fields, classed as "frosting".
new_frosting <- function() {
  out <- list(layers = NULL, requirements = NULL)
  class(out) <- "frosting"
  out
}


#' Create frosting for postprocessing predictions
#'
@@ -289,11 +305,11 @@ new_frosting <- function() {
#' p
frosting <- function(layers = NULL, requirements = NULL) {
if (!is_null(layers) || !is_null(requirements)) {
cli::cli_abort(
cli_abort(
"Currently, no arguments to `frosting()` are allowed to be non-null."
)
}
out <- new_frosting()
structure(list(layers = NULL, requirements = NULL), class = "frosting")
}


17 changes: 15 additions & 2 deletions R/layers.R
Original file line number Diff line number Diff line change
@@ -74,15 +74,28 @@ layer <- function(subclass, ..., .prefix = "layer_") {
#' p1
#' @export
update.layer <- function(object, ...) {
changes <- list(...)
changes <- enlist(...)

# Replace the appropriate values in object with the changes
object <- recipes:::update_fields(object, changes)
object <- update_layers(object, changes)

# Call layer() to construct a new layer to ensure all new changes are validated
reconstruct_layer(object)
}

# Replace fields of a layer with new values.
#
# `object` is a layer (a classed list) and `changes` is a named list of
# replacement values. Every name in `changes` must already be a field of
# `object`; otherwise the function aborts, naming the layer type (the first
# class of `object`) and the unknown field. Returns the modified layer.
update_layers <- function(object, changes) {
  new_nms <- names(changes)
  old_nms <- names(object)
  layer_type <- class(object)[1]
  for (nm in new_nms) {
    if (!(nm %in% old_nms)) {
      cli::cli_abort("The layer you are trying to update, {.fn {layer_type}}, \\\n does not have the {.field {nm}} field.")
    }
    # `object[[nm]] <- NULL` would delete the field rather than set it to
    # NULL. Single-bracket assignment of a length-one list keeps the field.
    object[nm] <- list(changes[[nm]])
  }
  object
}

reconstruct_layer <- function(x) {
# Collect the subclass of the layer to use
# when recreating it
15 changes: 3 additions & 12 deletions R/model-methods.R
Original file line number Diff line number Diff line change
@@ -80,18 +80,9 @@ Add_model.epi_workflow <- function(x, spec, ..., formula = NULL) {
#' @rdname Add_model
#' @export
Remove_model.epi_workflow <- function(x) {
workflows:::validate_is_workflow(x)

if (!workflows:::has_spec(x)) {
rlang::warn("The workflow has no model to remove.")
}

new_epi_workflow(
pre = x$pre,
fit = workflows:::new_stage_fit(),
post = x$post,
trained = FALSE
)
x <- workflows::remove_model(x)
class(x) <- c("epi_workflow", class(x))
x
}

#' @rdname Add_model
2 changes: 2 additions & 0 deletions R/utils-misc.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Negation of `%in%`: TRUE for each element of `x` that has no match in
# `table`.
`%nin%` <- function(x, table) {
  !(x %in% table)
}

#' Check that newly created variable names don't overlap
#'
#' `check_pname` is to be used in a slather method to ensure that
153 changes: 0 additions & 153 deletions R/workflow-printing.R
Original file line number Diff line number Diff line change
@@ -1,84 +1,3 @@
# Print the banner and a one-line summary of each stage of an epi_workflow.
#
# Emits a double-ruled "Epi Workflow" heading, suffixed with " [trained]" when
# `workflows::is_trained_workflow(x)` is TRUE. Then reports the kind of
# preprocessor, the class of the model spec, and whether a frosting
# postprocessor is attached. Returns `x` invisibly.
#
# The local names `trained`, `preprocessor`, `spec` and `postprocessor` are
# interpolated by the cli strings below, so they must not be renamed.
print_header <- function(x) {
  cli::cli_text("")
  # Scalar condition: plain `if`/`else` rather than vectorised `ifelse()`.
  trained <- if (workflows::is_trained_workflow(x)) " [trained]" else ""
  d <- cli::cli_div(theme = list(rule = list("line-type" = "double")))
  cli::cli_rule("Epi Workflow{trained}")
  cli::cli_end(d)

  # First matching preprocessor type wins, in the same order as before.
  preprocessor <- if (workflows:::has_preprocessor_formula(x)) {
    "Formula"
  } else if (workflows:::has_preprocessor_recipe(x)) {
    "Recipe"
  } else if (workflows:::has_preprocessor_variables(x)) {
    "Variables"
  } else {
    "None"
  }
  cli::cli_text("{.emph Preprocessor:} {preprocessor}")

  if (workflows:::has_spec(x)) {
    spec <- class(workflows::extract_spec_parsnip(x))[[1]]
    spec <- glue::glue("{spec}()")
  } else {
    spec <- "None"
  }
  cli::cli_text("{.emph Model:} {spec}")

  postprocessor <- if (has_postprocessor_frosting(x)) "Frosting" else "None"
  cli::cli_text("{.emph Postprocessor:} {postprocessor}")
  cli::cli_text("")
  invisible(x)
}


# Print the "Preprocessor" section of a workflow.
#
# Prints nothing when the workflow has no formula, recipe or variables
# preprocessor. Otherwise prints a rule and delegates to the printer for each
# preprocessor type that is present. Returns `x` invisibly.
print_preprocessor <- function(x) {
  uses_formula <- workflows:::has_preprocessor_formula(x)
  uses_recipe <- workflows:::has_preprocessor_recipe(x)
  uses_variables <- workflows:::has_preprocessor_variables(x)

  if (!(uses_formula || uses_recipe || uses_variables)) {
    return(invisible(x))
  }

  cli::cli_rule("Preprocessor")
  cli::cli_text("")

  if (uses_formula) print_preprocessor_formula(x)
  if (uses_recipe) print_preprocessor_recipe(x)
  if (uses_variables) print_preprocessor_variables(x)

  cli::cli_text("")
  invisible(x)
}

# revision of workflows:::print_model()
# Print the "Model" section of a workflow.
#
# Without a model spec, only an empty line is emitted. Otherwise a "Model"
# rule is printed, followed by the fitted model when the workflow has a fit,
# or by the spec when it does not. Returns `x` invisibly.
print_model <- function(x) {
  if (!workflows:::has_spec(x)) {
    cli::cli_text("")
    return(invisible(x))
  }

  fitted <- workflows:::has_fit(x)
  cli::cli_rule("Model")

  if (fitted) {
    print_fit(x)
  } else {
    workflows:::print_spec(x)
  }
  cli::cli_text("")
  invisible(x)
}


print_postprocessor <- function(x) {
if (!has_postprocessor_frosting(x)) {
return(invisible(x))
@@ -94,78 +13,6 @@ print_postprocessor <- function(x) {
}


# subfunctions for printing -----------------------------------------------



# Print the workflow's formula preprocessor as text. Returns `x` invisibly.
print_preprocessor_formula <- function(x) {
  formula_text <- rlang::expr_text(workflows::extract_preprocessor(x))
  cli::cli_text(formula_text)
  invisible(x)
}

# Print the outcomes and predictors of a variables preprocessor.
#
# Both are stored as quosures on the extracted preprocessor; their
# expressions are deparsed to text and printed on separate lines. Returns `x`
# invisibly.
print_preprocessor_variables <- function(x) {
  vars <- workflows::extract_preprocessor(x)
  outcomes <- rlang::expr_text(rlang::quo_get_expr(vars$outcomes))
  predictors <- rlang::expr_text(rlang::quo_get_expr(vars$predictors))

  cli::cli_text("Outcomes: ", outcomes)
  cli::cli_text("")
  cli::cli_text("Predictors: ", predictors)
  invisible(x)
}

# Currently only used in the workflow printing
# Print the number of recipe steps and an ordered list of up to ten of them.
#
# When the recipe has more than ten steps, only the first ten are listed and
# a trailing line reports how many were omitted. Returns `x` invisibly.
#
# The local names `n_steps` and `extra_steps` are interpolated by the cli
# strings below, so they must not be renamed.
print_preprocessor_recipe <- function(x, ...) {
  steps <- workflows::extract_preprocessor(x)$steps
  n_steps <- length(steps)
  cli::cli_text("{n_steps} Recipe step{?s}.")

  if (n_steps > 0L) {
    step_names <- map_chr(steps, workflows:::pull_step_name)
    shown <- step_names[seq_len(min(n_steps, 10L))]
    cli::cli_ol(shown)

    if (n_steps > 10L) {
      extra_steps <- n_steps - 10L
      cli::cli_bullets("... and {extra_steps} more step{?s}.")
    }
  }
  invisible(x)
}




# Print the fitted model held by a workflow, truncating long output.
#
# The printed form of the underlying fit is captured first. With fewer than
# 50 lines the fit is printed normally. Otherwise only the first 50 captured
# lines are shown, with empty lines replaced by a single space, followed by a
# count of the lines left out. Returns `x` invisibly.
#
# The local name `n_extra_output` is interpolated by the cli string below, so
# it must not be renamed.
print_fit <- function(x) {
  fit <- workflows::extract_fit_parsnip(x)$fit
  output <- utils::capture.output(fit)
  n_output <- length(output)

  if (n_output >= 50L) {
    n_extra_output <- n_output - 50L
    shown <- output[seq_len(50L)]
    shown[shown == ""] <- " "

    cli::cli_verbatim(shown)
    cli::cli_text("")
    cli::cli_text("... and {n_extra_output} more line{?s}.")
  } else {
    print(fit)
  }
  invisible(x)
}

# Currently only used in the workflow printing
print_frosting <- function(x, ...) {
layers <- x$layers
44 changes: 0 additions & 44 deletions inst/templates/layer.R

This file was deleted.

9 changes: 5 additions & 4 deletions man/add_frosting.Rd
1 change: 1 addition & 0 deletions tests/testthat/test-epi_workflow.R
Original file line number Diff line number Diff line change
@@ -59,6 +59,7 @@ test_that("model can be added/updated/removed from epi_workflow", {
expect_equal(class(model_spec2), c("linear_reg", "model_spec"))

wf <- remove_model(wf)
expect_equal(class(wf), c("epi_workflow", "workflow"))
expect_error(extract_spec_parsnip(wf))
expect_equal(wf$fit$actions$model$spec, NULL)
})
12 changes: 6 additions & 6 deletions tests/testthat/test-frosting.R
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
test_that("frosting validators / constructors work", {
wf <- epi_workflow()
expect_s3_class(new_frosting(), "frosting")
expect_true(is_frosting(new_frosting()))
expect_silent(epi_workflow(postprocessor = new_frosting()))
expect_s3_class(frosting(), "frosting")
expect_true(is_frosting(frosting()))
expect_silent(epi_workflow(postprocessor = frosting()))
expect_false(has_postprocessor(wf))
expect_false(has_postprocessor_frosting(wf))
expect_silent(wf %>% add_frosting(new_frosting()))
expect_silent(wf %>% add_postprocessor(new_frosting()))
expect_silent(wf %>% add_frosting(frosting()))
expect_silent(wf %>% add_postprocessor(frosting()))
expect_error(wf %>% add_postprocessor(list()))

wf <- wf %>% add_frosting(new_frosting())
wf <- wf %>% add_frosting(frosting())
expect_true(has_postprocessor(wf))
expect_true(has_postprocessor_frosting(wf))
})