Commit 9b398f0

Pushing the docs to dev/ for branch: main, commit 325930e0151b67f6ddcd155d47a7b2a320f30ea7

1 parent 6474327, commit 9b398f0

1,595 files changed: +14,225 / -13,660 lines

Lines changed: 109 additions & 17 deletions

The file's content after this commit:

"""
========================
Decision Tree Regression
========================
In this example, we demonstrate the effect of changing the maximum depth of a
decision tree on how it fits the data. We do this once on a 1D regression task
and once on a multi-output regression task.
"""

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

# %%
# Decision Tree on a 1D Regression Task
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Here we fit a tree on a 1D regression task.
#
# A :ref:`decision tree <tree>` is used to fit a sine curve with additional
# noisy observations. As a result, it learns local linear regressions
# approximating the sine curve.
#
# We can see that if the maximum depth of the tree (controlled by the
# `max_depth` parameter) is set too high, the decision tree learns overly fine
# details of the training data and learns from the noise, i.e. it overfits.
#
# Create a random 1D dataset
# --------------------------
import numpy as np

rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))

# %%
# Fit regression model
# --------------------
# Here we fit two models with different maximum depths.
from sklearn.tree import DecisionTreeRegressor

regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)

# %%
# Predict
# -------
# Get predictions on the test set.
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)

# %%
# Plot the results
# ----------------
import matplotlib.pyplot as plt

plt.figure()
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
# ... (unchanged context lines not shown in the diff) ...
plt.title("Decision Tree Regression")
plt.legend()
plt.show()

# %%
# As you can see, the model with a depth of 5 (yellow) learns the details of the
# training data to the point that it overfits to the noise. On the other hand,
# the model with a depth of 2 (blue) learns the major tendencies in the data well
# and does not overfit. In real use cases, you need to make sure that the tree
# is not overfitting the training data, which can be done using cross-validation.
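The comment above recommends cross-validation for catching an overfit depth but does not show it. A minimal sketch, assuming scikit-learn's `GridSearchCV`; the candidate depth grid `[2, 3, 5, 8]` is my own choice, not part of the example:

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor

# Same synthetic sine data as in the example above.
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))

# Search over candidate depths with 5-fold cross-validation; the selected
# depth is the one with the highest mean validation R^2 score.
search = GridSearchCV(
    DecisionTreeRegressor(random_state=0),
    param_grid={"max_depth": [2, 3, 5, 8]},
    cv=5,
)
search.fit(X, y)
print(search.best_params_)
```

Because the validation folds are held out during each fit, a depth that merely memorizes the noise scores poorly and is not selected.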
# %%
# Decision Tree Regression with Multi-Output Targets
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Here a :ref:`decision tree <tree>` is used to simultaneously predict the
# noisy `x` and `y` observations of a circle given a single underlying
# feature. As a result, it learns local linear regressions approximating
# the circle.
#
# We can see that if the maximum depth of the tree (controlled by the
# `max_depth` parameter) is set too high, the decision tree learns overly fine
# details of the training data and learns from the noise, i.e. it overfits.

# %%
# Create a random dataset
# -----------------------
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
y[::5, :] += 0.5 - rng.rand(20, 2)

# %%
# Fit regression model
# --------------------
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_3 = DecisionTreeRegressor(max_depth=8)
regr_1.fit(X, y)
regr_2.fit(X, y)
regr_3.fit(X, y)

# %%
# Predict
# -------
# Get predictions on the test set.
X_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)
y_3 = regr_3.predict(X_test)

# %%
# Plot the results
# ----------------
plt.figure()
s = 25
plt.scatter(y[:, 0], y[:, 1], c="yellow", s=s, edgecolor="black", label="data")
plt.scatter(
    y_1[:, 0],
    y_1[:, 1],
    c="cornflowerblue",
    s=s,
    edgecolor="black",
    label="max_depth=2",
)
plt.scatter(y_2[:, 0], y_2[:, 1], c="red", s=s, edgecolor="black", label="max_depth=5")
plt.scatter(y_3[:, 0], y_3[:, 1], c="blue", s=s, edgecolor="black", label="max_depth=8")
plt.xlim([-6, 6])
plt.ylim([-6, 6])
plt.xlabel("target 1")
plt.ylabel("target 2")
plt.title("Multi-output Decision Tree Regression")
plt.legend(loc="best")
plt.show()

# %%
# As you can see, the higher the value of `max_depth`, the more details of the
# data are captured by the model. However, the model also overfits the data and
# is influenced by the noise.
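Mechanically, the multi-output case works because `DecisionTreeRegressor.fit` accepts a 2D `y` and `predict` then returns one column per target. A quick self-contained check on the same circle data (the two query points are my own choice, and no noise is added here, unlike in the example):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
# Two targets: the sin/cos coordinates of a circle of radius pi.
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T

# Fitting on a (100, 2) target makes the tree store a 2-vector per leaf.
regr = DecisionTreeRegressor(max_depth=5).fit(X, y)
pred = regr.predict(np.array([[0.0], [50.0]]))
print(pred.shape)  # → (2, 2): one row per query sample, one column per target
```

Each leaf predicts the per-target mean of its training samples, so both outputs share one tree structure and one set of splits.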

dev/_downloads/8a87659782dab72f4bb6ef792517234c/plot_tree_regression_multioutput.py: 0 additions & 68 deletions (this file was deleted).

0 commit comments