
# Iso-Probability Lines for Gaussian Processes Classification in Scikit-learn

A two-dimensional classification example showing iso-probability lines for the predicted probabilities.
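
An iso-probability line at level p is the set of inputs where the predicted probability equals p. The sketch below illustrates the idea in one dimension with made-up numbers; it is not part of the original example, and `level_crossings` is a hypothetical helper that linearly interpolates between neighboring samples to locate where the probability passes through a given level.

```python
def level_crossings(xs, probs, level):
    """Return x positions where the piecewise-linear curve through
    (xs, probs) crosses `level` (hypothetical helper for illustration)."""
    found = []
    for i in range(len(xs) - 1):
        p0, p1 = probs[i], probs[i + 1]
        # A crossing lies between two samples on opposite sides of the level.
        if (p0 - level) * (p1 - level) < 0:
            t = (level - p0) / (p1 - p0)
            found.append(xs[i] + t * (xs[i + 1] - xs[i]))
    return found


# Made-up probabilities sampled along a line through the input space.
xs = [0.0, 1.0, 2.0, 3.0]
probs = [0.2, 0.4, 0.8, 0.9]
crossing = level_crossings(xs, probs, 0.5)
print(crossing)  # one crossing, close to x = 1.25
```

On a two-dimensional grid, a contour plot applies the same idea along every grid edge and joins the crossings into lines.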

#### New to Plotly?

Plotly's Python library is free and open source! Get started by downloading the client and reading the primer.
You can set up Plotly to work in online or offline mode, or in Jupyter notebooks.
We also have a quick-reference cheatsheet (new!) to help you get started!

### Version

In [1]:
import sklearn
sklearn.__version__

Out[1]:
'0.18.1'

### Imports

This tutorial imports GaussianProcessClassifier together with the DotProduct and ConstantKernel kernels.

In [2]:
# Note: in plotly >= 4 this module moved to the separate chart_studio package
import plotly.plotly as py
import plotly.graph_objs as go
import numpy as np

from matplotlib import pyplot as plt
from matplotlib import cm
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import DotProduct, ConstantKernel as C


### Calculations

In [3]:
# A few constants
lim = 8

def g(x):
    """The function to predict (classification then consists of predicting
    whether g(x) <= 0 or not)."""
    return 5. - x[:, 1] - .5 * x[:, 0] ** 2.

# Design of experiments
X = np.array([[-4.61611719, -6.00099547],
              [4.10469096, 5.32782448],
              [0.00000000, -0.50000000],
              [-6.17289014, -4.6984743],
              [1.3109306, -6.93271427],
              [-5.03823144, 3.10584743],
              [-2.87600388, 6.74310541],
              [5.21301203, 4.26386883]])

# Observations
y = np.array(g(X) > 0, dtype=int)

# Instantiate and fit Gaussian Process Model
kernel = C(0.1, (1e-5, np.inf)) * DotProduct(sigma_0=0.1) ** 2
gp = GaussianProcessClassifier(kernel=kernel)
gp.fit(X, y)
print("Learned kernel: %s " % gp.kernel_)

# Evaluate real function and the predicted probability
res = 50
x_ = np.linspace(-lim, lim, res)
y_ = np.linspace(-lim, lim, res)
x1, x2 = np.meshgrid(x_, y_)
xx = np.vstack([x1.reshape(x1.size), x2.reshape(x2.size)]).T

y_true = g(xx)
y_prob = gp.predict_proba(xx)[:, 1]
y_true = y_true.reshape((res, res))
y_prob = y_prob.reshape((res, res))

Learned kernel: 0.0256**2 * DotProduct(sigma_0=5.72) ** 2
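
The printed kernel can be read as a formula: a `DotProduct` kernel computes `sigma_0**2 + x . x'`, the `** 2` squares that value, and the constant kernel scales the result. scikit-learn prints a constant kernel as the square root of its value followed by `**2`, so the scaling factor here is `0.0256**2`. The following is a minimal NumPy sketch of that formula, assuming the rounded values from the output above; `quadratic_dot_kernel` is a hypothetical helper name, not a scikit-learn function.

```python
import numpy as np


def quadratic_dot_kernel(A, B, constant_value, sigma_0):
    """k(a, b) = constant_value * (sigma_0**2 + a . b) ** 2 for every
    pair of rows (hypothetical helper, not part of scikit-learn)."""
    A = np.atleast_2d(A)
    B = np.atleast_2d(B)
    return constant_value * (sigma_0 ** 2 + A @ B.T) ** 2


# Values read off the printed kernel (rounded, so results are approximate).
constant_value = 0.0256 ** 2
sigma_0 = 5.72

# Two of the training points from the design of experiments.
pts = np.array([[0.0, -0.5],
                [4.10469096, 5.32782448]])
K = quadratic_dot_kernel(pts, pts, constant_value, sigma_0)
print(K)
```

Because the kernel is a squared inner product, the decision boundary the classifier can represent is quadratic in the inputs, which suits the parabola-shaped function `g`.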


### Plot the probabilistic classification iso-values

In [4]:
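def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a matplotlib colormap into a Plotly colorscale."""
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []

    for k in range(pl_entries):
        # cmap returns RGBA floats in [0, 1]; keep RGB as 0-255 integers
        rgb = tuple(int(c * 255) for c in cmap(k * h)[:3])
        pl_colorscale.append([k * h, 'rgb' + str(rgb)])

    return pl_colorscale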

cmap = matplotlib_to_plotly(cm.gray_r, 20)
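
The resulting colorscale is a list of `[position, 'rgb(r, g, b)']` pairs whose positions run from 0 to 1. To show that structure without requiring matplotlib, the sketch below swaps in a hypothetical `gray_ramp` function for `cm.gray_r`; both `gray_ramp` and `to_colorscale` are illustrative names, and only the shape of the output is meant to carry over.

```python
def gray_ramp(t):
    """Hypothetical stand-in for a reversed gray colormap:
    t = 0 gives white, t = 1 gives black, as RGBA floats in [0, 1]."""
    v = 1.0 - t
    return (v, v, v, 1.0)


def to_colorscale(cmap, n_entries):
    """Sample `cmap` at n_entries evenly spaced positions in [0, 1]."""
    step = 1.0 / (n_entries - 1)
    scale = []
    for k in range(n_entries):
        # Drop the alpha channel and truncate to 0-255 integers.
        r, g, b = (int(c * 255) for c in cmap(k * step)[:3])
        scale.append([k * step, 'rgb({}, {}, {})'.format(r, g, b)])
    return scale


scale = to_colorscale(gray_ramp, 3)
for position, color in scale:
    print(position, color)
```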

In [5]:
# Heatmap of the predicted probability of class 1
cax = go.Heatmap(x=x_, y=y_,
                 z=y_prob,
                 colorscale=cmap,
                 )

# Training points labelled 0, i.e. g(x) <= 0 (red)
trace1 = go.Scatter(x=X[y <= 0, 0], y=X[y <= 0, 1],
                    mode='markers',
                    marker=dict(color='red', size=10),
                    showlegend=False)

# Training points labelled 1, i.e. g(x) > 0 (blue)
trace2 = go.Scatter(x=X[y > 0, 0], y=X[y > 0, 1],
                    mode='markers',
                    marker=dict(color='blue', size=10),
                    showlegend=False)

# Contour lines of the true function g (dash-dot)
cs1 = go.Contour(x=x_,
                 y=y_[::-1],
                 z=y_true,
                 ncontours=2,
                 contours=dict(coloring='lines',
                               ),
                 line=dict(width=1, dash='dashdot'),
                 colorscale=[[0, 'black'], [1, 'white']],
                 showscale=False)

# Iso-probability line at 0.666 (blue)
cs2 = go.Contour(x=x_, y=y_[::-1], z=y_prob,
                 ncontours=2,
                 contours=dict(coloring='lines',
                               end=0.667,
                               start=0.666,
                               size=0.01),
                 line=dict(width=1),
                 colorscale=[[0, 'blue'], [1, 'white']],
                 showscale=False)

# Iso-probability line at 0.5 (black, dashed)
cs3 = go.Contour(x=x_, y=y_, z=y_prob[::-1],
                 contours=dict(coloring='lines',
                               end=0.51,
                               start=0.5,
                               size=0.1),
                 line=dict(width=1, dash='dash'),
                 colorscale=[[0, 'black'], [1, 'white']],
                 showscale=False)

# Iso-probability line at 0.334 (red)
cs4 = go.Contour(x=x_, y=y_, z=y_prob[::-1],
                 contours=dict(coloring='lines',
                               end=0.335,
                               start=0.334,
                               size=0.1),
                 line=dict(width=1),
                 colorscale=[[0, 'red'], [1, 'white']],
                 showscale=False)
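
Each probability contour above isolates one level by choosing a `start`/`end` window that is narrower than `size`, so only the level at `start` fits. The sketch below captures that pattern with plain dictionaries. `single_level_contours` and `levels_in_window` are hypothetical helpers, and the enumeration rule (levels at `start`, `start + size`, and so on up to `end`) is an assumption about how the contour levels are laid out, used here only to check the arithmetic.

```python
def single_level_contours(level, eps=1e-3):
    """Hypothetical helper: a `contours` dict whose [start, end] window is
    narrower than `size`, so only the level at `start` fits inside it."""
    return dict(coloring='lines', start=level, end=level + eps, size=10 * eps)


def levels_in_window(contours):
    """Enumerate start, start + size, ... up to end (assumed level rule)."""
    levels = []
    value = contours['start']
    while value <= contours['end']:
        levels.append(value)
        value += contours['size']
    return levels


spec = single_level_contours(0.5)
print(levels_in_window(spec))  # [0.5]
```

Under these assumptions, passing `contours=single_level_contours(0.5)` to `go.Contour` would play the same role as the hand-written dictionary in `cs3`.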

In [6]:
layout = go.Layout(yaxis=dict(autorange='reversed', title='x<sub>2</sub>'),
                   xaxis=dict(title='x<sub>1</sub>'),
                   hovermode='closest',
                   annotations=[dict(x=2, y=4.5,
                                     xref='x', yref='y',
                                     text='0.666',
                                     showarrow=False,
                                     font=dict(family='Courier New, monospace',
                                               size=12,
                                               color='blue')),
                                dict(x=2, y=0.9,
                                     xref='x', yref='y',
                                     text='0.5',
                                     showarrow=False,
                                     font=dict(family='Courier New, monospace',
                                               size=12,
                                               color='black')),
                                dict(x=2, y=-2.2,
                                     xref='x', yref='y',
                                     text='0.334',
                                     showarrow=False,
                                     font=dict(family='Courier New, monospace',
                                               size=12,
                                               color='red')),
                                ])

fig = go.Figure(data=[cax, cs1, cs2, cs3, cs4, trace1, trace2], layout=layout)
py.iplot(fig)

Out[6]:

Authors:

    Vincent Dubourg <vincent.dubourg@gmail.com>
    Jan Hendrik Metzen <jhm@informatik.uni-bremen.de>

License:

    BSD 3 clause