
# Decision Tree Regression with AdaBoost in Scikit-learn

A decision tree is boosted using the AdaBoost.R2 [1] algorithm on a 1D sinusoidal dataset with a small amount of Gaussian noise. The result of 299 boosts (300 decision trees) is compared with a single decision tree regressor. As the number of boosts increases, the regressor can fit more detail.

[1] H. Drucker, “Improving Regressors using Boosting Techniques”, 1997.
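To make the boosting step concrete, here is a minimal sketch of a single AdaBoost.R2 round with the linear loss, following the procedure described in [1]. The helper name `adaboost_r2_round` and its arguments are hypothetical and exist only for illustration; scikit-learn's `AdaBoostRegressor` runs this loop internally, and its implementation may differ in detail.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def adaboost_r2_round(X, y, weights, rng, max_depth=4):
    """Run one illustrative AdaBoost.R2 round (linear loss).

    Returns the fitted tree, its beta value, and the renormalized weights.
    Assumes at least one training residual is nonzero.
    """
    n_samples = len(y)
    # Draw a bootstrap sample according to the current sample weights
    idx = rng.choice(n_samples, size=n_samples, replace=True, p=weights)
    tree = DecisionTreeRegressor(max_depth=max_depth).fit(X[idx], y[idx])

    # Linear loss: absolute error scaled into [0, 1] by the largest error
    abs_err = np.abs(tree.predict(X) - y)
    loss = abs_err / abs_err.max()

    # Weighted average loss; boosting normally stops once this reaches 0.5
    avg_loss = np.sum(weights * loss)
    beta = avg_loss / (1.0 - avg_loss)

    # Samples with a small loss are scaled by a power of beta close to 1,
    # so when beta < 1 the poorly predicted samples gain relative weight
    new_weights = weights * beta ** (1.0 - loss)
    return tree, beta, new_weights / new_weights.sum()


rng = np.random.RandomState(1)
X = np.linspace(0, 6, 100)[:, np.newaxis]
y = np.sin(X).ravel() + np.sin(6 * X).ravel() + rng.normal(0, 0.1, X.shape[0])

uniform = np.full(len(y), 1.0 / len(y))
tree, beta, new_weights = adaboost_r2_round(X, y, uniform, rng)
print("beta after the first round:", beta)
```

In the full algorithm the trees are combined by a weighted median, with each tree weighted by `log(1 / beta)`, so trees with a lower average loss have more influence on the final prediction.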


### Version

In [1]:
import sklearn
sklearn.__version__

Out[1]:
'0.18.1'

### Imports

This tutorial imports DecisionTreeRegressor and AdaBoostRegressor.

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import AdaBoostRegressor

Automatically created module for IPython interactive environment


### Calculations

In [3]:
rng = np.random.RandomState(1)
X = np.linspace(0, 6, 100)[:, np.newaxis]
y = np.sin(X).ravel() + np.sin(6 * X).ravel() + rng.normal(0, 0.1, X.shape[0])

# Fit regression model
regr_1 = DecisionTreeRegressor(max_depth=4)

regr_2 = AdaBoostRegressor(DecisionTreeRegressor(max_depth=4),
                           n_estimators=300, random_state=rng)

regr_1.fit(X, y)
regr_2.fit(X, y)

# Predict
y_1 = regr_1.predict(X)
y_2 = regr_2.predict(X)
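One way to check the claim that more boosts fit more detail is to track the training error after each boosting stage. The sketch below uses `AdaBoostRegressor.staged_predict`, which yields the ensemble prediction after each added tree. It rebuilds the same data and model so that it runs on its own; the variable names `boosted` and `staged_mse` are illustrative. The error is measured on the training samples, so it says nothing about generalization.

```python
import numpy as np
from sklearn.ensemble import AdaBoostRegressor
from sklearn.metrics import mean_squared_error
from sklearn.tree import DecisionTreeRegressor

# Same synthetic data as in the cell above
rng = np.random.RandomState(1)
X = np.linspace(0, 6, 100)[:, np.newaxis]
y = np.sin(X).ravel() + np.sin(6 * X).ravel() + rng.normal(0, 0.1, X.shape[0])

boosted = AdaBoostRegressor(DecisionTreeRegressor(max_depth=4),
                            n_estimators=300, random_state=rng)
boosted.fit(X, y)

# One training-set MSE value per boosting stage
staged_mse = [mean_squared_error(y, y_stage)
              for y_stage in boosted.staged_predict(X)]

# Boosting can stop early, so guard against stages that were never fitted
for n in (1, 10, 100, len(staged_mse)):
    if n <= len(staged_mse):
        print("trees: %3d  training MSE: %.4f" % (n, staged_mse[n - 1]))
```

The first entry roughly corresponds to a single tree, and the last to the full ensemble.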


### Plot Results

In [4]:
def data_to_plotly(x):
    plotly_data = []
    for i in range(0, len(x)):
        plotly_data.append(x[i][0])

    return plotly_data
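The helper above turns the `(100, 1)` column array into a flat list by taking the first element of each row. As a sketch, NumPy's `ravel()` followed by `tolist()` should give the same values without an explicit loop; the name `flat` is illustrative.

```python
import numpy as np


def data_to_plotly(x):
    # Helper from the tutorial: take the first element of each row
    plotly_data = []
    for i in range(0, len(x)):
        plotly_data.append(x[i][0])

    return plotly_data


X = np.linspace(0, 6, 100)[:, np.newaxis]

# Flatten the single-column array directly
flat = X.ravel().tolist()

print(flat[:3])
print(data_to_plotly(X) == flat)
```

This shortcut only matches the helper when the array has exactly one column, as `X` does here.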

In [5]:
training_samples = go.Scatter(x=data_to_plotly(X),
                              y=y,
                              name="training samples",
                              mode='markers',
                              marker=dict(color='black', size=6)
                              )

n_estimator1 = go.Scatter(x=data_to_plotly(X),
                          y=y_1,
                          name="n_estimators=1",
                          mode='lines',
                          line=dict(color='green'),
                          )

n_estimator300 = go.Scatter(x=data_to_plotly(X),
                            y=y_2,
                            name="n_estimators=300",
                            mode='lines',
                            line=dict(color='red'),
                            )
data = [training_samples, n_estimator1, n_estimator300]

layout = go.Layout(title='Boosted Decision Tree Regression',
                   xaxis=dict(title='data'),
                   yaxis=dict(title='target')
                   )
fig = go.Figure(data=data, layout=layout)

In [6]:
py.iplot(fig)

Out[6]:

Author:

    Noel Dawe <noel.dawe@gmail.com>

License:

    BSD 3 clause