Show Sidebar Hide Sidebar

# Decision Tree Structure in Scikit-learn

The decision tree structure can be analysed to gain further insight on the relation between the features and the target to predict. In this example, we show how to retrieve:

• the binary tree structure;

• the depth of each node and whether or not it’s a leaf;

• the nodes that were reached by a sample using the decision_path method;

• the leaf that was reached by a sample using the apply method;

• the rules that were used to predict a sample;

• the decision path shared by a group of samples.

#### New to Plotly?¶

You can set up Plotly to work in online or offline mode, or in jupyter notebooks.
We also have a quick-reference cheatsheet (new!) to help you get started!

### Version¶

In [1]:
import sklearn
sklearn.__version__

Out[1]:
'0.18.1'

### Imports¶

This tutorial imports train_test_split, load_iris and DecisionTreeClassifier.

In [2]:
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


### Calculations¶

In [3]:
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)

Out[3]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
max_features=None, max_leaf_nodes=3, min_impurity_split=1e-07,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort=False, random_state=0,
splitter='best')

The decision estimator has an attribute called tree which stores the entire tree structure and allows access to low level attributes. The binary tree tree is represented as a number of parallel arrays. The i-th element of each array holds information about the node i. Node 0 is the tree's root. NOTE: Some of the arrays only apply to either leaves or split nodes, resp. In this case the values of nodes of the other type are arbitrary!

Among those arrays, we have:

- left_child, id of the left child of the node
- right_child, id of the right child of the node
- feature, feature used for splitting the node
- threshold, threshold value at the node



Using those arrays, we can parse the tree structure:

In [4]:
n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold


The tree structure can be traversed to compute various properties such as the depth of each node and whether or not it is a leaf.

In [5]:
node_depth = np.zeros(shape=n_nodes)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # seed is the root node id and its parent depth
while len(stack) > 0:
node_id, parent_depth = stack.pop()
node_depth[node_id] = parent_depth + 1

# If we have a test node
if (children_left[node_id] != children_right[node_id]):
stack.append((children_left[node_id], parent_depth + 1))
stack.append((children_right[node_id], parent_depth + 1))
else:
is_leaves[node_id] = True

print("The binary tree structure has %s nodes and has "
"the following tree structure:"
% n_nodes)
for i in range(n_nodes):
if is_leaves[i]:
print("%snode=%s leaf node." % (node_depth[i] * "\t", i))
else:
print("%snode=%s test node: go to node %s if X[:, %s] <= %ss else to "
"node %s."
% (node_depth[i] * "\t",
i,
children_left[i],
feature[i],
threshold[i],
children_right[i],
))
print()

The binary tree structure has 5 nodes and has the following tree structure:
node=0 test node: go to node 1 if X[:, 3] <= 0.800000011921s else to node 2.
node=1 leaf node.
node=2 test node: go to node 3 if X[:, 2] <= 4.94999980927s else to node 4.
node=3 leaf node.
node=4 leaf node.
()


First let's retrieve the decision path of each sample. The decision_path method allows to retrieve the node indicator functions. A non zero element of indicator matrix at the position (i, j) indicates that the sample i goes through the node j.

In [6]:
node_indicator = estimator.decision_path(X_test)

# Similarly, we can also have the leaves ids reached by each sample.

leave_id = estimator.apply(X_test)


Now, it's possible to get the tests that were used to predict a sample or a group of samples. First, let's make it for the sample.

In [7]:
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
if leave_id[sample_id] != node_id:
continue

if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
threshold_sign = "<="
else:
threshold_sign = ">"

print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
% (node_id,
sample_id,
feature[node_id],
X_test[i, feature[node_id]],
threshold_sign,
threshold[node_id]))

Rules used to predict sample 0:
decision id node 4 : (X[0, -2] (= 1.5) > -2.0)


For a group of samples, we have the following common node.

In [8]:
sample_ids = [0, 1]
common_nodes = (node_indicator.toarray()[sample_ids].sum(axis=0) ==
len(sample_ids))

common_node_id = np.arange(n_nodes)[common_nodes]

print("\nThe following samples %s share the node %s in the tree"
% (sample_ids, common_node_id))
print("It is %s %% of all nodes." % (100 * len(common_node_id) / n_nodes,))

The following samples [0, 1] share the node [0 2] in the tree
It is 40 % of all nodes.

Still need help?