Skip to content

Commit a6c39bf

Browse files
committed
Merge remote-tracking branch 'origin/main' into 37-add-get_webtextr
2 parents 599a693 + cfad627 commit a6c39bf

File tree

9 files changed

+270
-12
lines changed

9 files changed

+270
-12
lines changed

.github/workflows/jarl-check.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
name: "Lint R Code With JARL"
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
lint:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v6
14+
- uses: etiennebacher/[email protected]
15+
with:
16+
args: check . --fix --output-format github

.lintr

Lines changed: 0 additions & 2 deletions
This file was deleted.

.pre-commit-config.yaml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ repos:
1313
#####
1414
# Python
1515
- repo: https://github.com/astral-sh/ruff-pre-commit
16-
rev: v0.14.6
16+
rev: v0.14.8
1717
hooks:
1818
# Sort imports
1919
- id: ruff
@@ -25,12 +25,6 @@ repos:
2525
- id: ruff-format
2626
args: ['--line-length', '79']
2727
#####
28-
# R
29-
- repo: https://github.com/lorenzwalthert/precommit
30-
rev: v0.4.3.9017
31-
hooks:
32-
- id: lintr
33-
#####
3428
# Java
3529
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
3630
rev: v2.15.0

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ export(excluded_locations)
66
export(generate_hub_baseline)
77
export(generate_hub_ensemble)
88
export(generate_oracle_output)
9+
export(get_forecast_data)
910
export(get_hub_name)
1011
export(get_map_data)
1112
export(get_webtext)

R/generate_hub_baselines.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,16 @@ make_baseline_forecast <- function(
181181
#' @param base_hub_path Path to the base hub directory.
182182
#' @param reference_date Reference date (should be a Saturday).
183183
#' @param disease Disease name ("covid" or "rsv").
184+
#' @param as_of As of date to filter to, as an object
185+
#' coercible by as.Date(), or "latest" to filter to the
186+
#' most recent available vintage. Default "latest".
184187
#' @return NULL. Writes baseline forecast file to hub's model-output directory.
185188
#' @export
186189
generate_hub_baseline <- function(
187190
base_hub_path,
188191
reference_date,
189-
disease
192+
disease,
193+
as_of = "latest"
190194
) {
191195
checkmate::assert_scalar(disease)
192196
checkmate::assert_names(disease, subset.of = c("covid", "rsv"))
@@ -213,7 +217,7 @@ generate_hub_baseline <- function(
213217

214218
hub_target_data <- hubData::connect_target_timeseries(base_hub_path) |>
215219
dplyr::collect() |>
216-
forecasttools::hub_target_data_as_of()
220+
forecasttools::hub_target_data_as_of(as_of)
217221

218222
preds_hosp <- make_baseline_forecast(
219223
target_data = hub_target_data,

R/get_forecast_data.R

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#' Generate forecast data file containing all forecast hub
2+
#' model submissions.
3+
#'
4+
#' This function fetches all forecast submissions from a
5+
#' forecast hub based on the reference date. The forecast
6+
#' data is then pivoted to create a wide format with
7+
#' quantile levels as columns.
8+
#'
9+
#' The resulting file contains the following columns:
10+
#' - `location_name`: full state name (including "US" for
11+
#' the US state)
12+
#' - `abbreviation`: state abbreviation
13+
#' - `horizon`: forecast horizon
14+
#' - `forecast_date`: date the forecast was generated
15+
#' - `target_end_date`: target date for the forecast
16+
#' - `model`: model name
17+
#' - `quantile_*`: forecast values for various quantiles
18+
#' (e.g., 0.025, 0.5, 0.975)
19+
#' - `forecast_teams`: name of the team that generated the
20+
#' model
21+
#' - `forecast_fullnames`: full model name
22+
#'
23+
#' @param reference_date character, the reference date for
24+
#' the forecast in YYYY-MM-DD format (ISO-8601).
25+
#' @param base_hub_path character, path to the forecast
26+
#' hub directory.
27+
#' @param hub_reports_path character, path to forecast hub
28+
#' reports directory.
29+
#' @param disease character, disease name ("covid" or
30+
#' "rsv"). Used to derive target name and file prefix.
31+
#' @param horizons_to_include integer vector, horizons to
32+
#' include in the output. Default: c(0, 1, 2).
33+
#' @param excluded_locations character vector of location
34+
#' codes to exclude from the output. Default: character(0).
35+
#' @param output_format character, output file format. One
36+
#' of "csv", "tsv", or "parquet". Default: "csv".
37+
#' @param targets character vector, target name(s) to filter
38+
#' forecasts. If NULL (default), does not filter by target.
39+
#' Can be a single target like "wk inc covid hosp" or
40+
#' multiple targets like c("wk inc covid hosp", "wk inc
41+
#' covid prop ed visits").
42+
#'
43+
#' @export
44+
get_forecast_data <- function(
45+
reference_date,
46+
base_hub_path,
47+
hub_reports_path,
48+
disease,
49+
horizons_to_include = c(0, 1, 2),
50+
excluded_locations = character(0),
51+
output_format = "csv",
52+
targets = NULL
53+
) {
54+
checkmate::assert_choice(disease, choices = c("covid", "rsv"))
55+
checkmate::assert_subset(horizons_to_include, choices = c(-1, 0, 1, 2, 3))
56+
checkmate::assert_character(excluded_locations)
57+
checkmate::assert_choice(output_format, choices = c("csv", "tsv", "parquet"))
58+
checkmate::assert_character(targets, null.ok = TRUE)
59+
60+
reference_date <- lubridate::as_date(reference_date)
61+
62+
model_metadata <- hubData::load_model_metadata(
63+
base_hub_path,
64+
model_ids = NULL
65+
)
66+
67+
hub_content <- hubData::connect_hub(base_hub_path)
68+
69+
current_forecasts <- hub_content |>
70+
dplyr::filter(
71+
.data$reference_date == !!reference_date,
72+
!(.data$location %in% !!excluded_locations),
73+
.data$horizon %in% !!horizons_to_include
74+
) |>
75+
hubData::collect_hub() |>
76+
dplyr::filter(forecasttools::nullable_comparison(
77+
.data$target,
78+
"%in%",
79+
!!targets
80+
))
81+
82+
all_forecasts_data <- forecasttools::pivot_hubverse_quantiles_wider(
83+
hubverse_table = current_forecasts,
84+
pivot_quantiles = c(
85+
"quantile_0.025" = 0.025,
86+
"quantile_0.25" = 0.25,
87+
"quantile_0.5" = 0.5,
88+
"quantile_0.75" = 0.75,
89+
"quantile_0.975" = 0.975
90+
)
91+
) |>
92+
dplyr::mutate(
93+
location_name = forecasttools::us_location_recode(
94+
.data$location,
95+
"hub",
96+
"name"
97+
),
98+
abbreviation = forecasttools::us_location_recode(
99+
.data$location,
100+
"hub",
101+
"abbr"
102+
),
103+
dplyr::across(
104+
tidyselect::starts_with("quantile_"),
105+
round,
106+
.names = "{.col}_rounded"
107+
),
108+
forecast_due_date = as.Date(!!reference_date) - 3,
109+
location_sort_order = ifelse(.data$location_name == "United States", 0, 1)
110+
) |>
111+
dplyr::mutate(
112+
location_name = dplyr::case_match(
113+
.data$location_name,
114+
"United States" ~ "US",
115+
.default = .data$location_name
116+
)
117+
) |>
118+
dplyr::arrange(.data$location_sort_order, .data$location_name) |>
119+
dplyr::left_join(
120+
dplyr::distinct(
121+
model_metadata,
122+
.data$model_id,
123+
.keep_all = TRUE
124+
),
125+
by = "model_id"
126+
) |>
127+
dplyr::select(
128+
"location_name",
129+
"abbreviation",
130+
"horizon",
131+
forecast_date = "reference_date",
132+
"target_end_date",
133+
model = "model_id",
134+
"quantile_0.025",
135+
"quantile_0.25",
136+
"quantile_0.5",
137+
"quantile_0.75",
138+
"quantile_0.975",
139+
"quantile_0.025_rounded",
140+
"quantile_0.25_rounded",
141+
"quantile_0.5_rounded",
142+
"quantile_0.75_rounded",
143+
"quantile_0.975_rounded",
144+
forecast_team = "team_name",
145+
"forecast_due_date",
146+
model_full_name = "model_name"
147+
)
148+
149+
output_folder_path <- fs::path(
150+
hub_reports_path,
151+
"weekly-summaries",
152+
reference_date
153+
)
154+
output_filename <- glue::glue("{reference_date}_{disease}_forecasts_data")
155+
output_filepath <- fs::path(
156+
output_folder_path,
157+
output_filename,
158+
ext = output_format
159+
)
160+
161+
fs::dir_create(output_folder_path)
162+
163+
if (!fs::file_exists(output_filepath)) {
164+
forecasttools::write_tabular(all_forecasts_data, output_filepath)
165+
cli::cli_inform("File saved as: {output_filepath}")
166+
} else {
167+
cli::cli_abort("File already exists: {output_filepath}")
168+
}
169+
}

jarl.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[lint]
2+
default-exclude = true
3+
assignment = "<-"

man/generate_hub_baseline.Rd

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/get_forecast_data.Rd

Lines changed: 69 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)