
# OOB Errors for Random Forests in Scikit-learn

The RandomForestClassifier is trained using bootstrap aggregation, where each new tree is fit on a bootstrap sample of the training observations z_i = (x_i, y_i). The out-of-bag (OOB) error for each z_i is the average error computed using predictions from only those trees that do not contain z_i in their respective bootstrap sample. This allows the RandomForestClassifier to be fit and validated on held-out observations while it is being trained [1].

The example below demonstrates how the OOB error can be measured as each new tree is added during training.

The resulting plot allows a practitioner to approximate a suitable value of n_estimators at which the error stabilizes.

[1] T. Hastie, R. Tibshirani and J. Friedman, “Elements of Statistical Learning Ed. 2”, p592-593, Springer, 2009.
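Before walking through the full example, the mechanism can be seen in isolation: passing `oob_score=True` to `RandomForestClassifier` makes the fitted estimator expose `oob_score_`, the accuracy on out-of-bag samples. The minimal sketch below uses an illustrative toy dataset (its parameters are not the ones used later in this tutorial):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Toy dataset for illustration only.
X, y = make_classification(n_samples=500, n_features=25, random_state=0)

# oob_score=True makes each sample be scored using only the trees
# whose bootstrap samples did not include that sample.
clf = RandomForestClassifier(n_estimators=100, oob_score=True, random_state=0)
clf.fit(X, y)

# oob_score_ is the accuracy on out-of-bag samples, so the OOB error
# is 1 - oob_score_.
oob_error = 1 - clf.oob_score_
print(oob_error)
```

This single OOB error estimate is what the main example tracks repeatedly as the number of trees grows.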


### Version

In [1]:
import sklearn
sklearn.__version__

Out[1]:
'0.18.1'

### Imports

This tutorial imports make_classification and RandomForestClassifier.

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go

from collections import OrderedDict
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


### Calculations

In [3]:
RANDOM_STATE = 123

# Generate a binary classification dataset.
X, y = make_classification(n_samples=500, n_features=25,
                           n_clusters_per_class=1, n_informative=15,
                           random_state=RANDOM_STATE)

# NOTE: Setting the warm_start construction parameter to True disables
# support for parallelized ensembles but is necessary for tracking the OOB
# error trajectory during training.
ensemble_clfs = [
    ("RandomForestClassifier,<br>max_features='sqrt'",
     RandomForestClassifier(warm_start=True, oob_score=True,
                            max_features="sqrt",
                            random_state=RANDOM_STATE)),
    ("RandomForestClassifier,<br>max_features='log2'",
     RandomForestClassifier(warm_start=True, max_features='log2',
                            oob_score=True,
                            random_state=RANDOM_STATE)),
    ("RandomForestClassifier,<br>max_features=None",
     RandomForestClassifier(warm_start=True, max_features=None,
                            oob_score=True,
                            random_state=RANDOM_STATE))
]

# Map a classifier name to a list of (<n_estimators>, <error rate>) pairs.
error_rate = OrderedDict((label, []) for label, _ in ensemble_clfs)

# Range of n_estimators values to explore.
min_estimators = 15
max_estimators = 175

for label, clf in ensemble_clfs:
    for i in range(min_estimators, max_estimators + 1):
        clf.set_params(n_estimators=i)
        clf.fit(X, y)

        # Record the OOB error for each n_estimators=i setting.
        oob_error = 1 - clf.oob_score_
        error_rate[label].append((i, oob_error))
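Reading the stabilization point off the plot is usually done by eye, but it can also be sketched programmatically: pick the smallest n_estimators whose OOB error is within some tolerance of the best error observed. The helper and the `(n_estimators, error)` pairs below are hypothetical, not part of the tutorial's code:

```python
# Hypothetical helper: return the smallest n_estimators whose OOB error
# is within `tol` of the minimum error across all recorded pairs.
def pick_n_estimators(pairs, tol=0.005):
    best = min(err for _, err in pairs)
    for n, err in pairs:
        if err <= best + tol:
            return n

# Illustrative, made-up (n_estimators, OOB error) pairs.
pairs = [(15, 0.120), (50, 0.095), (100, 0.090), (175, 0.089)]
print(pick_n_estimators(pairs))  # -> 100
```

In the real example, `error_rate[label]` already holds pairs in this shape, so the same idea applies per classifier label.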


### Plot Results

In [4]:
data = []

for label, clf_err in error_rate.items():
    xs, ys = zip(*clf_err)
    trace = go.Scatter(x=xs, y=ys,
                       name=label,
                       mode='lines')
    data.append(trace)

layout = go.Layout(xaxis=dict(title='n_estimators'),
                   yaxis=dict(title='OOB error rate'),
                   hovermode='closest')

fig = go.Figure(data=data, layout=layout)

In [5]:
py.iplot(fig)

Out[5]:

Authors:

    Kian Ho <hui.kian.ho@gmail.com>
    Gilles Louppe <g.louppe@gmail.com>
    Andreas Mueller <amueller@ais.uni-bonn.de>

License:

    BSD 3 Clause