
Simple 1D Kernel Density Estimation in Scikit-learn

This example uses the sklearn.neighbors.KernelDensity class to demonstrate the principles of Kernel Density Estimation in one dimension.

The first plot shows one of the problems with using histograms to visualize the density of points in 1D. Intuitively, a histogram can be thought of as a scheme in which a unit “block” is stacked above each point on a regular grid. As the top two panels show, however, the choice of gridding for these blocks can lead to wildly divergent ideas about the underlying shape of the density distribution. If we instead center each block on the point it represents, we get the estimate shown in the bottom-left panel. This is a kernel density estimate with a “top hat” kernel. The idea generalizes to other kernel shapes: the bottom-right panel of the first figure shows a Gaussian kernel density estimate over the same distribution.

Scikit-learn implements efficient kernel density estimation using either a Ball Tree or KD Tree structure, through the sklearn.neighbors.KernelDensity estimator. The available kernels are shown in the second figure of this example.
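To make the "block centered on each point" idea concrete, the sketch below builds the top hat and Gaussian estimates by hand with NumPy and compares them with KernelDensity. The sample `points`, the bandwidth `h`, and the evaluation `grid` are made-up values for illustration, not the data used in the figures.

```python
import numpy as np
from sklearn.neighbors import KernelDensity

# Hypothetical toy sample, not the data used in the figures
points = np.array([0.0, 0.5, 2.0, 2.2, 5.0])
h = 0.75                          # bandwidth: half-width of each block
grid = np.linspace(-2, 7, 200)    # where the density is evaluated

# Signed distance from every grid location to every sample point
diff = grid[:, None] - points[None, :]

# Top hat: each point contributes a block of width 2h and height 1/(2h),
# centered on the point itself; the estimate is the average of the blocks.
manual_tophat = (np.abs(diff) < h).sum(axis=1) / (len(points) * 2 * h)

# Gaussian: each block is replaced by a normal bump with standard deviation h.
manual_gauss = (np.exp(-0.5 * (diff / h) ** 2).sum(axis=1)
                / (len(points) * h * np.sqrt(2 * np.pi)))

# The same two estimates from scikit-learn (score_samples returns log-density)
sk_tophat = np.exp(KernelDensity(kernel='tophat', bandwidth=h)
                   .fit(points[:, None]).score_samples(grid[:, None]))
sk_gauss = np.exp(KernelDensity(kernel='gaussian', bandwidth=h)
                  .fit(points[:, None]).score_samples(grid[:, None]))
```

With these inputs the hand-built arrays should agree with the scikit-learn ones up to floating-point error, which is a quick way to check one's understanding of what each kernel does.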

The third figure compares kernel density estimates for a distribution of 100 samples in 1 dimension. Though this example uses 1D distributions, kernel density estimation is easily and efficiently extensible to higher dimensions as well.
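As a rough illustration of the higher-dimensional case, the same estimator can be fit on a two-column array. The 2D sample, the bandwidth of 0.5, and the 50 x 50 evaluation grid below are assumptions chosen for this sketch and do not appear in the original example.

```python
import numpy as np
from sklearn.neighbors import KernelDensity

rng = np.random.RandomState(0)
# Hypothetical 2D sample: 100 points from a standard bivariate normal
X2 = rng.normal(0, 1, size=(100, 2))

# Fitting works exactly as in 1D; each row is one 2D point
kde2 = KernelDensity(kernel='gaussian', bandwidth=0.5).fit(X2)

# Evaluate the density on a regular 50 x 50 grid
xx, yy = np.meshgrid(np.linspace(-4, 4, 50), np.linspace(-4, 4, 50))
grid2 = np.column_stack([xx.ravel(), yy.ravel()])
dens2 = np.exp(kde2.score_samples(grid2)).reshape(xx.shape)

# Riemann-sum check: the estimated density should integrate to about 1
cell_area = (xx[0, 1] - xx[0, 0]) * (yy[1, 0] - yy[0, 0])
total_mass = dens2.sum() * cell_area
```

The array `dens2` could then be passed to a contour or heatmap trace; only the shape of the input changes compared with the 1D cells in this example.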

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

This tutorial imports norm and KernelDensity.

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
from scipy.stats import norm
from sklearn.neighbors import KernelDensity

Plot the Progression of Histograms to Kernels

In [3]:
# Plot the progression of histograms to kernels
np.random.seed(1)
N = 20
# Sample sizes must be integers, so cast the float products explicitly
X = np.concatenate((np.random.normal(0, 1, int(0.3 * N)),
                    np.random.normal(5, 1, int(0.7 * N))))[:, np.newaxis]
X_plot = np.linspace(-5, 10, 1000)[:, np.newaxis]
bins = 10

fig = tools.make_subplots(rows=2, cols=2,
                          subplot_titles=("Histogram", "Histogram, bins shifted",
                                          "Tophat Kernel Density", "Gaussian Kernel Density"))
# histogram 1 (the data is on x, so nbinsx caps the number of bins)
fig.append_trace(go.Histogram(x=X[:, 0], nbinsx=bins,
                              marker=dict(color='#AAAAFF',
                                          line=dict(color='black', width=1))),
                 1, 1)

# histogram 2, with a different binning of the same data
fig.append_trace(go.Histogram(x=X[:, 0], nbinsx=bins + 20,
                              marker=dict(color='#AAAAFF',
                                          line=dict(color='black', width=1))),
                 1, 2)

# tophat KDE
kde = KernelDensity(kernel='tophat', bandwidth=0.75).fit(X)
log_dens = kde.score_samples(X_plot)

fig.append_trace(go.Scatter(x=X_plot[:, 0], y=np.exp(log_dens),
                            mode='lines', fill='tozeroy',
                            line=dict(color='#AAAAFF', width=2)), 
                 2, 1)

# Gaussian KDE
kde = KernelDensity(kernel='gaussian', bandwidth=0.75).fit(X)
log_dens = kde.score_samples(X_plot)
fig.append_trace(go.Scatter(x=X_plot[:, 0], y=np.exp(log_dens), 
                            mode='lines', fill='tozeroy',
                            line=dict(color='#AAAAFF', width=2)),
                 2, 2)

for i in map(str, range(1, 5, 2)):
    y = 'yaxis' + i
    fig['layout'][y].update(title='Normalized Density')

fig['layout'].update(hovermode='closest', height=600,
                     showlegend=False)
        
py.iplot(fig)
This is the format of your plot grid:
[ (1,1) x1,y1 ]  [ (1,2) x2,y2 ]
[ (2,1) x3,y3 ]  [ (2,2) x4,y4 ]

Out[3]:

Plot all available kernels

In [4]:
X_plot = np.linspace(-6, 6, 1000)[:, None]
X_src = np.zeros((1, 1))

fig = tools.make_subplots(rows=2, cols=3, print_grid=False,
                          subplot_titles=('gaussian', 'tophat', 'epanechnikov',
                                          'exponential', 'linear', 'cosine'))

def format_func(x, loc):
    if x == 0:
        return '0'
    elif x == 1:
        return 'h'
    elif x == -1:
        return '-h'
    else:
        return '%ih' % x



for i, kernel in enumerate(['gaussian', 'tophat', 'epanechnikov',
                            'exponential', 'linear', 'cosine']):
    log_dens = KernelDensity(kernel=kernel).fit(X_src).score_samples(X_plot)
    
    trace=go.Scatter(x=X_plot[:, 0], y=np.exp(log_dens),
                     mode='lines', fill='tozeroy',
                     line=dict(color='#AAAAFF', width=2)) 
    # Integer division keeps the row index an int under Python 3
    fig.append_trace(trace, i // 3 + 1, i % 3 + 1)
    
fig['layout'].update(hovermode='closest',
                     showlegend=False, height=600, 
                     title='Available Kernels')
In [5]:
py.iplot(fig)
Out[5]:
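One property the six kernel shapes share is that each is a normalized density. The sketch below checks this numerically with a simple Riemann sum over the same range as the plot; the `areas` dictionary and the choice of a Riemann sum are illustrative assumptions, not part of the original example.

```python
import numpy as np
from sklearn.neighbors import KernelDensity

grid = np.linspace(-6, 6, 1000)
step = grid[1] - grid[0]
source = np.zeros((1, 1))    # a single point at the origin, as in the cell above

areas = {}
for kernel in ['gaussian', 'tophat', 'epanechnikov',
               'exponential', 'linear', 'cosine']:
    # Default bandwidth of 1.0, evaluated on the grid as log-density
    log_dens = KernelDensity(kernel=kernel).fit(source).score_samples(grid[:, None])
    # Approximate the area under the kernel; it should be close to 1
    areas[kernel] = np.exp(log_dens).sum() * step
```

Small deviations from 1 are expected: the exponential kernel has tails beyond the plotted range, and the top hat kernel has a jump that a coarse grid only approximates.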

Plot a 1D density example

In [6]:
N = 100
np.random.seed(1)
data = []
# Sample sizes must be integers, so cast the float products explicitly
X = np.concatenate((np.random.normal(0, 1, int(0.3 * N)),
                    np.random.normal(5, 1, int(0.7 * N))))[:, np.newaxis]

X_plot = np.linspace(-5, 10, 1000)[:, np.newaxis]

true_dens = (0.3 * norm(0, 1).pdf(X_plot[:, 0])
             + 0.7 * norm(5, 1).pdf(X_plot[:, 0]))


trace1 = go.Scatter(x=X_plot[:, 0], y=true_dens, 
                    mode='lines', fill='tozeroy',
                    line=dict(color='black', width=2), 
                    name='input distribution')
data.append(trace1)

for kernel in ['gaussian', 'tophat', 'epanechnikov']:
    kde = KernelDensity(kernel=kernel, bandwidth=0.5).fit(X)
    log_dens = kde.score_samples(X_plot)
    trace2 = go.Scatter(x=X_plot[:, 0], y=np.exp(log_dens),
                        mode='lines', 
                        line=dict(width=2, dash='dash'), 
                        name="kernel = '{0}'".format(kernel))
    data.append(trace2)
    
trace3 = go.Scatter(x=X[:, 0],
                    y=-0.005 - 0.01 * np.random.random(X.shape[0]), 
                    mode='markers', showlegend=False,
                    marker=dict(color='black'))  
data.append(trace3)

layout=go.Layout(annotations=[dict(x=6, y=0.38, showarrow=False,
                                   text="N={0} points".format(N)),
                                   ],
                 xaxis=dict(zeroline=False), hovermode='closest')
fig = go.Figure(data=data, layout=layout)
In [7]:
py.iplot(fig)
Out[7]:
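The visual comparison can be complemented by a number. The sketch below regenerates a sample of the same form and reports, for each kernel, the mean absolute difference between the estimate and the true mixture density on the plotting grid. The error measure and the `errors` dictionary are assumptions made for this illustration; the original example compares the curves only by eye.

```python
import numpy as np
from scipy.stats import norm
from sklearn.neighbors import KernelDensity

np.random.seed(1)
N = 100
# Same mixture as above: 30% from N(0, 1) and 70% from N(5, 1)
X = np.concatenate((np.random.normal(0, 1, int(0.3 * N)),
                    np.random.normal(5, 1, int(0.7 * N))))[:, np.newaxis]
X_plot = np.linspace(-5, 10, 1000)[:, np.newaxis]

true_dens = (0.3 * norm(0, 1).pdf(X_plot[:, 0])
             + 0.7 * norm(5, 1).pdf(X_plot[:, 0]))

errors = {}
for kernel in ['gaussian', 'tophat', 'epanechnikov']:
    kde = KernelDensity(kernel=kernel, bandwidth=0.5).fit(X)
    est = np.exp(kde.score_samples(X_plot))
    # Mean absolute deviation from the true density over the grid
    errors[kernel] = np.mean(np.abs(est - true_dens))

for kernel, err in errors.items():
    print("{0}: {1:.4f}".format(kernel, err))
```

The ranking of the kernels depends on the sample and on the bandwidth, so these numbers describe this one draw rather than the kernels in general.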

License

Author:

    Jake Vanderplas <jakevdp@cs.washington.edu>