Skip to content

Commit 07934b8

Browse files
committed
Use tmp_path for storing plots when testing gaussian model example
1 parent 5ca8463 commit 07934b8

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

example/observations_station_gaussian_model.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
""" # Noqa:D205,D400
1616
import logging
1717
import os
18+
from pathlib import Path
1819
from typing import Tuple
1920

2021
import matplotlib.pyplot as plt
@@ -30,6 +31,8 @@
3031
DwdObservationResolution,
3132
)
3233

34+
HERE = Path(__file__).parent
35+
3336
log = logging.getLogger()
3437

3538
try:
@@ -63,7 +66,7 @@ class ModelYearlyGaussians:
6366
6467
"""
6568

66-
def __init__(self, station_data: StationsResult):
69+
def __init__(self, station_data: StationsResult, plot_path: Path):
6770
self._station_data = station_data
6871

6972
result_values = station_data.values.all().df.drop_nulls()
@@ -81,7 +84,7 @@ def __init__(self, station_data: StationsResult):
8184

8285
log.info(f"Fit Result message: {out.result.message}")
8386

84-
self.plot_data_and_model(valid_data, out, savefig_to_file=True)
87+
self.plot_data_and_model(valid_data, out, savefig_to_file=True, plot_path=plot_path)
8588

8689
def get_valid_data(self, result_values: pl.DataFrame) -> pl.DataFrame:
8790
valid_data_lst = []
@@ -137,7 +140,7 @@ def model_pars_update(
137140

138141
return pars
139142

140-
def plot_data_and_model(self, valid_data: pl.DataFrame, out: ModelResult, savefig_to_file=True) -> None:
143+
def plot_data_and_model(self, valid_data: pl.DataFrame, out: ModelResult, savefig_to_file, plot_path: Path) -> None:
141144
"""plots the data and the model"""
142145
if savefig_to_file:
143146
_ = plt.subplots(figsize=(12, 12))
@@ -153,21 +156,21 @@ def plot_data_and_model(self, valid_data: pl.DataFrame, out: ModelResult, savefi
153156
if savefig_to_file:
154157
number_of_years = valid_data.get_column("date").dt.year().n_unique()
155158
filename = f"{self.__class__.__qualname__}_wetter_model_{number_of_years}"
156-
plt.savefig(filename, dpi=300, bbox_inches="tight")
159+
plt.savefig(plot_path / filename, dpi=300, bbox_inches="tight")
157160
log.info("saved fig to file: " + filename)
158161
if "PYTEST_CURRENT_TEST" not in os.environ:
159162
plt.show()
160163

161164

162-
def main():
165+
def main(plot_path=HERE):
163166
"""Run example."""
164167
logging.basicConfig(level=logging.INFO)
165168

166169
station_data_one_year = station_example(start_date="2020-12-25", end_date="2022-01-01")
167-
_ = ModelYearlyGaussians(station_data_one_year)
170+
_ = ModelYearlyGaussians(station_data_one_year, plot_path=plot_path)
168171

169172
station_data_many_years = station_example(start_date="1995-12-25", end_date="2022-12-31")
170-
_ = ModelYearlyGaussians(station_data_many_years)
173+
_ = ModelYearlyGaussians(station_data_many_years, plot_path=plot_path)
171174

172175

173176
if __name__ == "__main__":

tests/example/test_regular_examples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ def test_pdbufr_examples():
3535

3636
@pytest.mark.skipif(IS_CI and IS_LINUX, reason="stalls on Mac/Windows in CI")
3737
@pytest.mark.cflake
38-
def test_gaussian_example():
38+
def test_gaussian_example(tmp_path):
3939
from example import observations_station_gaussian_model
4040

41-
assert observations_station_gaussian_model.main() is None
41+
assert observations_station_gaussian_model.main(tmp_path) is None
4242

4343

4444
# @pytest.mark.skipif(IS_CI, reason="radar examples not working in CI")

0 commit comments

Comments
 (0)