Skip to content

Commit a4cc05d

Browse files
committedJun 2, 2025·
Pushing the docs to dev/ for branch: main, commit 6343cd74c9ff90526212cfaf65ac58a6c59a82e3
1 parent 0501554 commit a4cc05d

File tree

1,516 files changed

+6011
-5978
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,516 files changed

+6011
-5978
lines changed
 

‎dev/_downloads/010337852815f8103ac6cca38a812b3c/plot_roc_crossval.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,46 +62,56 @@
6262
# Classification and ROC analysis
6363
# -------------------------------
6464
#
65-
# Here we run a :class:`~sklearn.svm.SVC` classifier with cross-validation and
66-
# plot the ROC curves fold-wise. Notice that the baseline to define the chance
65+
# Here we run :func:`~sklearn.model_selection.cross_validate` on a
66+
# :class:`~sklearn.svm.SVC` classifier, then use the computed cross-validation results
67+
# to plot the ROC curves fold-wise. Notice that the baseline to define the chance
6768
# level (dashed ROC curve) is a classifier that would always predict the most
6869
# frequent class.
6970

7071
import matplotlib.pyplot as plt
7172

7273
from sklearn import svm
7374
from sklearn.metrics import RocCurveDisplay, auc
74-
from sklearn.model_selection import StratifiedKFold
75+
from sklearn.model_selection import StratifiedKFold, cross_validate
7576

7677
n_splits = 6
7778
cv = StratifiedKFold(n_splits=n_splits)
7879
classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)
80+
cv_results = cross_validate(
81+
classifier, X, y, cv=cv, return_estimator=True, return_indices=True
82+
)
83+
84+
prop_cycle = plt.rcParams["axes.prop_cycle"]
85+
colors = prop_cycle.by_key()["color"]
86+
curve_kwargs_list = [
87+
dict(alpha=0.3, lw=1, color=colors[fold % len(colors)]) for fold in range(n_splits)
88+
]
89+
names = [f"ROC fold {idx}" for idx in range(n_splits)]
7990

80-
tprs = []
81-
aucs = []
8291
mean_fpr = np.linspace(0, 1, 100)
92+
interp_tprs = []
93+
94+
_, ax = plt.subplots(figsize=(6, 6))
95+
viz = RocCurveDisplay.from_cv_results(
96+
cv_results,
97+
X,
98+
y,
99+
ax=ax,
100+
name=names,
101+
curve_kwargs=curve_kwargs_list,
102+
plot_chance_level=True,
103+
)
83104

84-
fig, ax = plt.subplots(figsize=(6, 6))
85-
for fold, (train, test) in enumerate(cv.split(X, y)):
86-
classifier.fit(X[train], y[train])
87-
viz = RocCurveDisplay.from_estimator(
88-
classifier,
89-
X[test],
90-
y[test],
91-
name=f"ROC fold {fold}",
92-
curve_kwargs=dict(alpha=0.3, lw=1),
93-
ax=ax,
94-
plot_chance_level=(fold == n_splits - 1),
95-
)
96-
interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
105+
for idx in range(n_splits):
106+
interp_tpr = np.interp(mean_fpr, viz.fpr[idx], viz.tpr[idx])
97107
interp_tpr[0] = 0.0
98-
tprs.append(interp_tpr)
99-
aucs.append(viz.roc_auc)
108+
interp_tprs.append(interp_tpr)
100109

101-
mean_tpr = np.mean(tprs, axis=0)
110+
mean_tpr = np.mean(interp_tprs, axis=0)
102111
mean_tpr[-1] = 1.0
103112
mean_auc = auc(mean_fpr, mean_tpr)
104-
std_auc = np.std(aucs)
113+
std_auc = np.std(viz.roc_auc)
114+
105115
ax.plot(
106116
mean_fpr,
107117
mean_tpr,
@@ -111,7 +121,7 @@
111121
alpha=0.8,
112122
)
113123

114-
std_tpr = np.std(tprs, axis=0)
124+
std_tpr = np.std(interp_tprs, axis=0)
115125
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
116126
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
117127
ax.fill_between(
Binary file not shown.

0 commit comments

Comments
 (0)
Please sign in to comment.