
Multi-Output Decision Tree Regression in Scikit-learn

An example illustrating multi-output regression with a decision tree.

A decision tree is used to simultaneously predict the noisy x and y observations of a circle from a single underlying feature. As a result, it learns a piecewise-constant approximation of the circle: each leaf predicts one (x, y) point, namely the mean of the training targets that fall into that leaf.
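The following is a minimal sketch that checks the multi-output behaviour directly. For simplicity it assumes noise-free circle data and a depth-2 tree. It uses `DecisionTreeRegressor.apply` to find the leaf of each sample and checks that every sample in a leaf receives the same two-dimensional prediction, equal to that leaf's mean target (the default squared-error criterion stores the mean).

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Noise-free version of the notebook's circle data: one feature, two targets.
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.column_stack([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()])

# A single tree fits both targets because y has shape (n_samples, 2).
tree = DecisionTreeRegressor(max_depth=2).fit(X, y)
pred = tree.predict(X)
print(pred.shape)  # (100, 2)

# apply() returns the index of the leaf each sample lands in.
leaves = tree.apply(X)
for leaf in np.unique(leaves):
    in_leaf = leaves == leaf
    # All samples in one leaf share a single 2-D prediction: the leaf's mean target.
    assert np.allclose(pred[in_leaf], y[in_leaf].mean(axis=0))

print(len(np.unique(leaves)), "leaves, so at most that many distinct predicted points")
```

With `max_depth=2` the tree has at most four leaves. The whole circle is therefore summarized by at most four points in the (target 1, target 2) plane.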

We can see that if the maximum depth of the tree (controlled by the max_depth parameter) is set too high, the decision trees learn overly fine details of the training data and end up fitting the noise, i.e. they overfit.
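One way to quantify this is to fit the same three depths used in the notebook and compare two errors. The first is the error on the noisy training targets. The second is the error against the noise-free circle on a dense grid, which this sketch assumes as a stand-in for "truth". The variable names `results` and `grid_mse` are illustrative.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Rebuild the notebook's noisy circle data.
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
y[::5, :] += (0.5 - rng.rand(20, 2))

# Noise-free circle on a dense grid, used here only as a reference curve.
X_grid = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]
y_true = np.column_stack([np.pi * np.sin(X_grid).ravel(),
                          np.pi * np.cos(X_grid).ravel()])

results = {}
for depth in (2, 5, 8):
    tree = DecisionTreeRegressor(max_depth=depth).fit(X, y)
    n_leaves = len(np.unique(tree.apply(X)))                   # leaves in the fitted tree
    train_mse = np.mean((tree.predict(X) - y) ** 2)            # error on noisy targets
    grid_mse = np.mean((tree.predict(X_grid) - y_true) ** 2)   # error vs. clean curve
    results[depth] = (n_leaves, train_mse, grid_mse)
    print("max_depth=%d: %d leaves, train MSE %.4f, grid MSE %.4f"
          % (depth, n_leaves, train_mse, grid_mse))
```

Training error can only shrink as `max_depth` grows, because a deeper tree keeps the shallower tree's splits and adds more. A falling training error alone is therefore not evidence of a better model, which is why the error against the clean curve is printed alongside it.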

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go

import numpy as np
from sklearn.tree import DecisionTreeRegressor
Automatically created module for IPython interactive environment

Calculations

In [3]:
# Create a random dataset
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
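# Add uniform noise in (-0.5, 0.5] to every fifth sample (20 of the 100 rows)
y[::5, :] += (0.5 - rng.rand(20, 2))

# Fit three regression trees of increasing depth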
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_3 = DecisionTreeRegressor(max_depth=8)
regr_1.fit(X, y)
regr_2.fit(X, y)
regr_3.fit(X, y)

# Predict both targets on a dense grid of the input feature
X_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)
y_3 = regr_3.predict(X_test)
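The following is a minimal sanity check of the arrays produced in the cell above. It regenerates the data, keeping a hypothetical noise-free copy `y_clean`. It then confirms that only every fifth row was perturbed, by at most 0.5 per coordinate, and that a tree maps the one-column test grid to two-column predictions. The depth-5 tree mirrors `regr_2`.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Same data generation as the Calculations cell, plus a noise-free copy.
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y_clean = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
y = y_clean.copy()
y[::5, :] += (0.5 - rng.rand(20, 2))

# Rows that differ from the clean circle: expected to be indices 0, 5, ..., 95.
noisy_rows = np.flatnonzero(np.any(y != y_clean, axis=1))
max_noise = np.abs(y - y_clean).max()
print(len(noisy_rows), max_noise)

# One input column in, two target columns out.
X_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]
regr = DecisionTreeRegressor(max_depth=5).fit(X, y)
y_pred = regr.predict(X_test)
print(X.shape, y.shape, X_test.shape, y_pred.shape)

# A depth-5 tree has at most 2**5 leaves, hence at most 32 distinct predictions.
print(len(np.unique(y_pred, axis=0)))
```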

Plot Results

In [4]:
# Noisy training targets, drawn as (target 1, target 2) points
p1 = go.Scatter(x=y[:, 0], y=y[:, 1],
                mode='markers',
                marker=dict(color="navy"),
                name="data")

# Predictions of each tree on the dense test grid
p2 = go.Scatter(x=y_1[:, 0], y=y_1[:, 1],
                mode='markers',
                marker=dict(color="cornflowerblue"),
                name="max_depth=2")

p3 = go.Scatter(x=y_2[:, 0], y=y_2[:, 1],
                mode='markers',
                marker=dict(color="cyan"),
                name="max_depth=5")

p4 = go.Scatter(x=y_3[:, 0], y=y_3[:, 1],
                mode='markers',
                marker=dict(color="orange"),
                name="max_depth=8")

# Axis titles name the two regression targets
layout = go.Layout(xaxis=dict(title="target 1", zeroline=False),
                   yaxis=dict(title="target 2", zeroline=False),
                   title="Multi-output Decision Tree Regression"
                  )
# Combine the data and the three prediction traces into one figure
fig = go.Figure(data=[p1, p2, p3, p4], layout=layout)
In [5]:
py.iplot(fig)
Out[5]:
(Interactive Plotly figure "Multi-output Decision Tree Regression": target 1 vs. target 2)