
Gaussian Process Regression: Basic Introductory Example in Scikit-learn

A simple one-dimensional regression example computed in two different ways:

  1. A noise-free case
  2. A noisy case with known noise-level per datapoint

In both cases, the kernel’s parameters are estimated using the maximum likelihood principle.

The figures illustrate the interpolating property of the Gaussian Process model as well as its probabilistic nature in the form of a pointwise 95% confidence interval.

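Note that the parameter alpha is applied as a Tikhonov regularization of the assumed covariance between the training points. It is added to the diagonal of the kernel matrix during fitting, so it can also be interpreted as the variance of Gaussian measurement noise on the training observations.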
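To make the role of alpha concrete, here is a minimal NumPy sketch of a zero-mean GP posterior mean with a fixed RBF kernel. It is not scikit-learn's implementation. The helper names `rbf` and `gp_posterior_mean` and the fixed length scale of 1.0 are illustrative assumptions, whereas the tutorial fits its hyperparameters by maximum likelihood.

```python
import numpy as np


def rbf(a, b, length_scale=1.0):
    """Squared-exponential kernel k(a, b) = exp(-(a - b)^2 / (2 l^2)) for 1-D inputs."""
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / length_scale) ** 2)


def gp_posterior_mean(X_train, y_train, X_test, alpha, length_scale=1.0):
    """Posterior mean of a zero-mean GP with alpha added to the kernel diagonal."""
    # Tikhonov regularization: solve with K + alpha * I instead of K
    K = rbf(X_train, X_train, length_scale) + alpha * np.eye(len(X_train))
    K_star = rbf(X_test, X_train, length_scale)
    return K_star @ np.linalg.solve(K, y_train)


X_train = np.array([1., 3., 5., 6., 7., 8.])
y_train = X_train * np.sin(X_train)

# A tiny alpha (scikit-learn's default is 1e-10) keeps the mean interpolating...
mean_small = gp_posterior_mean(X_train, y_train, X_train, alpha=1e-10)
# ...while a large alpha lets the mean move away from the observations
mean_large = gp_posterior_mean(X_train, y_train, X_train, alpha=1.0)

print(np.max(np.abs(mean_small - y_train)))  # close to zero
print(np.max(np.abs(mean_large - y_train)))  # clearly non-zero
```

Increasing alpha trades exact interpolation for smoothness, which is what the noisy case below relies on.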


Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

This tutorial imports GaussianProcessRegressor and the kernels RBF and ConstantKernel (aliased as C).

In [2]:
import plotly.plotly as py  # with Plotly >= 4, use `import chart_studio.plotly as py` instead
import plotly.graph_objs as go

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

Calculations

In [3]:
np.random.seed(1)


def f(x):
    """The function to predict."""
    return x * np.sin(x)

def data_to_plotly(x):
    """Return the first column of a 2-D array as a plain list for Plotly."""
    return [row[0] for row in x]

The Noiseless Case

In [4]:
X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T

# Observations
y = f(X).ravel()

# Mesh the input space for evaluations of the real function, the prediction and
# its standard deviation
x = np.atleast_2d(np.linspace(0, 10, 1000)).T

# Instantiate a Gaussian Process model
kernel = C(1.0, (1e-3, 1e3)) * RBF(10, (1e-2, 1e2))
gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9)

# Fit to data using Maximum Likelihood Estimation of the parameters
gp.fit(X, y)

# Make the prediction on the meshed x-axis (also return the predictive standard deviation)
y_pred, sigma = gp.predict(x, return_std=True)
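A quick way to see what the fit produced is to inspect the fitted estimator. This standalone sketch repeats the noiseless setup with `f` inlined and reads the `kernel_` and `log_marginal_likelihood_value_` attributes of `GaussianProcessRegressor`. The printed hyperparameter values depend on the optimizer restarts and the scikit-learn version, so none are quoted here.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

np.random.seed(1)

# scikit-learn expects a 2-D design matrix of shape (n_samples, n_features)
X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
y = (X * np.sin(X)).ravel()

kernel = C(1.0, (1e-3, 1e3)) * RBF(10, (1e-2, 1e2))
gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9)
gp.fit(X, y)

# kernel_ is a fitted copy holding the maximum-likelihood hyperparameters;
# the kernel passed to the constructor is left unchanged
print(gp.kernel_)
print(gp.log_marginal_likelihood_value_)

# In the noise-free case the mean passes through the observations and the
# predictive standard deviation collapses to (nearly) zero there
y_at_X, sigma_at_X = gp.predict(X, return_std=True)
print(np.max(np.abs(y_at_X - y)), np.max(sigma_at_X))
```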

Plot the function, the prediction and the 95% confidence interval based on the predicted standard deviation

In [5]:
p1 = go.Scatter(x=data_to_plotly(x), y=data_to_plotly(f(x)), 
                mode='lines',
                line=dict(color='red', dash='dot'),
                name=u'<i>f(x) = xsin(x)</i>')

p2 = go.Scatter(x=data_to_plotly(X), y=y, 
               mode='markers',
               marker=dict(color='red'),
               name=u'Observations')

p3 = go.Scatter(x=data_to_plotly(x), y=y_pred, 
                mode='lines',
                line=dict(color='blue'),
                name=u'Prediction',
               )

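# Walk the lower bound left to right, then the upper bound right to left,
# so that the trace outlines a closed polygon around the band
p4 = go.Scatter(x=data_to_plotly(np.concatenate([x, x[::-1]])),
                y=np.concatenate([y_pred - 1.9600 * sigma,
                                  (y_pred + 1.9600 * sigma)[::-1]]),
                mode='lines',
                line=dict(color='blue'),
                fill='toself',
                name='95% confidence interval')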


data = [p3, p4, p1, p2]
layout = go.Layout(xaxis=dict(title='<i>x</i>'),
                   yaxis=dict(title='<i>f(x)</i>'),
                  )
fig = go.Figure(data=data, layout=layout)
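The shaded band is built from two facts: 1.96 is the 97.5% quantile of the standard normal distribution, and a single filled trace needs the lower and upper bounds joined into one closed outline. The sketch below shows both with toy inputs. The helper `band_polygon` is a hypothetical name, and `statistics.NormalDist` requires Python 3.8 or later.

```python
import numpy as np
from statistics import NormalDist

# Two-sided 95% interval: the 97.5% quantile of the standard normal (about 1.96)
Z_95 = NormalDist().inv_cdf(0.975)


def band_polygon(x, mean, std, z=Z_95):
    """Return (xs, ys) tracing the closed outline of the band mean +/- z * std.

    The lower bound is walked left to right and the upper bound right to left,
    so one filled trace (for example Plotly's fill='toself') shades the band.
    """
    x, mean, std = np.ravel(x), np.ravel(mean), np.ravel(std)
    lower = mean - z * std
    upper = mean + z * std
    xs = np.concatenate([x, x[::-1]])
    ys = np.concatenate([lower, upper[::-1]])
    return xs, ys


# Toy inputs standing in for the tutorial's x, y_pred and sigma
x = np.linspace(0, 10, 5)
mean = np.sin(x)
std = np.full_like(x, 0.5)
xs, ys = band_polygon(x, mean, std)
print(round(Z_95, 4))  # 1.96
```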
In [6]:
py.iplot(fig)
Out[6]:

The Noisy Case

In [7]:
X = np.linspace(0.1, 9.9, 20)
X = np.atleast_2d(X).T

# Observations corrupted by Gaussian noise whose standard deviation dy
# differs from one datapoint to the next
y = f(X).ravel()
dy = 0.5 + 1.0 * np.random.random(y.shape)
noise = np.random.normal(0, dy)
y += noise

# Instantiate a Gaussian Process model; alpha holds the per-point noise
# variance (dy squared), added to the diagonal of the kernel matrix
gp = GaussianProcessRegressor(kernel=kernel, alpha=dy ** 2,
                              n_restarts_optimizer=10)

# Fit to data using Maximum Likelihood Estimation of the parameters
gp.fit(X, y)

# Make the prediction on the meshed x-axis (also return the predictive standard deviation)
y_pred, sigma = gp.predict(x, return_std=True)
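Because `dy` is the noise standard deviation of each observation, the matching per-sample `alpha` is `dy ** 2`. This standalone sketch (with `f` inlined) checks one consequence. `predict` returns the standard deviation of the latent function without adding alpha back, so at every training point it should not exceed that point's noise level. Silencing optimizer warnings is a presentational choice here, not part of the tutorial.

```python
import warnings

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

np.random.seed(1)

X = np.atleast_2d(np.linspace(0.1, 9.9, 20)).T
y = (X * np.sin(X)).ravel()
dy = 0.5 + 1.0 * np.random.random(y.shape)  # per-point noise standard deviation
y += np.random.normal(0, dy)

kernel = C(1.0, (1e-3, 1e3)) * RBF(10, (1e-2, 1e2))
# alpha may be an array with one noise variance per training sample
gp = GaussianProcessRegressor(kernel=kernel, alpha=dy ** 2, n_restarts_optimizer=10)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # the optimizer may warn about hitting bounds
    gp.fit(X, y)

# Conditioning on all the data can only shrink the latent variance below
# what y_i alone gives, which is already smaller than dy_i ** 2
_, sigma_at_X = gp.predict(X, return_std=True)
print(np.all(sigma_at_X <= dy))
```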

Plot the function, the prediction and the 95% confidence interval based on the predicted standard deviation

In [8]:
p1 = go.Scatter(x=data_to_plotly(x), y=data_to_plotly(f(x)), 
                mode='lines',
                line=dict(color='red', dash='dot'),
                name=u'<i>f(x) = xsin(x)</i>')

p2 = go.Scatter(x=X.ravel(), y=y, 
               mode='markers',
               marker=dict(color='red'),
               error_y=dict(type='data', array=dy, visible=True),
               name=u'Observations')

p3 = go.Scatter(x=data_to_plotly(x), y=y_pred, 
                mode='lines',
                line=dict(color='blue'),
                name=u'Prediction',
               )

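# Walk the lower bound left to right, then the upper bound right to left,
# so that the trace outlines a closed polygon around the band
p4 = go.Scatter(x=data_to_plotly(np.concatenate([x, x[::-1]])),
                y=np.concatenate([y_pred - 1.9600 * sigma,
                                  (y_pred + 1.9600 * sigma)[::-1]]),
                mode='lines',
                line=dict(color='blue'),
                fill='toself',
                name='95% confidence interval')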


data = [p3, p4, p1, p2]
layout = go.Layout(xaxis=dict(title='<i>x</i>'),
                   yaxis=dict(title='<i>f(x)</i>'),
                  )
fig = go.Figure(data=data, layout=layout)
In [9]:
py.iplot(fig)
Out[9]:

License

Authors:

    Vincent Dubourg <vincent.dubourg@gmail.com>

    Jake Vanderplas <vanderplas@astro.washington.edu>

    Jan Hendrik Metzen <jhm@informatik.uni-bremen.de>

License:

    BSD 3 clause