Receiver Operating Characteristic (ROC) with Cross Validation in Scikit-learn

An example of using the Receiver Operating Characteristic (ROC) metric with cross-validation to evaluate the quality of a classifier's output.

ROC curves typically plot the true positive rate on the Y axis and the false positive rate on the X axis. The top left corner of the plot is therefore the “ideal” point: a false positive rate of zero and a true positive rate of one. Real classifiers rarely reach this point, but it does mean that a larger area under the curve (AUC) is usually better.
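
To make the two axes concrete, here is a minimal NumPy sketch of how an ROC curve arises: sweep a decision threshold over the classifier's scores, record (false positive rate, true positive rate) at each step, and integrate the resulting curve with the trapezoidal rule. The helper names roc_points and trapezoid_auc are hypothetical. scikit-learn's roc_curve follows the same idea, but by default it may drop intermediate thresholds that would not change the plotted curve, so its list of points can be shorter.

```python
import numpy as np

def roc_points(y_true, scores):
    """Return (fpr, tpr) arrays by sweeping a decision threshold downwards.

    Hypothetical helper: every distinct score is tried as a threshold,
    highest first; samples whose score is >= the threshold are predicted
    positive (label 1).
    """
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    n_pos = np.sum(y_true == 1)
    n_neg = np.sum(y_true == 0)
    # Threshold above every score: nothing is predicted positive.
    fpr, tpr = [0.0], [0.0]
    for t in np.unique(scores)[::-1]:  # distinct scores, highest first
        pred = scores >= t
        tpr.append(np.sum(pred & (y_true == 1)) / n_pos)
        fpr.append(np.sum(pred & (y_true == 0)) / n_neg)
    return np.array(fpr), np.array(tpr)

def trapezoid_auc(x, y):
    """Area under a piecewise-linear curve (trapezoidal rule), x ascending."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2.0))

# Toy labels and scores, made up for illustration.
labels = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])
fpr, tpr = roc_points(labels, scores)
print(fpr, tpr)
print(trapezoid_auc(fpr, tpr))  # 0.75
```

For these toy scores the curve passes through fpr = [0, 0, 0.5, 0.5, 1] and tpr = [0, 0.5, 0.5, 1, 1], giving an area of 0.75; a classifier that scores every positive above every negative would reach the top left corner and an area of 1.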

The “steepness” of ROC curves is also important, since it is ideal to maximize the true positive rate while minimizing the false positive rate.
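
A curve that rises steeply near a false positive rate of zero means the classifier tends to score positives above negatives. That intuition has an exact form: the AUC equals the fraction of (positive, negative) pairs that the scores order correctly, with ties counting one half. The sketch below (the name rank_auc is hypothetical) computes the AUC this way, using only pairwise comparisons and no curve at all.

```python
import numpy as np

def rank_auc(y_true, scores):
    """AUC as P(score of a random positive > score of a random negative).

    Hypothetical helper; ties between a positive and a negative count 1/2.
    """
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    # Compare every positive score with every negative score via broadcasting.
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((greater + 0.5 * ties) / (len(pos) * len(neg)))

# Same toy data as a threshold-sweep ROC example would use: 3 of 4 pairs ordered correctly.
print(rank_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

A classifier that gives every sample the same score orders no pair correctly and ties all of them, so its AUC is 0.5, which matches the diagonal “Luck” line.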

This example shows the ROC curves obtained on the different test folds produced by K-fold cross-validation. From these curves it is possible to calculate a mean ROC curve and its area under the curve, and to see how much the curve varies when the training set is split into different subsets. This roughly shows how the classifier output is affected by changes in the training data, and how different the splits generated by K-fold cross-validation are from one another.
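
Averaging ROC curves needs a common X grid, because each fold yields its own set of false positive rates. The sketch below shows vertical averaging, the approach the notebook's loop takes with mean_fpr and interp: interpolate each fold's TPR onto a fixed FPR grid, average point by point, and pin the endpoints to (0, 0) and (1, 1). The function name mean_roc and the two toy fold curves are hypothetical; the returned standard deviation is one way to quantify the fold-to-fold variation.

```python
import numpy as np

def mean_roc(fold_curves, n_points=100):
    """Vertically average ROC curves on a shared FPR grid.

    fold_curves: list of (fpr, tpr) pairs, each fpr sorted ascending.
    Returns (grid, mean_tpr, std_tpr).
    """
    grid = np.linspace(0.0, 1.0, n_points)
    tprs = []
    for fpr, tpr in fold_curves:
        interp_tpr = np.interp(grid, fpr, tpr)  # this fold's TPR at each grid FPR
        interp_tpr[0] = 0.0                     # anchor every curve at the origin
        tprs.append(interp_tpr)
    tprs = np.array(tprs)
    mean_tpr = tprs.mean(axis=0)
    mean_tpr[-1] = 1.0                          # anchor the mean curve at (1, 1)
    std_tpr = tprs.std(axis=0)                  # spread between folds at each FPR
    return grid, mean_tpr, std_tpr

# Two made-up fold curves, evaluated on a coarse 3-point grid.
folds = [
    (np.array([0.0, 0.5, 1.0]), np.array([0.0, 1.0, 1.0])),
    (np.array([0.0, 0.5, 1.0]), np.array([0.0, 0.5, 1.0])),
]
grid, m, s = mean_roc(folds, n_points=3)
print(grid, m, s)
```

Here the mean TPR at FPR 0.5 is 0.75 with a standard deviation of 0.25; plotting mean_tpr plus or minus std_tpr would show the variation between folds as a band.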

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

This tutorial imports roc_curve, auc and StratifiedKFold.

In [2]:
print(__doc__)

import plotly.plotly as py  # Plotly < 4; in Plotly 4+ this module moved to chart_studio.plotly
import plotly.graph_objs as go

import numpy as np
from numpy import interp  # scipy.interp was an alias of numpy.interp and is deprecated/removed in recent SciPy
from itertools import cycle

from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import StratifiedKFold
Automatically created module for IPython interactive environment

Calculations

Data IO and generation

In [3]:
# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
X, y = X[y != 2], y[y != 2]
n_samples, n_features = X.shape

# Add noisy features
random_state = np.random.RandomState(0)
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
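
The cell above turns iris into a binary problem and then makes it harder. Iris has 150 samples (50 per class) and 4 features; keeping classes 0 and 1 leaves 100 samples, and 200 * 4 = 800 columns of Gaussian noise are appended. The shape-only sketch below replays those steps on a placeholder array (the zeros are stand-ins, not iris values) to confirm the resulting dimensions.

```python
import numpy as np

# Placeholder with iris's shape: 150 samples, 4 features, labels 0/1/2 (50 each).
# The values are stand-ins; only the shapes matter here.
X = np.zeros((150, 4))
y = np.repeat([0, 1, 2], 50)

mask = y != 2                     # keep classes 0 and 1 -> binary problem
X, y = X[mask], y[mask]
n_samples, n_features = X.shape   # (100, 4)

rng = np.random.RandomState(0)
noise = rng.randn(n_samples, 200 * n_features)  # 800 pure-noise columns
X = np.c_[X, noise]               # column-wise concatenation

print(X.shape, np.bincount(y))    # (100, 804) [50 50]
```

With 4 informative features buried among 800 noise features, the classifier's ROC curves vary noticeably from fold to fold, which is what this example sets out to show.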

Classification and ROC analysis

In [4]:
# Run classifier with cross-validation  
cv = StratifiedKFold(n_splits=6)
classifier = svm.SVC(kernel='linear', probability=True,
                     random_state=random_state)

mean_tpr = 0.0
mean_fpr = np.linspace(0, 1, 100)

colors = cycle(['cyan', 'indigo', 'seagreen', 'yellow', 'blue', 'darkorange'])
lw = 2

i = 0
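
StratifiedKFold keeps the class proportions of y in every test fold, so each fold's ROC curve here is computed on an even mix of both classes. The simplified sketch below (a hypothetical stratified_folds helper, not scikit-learn's exact assignment of samples to folds) deals each class's indices round-robin into n_splits folds.

```python
import numpy as np

def stratified_folds(y, n_splits):
    """Yield (train_idx, test_idx) pairs whose test folds keep class proportions.

    Hypothetical sketch: the indices of each class are dealt round-robin
    into n_splits folds. Not identical to scikit-learn's StratifiedKFold.
    """
    y = np.asarray(y)
    fold_of = np.empty(len(y), dtype=int)
    for label in np.unique(y):
        idx = np.flatnonzero(y == label)
        fold_of[idx] = np.arange(len(idx)) % n_splits  # deal this class across folds
    for k in range(n_splits):
        test = np.flatnonzero(fold_of == k)
        train = np.flatnonzero(fold_of != k)
        yield train, test

y = np.repeat([0, 1], 50)  # 50 samples per class, as in the binary iris subset
for train, test in stratified_folds(y, n_splits=6):
    print(len(test), np.bincount(y[test], minlength=2))
```

With 50 samples per class and 6 splits, the first two test folds get 9 samples of each class and the other four get 8; every sample lands in exactly one test fold.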

Plot ROC Curves

In [5]:
data = []
for (train, test), color in zip(cv.split(X, y), colors):
    probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
    # Compute the ROC curve and the area under the curve for this fold
    fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
    mean_tpr += interp(mean_fpr, fpr, tpr)
    mean_tpr[0] = 0.0
    roc_auc = auc(fpr, tpr)
    
    trace = go.Scatter(x=fpr, y=tpr, 
                       mode='lines', 
                       line=dict(width=lw, color=color),
                       name='ROC fold %d (area = %0.2f)' % (i, roc_auc))

    data.append(trace)
    i += 1
    
trace = go.Scatter(x=[0, 1], y=[0, 1], 
                   mode='lines', 
                   line=dict(width=lw, color='black', dash='dash'),
                   name='Luck')
data.append(trace)

mean_tpr /= cv.get_n_splits(X, y)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)

trace = go.Scatter(x=mean_fpr, y=mean_tpr, 
                   mode='lines', 
                   line=dict(width=lw, color='green', dash='dash'),
                   name='Mean ROC (area = %0.2f)' % mean_auc)
data.append(trace)

layout = go.Layout(title='Receiver operating characteristic example',
                   xaxis=dict(title='False Positive Rate', showgrid=False,
                              range=[-0.05, 1.05]),
                   yaxis=dict(title='True Positive Rate', showgrid=False,
                              range=[-0.05, 1.05]))
fig = go.Figure(data=data, layout=layout)
In [6]:
py.iplot(fig)
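Out[6]:
[Figure: Receiver operating characteristic example. ROC curve for each of the six folds, the mean ROC curve, and the dashed “Luck” diagonal; X axis: False Positive Rate, Y axis: True Positive Rate]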