Skip to content

add classification threshold selection with GHOST #183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions molpipeline/estimators/ghost_classification_threshold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
"""Classification threshold adjustment with GHOST."""

from typing import Any, Self, Literal, override

import ghostml
from sklearn.utils._random import check_random_state
from sklearn.base import BaseEstimator, TransformerMixin

from molpipeline import FilterReinserter
from molpipeline.post_prediction import PostPredictionWrapper
import numpy as np
import numpy.typing as npt

from molpipeline.utils.molpipeline_types import AnyPredictor


class Ghost(BaseEstimator, TransformerMixin):
    """GHOST estimator wrapper for classification threshold adjustment.

    Applies the GHOST (Generalized tHreshOld ShifTing) algorithm to adjust the
    classification threshold. See the paper and repository for more details:

    Esposito, C., Landrum, G.A., Schneider, N., Stiefl, N. and Riniker, S., 2021.
    "GHOST: adjusting the decision threshold to handle imbalanced data in machine
    learning."
    Journal of Chemical Information and Modeling, 61(6), pp.2623-2640.
    https://doi.org/10.1021/acs.jcim.1c00160

    https://github.com/rinikerlab/GHOST

    """

    # NOTE(review): parameters are validated here at construction time. This
    # duplicates any validation ghostml performs itself and must be kept in
    # sync with ghostml -- TODO decide which side should own validation.
    def __init__(
        self,
        thresholds: list[float] | None = None,
        optimization_metric: Literal["Kappa", "ROC"] = "Kappa",
        random_state: int | None = None,
    ):
        """Initialize the GHOST threshold-adjustment estimator.

        Parameters
        ----------
        thresholds : list[float] | None, optional
            Candidate decision thresholds to evaluate. If None, the default
            grid from the GHOST paper (0.05, 0.10, ..., 0.50) is used.
        optimization_metric : Literal["Kappa", "ROC"], optional
            Metric to optimize when selecting the threshold, by default "Kappa".
        random_state : int | None, optional
            Seed used to derive the integer seed passed to ghostml,
            by default None.

        Raises
        ------
        ValueError
            If the thresholds are empty, non-unique or outside [0, 1], or if
            the optimization metric is not one of "Kappa" or "ROC".

        """
        if thresholds is None:
            # use default bins from GHOST paper
            thresholds = list(np.round(np.arange(0.05, 0.55, 0.05), 2))
        self._check_thresholds(thresholds)
        self.thresholds = thresholds
        self._check_optimization_metric(optimization_metric)
        self.optimization_metric = optimization_metric
        # Store the constructor argument under its own name so scikit-learn's
        # get_params/set_params/clone machinery (which inspects the __init__
        # signature for matching attributes) does not raise AttributeError.
        self.random_state = random_state
        # ghostml expects a plain integer seed, so derive one deterministically.
        self.random_seed = self._get_random_seed_from_input(random_state)
        # Set during fit; None signals that the estimator is not fitted yet.
        self.decision_threshold: float | None = None

    @staticmethod
    def _check_optimization_metric(
        optimization_metric: Literal["Kappa", "ROC"],
    ) -> None:
        """Raise ValueError if the optimization metric is not supported."""
        if optimization_metric not in {"Kappa", "ROC"}:
            raise ValueError(
                "optimization_metric must be either 'Kappa' or 'ROC'",
            )

    @staticmethod
    def _get_random_seed_from_input(random_state: int | None) -> int:
        """Derive an integer random seed from the random_state argument.

        GHOST expects an integer random seed, so one is drawn from the
        (possibly seeded) RandomState when not provided directly.
        """
        rng = check_random_state(random_state)
        return rng.randint(0, np.iinfo(np.int32).max)

    @staticmethod
    def _check_thresholds(thresholds: list[float]) -> None:
        """Raise ValueError if the candidate thresholds are invalid."""
        if len(thresholds) == 0:
            raise ValueError("Thresholds must not be empty.")
        if not all(0 <= t <= 1 for t in thresholds):
            raise ValueError("All thresholds must be between 0 and 1.")
        if len(set(thresholds)) != len(thresholds):
            raise ValueError("Thresholds must be unique.")

    def _check_and_process_X(self, X: npt.NDArray[Any]) -> npt.NDArray[Any]:
        """Validate predicted probabilities and reduce them to a 1D array.

        Parameters
        ----------
        X : npt.NDArray[Any]
            Predicted probabilities: either a 1D array of positive-class
            probabilities or a 2D array with exactly two per-class columns.

        Returns
        -------
        npt.NDArray[Any]
            1D array of positive-class probabilities.

        Raises
        ------
        ValueError
            If X has an unsupported shape or contains values outside [0, 1].

        """
        y_pred = X
        if y_pred.ndim == 2:
            # assume binary classification output when it's a 2D array:
            # column 1 holds the probabilities for the positive class.
            # Guard against (n, 1) arrays (IndexError) and silently picking
            # the wrong column of a multiclass output.
            if y_pred.shape[1] != 2:
                raise ValueError(
                    "A 2D X must have exactly two columns "
                    "(binary classification probabilities)."
                )
            y_pred = y_pred[:, 1]
        if y_pred.ndim != 1:
            raise ValueError("X must be a 1D or 2D array.")
        if not np.all((y_pred >= 0) & (y_pred <= 1)):
            raise ValueError("All values in X must be between 0 and 1.")
        return y_pred

    def fit(
        self,
        X: npt.NDArray[Any],  # pylint: disable=invalid-name
        y: npt.NDArray[Any] | None = None,
    ) -> Self:
        """Fit the GHOST post-prediction wrapper.

        Determines the decision threshold from the predicted probabilities
        and the true labels via ghostml.

        Parameters
        ----------
        X : npt.NDArray[Any]
            Input data. The predicted probabilities.
        y : npt.NDArray[Any] | None, optional
            Target data. The true binary labels (0 or 1).

        Returns
        -------
        Self
            The fitted estimator.

        Raises
        ------
        ValueError
            If y is missing or not binary, or if X has an invalid shape or
            contains values outside [0, 1].

        """
        # TODO: maybe also drop NaN etc. from the predicted probabilities?
        y_pred = self._check_and_process_X(X)
        if y is None:
            raise ValueError("y must be provided for fitting the GHOST wrapper.")
        if not np.all(np.isin(y, [0, 1])):
            raise ValueError("y must be binary (0 or 1).")

        self.decision_threshold = ghostml.optimize_threshold_from_predictions(
            y,
            y_pred,
            thresholds=self.thresholds,
            ThOpt_metrics=self.optimization_metric,
            random_seed=self.random_seed,
        )

        return self

    def transform(
        self,
        X: npt.NDArray[Any],  # pylint: disable=invalid-name
    ) -> npt.NDArray[np.int64]:
        """Binarize predicted probabilities with the fitted threshold.

        Parameters
        ----------
        X : npt.NDArray[Any]
            Predicted probabilities (1D, or 2D with two columns).

        Returns
        -------
        npt.NDArray[np.int64]
            Predicted class labels (0 or 1).

        Raises
        ------
        ValueError
            If the estimator has not been fitted yet or X is invalid.

        """
        if self.decision_threshold is None:
            raise ValueError("Call fit first before calling transform.")

        # TODO: maybe also drop NaN etc. from the predicted probabilities?
        y_pred = self._check_and_process_X(X)
        return (y_pred > self.decision_threshold).astype(np.int64)


class GhostPostPredictionWrapper(PostPredictionWrapper):
    """Post-prediction wrapper for GHOST classification threshold adjustment."""

    def __init__(
        self,
        thresholds: list[float] | None = None,
        optimization_metric: Literal["Kappa", "ROC"] = "Kappa",
        random_state: int | None = None,
    ):
        """Initialize the wrapper with a Ghost estimator.

        Parameters
        ----------
        thresholds : list[float] | None, optional
            Candidate decision thresholds; None uses the GHOST paper defaults.
        optimization_metric : Literal["Kappa", "ROC"], optional
            Metric to optimize when selecting the threshold, by default "Kappa".
        random_state : int | None, optional
            Seed for reproducible threshold selection, by default None.

        """
        # Store the constructor arguments under their own names so that
        # scikit-learn's get_params/clone (which inspect this __init__
        # signature for matching attributes) do not raise AttributeError.
        self.thresholds = thresholds
        self.optimization_metric = optimization_metric
        self.random_state = random_state
        super().__init__(
            wrapped_estimator=Ghost(
                thresholds=thresholds,
                optimization_metric=optimization_metric,
                random_state=random_state,
            )
        )

    @staticmethod
    def _check_estimator(estimator: AnyPredictor) -> None:
        """Raise ValueError if the estimator lacks a predict_proba method."""
        if not hasattr(estimator, "predict_proba"):
            raise ValueError(
                f"GHOST requires an estimator with a predict_proba method. Got: {estimator}"
            )

    @override
    def prepare_input(
        self,
        X: npt.NDArray[np.float64],
        y: npt.NDArray[np.float64],
        final_estimator: AnyPredictor,
    ) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
        """Prepare input data for fitting.

        Replaces the feature matrix with the final estimator's predicted
        probabilities for the positive class, which is what Ghost.fit expects.

        Parameters
        ----------
        X : npt.NDArray[np.float64]
            Feature matrix passed to the pipeline.
        y : npt.NDArray[np.float64]
            True labels, returned unchanged.
        final_estimator : AnyPredictor
            Fitted estimator providing predict_proba.

        Returns
        -------
        tuple[npt.NDArray[Any], npt.NDArray[Any]]
            Positive-class probabilities and the unchanged labels.

        Raises
        ------
        ValueError
            If the estimator has no predict_proba or its output shape is not
            binary-classification compatible.

        """
        self._check_estimator(final_estimator)
        y_pred_proba = final_estimator.predict_proba(X)
        if y_pred_proba.ndim == 2 and y_pred_proba.shape[1] == 2:
            # binary classification, take probabilities for class 1
            y_pred_proba = y_pred_proba[:, 1]
        elif y_pred_proba.ndim != 1:
            # also reached for 2D outputs with != 2 columns (e.g. multiclass)
            raise ValueError(
                "predict_proba must return a 1D array or a 2D array with "
                "exactly two columns (binary classification)."
            )

        return y_pred_proba, y
10 changes: 10 additions & 0 deletions molpipeline/pipeline/_skl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,13 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
]
self._final_estimator.fit(Xt, yt, **fit_params_last_step["fit"])

# fit post-processing steps
for _, post_element in self._post_processing_steps():
X_post_pred, y_post_pred = post_element.prepare_input(
Xt, yt, self._final_estimator
)
post_element.fit(X_post_pred, y_post_pred)

return self

def _can_fit_transform(self) -> bool:
Expand Down Expand Up @@ -624,6 +631,9 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any:
f"match fit_transform of Pipeline {self.__class__.__name__}"
)
for _, post_element in self._post_processing_steps():
iter_input, iter_label = post_element.prepare_input(
iter_input, iter_label, self._final_estimator
)
iter_input = post_element.fit_transform(iter_input, iter_label)
return iter_input

Expand Down
11 changes: 11 additions & 0 deletions molpipeline/post_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ def transform(self, X: Any, **params: Any) -> Any: # pylint: disable=invalid-na
Transformed data.
"""

    def prepare_input(
        self, X: Any, y: Any, final_estimator: AnyPredictor
    ) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
        """Prepare the input data for fitting this post-prediction step.

        Use this function to prepare the input data for fitting.

        The default implementation is the identity: it returns ``(X, y)``
        unchanged and ignores ``final_estimator``.

        This method can be overridden in subclasses to implement specific
        behavior, e.g. replacing ``X`` with predictions of ``final_estimator``.

        Parameters
        ----------
        X : Any
            Input data forwarded to this post-prediction step.
        y : Any
            Target values forwarded to this post-prediction step.
        final_estimator : AnyPredictor
            The final estimator of the pipeline; unused by this default
            implementation.

        Returns
        -------
        tuple[npt.NDArray[Any], npt.NDArray[Any]]
            The unchanged ``(X, y)`` pair.
        """
        return X, y


class PostPredictionWrapper(PostPredictionTransformation):
"""Wrapper for post prediction transformations.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ authors = [
description = "Integration of rdkit functionality into sklearn pipelines."
readme = "README.md"
dependencies = [
"ghostml>=0.3.0",
"joblib>=1.3.0",
"loguru>=0.7.3",
"matplotlib>=3.10.1",
Expand Down
19 changes: 19 additions & 0 deletions tests/test_elements/test_post_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier

from molpipeline import Pipeline
from molpipeline.post_prediction import PostPredictionWrapper


Expand Down Expand Up @@ -88,3 +89,21 @@ def test_inverse_transform(self) -> None:
self.assertEqual(pca_inverse.shape, ppw_inverse.shape)

self.assertTrue(np.allclose(pca_inverse, ppw_inverse))

# def test_pipeline_integration(self) -> None:
# """Test integration with a pipeline."""
#
# rng = np.random.default_rng(20240918)
# features = rng.random((10, 5))
# y = rng.integers(0, 2, size=(10,))
#
# pipeline = Pipeline(
# [
# ("rf", RandomForestClassifier(n_estimators=10)),
# ("pca", PostPredictionWrapper(PCA(n_components=1))),
# ]
# )
#
# pipeline.fit(features, y)
# ppw_transformed = pipeline.predict(features)
# self.assertEqual(ppw_transformed.shape, (10, 3))
Loading
Loading