Gaussian Process Classification (GPC) on the XOR Dataset in Scikit-learn

This example illustrates Gaussian process classification (GPC) on XOR data. It compares a stationary, isotropic kernel (RBF) with a non-stationary kernel (DotProduct). On this particular dataset, the DotProduct kernel obtains considerably better results because the class boundaries are linear and coincide with the coordinate axes. In general, however, stationary kernels often obtain better results.

In [1]:
# Show which scikit-learn version is installed; the value is the cell's output.
import sklearn
sklearn.__version__

Out[1]:
'0.18.1'

Imports

This tutorial imports GaussianProcessClassifier, RBF and DotProduct.

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
import matplotlib.pyplot as plt

from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF, DotProduct


Calculations

In [3]:
# Training set: 200 points drawn from a seeded standard-normal generator.
# A point is labelled True when exactly one of its two coordinates is
# positive, i.e. the XOR of the coordinate signs.
rng = np.random.RandomState(0)
X = rng.randn(200, 2)
Y = (X[:, 0] > 0) ^ (X[:, 1] > 0)

# Regular 50 x 50 grid over [-3, 3] x [-3, 3] on which the fitted
# classifiers are evaluated later.
x_ = np.linspace(-3, 3, num=50)
y_ = np.linspace(-3, 3, num=50)
xx, yy = np.meshgrid(x_, y_)

# Kernels to compare: an RBF kernel and a squared DotProduct kernel, each
# multiplied by a constant factor of 1.0. Nothing is fitted in this cell; the
# classifiers are created and fitted where the kernels are iterated over.
# No matplotlib figure is opened here because every plot below is built with
# Plotly, so an empty matplotlib figure would only be a stray side effect.
kernels = [1.0 * RBF(length_scale=1.0), 1.0 * DotProduct(sigma_0=1.0)**2]


Plot Results

Define colormaps for plots

In [4]:
def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a matplotlib-style colormap into a Plotly colorscale list.

    Parameters
    ----------
    cmap : callable
        Called with a float position ``k * h`` in [0, 1]; must return a
        sequence whose first three items are the red, green and blue
        components as floats in [0, 1] (any further items are ignored).
    pl_entries : int
        Number of evenly spaced positions to sample. A value of 1 raises
        ``ZeroDivisionError`` because the step is ``1 / (pl_entries - 1)``.

    Returns
    -------
    list
        ``pl_entries`` pairs of the form ``[position, 'rgb(r, g, b)']`` with
        integer components in 0..255.
    """
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []

    for k in range(pl_entries):
        # Scale to 0..255 and truncate to uint8. The result is materialised as
        # an array (not a lazy ``map`` object, which cannot be indexed on
        # Python 3) and then converted to plain ints so that the formatted
        # tuple contains bare numbers rather than NumPy scalar reprs.
        rgb = (np.array(cmap(k * h)[:3]) * 255).astype(np.uint8)
        r, g, b = (int(component) for component in rgb)
        pl_colorscale.append([k * h, 'rgb' + str((r, g, b))])

    return pl_colorscale

# Four-entry Plotly colorscales sampled from two matplotlib colormaps:
# cmap1 comes from plt.cm.PuOr_r and cmap2 from plt.cm.Paired.
cmap1, cmap2 = (
    matplotlib_to_plotly(mpl_cmap, 4)
    for mpl_cmap in (plt.cm.PuOr_r, plt.cm.Paired)
)

In [5]:
# Subplot titles, one per kernel, filled in as each classifier is fitted.
titles = []
# One list of traces (heatmap + scatter) per kernel.
plots = [[] for _ in kernels]

# Grid coordinates stacked into one row per grid point. The grid is the same
# for every kernel, so it is built once outside the loop.
grid_points = np.vstack((xx.ravel(), yy.ravel())).T

for i, kernel in enumerate(kernels):
    clf = GaussianProcessClassifier(kernel=kernel, warm_start=True).fit(X, Y)

    # Column 1 of predict_proba at every grid point, reshaped back onto the
    # 2-D grid so it can be drawn as a heatmap.
    Z = clf.predict_proba(grid_points)[:, 1]
    Z = Z.reshape(xx.shape)

    contours = go.Heatmap(x=x_,
                          y=y_,
                          z=Z,
                          colorscale=cmap1
                          )
    plots[i].append(contours)

    # Training points coloured by their class label Y (cast to 0/1 so it is
    # numeric), not by their x coordinate, so the colours show the XOR classes.
    scatter = go.Scatter(x=X[:, 0], y=X[:, 1],
                         showlegend=False,
                         mode='markers',
                         marker=dict(color=Y.astype(int),
                                     colorscale=cmap2,
                                     showscale=False,
                                     line=dict(color='black', width=1))
                         )
    plots[i].append(scatter)

    # Title: the fitted kernel and its log-marginal-likelihood.
    titles.append("%s<br> Log-Marginal-Likelihood:%.3f"
                  % (clf.kernel_, clf.log_marginal_likelihood(clf.kernel_.theta)))



In [6]:
# One row with two side-by-side subplots, titled from `titles`.
fig = tools.make_subplots(
    rows=1,
    cols=2,
    print_grid=False,
    subplot_titles=tuple(titles),
)

# Add the traces pairwise: the idx-th trace of the first kernel goes to the
# left subplot, then the idx-th trace of the second kernel to the right one.
for idx, left_trace in enumerate(plots[0]):
    fig.append_trace(left_trace, 1, 1)
    fig.append_trace(plots[1][idx], 1, 2)

# Hide tick marks and tick labels on the y and x axes of both subplots.
for suffix in ('1', '2'):
    for axis in ('yaxis', 'xaxis'):
        fig['layout'][axis + suffix].update(showticklabels=False, ticks='')

In [7]:
# Display the assembled two-subplot figure; `py` is the plotly.plotly module imported at the top.
py.iplot(fig)

Out[7]:

Authors:

    Jan Hendrik Metzen <jhm@informatik.uni-bremen.de>



    BSD 3 clause