From 98755e5e0f5d10f972549130c0d163594f5083fa Mon Sep 17 00:00:00 2001
From: Jochen Sieg
Date: Tue, 17 Jun 2025 15:55:57 +0200
Subject: [PATCH] wip: add ghost and adapt post-processing

---
 .../ghost_classification_threshold.py        | 196 ++++++++++++++++++
 molpipeline/pipeline/_skl_pipeline.py        |  10 +
 molpipeline/post_prediction.py               |  11 +
 pyproject.toml                               |   1 +
 tests/test_elements/test_post_prediction.py  |  19 ++
 .../test_ghost_classification_threshold.py   | 123 +++++++++++
 6 files changed, 360 insertions(+)
 create mode 100644 molpipeline/estimators/ghost_classification_threshold.py
 create mode 100644 tests/test_estimators/test_ghost_classification_threshold.py

diff --git a/molpipeline/estimators/ghost_classification_threshold.py b/molpipeline/estimators/ghost_classification_threshold.py
new file mode 100644
index 00000000..37ab75f2
--- /dev/null
+++ b/molpipeline/estimators/ghost_classification_threshold.py
@@ -0,0 +1,196 @@
+"""Classification threshold adjustment with GHOST."""
+
+from typing import Any, Literal, Self, override
+
+import ghostml
+import numpy as np
+import numpy.typing as npt
+from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.utils._random import check_random_state
+
+from molpipeline.post_prediction import PostPredictionWrapper
+from molpipeline.utils.molpipeline_types import AnyPredictor
+
+
+class Ghost(BaseEstimator, TransformerMixin):
+    """GHOST estimator wrapper for classification threshold adjustment.
+
+    Applies the GHOST (Generalized tHreshOld ShifTing) algorithm to adjust the
+    classification threshold. See the paper and repository for more details:
+
+    Esposito, C., Landrum, G.A., Schneider, N., Stiefl, N. and Riniker, S., 2021.
+    "GHOST: adjusting the decision threshold to handle imbalanced data in machine
+    learning."
+    Journal of Chemical Information and Modeling, 61(6), pp.2623-2640.
+    https://doi.org/10.1021/acs.jcim.1c00160
+
+    https://github.com/rinikerlab/GHOST
+
+    """
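+
+    # Minimal usage sketch (illustrative only; assumes a fitted binary classifier
+    # `clf` exposing predict_proba plus train/test splits, none of which are part
+    # of this module):
+    #
+    #     ghost = Ghost(random_state=0)
+    #     ghost.fit(clf.predict_proba(X_train), y_train)        # select threshold
+    #     labels = ghost.transform(clf.predict_proba(X_test))   # 0/1 labels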
+
+    # TODO: Do the GHOST parameters need to be validated on the MolPipeline side,
+    #  or is ghostml better placed to do this?
+    #  - Validating here means it happens at construction time.
+    #  - However, the parameters and their validation must then be kept in sync
+    #    with ghostml.
+    def __init__(
+        self,
+        thresholds: list[float] | None = None,
+        optimization_metric: Literal["Kappa", "ROC"] = "Kappa",
+        random_state: int | None = None,
+    ):
+        """Initialize the GHOST threshold estimator.
+
+        Parameters
+        ----------
+        thresholds : list[float] | None, optional
+            Candidate decision thresholds to evaluate. Defaults to the bins used
+            in the GHOST paper, i.e. 0.05 to 0.5 in steps of 0.05.
+        optimization_metric : Literal["Kappa", "ROC"], optional
+            Metric optimized during threshold selection, by default "Kappa".
+        random_state : int | None, optional
+            Seed of the random number generator, by default None.
+
+        """
+        if thresholds is None:
+            # use default bins from the GHOST paper
+            thresholds = list(np.round(np.arange(0.05, 0.55, 0.05), 2))
+        self._check_thresholds(thresholds)
+        self.thresholds = thresholds
+        self._check_optimization_metric(optimization_metric)
+        self.optimization_metric = optimization_metric
+        self.random_seed = self._get_random_seed_from_input(random_state)
+        self.decision_threshold: float | None = None
+
+    @staticmethod
+    def _check_optimization_metric(
+        optimization_metric: Literal["Kappa", "ROC"],
+    ) -> None:
+        """Check that the optimization metric is valid."""
+        if optimization_metric not in {"Kappa", "ROC"}:
+            raise ValueError(
+                "optimization_metric must be either 'Kappa' or 'ROC'",
+            )
+
+    @staticmethod
+    def _get_random_seed_from_input(random_state: int | None) -> int:
+        """Derive an integer random seed from the given random state.
+
+        GHOST expects an integer random seed, so one is generated if not provided.
+        """
+        rng = check_random_state(random_state)
+        return rng.randint(0, np.iinfo(np.int32).max)
+
+    @staticmethod
+    def _check_thresholds(thresholds: list[float]) -> None:
+        """Check that the thresholds are valid."""
+        if len(thresholds) == 0:
+            raise ValueError("Thresholds must not be empty.")
+        if not all(0 <= t <= 1 for t in thresholds):
+            raise ValueError("All thresholds must be between 0 and 1.")
+        if len(set(thresholds)) != len(thresholds):
+            raise ValueError("Thresholds must be unique.")
+
+    def _check_and_process_X(self, X: npt.NDArray[Any]) -> npt.NDArray[Any]:
+        """Check and process the input predictions."""
+        y_pred = X
+        if y_pred.ndim == 2:
+            # assume binary classification output when X is a 2D array and
+            # take the predicted probabilities of class 1
+            y_pred = y_pred[:, 1]
+        if y_pred.ndim != 1:
+            raise ValueError("X must be a 1D or 2D array.")
+        if not np.all((y_pred >= 0) & (y_pred <= 1)):
+            raise ValueError("All values in X must be between 0 and 1.")
+        return y_pred
+
+    def fit(
+        self,
+        X: npt.NDArray[Any],  # pylint: disable=invalid-name
+        y: npt.NDArray[Any] | None = None,
+    ) -> Self:
+        """Fit the GHOST threshold estimator.
+
+        Determines the decision threshold from the predicted probabilities.
+
+        Parameters
+        ----------
+        X : npt.NDArray[Any]
+            Input data: the predicted probabilities.
+        y : npt.NDArray[Any] | None, optional
+            Target data: the true binary labels.
+
+        Returns
+        -------
+        Self
+            The fitted estimator.
+
+        """
+        # TODO: maybe also filter out NaN values etc. from y_pred_proba?
+
+        y_pred = X
+        y_true = y
+
+        y_pred = self._check_and_process_X(y_pred)
+        if y_true is None:
+            raise ValueError("y must be provided for fitting the GHOST wrapper.")
+        if not np.all(np.isin(y_true, [0, 1])):
+            raise ValueError("y must be binary (0 or 1).")
+
+        self.decision_threshold = ghostml.optimize_threshold_from_predictions(
+            y_true,
+            y_pred,
+            thresholds=self.thresholds,
+            ThOpt_metrics=self.optimization_metric,
+            random_seed=self.random_seed,
+        )
+
+        return self
+
+    def transform(
+        self,
+        X: npt.NDArray[Any],  # pylint: disable=invalid-name
+    ) -> npt.NDArray[np.int64]:
+        """Apply the fitted decision threshold to predicted probabilities.
+
+        Parameters
+        ----------
+        X : npt.NDArray[Any]
+            Input data: the predicted probabilities.
+
+        Returns
+        -------
+        npt.NDArray[np.int64]
+            Binary class labels obtained by thresholding the probabilities.
+
+        """
+        if self.decision_threshold is None:
+            raise ValueError("Call fit first before calling transform.")
+
+        y_pred = X
+        y_pred = self._check_and_process_X(y_pred)
+
+        # TODO: maybe also filter out NaN values etc. from y_pred_proba?
+        return (y_pred > self.decision_threshold).astype(np.int64)
+
+
+class GhostPostPredictionWrapper(PostPredictionWrapper):
+    """Post-prediction wrapper for GHOST classification threshold adjustment."""
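+
+    # Usage sketch (illustrative only; assumes the pipeline's final estimator is a
+    # classifier implementing predict_proba, e.g. a RandomForestClassifier):
+    #
+    #     pipeline = Pipeline(
+    #         [("clf", RandomForestClassifier()),
+    #          ("ghost", GhostPostPredictionWrapper(random_state=0))]
+    #     )
+    #     pipeline.fit(X_train, y_train)  # threshold is tuned on training predictions
+    #     labels = pipeline.predict(X_test)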
+
+    def __init__(
+        self,
+        thresholds: list[float] | None = None,
+        optimization_metric: Literal["Kappa", "ROC"] = "Kappa",
+        random_state: int | None = None,
+    ):
+        """Initialize the wrapper with a Ghost estimator.
+
+        Parameters
+        ----------
+        thresholds : list[float] | None, optional
+            Candidate decision thresholds, passed to Ghost.
+        optimization_metric : Literal["Kappa", "ROC"], optional
+            Metric optimized during threshold selection, by default "Kappa".
+        random_state : int | None, optional
+            Seed of the random number generator, by default None.
+
+        """
+        super().__init__(
+            wrapped_estimator=Ghost(
+                thresholds=thresholds,
+                optimization_metric=optimization_metric,
+                random_state=random_state,
+            )
+        )
+
+    @staticmethod
+    def _check_estimator(estimator: AnyPredictor) -> None:
+        """Check that the wrapped estimator has a predict_proba method."""
+        if not hasattr(estimator, "predict_proba"):
+            raise ValueError(
+                "GHOST requires an estimator with a predict_proba method. "
+                f"Got: {estimator}"
+            )
+
+    @override
+    def prepare_input(
+        self,
+        X: npt.NDArray[np.float64],
+        y: npt.NDArray[np.float64],
+        final_estimator: AnyPredictor,
+    ) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
+        """Prepare input data for fitting.
+
+        Replaces the features with the final estimator's predicted probabilities
+        for class 1.
+        """
+        self._check_estimator(final_estimator)
+        y_pred_proba = final_estimator.predict_proba(X)
+        if y_pred_proba.ndim == 2 and y_pred_proba.shape[1] == 2:
+            # binary classification: take the predicted probabilities of class 1
+            y_pred_proba = y_pred_proba[:, 1]
+        elif y_pred_proba.ndim != 1:
+            raise ValueError("predict_proba must return a 1D or 2D array.")
+
+        return y_pred_proba, y
diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index a4029fb8..571940b5 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -531,6 +531,13 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
             ]
             self._final_estimator.fit(Xt, yt, **fit_params_last_step["fit"])
 
+        # fit post-processing steps
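+        # each step may first replace (X, y) via prepare_input, e.g. with the
+        # final estimator's predictions, before being fitted on that input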
+        for _, post_element in self._post_processing_steps():
+            X_post_pred, y_post_pred = post_element.prepare_input(
+                Xt, yt, self._final_estimator
+            )
+            post_element.fit(X_post_pred, y_post_pred)
+
         return self
 
     def _can_fit_transform(self) -> bool:
@@ -624,6 +631,9 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any:
                     f"match fit_transform of Pipeline {self.__class__.__name__}"
                 )
             for _, post_element in self._post_processing_steps():
+                iter_input, iter_label = post_element.prepare_input(
+                    iter_input, iter_label, self._final_estimator
+                )
                 iter_input = post_element.fit_transform(iter_input, iter_label)
 
         return iter_input
diff --git a/molpipeline/post_prediction.py b/molpipeline/post_prediction.py
index 38a7b65a..4abb7a2d 100644
--- a/molpipeline/post_prediction.py
+++ b/molpipeline/post_prediction.py
@@ -42,6 +42,17 @@ def transform(self, X: Any, **params: Any) -> Any:  # pylint: disable=invalid-na
             Transformed data.
 
         """
+
+    def prepare_input(
+        self, X, y, final_estimator: AnyPredictor
+    ) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
+        """Prepare the input data for fitting a post-prediction step.
+
+        By default, X and y are returned unchanged.
+
+        Subclasses can override this method to implement specific behavior, e.g. to
+        pass the final estimator's predictions instead of the raw features.
+        """
+        return X, y
 
 class PostPredictionWrapper(PostPredictionTransformation):
     """Wrapper for post prediction transformations.
diff --git a/pyproject.toml b/pyproject.toml
index 2eb94220..c9c080d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,7 @@ authors = [
 ]
 description = "Integration of rdkit functionality into sklearn pipelines."
 readme = "README.md"
 dependencies = [
+    "ghostml>=0.3.0",
     "joblib>=1.3.0",
     "loguru>=0.7.3",
     "matplotlib>=3.10.1",
diff --git a/tests/test_elements/test_post_prediction.py b/tests/test_elements/test_post_prediction.py
index 74e483b8..66e70f42 100644
--- a/tests/test_elements/test_post_prediction.py
+++ b/tests/test_elements/test_post_prediction.py
@@ -7,6 +7,7 @@
 from sklearn.decomposition import PCA
 from sklearn.ensemble import RandomForestClassifier
 
+from molpipeline import Pipeline
 from molpipeline.post_prediction import PostPredictionWrapper
 
 
@@ -88,3 +89,21 @@ def test_inverse_transform(self) -> None:
 
         self.assertEqual(pca_inverse.shape, ppw_inverse.shape)
         self.assertTrue(np.allclose(pca_inverse, ppw_inverse))
+
+    # def test_pipeline_integration(self) -> None:
+    #     """Test integration with a pipeline."""
+    #
+    #     rng = np.random.default_rng(20240918)
+    #     features = rng.random((10, 5))
+    #     y = rng.integers(0, 2, size=(10,))
+    #
+    #     pipeline = Pipeline(
+    #         [
+    #             ("rf", RandomForestClassifier(n_estimators=10)),
+    #             ("pca", PostPredictionWrapper(PCA(n_components=1))),
+    #         ]
+    #     )
+    #
+    #     pipeline.fit(features, y)
+    #     ppw_transformed = pipeline.predict(features)
+    #     self.assertEqual(ppw_transformed.shape, (10, 3))
diff --git a/tests/test_estimators/test_ghost_classification_threshold.py b/tests/test_estimators/test_ghost_classification_threshold.py
new file mode 100644
index 00000000..b2cee5c5
--- /dev/null
+++ b/tests/test_estimators/test_ghost_classification_threshold.py
@@ -0,0 +1,123 @@
+"""Tests for the GHOST classification threshold estimators."""
+
+import unittest
+
+import numpy as np
+from sklearn.datasets import make_classification
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
+
+from molpipeline import Pipeline, PostPredictionWrapper
+from molpipeline.estimators.ghost_classification_threshold import (
+    Ghost,
+    GhostPostPredictionWrapper,
+)
+
+
+class TestGhostClassificationThreshold(unittest.TestCase):
+    """Tests for the Ghost classification threshold estimator."""
+
+    def setUp(self):
+        """Set up test data and objects."""
+        # Create a simple binary classification dataset
+        self.X, self.y = make_classification(
+            n_samples=100, n_features=5, n_classes=2, random_state=42
+        )
+        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
+            self.X, self.y, test_size=0.2, random_state=42
+        )
+        # Create a simple classifier
+        self.clf = LogisticRegression(random_state=42)
+        self.clf.fit(self.X_train, self.y_train)
+        self.y_pred_train = self.clf.predict_proba(self.X_train)
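+        # predict_proba output of shape (n_samples, 2); Ghost uses the class-1 column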
+
+    def test_init_default_params(self):
+        """Test initialization with default parameters."""
+        ghost_clf = Ghost()
+        self.assertEqual(ghost_clf.optimization_metric, "Kappa")
+        self.assertEqual(
+            ghost_clf.thresholds, list(np.round(np.arange(0.05, 0.55, 0.05), 2))
+        )
+        self.assertIsNone(ghost_clf.decision_threshold)
+
+    def test_init_custom_params(self):
+        """Test initialization with custom parameters."""
+        thresholds = [0.1, 0.3, 0.5, 0.7, 0.9]
+        ghost_clf = Ghost(
+            thresholds=thresholds, optimization_metric="ROC", random_state=42
+        )
+        self.assertEqual(ghost_clf.optimization_metric, "ROC")
+        self.assertEqual(ghost_clf.thresholds, thresholds)
+        self.assertIsNotNone(ghost_clf.random_seed)
+
+    def test_invalid_thresholds(self):
+        """Test initialization with invalid thresholds."""
+        # Test with invalid thresholds (outside [0, 1])
+        with self.assertRaises(ValueError):
+            Ghost(thresholds=[-0.1, 0.5, 1.2])
+
+        # Test with duplicated thresholds
+        with self.assertRaises(ValueError):
+            Ghost(thresholds=[0.1, 0.5, 0.1])
+
+    def test_invalid_optimization_metric(self):
+        """Test initialization with an invalid optimization metric."""
+        with self.assertRaises(ValueError):
+            Ghost(optimization_metric="Invalid")
+
+    def test_fit_and_transform_interface(self) -> None:
+        """Test fit followed by transform."""
+        estimator = Ghost()
+
+        # Call fit and check that the decision threshold is set
+        self.assertIsNone(estimator.decision_threshold)
+        estimator.fit(self.y_pred_train, self.y_train)
+        self.assertIsInstance(estimator.decision_threshold, float)
+        self.assertTrue(0 <= estimator.decision_threshold <= 1)
+
+        # Call transform and check the output
+        y_pred_transformed = estimator.transform(self.y_pred_train)
+        self.assertIsInstance(y_pred_transformed, np.ndarray)
+        self.assertEqual(y_pred_transformed.shape, (len(self.y_train),))
+        self.assertTrue(np.all(np.isin(y_pred_transformed, [0, 1])))
+
+    def test_fit_transform_interface(self) -> None:
+        """Test fit_transform in a single call."""
+        estimator = Ghost()
+
+        # Call fit_transform and check the result and the decision threshold
+        self.assertIsNone(estimator.decision_threshold)
+        y_pred_transformed = estimator.fit_transform(self.y_pred_train, self.y_train)
+        self.assertIsInstance(estimator.decision_threshold, float)
+        self.assertTrue(0 <= estimator.decision_threshold <= 1)
+        self.assertIsInstance(y_pred_transformed, np.ndarray)
+        self.assertEqual(y_pred_transformed.shape, (len(self.y_train),))
+        self.assertTrue(np.all(np.isin(y_pred_transformed, [0, 1])))
+
+    def test_transform_without_fit(self):
+        """Test error when transforming without fitting first."""
+        ghost_clf = Ghost()
+        with self.assertRaises(ValueError):
+            ghost_clf.transform(self.X_test)
+
+    def test_pipeline_using_prediction_wrapper(self):
+        """Test integration with Pipeline."""
+        pipeline = Pipeline(
+            [
+                ("clf", LogisticRegression(random_state=42)),
+                (
+                    "clf_threshold",
+                    PostPredictionWrapper(
+                        Ghost(
+                            random_state=42,
+                        )
+                    ),
+                ),
+            ]
+        )
+
+        # Fit and predict
+        pipeline.fit(self.X_train, self.y_train)
+        y_pred = pipeline.predict(self.X_test)
+
+        # Predictions should be binary labels for all test samples
+        self.assertEqual(y_pred.shape, (len(self.y_test),))
+        self.assertTrue(np.all(np.isin(y_pred, [0, 1])))
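+
+    def test_pipeline_using_ghost_wrapper(self):
+        """Sketch of the same pipeline using GhostPostPredictionWrapper.
+
+        Illustrative variant of the test above; it assumes GhostPostPredictionWrapper
+        is a drop-in replacement for PostPredictionWrapper(Ghost(...)).
+        """
+        pipeline = Pipeline(
+            [
+                ("clf", LogisticRegression(random_state=42)),
+                ("clf_threshold", GhostPostPredictionWrapper(random_state=42)),
+            ]
+        )
+
+        # Fitting tunes the decision threshold on the training predictions
+        pipeline.fit(self.X_train, self.y_train)
+        y_pred = pipeline.predict(self.X_test)
+
+        self.assertEqual(y_pred.shape, (len(self.y_test),))
+        self.assertTrue(np.all(np.isin(y_pred, [0, 1])))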