|
9 | 9 | from sklearn.utils.validation import check_X_y
|
10 | 10 | from tqdm.auto import tqdm
|
11 | 11 |
|
12 |
| -from autoemulate.cross_validate import _run_cv, _sum_cv, _sum_cvs |
| 12 | +from autoemulate.cross_validate import _run_cv |
| 13 | +from autoemulate.cross_validate import _sum_cv |
| 14 | +from autoemulate.cross_validate import _sum_cvs |
13 | 15 | from autoemulate.data_splitting import _split_data
|
14 | 16 | from autoemulate.emulators import model_registry
|
15 | 17 | from autoemulate.hyperparam_searching import _optimize_params
|
16 | 18 | from autoemulate.logging_config import _configure_logging
|
17 | 19 | from autoemulate.metrics import METRIC_REGISTRY
|
18 | 20 | from autoemulate.model_processing import AutoEmulatePipeline
|
19 |
| -from autoemulate.plotting import _plot_cv, _plot_model |
20 |
| -from autoemulate.preprocess_target import NonTrainableTransformer, get_dim_reducer |
| 21 | +from autoemulate.plotting import _plot_cv |
| 22 | +from autoemulate.plotting import _plot_model |
| 23 | +from autoemulate.preprocess_target import get_dim_reducer |
| 24 | +from autoemulate.preprocess_target import NonTrainableTransformer |
21 | 25 | from autoemulate.printing import _print_setup
|
22 | 26 | from autoemulate.save import ModelSerialiser
|
23 |
| -from autoemulate.sensitivity_analysis import ( |
24 |
| - _plot_morris_analysis, |
25 |
| - _plot_sobol_analysis, |
26 |
| - _sensitivity_analysis, |
27 |
| -) |
28 |
| -from autoemulate.utils import ( |
29 |
| - _check_cv, |
30 |
| - _ensure_2d, |
31 |
| - _get_full_model_name, |
32 |
| - _redirect_warnings, |
33 |
| - get_model_name, |
34 |
| - get_short_model_name, |
35 |
| -) |
| 27 | +from autoemulate.sensitivity_analysis import _plot_morris_analysis |
| 28 | +from autoemulate.sensitivity_analysis import _plot_sobol_analysis |
| 29 | +from autoemulate.sensitivity_analysis import _sensitivity_analysis |
| 30 | +from autoemulate.utils import _check_cv |
| 31 | +from autoemulate.utils import _ensure_2d |
| 32 | +from autoemulate.utils import _get_full_model_name |
| 33 | +from autoemulate.utils import _redirect_warnings |
| 34 | +from autoemulate.utils import get_model_name |
| 35 | +from autoemulate.utils import get_short_model_name |
36 | 36 |
|
37 | 37 |
|
38 | 38 | class AutoEmulate:
|
@@ -370,10 +370,10 @@ def compare(self):
|
370 | 370 | pbar.update(1)
|
371 | 371 |
|
372 | 372 | # Get best model for this preprocessing method
|
373 |
| - self.preprocessing_results[prep_name]["best_model"] = ( |
374 |
| - self.get_best_model_for_prep( |
375 |
| - prep_results=self.preprocessing_results[prep_name], metric="r2" |
376 |
| - ) |
| 373 | + self.preprocessing_results[prep_name][ |
| 374 | + "best_model" |
| 375 | + ] = self.get_best_model_for_prep( |
| 376 | + prep_results=self.preprocessing_results[prep_name], metric="r2" |
377 | 377 | )
|
378 | 378 |
|
379 | 379 | # Find the overall best model and preprocessing method
|
|
0 commit comments