|
| 1 | +#' Transform a modeling task represented as a nested list |
| 2 | +#' to a single data frame |
| 3 | +#' |
| 4 | +#' @param task Nested list representing a modeling task, |
| 5 | +#' as one entry of the output of [hubUtils::get_round_model_tasks()]. |
| 6 | +#' Must have a `target_end_date` specification. |
| 7 | +#' @return A [`tibble`][tibble::tibble()] of all potentially |
| 8 | +#' valid submittable outputs for the modeling task defined in `task`. |
| 9 | +#' Each row of the table represents a single valid forecastable quantity |
| 10 | +#' (e.g. "`target` X on `target_end_date` Y in `location` Z"), |
| 11 | +#' plus a valid submittable output_type for forecasting that quantity. |
| 12 | +#' If multiple `output_type`s are accepted for a given valid forecastable |
| 13 | +#' quantity, that quantity will be represented multiple times, with |
| 14 | +#' one row for each valid associated `output_type`. |
| 15 | +flatten_task <- function(task) { |
| 16 | + checkmate::assert_names( |
| 17 | + names(task), |
| 18 | + must.include = c("output_type", "task_ids") |
| 19 | + ) |
| 20 | + checkmate::assert_names( |
| 21 | + names(task$task_ids), |
| 22 | + must.include = "target_end_date" |
| 23 | + ) |
| 24 | + output_types <- names(task$output_type) |
| 25 | + |
| 26 | + task_params <- purrr::map(task$task_ids, \(x) c(x$required, x$optional)) |> |
| 27 | + purrr::discard_at(c("horizon", "reference_date")) |
| 28 | + ## discard columns that are redundant with `target_end_date` |
| 29 | + |
| 30 | + return(do.call( |
| 31 | + tidyr::crossing, |
| 32 | + c(task_params, list(output_type = output_types)) |
| 33 | + )) |
| 34 | +} |
| 35 | + |
| 36 | + |
| 37 | +#' Transform a group of modeling task represented as a list of |
| 38 | +#' nested lists into a single data frame. |
| 39 | +#' |
| 40 | +#' Calls [flatten_task()] on each entry of the task list. |
| 41 | +#' |
| 42 | +#' @param task_list List of tasks. Each entry should itself be |
| 43 | +#' be a nested list that can be passed to [flatten_task()]. |
| 44 | +#' @param .deduplicate deduplicate the output if the same flat |
| 45 | +#' configuration is found multiple times while flattening the task list? |
| 46 | +#' Default `TRUE`. |
| 47 | +#' |
| 48 | +#' @return A [`tibble`][tibble::tibble()] of all potentially |
| 49 | +#' valid submittable outputs for all the modeling tasks defined in `task_lists`. |
| 50 | +#' Each row of the table represents a single valid forecastable quantity |
| 51 | +#' (e.g. "`target` X on `target_end_date` Y in `location` Z"), |
| 52 | +#' plus a valid submittable output_type for forecasting that quantity. |
| 53 | +#' If multiple `output_type`s are accepted for a given valid forecastable |
| 54 | +#' quantity, that quantity will be represented multiple times, with |
| 55 | +#' one row for each valid associated `output_type`. |
| 56 | +#' |
| 57 | +flatten_task_list <- function(task_list, .deduplicate = TRUE) { |
| 58 | + flat_tasks <- purrr::map_df(task_list, flatten_task) |
| 59 | + |
| 60 | + if (.deduplicate) { |
| 61 | + flat_tasks <- dplyr::distinct(flat_tasks) |
| 62 | + } |
| 63 | + |
| 64 | + return(flat_tasks) |
| 65 | +} |
| 66 | + |
| 67 | +#' Generate and save oracle output for the Hub |
| 68 | +#' |
| 69 | +#' @param hub_path Path to the hub root. |
| 70 | +#' |
| 71 | +#' @return nothing, invisibly, on success. |
| 72 | +generate_oracle_output <- function(hub_path) { |
| 73 | + output_dirpath <- fs::path(hub_path, "target-data") |
| 74 | + fs::dir_create(output_dirpath) |
| 75 | + target_ts <- hubData::connect_target_timeseries(hub_path) |
| 76 | + |
| 77 | + config_tasks <- hubUtils::read_config(hub_path, "tasks") |
| 78 | + round_ids <- hubUtils::get_round_ids(config_tasks) |
| 79 | + |
| 80 | + ## this involves duplication given how hubUtils::get_round_model_tasks |
| 81 | + ## behaves by default with round ids created from reference dates, |
| 82 | + ## but we do this this way for completeness / generality |
| 83 | + list_of_task_lists <- purrr::map(round_ids, \(id) { |
| 84 | + hubUtils::get_round_model_tasks(config_tasks, id) |
| 85 | + }) |
| 86 | + |
| 87 | + unique_tasks <- purrr::map_df(list_of_task_lists, flatten_task_list) |> |
| 88 | + dplyr::distinct() |> |
| 89 | + dplyr::mutate(target_end_date = as.Date(.data$target_end_date)) |
| 90 | + |
| 91 | + target_data <- target_ts |> |
| 92 | + forecasttools::hub_target_data_as_of("latest", .drop = TRUE) |> |
| 93 | + dplyr::collect() |> |
| 94 | + dplyr::rename(target_end_date = "date") |
| 95 | + |
| 96 | + join_key <- intersect( |
| 97 | + colnames(unique_tasks), |
| 98 | + colnames(target_data) |
| 99 | + ) |
| 100 | + |
| 101 | + oracle_data <- dplyr::inner_join(unique_tasks, target_data, by = join_key) |> |
| 102 | + dplyr::mutate(output_type_id = NA) |> |
| 103 | + dplyr::rename( |
| 104 | + oracle_value = "observation" |
| 105 | + ) |
| 106 | + |
| 107 | + output_file <- fs::path(output_dirpath, "oracle-output", ext = "parquet") |
| 108 | + forecasttools::write_tabular_file(oracle_data, output_file) |
| 109 | + invisible() |
| 110 | +} |
0 commit comments