Show Sidebar Hide Sidebar

Iso-Probability Lines for Gaussian Processes Classification in Scikit-learn

A two-dimensional classification example showing iso-probability lines for the predicted probabilities.

New to Plotly?

Plotly's Python library is free and open source! Get started by downloading the client and reading the primer.
You can set up Plotly to work in online or offline mode, or in jupyter notebooks.
We also have a quick-reference cheatsheet (new!) to help you get started!

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

This tutorial imports GaussianProcessClassifier, DotProduct and ConstantKernel.

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go
import numpy as np

from matplotlib import pyplot as plt
from matplotlib import cm
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import DotProduct, ConstantKernel as C

Calculations

In [3]:
# A few constants
# Half-width of the square region studied: the evaluation grid built further
# down spans [-lim, lim] along both axes.
lim = 8


def g(x):
    """Latent function whose sign defines the two classes.

    ``x`` is indexed as a 2-D array holding one point per row, with the two
    coordinates in columns 0 and 1. The result is one value per row:
    ``5 - x[:, 1] - 0.5 * x[:, 0] ** 2``. The classification task is to
    predict on which side of zero this value falls.
    """
    first_coord = x[:, 0]
    second_coord = x[:, 1]
    half_square = .5 * first_coord ** 2.
    return 5. - second_coord - half_square

# Design of experiments
# Eight hand-picked 2-D training points, one (x1, x2) pair per row.
X = np.array([[-4.61611719, -6.00099547],
              [4.10469096, 5.32782448],
              [0.00000000, -0.50000000],
              [-6.17289014, -4.6984743],
              [1.3109306, -6.93271427],
              [-5.03823144, 3.10584743],
              [-2.87600388, 6.74310541],
              [5.21301203, 4.26386883]])

# Observations
# Binary labels: 1 where g(X) > 0, 0 otherwise.
y = np.array(g(X) > 0, dtype=int)

# Instantiate and fit Gaussian Process Model
# Kernel = ConstantKernel (imported as C) times the squared DotProduct kernel.
# The numbers given here are only starting values: the kernel actually used
# after fitting is gp.kernel_, which is printed below.
kernel = C(0.1, (1e-5, np.inf)) * DotProduct(sigma_0=0.1) ** 2
gp = GaussianProcessClassifier(kernel=kernel)
gp.fit(X, y)
print("Learned kernel: %s " % gp.kernel_)

# Evaluate real function and the predicted probability
# res x res regular grid covering [-lim, lim] on both axes.
res = 50
x_ = np.linspace(- lim, lim, res)
y_ =  np.linspace(- lim, lim, res)
# Default meshgrid indexing: x1[i, j] == x_[j] and x2[i, j] == y_[i].
x1, x2 = np.meshgrid(x_, y_)
# Flatten the grid into a (res * res, 2) array, one (x1, x2) point per row.
xx = np.vstack([x1.reshape(x1.size), x2.reshape(x2.size)]).T

y_true = g(xx)
# Column 1 of predict_proba is the probability of the second class, which is
# label 1 (g(x) > 0) since the labels here are 0 and 1.
y_prob = gp.predict_proba(xx)[:, 1]
# Back to grid shape: row index follows y_ (x2), column index follows x_ (x1).
y_true = y_true.reshape((res, res))
y_prob = y_prob.reshape((res, res))
Learned kernel: 0.0256**2 * DotProduct(sigma_0=5.72) ** 2 

Plot the probabilistic classification iso-values

In [4]:
def matplotlib_to_plotly(cmap, pl_entries):
    """Convert a matplotlib colormap into a Plotly colorscale.

    The colormap is sampled at ``pl_entries`` evenly spaced positions between
    0 and 1, and every sample becomes a ``[position, 'rgb(r, g, b)']`` pair.

    Parameters
    ----------
    cmap : callable
        Called with a float in [0, 1]; must return a sequence whose first
        three items are the red, green and blue components as floats in
        [0, 1] (any further items, such as alpha, are ignored).
    pl_entries : int
        Number of samples to take. The step is ``1 / (pl_entries - 1)``, so
        a value of 1 raises ``ZeroDivisionError``.

    Returns
    -------
    list
        ``pl_entries`` pairs ``[position, 'rgb(r, g, b)']`` with integer
        channel values obtained by truncating ``component * 255``.
    """
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []

    for k in range(pl_entries):
        # Unpack into plain Python ints instead of indexing the result of
        # map(): on Python 3 map() returns an iterator that cannot be
        # subscripted, and NumPy scalar reprs would leak into the string.
        red, green, blue = (int(channel * 255) for channel in cmap(k * h)[:3])
        pl_colorscale.append([k * h, 'rgb' + str((red, green, blue))])

    return pl_colorscale

# Plotly colorscale with 20 samples taken from matplotlib's reversed gray
# colormap; used as the colorscale of the probability heatmap.
cmap = matplotlib_to_plotly(cm.gray_r, 20)
In [5]:
cax = go.Heatmap(x=x_, y=y_,
                 z=y_prob, 
                 colorscale=cmap, 
                 )

trace1 = go.Scatter(x=X[y <= 0, 0], y=X[y <= 0, 1],
                    mode='markers', 
                    marker=dict(color='red', size=10),
                    showlegend=False)

trace2 = go.Scatter(x=X[y > 0, 0], y=X[y > 0, 1],
                    mode='markers', 
                    marker=dict(color='blue', size=10),
                    showlegend=False)

cs1 = go.Contour(x=x_,
                 y=y_[:: -1],
                 z=y_true, 
                 ncontours=2,
                 contours=dict(coloring='lines', 
                              ),
                 line=dict(width=1, dash='dashdot'),
                 colorscale=[[0, 'black'], [1, 'white']],
                 showscale=False)

cs2 = go.Contour(x=x_, y=y_[:: -1], z=y_prob, 
                 ncontours=2,
                 contours=dict(coloring='lines', 
                               end=0.667, 
                               start=0.666, 
                               size=0.01),
                 line=dict(width=1),
                 colorscale=[[0, 'blue'], [1, 'white']],
                 showscale=False)

cs3 = go.Contour(x=x_, y=y_, z=y_prob[:: -1],  
                 contours=dict(coloring='lines', 
                               end=0.51, 
                               start=0.5, 
                               size=0.1),
                 line=dict(width=1, dash='dash'),
                 colorscale=[[0, 'black'], [1, 'white']],
                 showscale=False)


cs4 = go.Contour(x=x_, y=y_, z=y_prob[:: -1],  
                 contours=dict(coloring='lines', 
                               end=0.335, 
                               start=0.334, 
                               size=0.1),
                 line=dict(width=1),
                 colorscale=[[0, 'red'], [1, 'white']],
                 showscale=False)
In [6]:
layout = go.Layout(yaxis=dict(autorange='reversed', title='x<sub>2</sub>'),
                   xaxis=dict(title='x<sub>1</sub>'),
                   hovermode='closest',
                   annotations=[dict(
                                    x=2, y=4.5,
                                    xref='x', yref='y',
                                    text='0.666',
                                    showarrow=False,
                                    font=dict(
                                        family='Courier New, monospace',
                                        size=12,
                                        color='blue')),
                                dict(
                                    x=2, y=0.9,
                                    xref='x', yref='y',
                                    text='0.5',
                                    showarrow=False,
                                    font=dict(
                                        family='Courier New, monospace',
                                        size=12,
                                        color='black')),
                                dict(
                                    x=2, y=-2.2,
                                    xref='x', yref='y',
                                    text='0.334',
                                    showarrow=False,
                                    font=dict(
                                        family='Courier New, monospace',
                                        size=12,
                                        color='red')),
                               ])

fig = go.Figure(data=[cax, cs1, cs2, cs3, cs4, trace1, trace2], layout=layout)
py.iplot(fig)
Out[6]:

License

Author:

    Vincent Dubourg <vincent.dubourg@gmail.com>

Adapted to GaussianProcessClassifier:

    Jan Hendrik Metzen <jhm@informatik.uni-bremen.de>

License:

    BSD 3 clause
Still need help?
Contact Us

For guaranteed 24 hour response turnarounds, upgrade to a Developer Support Plan.