Skip to content

Commit 84f35e6

Browse files
committed
attempt to dryify code
1 parent f78d90e commit 84f35e6

12 files changed

+711
-369
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Description: This package provides functions to help with the maintenance of CDC
99
License: Apache License (>= 2)
1010
Encoding: UTF-8
1111
Roxygen: list(markdown = TRUE)
12-
RoxygenNote: 7.3.2
12+
RoxygenNote: 7.3.3
1313
Imports:
1414
checkmate,
1515
cli,

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ export(get_forecast_data)
1010
export(get_hub_name)
1111
export(get_map_data)
1212
export(included_locations)
13+
export(summarize_ref_date_forecasts)
1314
export(update_authorized_users)
1415
export(update_hub_target_data)
16+
export(write_ref_date_summary)
17+
export(write_ref_date_summary_all)
18+
export(write_ref_date_summary_ensemble)
1519
importFrom(rlang,":=")
1620
importFrom(rlang,.data)

R/get_forecast_data.R

Lines changed: 10 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,10 @@
1-
#' Generate forecast data file containing all forecast hub
2-
#' model submissions.
1+
#' Generate forecast data file containing all forecast hub model submissions
32
#'
43
#' This function fetches all forecast submissions from a
54
#' forecast hub based on the reference date. The forecast
65
#' data is then pivoted to create a wide format with
76
#' quantile levels as columns.
87
#'
9-
#' The resulting file contains the following columns:
10-
#' - `location_name`: full location name (with "US" used for
11-
#' the United States as a whole)
12-
#' - `abbreviation`: state abbreviation
13-
#' - `horizon`: forecast horizon
14-
#' - `forecast_date`: date the forecast was generated
15-
#' - `target_end_date`: target date for the forecast
16-
#' - `model`: model name
17-
#' - `quantile_*`: forecast values for various quantiles
18-
#' (e.g., 0.025, 0.5, 0.975)
19-
#' - `forecast_team`: name of the team that generated the
20-
#' model
21-
#' - `model_full_name`: full model name
22-
#'
238
#' @param reference_date character, the reference date for
249
#' the forecast in YYYY-MM-DD format (ISO-8601).
2510
#' @param base_hub_path character, path to the forecast
@@ -51,119 +36,14 @@ get_forecast_data <- function(
5136
output_format = "csv",
5237
targets = NULL
5338
) {
54-
checkmate::assert_choice(disease, choices = c("covid", "rsv"))
55-
checkmate::assert_subset(horizons_to_include, choices = c(-1, 0, 1, 2, 3))
56-
checkmate::assert_character(excluded_locations)
57-
checkmate::assert_choice(output_format, choices = c("csv", "tsv", "parquet"))
58-
checkmate::assert_character(targets, null.ok = TRUE)
59-
60-
reference_date <- lubridate::as_date(reference_date)
61-
62-
model_metadata <- hubData::load_model_metadata(
63-
base_hub_path,
64-
model_ids = NULL
65-
)
66-
67-
hub_content <- hubData::connect_hub(base_hub_path)
68-
69-
current_forecasts <- hub_content |>
70-
dplyr::filter(
71-
.data$reference_date == !!reference_date,
72-
!(.data$location %in% !!excluded_locations),
73-
.data$horizon %in% !!horizons_to_include
74-
) |>
75-
hubData::collect_hub() |>
76-
dplyr::filter(forecasttools::nullable_comparison(
77-
.data$target,
78-
"%in%",
79-
!!targets
80-
))
81-
82-
all_forecasts_data <- forecasttools::pivot_hubverse_quantiles_wider(
83-
hubverse_table = current_forecasts,
84-
pivot_quantiles = c(
85-
"quantile_0.025" = 0.025,
86-
"quantile_0.25" = 0.25,
87-
"quantile_0.5" = 0.5,
88-
"quantile_0.75" = 0.75,
89-
"quantile_0.975" = 0.975
90-
)
91-
) |>
92-
dplyr::mutate(
93-
location_name = forecasttools::us_location_recode(
94-
.data$location,
95-
"hub",
96-
"name"
97-
),
98-
abbreviation = forecasttools::us_location_recode(
99-
.data$location,
100-
"hub",
101-
"abbr"
102-
),
103-
dplyr::across(
104-
tidyselect::starts_with("quantile_"),
105-
round,
106-
.names = "{.col}_rounded"
107-
),
108-
forecast_due_date = as.Date(!!reference_date) - 3,
109-
location_sort_order = ifelse(.data$location_name == "United States", 0, 1)
110-
) |>
111-
dplyr::mutate(
112-
location_name = dplyr::case_match(
113-
.data$location_name,
114-
"United States" ~ "US",
115-
.default = .data$location_name
116-
)
117-
) |>
118-
dplyr::arrange(.data$location_sort_order, .data$location_name) |>
119-
dplyr::left_join(
120-
dplyr::distinct(
121-
model_metadata,
122-
.data$model_id,
123-
.keep_all = TRUE
124-
),
125-
by = "model_id"
126-
) |>
127-
dplyr::select(
128-
"location_name",
129-
"abbreviation",
130-
"horizon",
131-
forecast_date = "reference_date",
132-
"target_end_date",
133-
model = "model_id",
134-
"quantile_0.025",
135-
"quantile_0.25",
136-
"quantile_0.5",
137-
"quantile_0.75",
138-
"quantile_0.975",
139-
"quantile_0.025_rounded",
140-
"quantile_0.25_rounded",
141-
"quantile_0.5_rounded",
142-
"quantile_0.75_rounded",
143-
"quantile_0.975_rounded",
144-
forecast_team = "team_name",
145-
"forecast_due_date",
146-
model_full_name = "model_name"
147-
)
148-
149-
output_folder_path <- fs::path(
150-
hub_reports_path,
151-
"weekly-summaries",
152-
reference_date
153-
)
154-
output_filename <- glue::glue("{reference_date}_{disease}_forecasts_data")
155-
output_filepath <- fs::path(
156-
output_folder_path,
157-
output_filename,
158-
ext = output_format
39+
write_ref_date_summary_all(
40+
reference_date = reference_date,
41+
base_hub_path = base_hub_path,
42+
hub_reports_path = hub_reports_path,
43+
disease = disease,
44+
horizons_to_include = horizons_to_include,
45+
excluded_locations = excluded_locations,
46+
output_format = output_format,
47+
targets = targets
15948
)
160-
161-
fs::dir_create(output_folder_path)
162-
163-
if (!fs::file_exists(output_filepath)) {
164-
forecasttools::write_tabular(all_forecasts_data, output_filepath)
165-
cli::cli_inform("File saved as: {output_filepath}")
166-
} else {
167-
cli::cli_abort("File already exists: {output_filepath}")
168-
}
16949
}

R/get_map_data.R

Lines changed: 10 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
#' Generate map data file containing ensemble forecast
2-
#' data.
1+
#' Generate map data file containing ensemble forecast data
32
#'
43
#' This function loads the latest ensemble forecast data
54
#' from the forecast hub and processes it into the required
@@ -8,37 +7,6 @@
87
#' various forecast horizons, and quantiles (0.025, 0.5,
98
#' and 0.975).
109
#'
11-
#' The ensemble data is expected to contain the following
12-
#' columns:
13-
#' - `reference_date`: the date of the forecast
14-
#' - `location`: hub location code
15-
#' - `horizon`: forecast horizon
16-
#' - `target`: forecast target (e.g., "wk inc covid hosp")
17-
#' - `target_end_date`: the forecast target date
18-
#' - `output_type`: type of output (e.g., "quantile")
19-
#' - `output_type_id`: quantile value (e.g., 0.025, 0.5,
20-
#' 0.975)
21-
#' - `value`: forecast value
22-
#'
23-
#' The resulting map file will have the following columns:
24-
#' - `location_name`: full location name (with "US" used for
25-
#' the United States as a whole)
26-
#' - `quantile_*`: the quantile forecast values, as counts and per
27-
#' 100k population, raw and rounded (per-100k to two decimal places)
28-
#' - `horizon`: forecast horizon
29-
#' - `target`: forecast target (e.g., "7 day ahead inc
30-
#' hosp")
31-
#' - `target_end_date`: target date for the forecast (Ex:
32-
#' 2024-11-30)
33-
#' - `reference_date`: date that the forecast was generated
34-
#' (Ex: 2024-11-23)
35-
#' - `target_end_date_formatted`: target date for the
36-
#' forecast, prettily re-formatted as a string (Ex:
37-
#' "November 30, 2024")
38-
#' - `reference_date_formatted`: date that the forecast
39-
#' was generated, prettily re-formatted as a string
40-
#' (Ex: "November 23, 2024")
41-
#'
4210
#' @param reference_date character, the reference date for
4311
#' the forecast in YYYY-MM-DD format (ISO-8601).
4412
#' @param base_hub_path character, path to the forecast
@@ -67,143 +35,14 @@ get_map_data <- function(
6735
excluded_locations = character(0),
6836
output_format = "csv"
6937
) {
70-
checkmate::assert_choice(disease, choices = c("covid", "rsv"))
71-
checkmate::assert_subset(horizons_to_include, choices = c(-1, 0, 1, 2, 3))
72-
checkmate::assert_data_frame(population_data)
73-
checkmate::assert_names(
74-
colnames(population_data),
75-
must.include = c("location", "population")
76-
)
77-
checkmate::assert_character(excluded_locations)
78-
checkmate::assert_choice(output_format, choices = c("csv", "tsv", "parquet"))
79-
80-
reference_date <- lubridate::as_date(reference_date)
81-
82-
hub_name <- get_hub_name(disease)
83-
ensemble_model_name <- glue::glue("{hub_name}-ensemble")
84-
85-
ensemble_data <- hubData::connect_hub(base_hub_path) |>
86-
dplyr::filter(
87-
.data$reference_date == !!reference_date,
88-
.data$model_id == !!ensemble_model_name
89-
) |>
90-
hubData::collect_hub()
91-
92-
if (nrow(ensemble_data) == 0) {
93-
cli::cli_abort(
94-
glue::glue(
95-
"No ensemble data found for reference date {reference_date} ",
96-
"and model {ensemble_model_name}"
97-
)
98-
)
99-
}
100-
101-
# process ensemble data into the required format for Map file
102-
map_data <- forecasttools::pivot_hubverse_quantiles_wider(
103-
hubverse_table = ensemble_data,
104-
pivot_quantiles = c(
105-
"quantile_0.025" = 0.025,
106-
"quantile_0.25" = 0.25,
107-
"quantile_0.5" = 0.5,
108-
"quantile_0.75" = 0.75,
109-
"quantile_0.975" = 0.975
110-
)
111-
) |>
112-
dplyr::filter(.data$horizon %in% !!horizons_to_include) |>
113-
dplyr::filter(!(.data$location %in% !!excluded_locations)) |>
114-
dplyr::mutate(
115-
reference_date = as.Date(.data$reference_date),
116-
target_end_date = as.Date(.data$target_end_date),
117-
model = !!ensemble_model_name
118-
) |>
119-
# convert location column codes to full location names
120-
dplyr::mutate(
121-
location = forecasttools::us_location_recode(
122-
.data$location,
123-
"hub",
124-
"name"
125-
)
126-
) |>
127-
# long name "United States" to "US"
128-
dplyr::mutate(
129-
location = dplyr::case_match(
130-
.data$location,
131-
"United States" ~ "US",
132-
.default = .data$location
133-
),
134-
# sort locations alphabetically, except for US
135-
location_sort_order = ifelse(.data$location == "US", 0, 1)
136-
) |>
137-
dplyr::arrange(.data$location_sort_order, .data$location) |>
138-
dplyr::left_join(
139-
population_data,
140-
by = "location"
141-
) |>
142-
dplyr::mutate(
143-
population = as.numeric(.data$population),
144-
quantile_0.025_per100k = .data$quantile_0.025 / .data$population * 100000,
145-
quantile_0.5_per100k = .data$quantile_0.5 / .data$population * 100000,
146-
quantile_0.975_per100k = .data$quantile_0.975 / .data$population * 100000,
147-
quantile_0.025_count = .data$quantile_0.025,
148-
quantile_0.5_count = .data$quantile_0.5,
149-
quantile_0.975_count = .data$quantile_0.975,
150-
quantile_0.025_per100k_rounded = round(.data$quantile_0.025_per100k, 2),
151-
quantile_0.5_per100k_rounded = round(.data$quantile_0.5_per100k, 2),
152-
quantile_0.975_per100k_rounded = round(.data$quantile_0.975_per100k, 2),
153-
quantile_0.025_count_rounded = round(.data$quantile_0.025_count),
154-
quantile_0.5_count_rounded = round(.data$quantile_0.5_count),
155-
quantile_0.975_count_rounded = round(.data$quantile_0.975_count),
156-
target_end_date_formatted = format(.data$target_end_date, "%B %d, %Y"),
157-
reference_date_formatted = format(.data$reference_date, "%B %d, %Y"),
158-
forecast_due_date = as.Date(!!reference_date) - 3,
159-
forecast_due_date_formatted = format(
160-
.data$forecast_due_date,
161-
"%B %d, %Y"
162-
),
163-
) |>
164-
dplyr::select(
165-
location_name = "location",
166-
"horizon",
167-
"quantile_0.025_per100k",
168-
"quantile_0.5_per100k",
169-
"quantile_0.975_per100k",
170-
"quantile_0.025_count",
171-
"quantile_0.5_count",
172-
"quantile_0.975_count",
173-
"quantile_0.025_per100k_rounded",
174-
"quantile_0.5_per100k_rounded",
175-
"quantile_0.975_per100k_rounded",
176-
"quantile_0.025_count_rounded",
177-
"quantile_0.5_count_rounded",
178-
"quantile_0.975_count_rounded",
179-
"target",
180-
"target_end_date",
181-
"reference_date",
182-
"forecast_due_date",
183-
"target_end_date_formatted",
184-
"forecast_due_date_formatted",
185-
"reference_date_formatted",
186-
"model",
187-
)
188-
189-
output_folder_path <- fs::path(
190-
hub_reports_path,
191-
"weekly-summaries",
192-
reference_date
193-
)
194-
output_filename <- glue::glue("{reference_date}_{disease}_map_data")
195-
output_filepath <- fs::path(
196-
output_folder_path,
197-
output_filename,
198-
ext = output_format
38+
write_ref_date_summary_ensemble(
39+
reference_date = reference_date,
40+
base_hub_path = base_hub_path,
41+
hub_reports_path = hub_reports_path,
42+
disease = disease,
43+
horizons_to_include = horizons_to_include,
44+
population_data = population_data,
45+
excluded_locations = excluded_locations,
46+
output_format = output_format
19947
)
200-
201-
fs::dir_create(output_folder_path)
202-
203-
if (!fs::file_exists(output_filepath)) {
204-
forecasttools::write_tabular(map_data, output_filepath)
205-
cli::cli_inform("File saved as: {output_filepath}")
206-
} else {
207-
cli::cli_abort("File already exists: {output_filepath}")
208-
}
20948
}

0 commit comments

Comments
 (0)