
# Classification Probability in Scikit-learn

Plot the classification probability for different classifiers. We use a 3 class dataset, and we classify it with a Support Vector classifier, L1 and L2 penalized logistic regression with either a One-Vs-Rest or multinomial setting, and Gaussian process classification.

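Logistic regression is a binary classifier at heart, so it is not multiclass out of the box. Here it is extended to the three classes in one of two ways: One-Vs-Rest, which fits one binary model per class, or a multinomial (softmax) loss over all classes at once.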
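The One-Vs-Rest and multinomial settings named above differ in how per-class decision scores become probabilities. Below is a minimal NumPy sketch with made-up scores and hypothetical helper names (`ovr_proba`, `multinomial_proba`). One-Vs-Rest squashes each class score with a sigmoid independently and then rescales each row to sum to 1. The multinomial setting applies a softmax across all classes. It illustrates the two ideas and is not scikit-learn's source.

```python
import numpy as np

def sigmoid(z):
    """Logistic function applied elementwise."""
    return 1.0 / (1.0 + np.exp(-z))

def ovr_proba(scores):
    """One-Vs-Rest: squash each class score on its own, then renormalize each row."""
    p = sigmoid(scores)
    return p / p.sum(axis=1, keepdims=True)

def multinomial_proba(scores):
    """Multinomial: softmax over the class scores (shifted for numerical stability)."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

# Hypothetical decision scores: two samples, three classes.
scores = np.array([[2.0, 0.5, -1.0],
                   [0.0, 0.0, 0.0]])

p_ovr = ovr_proba(scores)
p_mn = multinomial_proba(scores)
print(p_ovr.round(3))
print(p_mn.round(3))
```

For these scores, both schemes pick class 0 for the first sample, but the softmax puts noticeably more mass on it. The all-zero row gives 1/3 per class either way.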


### Version

In [1]:
import sklearn
sklearn.__version__

Out[1]:
'0.18'

### Imports

This tutorial imports LogisticRegression, SVC, GaussianProcessClassifier and RBF.
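`RBF` supplies the covariance function for the Gaussian process classifier. Later in this tutorial, `1.0 * RBF([1.0, 1.0])` multiplies a constant factor by a radial basis function that has one length scale per input feature. The following is a minimal NumPy sketch of that formula on a few made-up points. The helper `rbf_kernel` is hypothetical, not scikit-learn's implementation. The values shown are only starting hyperparameters, because `GaussianProcessClassifier` re-optimizes them during `fit` by default.

```python
import numpy as np

def rbf_kernel(A, B, length_scale, amplitude=1.0):
    """Anisotropic RBF: amplitude * exp(-0.5 * ||(a - b) / length_scale||^2) for every row pair."""
    A = np.asarray(A, dtype=float) / length_scale
    B = np.asarray(B, dtype=float) / length_scale
    # Pairwise squared Euclidean distances between the scaled rows.
    sq_dist = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return amplitude * np.exp(-0.5 * sq_dist)

# One length scale per feature, mirroring RBF([1.0, 1.0]) in the tutorial.
length_scale = np.array([1.0, 1.0])

# Made-up (sepal length, sepal width) points; the first and last are identical.
X_demo = np.array([[5.1, 3.5], [6.2, 2.9], [5.1, 3.5]])
K = rbf_kernel(X_demo, X_demo, length_scale)
print(K.round(3))
```

Identical points get covariance 1 (times the constant factor), and the covariance decays toward 0 as points move apart. A larger length scale for a feature makes the kernel less sensitive to differences along that feature.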

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn import datasets

Automatically created module for IPython interactive environment


### Calculations

In [3]:
iris = datasets.load_iris()
X = iris.data[:, 0:2]  # we only take the first two features for visualization
y = iris.target

n_features = X.shape[1]

C = 1.0
kernel = 1.0 * RBF([1.0, 1.0])  # for GPC

# Create different classifiers. Logistic regression is binary at heart, so it
# handles the three classes either One-vs-Rest or with a multinomial loss.
# solver='liblinear' (the default in 0.18) supports the L1 penalty and fits
# one binary model per class (One-vs-Rest).
classifiers = {'L1 logistic': LogisticRegression(C=C, penalty='l1',
                                                 solver='liblinear'),
               'L2 logistic (OvR)': LogisticRegression(C=C, penalty='l2',
                                                       solver='liblinear'),
               'Linear SVC': SVC(kernel='linear', C=C, probability=True,
                                 random_state=0),
               'L2 logistic (Multinomial)': LogisticRegression(
                   C=C, solver='lbfgs', multi_class='multinomial'),
               'GPC': GaussianProcessClassifier(kernel)
               }

n_classifiers = len(classifiers)

xx = np.linspace(3, 9, 100)
yy = np.linspace(1, 5, 100).T
xx, yy = np.meshgrid(xx, yy)
Xfull = np.c_[xx.ravel(), yy.ravel()]
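The probability heatmaps are evaluated on this grid. `np.meshgrid` produces two 100 by 100 arrays, and `np.c_[xx.ravel(), yy.ravel()]` flattens them into 10,000 (x, y) rows. Later, `probas[:, k].reshape((100, 100))` folds one probability column back into an image. The sketch below uses a small 3 by 4 grid with hypothetical names `gx`, `gy` so it does not overwrite the tutorial's `xx`, `yy`. It checks that rows of the reshaped array follow the y values and columns follow the x values, which is the `z`/`x`/`y` layout a Plotly heatmap expects.

```python
import numpy as np

# Small stand-in for the 100 x 100 grid used in the tutorial.
x_vals = np.linspace(3, 9, 4)        # 4 x positions
y_vals = np.linspace(1, 5, 3)        # 3 y positions
gx, gy = np.meshgrid(x_vals, y_vals) # both shape (3, 4): rows follow y, columns follow x
grid = np.c_[gx.ravel(), gy.ravel()] # shape (12, 2), flattened in row-major order

# A fake per-point "score" that just records each grid point's x coordinate.
fake_score = grid[:, 0]
image = fake_score.reshape(gx.shape) # back to (3, 4)

# Every row of `image` repeats x_vals, so image[i, j] sits at (x_vals[j], y_vals[i]).
print(image)
```

Because the flattening and the reshape both use row-major order, the reshaped image lines up with the grid without any transposition.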


### Plots

In [4]:
def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a Matplotlib colormap into a Plotly colorscale list."""
    h = 1.0/(pl_entries-1)
    pl_colorscale = []

    for k in range(pl_entries):
        # list(...) is needed on Python 3, where map() returns an iterator
        C = list(map(np.uint8, np.array(cmap(k*h)[:3])*255))
        pl_colorscale.append([k*h, 'rgb(%d, %d, %d)' % (C[0], C[1], C[2])])

    return pl_colorscale

fig = tools.make_subplots(rows=5, cols=3, print_grid=False,
                          subplot_titles=('Class 0', 'Class 1', 'Class 2',
                                          'Class 0', 'Class 1', 'Class 2',
                                          'Class 0', 'Class 1', 'Class 2',
                                          'Class 0', 'Class 1', 'Class 2',
                                          'Class 0', 'Class 1', 'Class 2'))

# One row of subplots per classifier; remember the row order for the titles
titles = []

for index, (name, classifier) in enumerate(classifiers.items()):
    row = index + 1
    titles.append(name)
    classifier.fit(X, y)

    y_pred = classifier.predict(X)
    classif_rate = np.mean(y_pred.ravel() == y.ravel()) * 100
    print("classif_rate for %s : %f " % (name, classif_rate))

    # View probabilities on the grid
    probas = classifier.predict_proba(Xfull)
    n_classes = np.unique(y_pred).size

    for k in range(n_classes):
        # Heatmap of P(class k) over the whole grid (x from 3 to 9, y from 1 to 5)
        imshow_handle = go.Heatmap(z=probas[:, k].reshape((100, 100)),
                                   x=xx[0, :],
                                   y=yy[:, 0],
                                   showscale=False,
                                   colorscale=matplotlib_to_plotly(cm.jet, 100))
        fig.append_trace(imshow_handle, row, k+1)

        # Overlay the training points predicted as class k
        idx = (y_pred == k)
        if idx.any():
            trace = go.Scatter(x=X[idx, 0], y=X[idx, 1], mode='markers',
                               showlegend=False,
                               marker=dict(color='black', size=10))
            fig.append_trace(trace, row, k+1)

# Hide the tick labels on all 15 subplots
for k in map(str, range(1, 16)):
    fig['layout']['xaxis' + k].update(showticklabels=False, ticks='')
    fig['layout']['yaxis' + k].update(showticklabels=False, ticks='')

# Label the first subplot of each row with its classifier, in drawing order
for row_index, name in enumerate(titles):
    fig['layout']['yaxis%d' % (3*row_index + 1)].update(title=name)

fig['layout'].update(height=1000)

classif_rate for GPC : 82.666667
classif_rate for L2 logistic (OvR) : 76.666667
classif_rate for L1 logistic : 79.333333
classif_rate for Linear SVC : 82.000000
classif_rate for L2 logistic (Multinomial) : 82.000000
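Plotly expects a colorscale as a list of `[position, color]` pairs with positions running from 0 to 1. `matplotlib_to_plotly` builds this list from `cm.jet`. The sketch below shows the same idea in isolation. `blue_to_red` is a hypothetical colormap standing in for a Matplotlib one: like a Matplotlib colormap called with a float, it returns an RGBA tuple of floats in [0, 1]. Unlike `matplotlib_to_plotly`, this version rounds channel values instead of truncating them. It also converts them to plain Python ints, so the color strings stay clean.

```python
import numpy as np

def blue_to_red(value):
    """Hypothetical stand-in for a Matplotlib colormap: maps [0, 1] to an RGBA float tuple."""
    return (value, 0.0, 1.0 - value, 1.0)

def to_plotly_colorscale(cmap, n_entries):
    """Sample `cmap` at n_entries evenly spaced positions and format a Plotly colorscale."""
    colorscale = []
    for pos in np.linspace(0.0, 1.0, n_entries):
        # Drop alpha, scale each channel to 0-255 and convert to plain ints.
        r, g, b = (int(round(c * 255)) for c in cmap(pos)[:3])
        colorscale.append([float(pos), 'rgb(%d, %d, %d)' % (r, g, b)])
    return colorscale

scale = to_plotly_colorscale(blue_to_red, 3)
print(scale)
```

More entries give a smoother gradient. Plotly interpolates linearly between neighboring pairs, so a few dozen samples are usually enough for a colormap like `jet`.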

In [5]:
py.iplot(fig)

Out[5]:

Author:

    Alexandre Gramfort <alexandre.gramfort@inria.fr>

License:

    BSD 3 clause