15
15
""" # Noqa:D205,D400
16
16
import logging
17
17
import os
18
+ from pathlib import Path
18
19
from typing import Tuple
19
20
20
21
import matplotlib .pyplot as plt
30
31
DwdObservationResolution ,
31
32
)
32
33
34
+ HERE = Path (__file__ ).parent
35
+
33
36
log = logging .getLogger ()
34
37
35
38
try :
@@ -63,7 +66,7 @@ class ModelYearlyGaussians:
63
66
64
67
"""
65
68
66
- def __init__ (self , station_data : StationsResult ):
69
+ def __init__ (self , station_data : StationsResult , plot_path : Path ):
67
70
self ._station_data = station_data
68
71
69
72
result_values = station_data .values .all ().df .drop_nulls ()
@@ -81,7 +84,7 @@ def __init__(self, station_data: StationsResult):
81
84
82
85
log .info (f"Fit Result message: { out .result .message } " )
83
86
84
- self .plot_data_and_model (valid_data , out , savefig_to_file = True )
87
+ self .plot_data_and_model (valid_data , out , savefig_to_file = True , plot_path = plot_path )
85
88
86
89
def get_valid_data (self , result_values : pl .DataFrame ) -> pl .DataFrame :
87
90
valid_data_lst = []
@@ -137,7 +140,7 @@ def model_pars_update(
137
140
138
141
return pars
139
142
140
- def plot_data_and_model (self , valid_data : pl .DataFrame , out : ModelResult , savefig_to_file = True ) -> None :
143
+ def plot_data_and_model (self , valid_data : pl .DataFrame , out : ModelResult , savefig_to_file , plot_path : Path ) -> None :
141
144
"""plots the data and the model"""
142
145
if savefig_to_file :
143
146
_ = plt .subplots (figsize = (12 , 12 ))
@@ -153,21 +156,21 @@ def plot_data_and_model(self, valid_data: pl.DataFrame, out: ModelResult, savefi
153
156
if savefig_to_file :
154
157
number_of_years = valid_data .get_column ("date" ).dt .year ().n_unique ()
155
158
filename = f"{ self .__class__ .__qualname__ } _wetter_model_{ number_of_years } "
156
- plt .savefig (filename , dpi = 300 , bbox_inches = "tight" )
159
+ plt .savefig (plot_path / filename , dpi = 300 , bbox_inches = "tight" )
157
160
log .info ("saved fig to file: " + filename )
158
161
if "PYTEST_CURRENT_TEST" not in os .environ :
159
162
plt .show ()
160
163
161
164
162
- def main ():
165
+ def main (plot_path = HERE ):
163
166
"""Run example."""
164
167
logging .basicConfig (level = logging .INFO )
165
168
166
169
station_data_one_year = station_example (start_date = "2020-12-25" , end_date = "2022-01-01" )
167
- _ = ModelYearlyGaussians (station_data_one_year )
170
+ _ = ModelYearlyGaussians (station_data_one_year , plot_path = plot_path )
168
171
169
172
station_data_many_years = station_example (start_date = "1995-12-25" , end_date = "2022-12-31" )
170
- _ = ModelYearlyGaussians (station_data_many_years )
173
+ _ = ModelYearlyGaussians (station_data_many_years , plot_path = plot_path )
171
174
172
175
173
176
if __name__ == "__main__" :
0 commit comments