From db49aa7b2d590372435e20b22025a9af6037cd24 Mon Sep 17 00:00:00 2001 From: butterman0 Date: Wed, 9 Apr 2025 16:56:41 +0200 Subject: [PATCH 1/3] set is_fitted = True after .fit call --- pymc_extras/model_builder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 6e712e5d7..4fbf98f62 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -524,6 +524,8 @@ def fit( ) self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore + self.is_fitted_ = True + return self.idata # type: ignore def predict( From edec5395359e4e38f7a2820c3478e620a358c31a Mon Sep 17 00:00:00 2001 From: butterman0 Date: Wed, 9 Apr 2025 16:57:52 +0200 Subject: [PATCH 2/3] is_fitted = True when trained model is loaded --- pymc_extras/model_builder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 4fbf98f62..5f22013d5 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -444,6 +444,7 @@ def load(cls, fname: str): sampler_config=json.loads(idata.attrs["sampler_config"]), ) model.idata = idata + model.is_fitted_ = True dataset = idata.fit_data.to_dataframe() X = dataset.drop(columns=[model.output_var]) y = dataset[model.output_var] From 93c5ed61b1253186d867fc0963513f56f792bcc3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 25 Apr 2025 07:38:36 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc_extras/model_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 5f22013d5..be672db7b 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -526,7 +526,7 @@ def fit( self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore self.is_fitted_ = True - + return self.idata # type: ignore def predict(