
# Gaussian Process Regression: Basic Introductory Example in Scikit-learn

A simple one-dimensional regression example computed in two different ways:

1. A noise-free case
2. A noisy case with known noise-level per datapoint

In both cases, the kernel's parameters are estimated using the maximum likelihood principle.
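
To make the maximum likelihood principle concrete, here is a minimal NumPy sketch that scores candidate length scales of an RBF kernel by the log marginal likelihood and keeps the best one. The helper names `rbf_kernel` and `log_marginal_likelihood`, the fixed unit amplitude, the small `jitter` term and the grid search are all assumptions made for illustration; scikit-learn instead maximizes the same quantity with a gradient-based optimizer over all kernel parameters.

```python
import numpy as np

def rbf_kernel(a, b, length_scale):
    """Squared-exponential kernel between two 1-D input arrays (unit amplitude)."""
    sq_dist = (a[:, None] - b[None, :]) ** 2
    return np.exp(-0.5 * sq_dist / length_scale ** 2)

def log_marginal_likelihood(x_train, y_train, length_scale, jitter=1e-8):
    """log p(y | X, length_scale) for a zero-mean GP with an RBF kernel."""
    n = len(x_train)
    # The jitter keeps the Cholesky factorization numerically stable
    K = rbf_kernel(x_train, x_train, length_scale) + jitter * np.eye(n)
    L = np.linalg.cholesky(K)
    weights = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    return (-0.5 * y_train @ weights          # data-fit term
            - np.log(np.diag(L)).sum()        # complexity penalty, 0.5 * log|K|
            - 0.5 * n * np.log(2 * np.pi))    # normalization constant

# Same training inputs and target function as the noise-free case
x_train = np.array([1., 3., 5., 6., 7., 8.])
y_train = x_train * np.sin(x_train)

# Hypothetical grid of candidate length scales between 0.1 and 10
candidates = np.logspace(-1, 1, 41)
scores = [log_marginal_likelihood(x_train, y_train, ls) for ls in candidates]
best_length_scale = candidates[int(np.argmax(scores))]
print("length scale with the highest log marginal likelihood:", best_length_scale)
```

A grid is used here only because it is easy to read; it would not scale to kernels with several parameters.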

The figures illustrate the interpolating property of the Gaussian Process model as well as its probabilistic nature in the form of a pointwise 95% confidence interval.

Note that the parameter alpha is applied as a Tikhonov regularization of the assumed covariance between the training points.
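
The effect of `alpha` can be sketched with a small from-scratch posterior computation. This is an illustration under stated assumptions, not scikit-learn's implementation: the helper names `rbf` and `gp_posterior` are hypothetical, and the kernel is a unit-amplitude RBF with a fixed length scale. The only place `alpha` enters is the diagonal of the training covariance matrix.

```python
import numpy as np

def rbf(a, b, length_scale=1.0):
    """Unit-amplitude RBF kernel between two 1-D input arrays."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / length_scale ** 2)

def gp_posterior(x_train, y_train, x_test, alpha, length_scale=1.0):
    """Posterior mean and standard deviation of a zero-mean GP.

    alpha may be a scalar or one value per training point; it is added
    to the diagonal of the training covariance (Tikhonov regularization).
    """
    K = rbf(x_train, x_train, length_scale)
    K[np.diag_indices_from(K)] += alpha
    K_s = rbf(x_test, x_train, length_scale)

    L = np.linalg.cholesky(K)
    weights = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    mean = K_s @ weights

    v = np.linalg.solve(L, K_s.T)
    var = 1.0 - np.sum(v ** 2, axis=0)   # prior variance is 1 for this kernel
    return mean, np.sqrt(np.clip(var, 0.0, None))

x_train = np.array([1., 3., 5., 6., 7., 8.])
y_train = x_train * np.sin(x_train)

# Tiny alpha: the posterior mean interpolates the observations
mean_exact, std_exact = gp_posterior(x_train, y_train, x_train, alpha=1e-10)
# Large alpha: the observations are treated as noisy and only smoothed
mean_smooth, std_smooth = gp_posterior(x_train, y_train, x_train, alpha=1.0)

print("max residual, alpha=1e-10:", np.max(np.abs(mean_exact - y_train)))
print("max residual, alpha=1.0:  ", np.max(np.abs(mean_smooth - y_train)))
```

With a tiny `alpha` the uncertainty collapses at the training inputs, which is the interpolating behaviour shown in the first figure; a larger value leaves residual uncertainty there, as in the noisy case.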

### Version

In [1]:
import sklearn
sklearn.__version__

Out[1]:
'0.18.1'

### Imports

This tutorial imports GaussianProcessRegressor and RBF.

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C


### Calculations

In [3]:
np.random.seed(1)

def f(x):
    """The function to predict."""
    return x * np.sin(x)

def data_to_plotly(x):
    """Flatten a column vector of shape (n, 1) into a plain list for Plotly."""
    k = []

    for i in range(0, len(x)):
        k.append(x[i][0])

    return k
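
As a side note, the `data_to_plotly` helper only extracts the first column of a 2-D array. The short sketch below, which uses a made-up five-point column named `column`, suggests that NumPy's `ravel().tolist()` yields the same values for single-column input without an explicit loop.

```python
import numpy as np

def data_to_plotly(x):
    """Helper from this tutorial: first column of a 2-D array as a list."""
    k = []
    for i in range(0, len(x)):
        k.append(x[i][0])
    return k

# Hypothetical single-column input, shaped like the mesh used in this tutorial
column = np.atleast_2d(np.linspace(0, 10, 5)).T

# Vectorized alternative for single-column arrays
flat = column.ravel().tolist()

print(flat == data_to_plotly(column))
```

The two approaches differ for arrays with more than one column: the helper keeps only the first column, while `ravel()` flattens every element.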


### The Noiseless Case

In [4]:
X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T

# Observations
y = f(X).ravel()

# Mesh the input space for evaluations of the real function, the prediction and
# its MSE
x = np.atleast_2d(np.linspace(0, 10, 1000)).T

# Instantiate a Gaussian Process model
kernel = C(1.0, (1e-3, 1e3)) * RBF(10, (1e-2, 1e2))
gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9)

# Fit to data using Maximum Likelihood Estimation of the parameters
gp.fit(X, y)

# Make the prediction on the meshed x-axis (ask for the standard deviation as well)
y_pred, sigma = gp.predict(x, return_std=True)
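
After fitting, it can be useful to look at what the optimizer actually found. The sketch below repeats the noise-free fit in a self-contained form and prints the initial kernel, the fitted kernel (`kernel_`) and its log marginal likelihood. The `random_state=0` argument is an addition for reproducibility and is not part of the cell above; the printed values depend on the installed scikit-learn version.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

# Same training data as the noise-free case
X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
y = (X * np.sin(X)).ravel()

kernel = C(1.0, (1e-3, 1e3)) * RBF(10, (1e-2, 1e2))
gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9,
                              random_state=0)  # random_state added here only
gp.fit(X, y)

# gp.kernel is the prior kernel passed in; gp.kernel_ holds the fitted parameters
print("initial kernel:", gp.kernel)
print("fitted kernel: ", gp.kernel_)
print("log marginal likelihood:",
      gp.log_marginal_likelihood(gp.kernel_.theta))

# With the default alpha, predictions at the training inputs should
# reproduce the observations almost exactly
y_at_train, sigma_at_train = gp.predict(X, return_std=True)
print("max |prediction - observation|:", np.max(np.abs(y_at_train - y)))
print("max std at training inputs:", sigma_at_train.max())
```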


### Plot the function, the prediction and the 95% confidence interval based on the MSE

In [5]:
p1 = go.Scatter(x=data_to_plotly(x), y=data_to_plotly(f(x)),
                mode='lines',
                line=dict(color='red', dash='dot'),
                name=u'<i>f(x) = xsin(x)</i>')

p2 = go.Scatter(x=data_to_plotly(X), y=y,
                mode='markers',
                marker=dict(color='red'),
                name=u'Observations')

p3 = go.Scatter(x=data_to_plotly(x), y=y_pred,
                mode='lines',
                line=dict(color='blue'),
                name=u'Prediction',
               )

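# Closed outline: lower bound from left to right, then upper bound from
# right to left, so that y has the same length as the concatenated x
p4 = go.Scatter(x=data_to_plotly(np.concatenate([x, x[::-1]])),
                y=np.concatenate([y_pred - 1.9600 * sigma,
                                  (y_pred + 1.9600 * sigma)[::-1]]),
                mode='lines',
                line=dict(color='blue'),
                fill='toself',
                name='95% confidence interval')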

data = [p3, p4, p1, p2]
layout = go.Layout(xaxis=dict(title='<i>x</i>'),
                   yaxis=dict(title='<i>f(x)</i>'),
                  )
fig = go.Figure(data=data, layout=layout)
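
The shaded band in this figure is drawn as one closed outline: the x values are traversed forward and then backward, so the y values need the lower bound followed by the reversed upper bound. The sketch below builds such an outline with NumPy only. The function name `confidence_band_polygon` and the demo arrays are assumptions for illustration; the multiplier 1.96 used in this tutorial is the 97.5% quantile of the standard normal distribution, computed here with `statistics.NormalDist`.

```python
import numpy as np
from statistics import NormalDist

def confidence_band_polygon(x, mean, std, level=0.95):
    """Return (poly_x, poly_y) outlining a pointwise confidence band."""
    # Two-sided normal quantile, about 1.96 for a 95% band
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    x, mean, std = np.ravel(x), np.ravel(mean), np.ravel(std)

    lower = mean - z * std
    upper = mean + z * std

    # Forward along the lower bound, then back along the upper bound
    poly_x = np.concatenate([x, x[::-1]])
    poly_y = np.concatenate([lower, upper[::-1]])
    return poly_x, poly_y

# Made-up demo values standing in for the mesh, prediction and std
x_demo = np.linspace(0, 10, 5)
mean_demo = x_demo * np.sin(x_demo)
std_demo = np.full_like(x_demo, 0.5)

poly_x, poly_y = confidence_band_polygon(x_demo, mean_demo, std_demo)
print(len(poly_x), len(poly_y))
```

The two returned arrays always have equal length, which is what a filled Plotly trace such as `go.Scatter(x=poly_x, y=poly_y, fill='toself')` expects.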

In [6]:
py.iplot(fig)

Out[6]:

### The Noisy Case

In [7]:
X = np.linspace(0.1, 9.9, 20)
X = np.atleast_2d(X).T

# Observations and noise
y = f(X).ravel()
dy = 0.5 + 1.0 * np.random.random(y.shape)
noise = np.random.normal(0, dy)
y += noise

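# Instantiate a Gaussian Process model
# alpha is the per-point noise variance added to the kernel diagonal;
# the noise above has standard deviation dy, so the variance is dy ** 2
gp = GaussianProcessRegressor(kernel=kernel, alpha=dy ** 2,
                              n_restarts_optimizer=10)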

# Fit to data using Maximum Likelihood Estimation of the parameters
gp.fit(X, y)

# Make the prediction on the meshed x-axis (ask for the standard deviation as well)
y_pred, sigma = gp.predict(x, return_std=True)


### Plot the function, the prediction and the 95% confidence interval based on the MSE

In [8]:
p1 = go.Scatter(x=data_to_plotly(x), y=data_to_plotly(f(x)),
                mode='lines',
                line=dict(color='red', dash='dot'),
                name=u'<i>f(x) = xsin(x)</i>')

p2 = go.Scatter(x=X.ravel(), y=y,
                mode='markers',
                marker=dict(color='red'),
                # Symmetric error bars of half-width dy around each observation
                error_y=dict(type='data', array=dy, visible=True),
                name=u'Observations')

p3 = go.Scatter(x=data_to_plotly(x), y=y_pred,
                mode='lines',
                line=dict(color='blue'),
                name=u'Prediction',
               )

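# Closed outline: lower bound from left to right, then upper bound from
# right to left, so that y has the same length as the concatenated x
p4 = go.Scatter(x=data_to_plotly(np.concatenate([x, x[::-1]])),
                y=np.concatenate([y_pred - 1.9600 * sigma,
                                  (y_pred + 1.9600 * sigma)[::-1]]),
                mode='lines',
                line=dict(color='blue'),
                fill='toself',
                name='95% confidence interval')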

data = [p3, p4, p1, p2]
layout = go.Layout(xaxis=dict(title='<i>x</i>'),
                   yaxis=dict(title='<i>f(x)</i>'),
                  )
fig = go.Figure(data=data, layout=layout)

In [9]:
py.iplot(fig)

Out[9]:

Authors:

    Vincent Dubourg <vincent.dubourg@gmail.com>
    Jake Vanderplas <vanderplas@astro.washington.edu>
    Jan Hendrik Metzen <jhm@informatik.uni-bremen.de>

License:

    BSD 3 clause