Skip to content

Commit 166f2be

Browse files
committed
add function to generate oracle data
1 parent a51473d commit 166f2be

File tree

5 files changed

+187
-1
lines changed

5 files changed

+187
-1
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ Imports:
2727
purrr,
2828
stringr,
2929
tidyr,
30-
tidyselect
30+
tidyselect,
31+
hubUtils
3132
Remotes:
3233
forecasttools=github::cdcgov/forecasttools,
3334
hubUtils=github::hubverse-org/hubUtils,

R/generate_oracle_output.R

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#' Transform a modeling task represented as a nested list
2+
#' to a single data frame
3+
#'
4+
#' @param task Nested list representing a modeling task,
5+
#' as one entry of the output of [hubUtils::get_round_model_tasks()].
6+
#' Must have a `target_end_date` specification.
7+
#' @return A [`tibble`][tibble::tibble()] of all potentially
8+
#' valid submittable outputs for the modeling task defined in `task`.
9+
#' Each row of the table represents a single valid forecastable quantity
10+
#' (e.g. "`target` X on `target_end_date` Y in `location` Z"),
11+
#' plus a valid submittable output_type for forecasting that quantity.
12+
#' If multiple `output_type`s are accepted for a given valid forecastable
13+
#' quantity, that quantity will be represented multiple times, with
14+
#' one row for each valid associated `output_type`.
15+
flatten_task <- function(task) {
16+
checkmate::assert_names(
17+
names(task),
18+
must.include = c("output_type", "task_ids")
19+
)
20+
checkmate::assert_names(
21+
names(task$task_ids),
22+
must.include = "target_end_date"
23+
)
24+
output_types <- names(task$output_type)
25+
26+
task_params <- purrr::map(task$task_ids, \(x) c(x$required, x$optional)) |>
27+
purrr::discard_at(c("horizon", "reference_date"))
28+
## discard columns that are redundant with `target_end_date`
29+
30+
return(do.call(
31+
tidyr::crossing,
32+
c(task_params, list(output_type = output_types))
33+
))
34+
}
35+
36+
37+
#' Transform a group of modeling task represented as a list of
38+
#' nested lists into a single data frame.
39+
#'
40+
#' Calls [flatten_task()] on each entry of the task list.
41+
#'
42+
#' @param task_list List of tasks. Each entry should itself be
43+
#' be a nested list that can be passed to [flatten_task()].
44+
#' @param .deduplicate deduplicate the output if the same flat
45+
#' configuration is found multiple times while flattening the task list?
46+
#' Default `TRUE`.
47+
#'
48+
#' @return A [`tibble`][tibble::tibble()] of all potentially
49+
#' valid submittable outputs for all the modeling tasks defined in `task_lists`.
50+
#' Each row of the table represents a single valid forecastable quantity
51+
#' (e.g. "`target` X on `target_end_date` Y in `location` Z"),
52+
#' plus a valid submittable output_type for forecasting that quantity.
53+
#' If multiple `output_type`s are accepted for a given valid forecastable
54+
#' quantity, that quantity will be represented multiple times, with
55+
#' one row for each valid associated `output_type`.
56+
#'
57+
flatten_task_list <- function(task_list, .deduplicate = TRUE) {
58+
flat_tasks <- purrr::map_df(task_list, flatten_task)
59+
60+
if (.deduplicate) {
61+
flat_tasks <- dplyr::distinct(flat_tasks)
62+
}
63+
64+
return(flat_tasks)
65+
}
66+
67+
#' Generate and save oracle output for the Hub
68+
#'
69+
#' @param hub_path Path to the hub root.
70+
#'
71+
#' @return nothing, invisibly, on success.
72+
generate_oracle_output <- function(hub_path) {
73+
output_dirpath <- fs::path(hub_path, "target-data")
74+
fs::dir_create(output_dirpath)
75+
target_ts <- hubData::connect_target_timeseries(hub_path)
76+
77+
config_tasks <- hubUtils::read_config(hub_path, "tasks")
78+
round_ids <- hubUtils::get_round_ids(config_tasks)
79+
80+
## this involves duplication given how hubUtils::get_round_model_tasks
81+
## behaves by default with round ids created from reference dates,
82+
## but we do this this way for completeness / generality
83+
list_of_task_lists <- purrr::map(round_ids, \(id) {
84+
hubUtils::get_round_model_tasks(config_tasks, id)
85+
})
86+
87+
unique_tasks <- purrr::map_df(list_of_task_lists, flatten_task_list) |>
88+
dplyr::distinct() |>
89+
dplyr::mutate(target_end_date = as.Date(.data$target_end_date))
90+
91+
target_data <- target_ts |>
92+
forecasttools::hub_target_data_as_of("latest", .drop = TRUE) |>
93+
dplyr::collect() |>
94+
dplyr::rename(target_end_date = "date")
95+
96+
join_key <- intersect(
97+
colnames(unique_tasks),
98+
colnames(target_data)
99+
)
100+
101+
oracle_data <- dplyr::inner_join(unique_tasks, target_data, by = join_key) |>
102+
dplyr::mutate(output_type_id = NA) |>
103+
dplyr::rename(
104+
oracle_value = "observation"
105+
)
106+
107+
output_file <- fs::path(output_dirpath, "oracle-output", ext = "parquet")
108+
forecasttools::write_tabular_file(oracle_data, output_file)
109+
invisible()
110+
}

man/flatten_task.Rd

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/flatten_task_list.Rd

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/generate_oracle_output.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)