Show Sidebar Hide Sidebar

# Iso-Probability Lines for Gaussian Process Classification in Scikit-learn

A two-dimensional classification example showing iso-probability lines for the predicted probabilities.

#### New to Plotly?

You can set up Plotly to work in online or offline mode, or in Jupyter notebooks.
We also have a quick-reference cheatsheet (new!) to help you get started!

### Version

In [1]:
import sklearn
# Bare expression as the last line of the cell: Jupyter echoes the installed
# scikit-learn version as this cell's output.
sklearn.__version__
Out[1]:
'0.18.1'

### Imports

This tutorial imports GaussianProcessClassifier, DotProduct, and ConstantKernel (aliased as C).

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go
import numpy as np

from matplotlib import pyplot as plt
from matplotlib import cm
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import DotProduct, ConstantKernel as C

### Calculations

In [3]:
# A few constants
lim = 8  # half-width of the square [-lim, lim]^2 used for the evaluation grid


def g(x):
    """Latent function whose sign defines the two classes.

    The classification task then consists of predicting whether
    ``g(x) <= 0`` or not.

    Parameters
    ----------
    x : ndarray of shape (n_samples, 2)
        Points to evaluate; column 0 is x1 and column 1 is x2.

    Returns
    -------
    ndarray of shape (n_samples,)
        ``5 - x2 - 0.5 * x1**2`` evaluated row by row.
    """
    first, second = x[:, 0], x[:, 1]
    return 5. - second - .5 * first ** 2.

# Design of experiments: eight fixed training inputs, one (x1, x2) pair per row.
X = np.array([
    [-4.61611719, -6.00099547],
    [4.10469096, 5.32782448],
    [0.00000000, -0.50000000],
    [-6.17289014, -4.6984743],
    [1.3109306, -6.93271427],
    [-5.03823144, 3.10584743],
    [-2.87600388, 6.74310541],
    [5.21301203, 4.26386883],
])

# Observations: integer labels, 1 where g(x) > 0 and 0 otherwise.
y = (g(X) > 0).astype(int)

# Instantiate and fit the Gaussian Process classifier.  The kernel is a
# constant (amplitude) kernel times a squared dot-product kernel; the fitted
# kernel is available afterwards as gp.kernel_.
kernel = C(0.1, (1e-5, np.inf)) * (DotProduct(sigma_0=0.1) ** 2)
gp = GaussianProcessClassifier(kernel=kernel)
gp.fit(X, y)
print("Learned kernel: %s " % gp.kernel_)

# Evaluate the real function and the predicted probability of class 1 on a
# res x res grid spanning [-lim, lim] along both axes.
res = 50
x_ = np.linspace(-lim, lim, res)
y_ = np.linspace(-lim, lim, res)
x1, x2 = np.meshgrid(x_, y_)
# Flatten the grid into query points: column 0 is x1, column 1 is x2.
xx = np.vstack([x1.ravel(), x2.ravel()]).T

y_true = g(xx).reshape((res, res))
y_prob = gp.predict_proba(xx)[:, 1].reshape((res, res))
Learned kernel: 0.0256**2 * DotProduct(sigma_0=5.72) ** 2

### Plot the probabilistic classification iso-values

In [4]:
def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a matplotlib colormap into a Plotly colorscale.

    Parameters
    ----------
    cmap : callable
        A matplotlib colormap (or any callable) mapping a float in [0, 1] to
        an RGB(A) tuple of floats in [0, 1]; only the first three channels
        are used.
    pl_entries : int
        Number of evenly spaced samples; must be at least 2 so that the
        scale has entries at both ends, 0.0 and 1.0.

    Returns
    -------
    list of [float, str]
        ``[[position, 'rgb(r, g, b)'], ...]`` with positions running from 0.0
        to exactly 1.0 and integer channels in 0..255.

    Raises
    ------
    ValueError
        If ``pl_entries < 2`` (the step 1 / (pl_entries - 1) is undefined).
    """
    if pl_entries < 2:
        raise ValueError("pl_entries must be at least 2, got %r" % (pl_entries,))

    pl_colorscale = []
    # linspace spaces the positions by 1 / (pl_entries - 1) but pins the last
    # one to exactly 1.0, which repeated float multiplication can miss.
    for position in np.linspace(0.0, 1.0, pl_entries):
        position = float(position)
        # Scale channels to 0..255 and truncate to uint8.
        rgb = (np.array(cmap(position)[:3]) * 255).astype(np.uint8)
        # Format from plain Python ints: a `map` object cannot be indexed on
        # Python 3, and str() of a tuple of NumPy scalars renders as
        # "np.uint8(...)" on NumPy >= 2, which Plotly cannot parse.
        pl_colorscale.append(
            [position, 'rgb(%d, %d, %d)' % tuple(int(c) for c in rgb)])

    return pl_colorscale

# Plotly colorscale for the heatmap: 20 evenly spaced samples of cm.gray_r.
cmap = matplotlib_to_plotly(cmap=cm.gray_r, pl_entries=20)
In [5]:
# Traces for the figure: probability heatmap, training points, and contours.

# Heatmap of y_prob (column 1 of predict_proba) over the evaluation grid,
# coloured with the scale sampled from cm.gray_r in the previous cell.
cax = go.Heatmap(x=x_, y=y_,
z=y_prob,
colorscale=cmap,
)

# Training points labelled 0 (g(x) <= 0), drawn in red.
trace1 = go.Scatter(x=X[y <= 0, 0], y=X[y <= 0, 1],
mode='markers',
marker=dict(color='red', size=10),
showlegend=False)

# Training points labelled 1 (g(x) > 0), drawn in blue.
trace2 = go.Scatter(x=X[y > 0, 0], y=X[y > 0, 1],
mode='markers',
marker=dict(color='blue', size=10),
showlegend=False)

# NOTE(review): the four contour traces below are vertically mirrored relative
# to the heatmap and the scatter points. x_ and y_ are the symmetric
# np.linspace(-lim, lim, res), so both `y=y_[:: -1]` with unflipped z
# (cs1, cs2) and `y=y_` with `z=y_prob[:: -1]` (cs3, cs4) place data row i at
# x2 = -y_[i] instead of y_[i]. As drawn, the dash-dot true-boundary curve
# does not separate the red and blue training points. A fix is to pass `y=y_`
# with unflipped z to every contour AND negate the annotation y positions in
# the layout cell (they are tuned to the mirrored curves) -- TODO fix together.

# Dash-dot contour of the true function g over the grid.
# NOTE(review): ncontours=2 lets Plotly choose the levels, so the g(x) = 0
# level (presumably the intent) is not guaranteed; consider
# contours=dict(start=0, end=0, ...) -- confirm.
cs1 = go.Contour(x=x_,
y=y_[:: -1],
z=y_true,
ncontours=2,
contours=dict(coloring='lines',
),
line=dict(width=1, dash='dashdot'),
colorscale=[[0, 'black'], [1, 'white']],
showscale=False)

# Solid iso-probability line at 0.666: with start=0.666, end=0.667 and
# size=0.01, presumably only that single level falls inside the range.
cs2 = go.Contour(x=x_, y=y_[:: -1], z=y_prob,
ncontours=2,
contours=dict(coloring='lines',
end=0.667,
start=0.666,
size=0.01),
line=dict(width=1),
colorscale=[[0, 'blue'], [1, 'white']],
showscale=False)

# Dashed iso-probability line at 0.5 (single level in [0.5, 0.51]).
cs3 = go.Contour(x=x_, y=y_, z=y_prob[:: -1],
contours=dict(coloring='lines',
end=0.51,
start=0.5,
size=0.1),
line=dict(width=1, dash='dash'),
colorscale=[[0, 'black'], [1, 'white']],
showscale=False)

# Solid iso-probability line at 0.334 (single level in [0.334, 0.335]).
cs4 = go.Contour(x=x_, y=y_, z=y_prob[:: -1],
contours=dict(coloring='lines',
end=0.335,
start=0.334,
size=0.1),
line=dict(width=1),
colorscale=[[0, 'red'], [1, 'white']],
showscale=False)
In [6]:
# Text labels for the three iso-probability lines, all placed at x1 = 2.
# NOTE(review): these y positions follow the contour traces as currently
# drawn, which pair y_ and y_prob with one of them flipped ([:: -1]) while the
# heatmap and scatter points do not; if the contours are re-oriented, negate
# these y values so the labels stay on their lines.
iso_label_specs = [
    # (y position, label text, font colour)
    (4.5, '0.666', 'blue'),
    (0.9, '0.5', 'black'),
    (-2.2, '0.334', 'red'),
]
iso_labels = [
    dict(x=2, y=y_pos,
         xref='x', yref='y',
         text=label,
         showarrow=False,
         font=dict(family='Courier New, monospace', size=12, color=colour))
    for y_pos, label, colour in iso_label_specs
]

# The y axis is reversed, so x2 increases downwards.
layout = go.Layout(
    xaxis=dict(title='x<sub>1</sub>'),
    yaxis=dict(autorange='reversed', title='x<sub>2</sub>'),
    hovermode='closest',
    annotations=iso_labels,
)

# Trace order: heatmap, true-function contour, the three iso-probability
# contours, then the two sets of training points.
traces = [cax, cs1, cs2, cs3, cs4, trace1, trace2]
fig = go.Figure(data=traces, layout=layout)
py.iplot(fig)
Out[6]:

Author:

Vincent Dubourg <vincent.dubourg@gmail.com>