Sparse Coding with a Precomputed Dictionary in Scikit-learn

This example transforms a signal into a sparse combination of Ricker wavelets and visually compares different sparse coding methods using the sklearn.decomposition.SparseCoder estimator. The Ricker wavelet (also known as the Mexican hat wavelet, or the second derivative of a Gaussian) is not a particularly good kernel for representing piecewise constant signals like this one. The results therefore show how much adding atoms of different widths matters, which in turn motivates learning the dictionary to best fit your type of signal.
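
For reference, the wavelet that ricker_function (defined in the Calculations section) is meant to sample can be written as below. The symbols c and σ are names chosen here for illustration; they stand for the `center` and `width` arguments. The leading constant has no effect on the final dictionary, because ricker_matrix rescales every atom to unit norm.

```latex
\psi_{c,\sigma}(t) = \frac{2}{\sqrt{3\sigma}\,\pi^{1/4}}
  \left(1 - \frac{(t - c)^2}{\sigma^2}\right)
  \exp\!\left(-\frac{(t - c)^2}{2\sigma^2}\right)
```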

The richer, multiple-widths dictionary is not larger than the fixed-width one: each width is subsampled more heavily so that both dictionaries stay on the same order of magnitude.
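
As a rough check of that size claim, the atom counts can be worked out from the settings used in the Calculations section (a resolution of 1024, a subsampling factor of 3, and five widths). The variable names `n_fixed`, `n_per_width`, and `n_multi` are introduced here only for illustration.

```python
resolution = 1024
subsampling = 3

# Fixed-width dictionary: roughly one atom every `subsampling` samples.
n_fixed = resolution // subsampling

# Multiple-widths dictionary: five widths, each given a fifth of the atoms.
widths = (10, 50, 100, 500, 1000)
n_per_width = n_fixed // len(widths)
n_multi = n_per_width * len(widths)

print(n_fixed, n_per_width, n_multi)
```

Each width therefore gets about a fifth of the centers, which is the heavier subsampling mentioned above.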

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18'

Imports

This tutorial imports the SparseCoder estimator from sklearn.decomposition.

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
from sklearn.decomposition import SparseCoder
Automatically created module for IPython interactive environment

Calculations

In [3]:
def ricker_function(resolution, center, width):
    """Discrete sub-sampled Ricker (Mexican hat) wavelet"""
    x = np.linspace(0, resolution - 1, resolution)
    # Normalization constant: 2 / (sqrt(3 * width) * pi ** (1 / 4))
    x = ((2 / (np.sqrt(3 * width) * np.pi ** 0.25))
         * (1 - ((x - center) ** 2 / width ** 2))
         * np.exp((-(x - center) ** 2) / (2 * width ** 2)))
    return x


def ricker_matrix(width, resolution, n_components):
    """Dictionary of Ricker (Mexican hat) wavelets"""
    centers = np.linspace(0, resolution - 1, n_components)
    D = np.empty((n_components, resolution))
    for i, center in enumerate(centers):
        D[i] = ricker_function(resolution, center, width)
    D /= np.sqrt(np.sum(D ** 2, axis=1))[:, np.newaxis]
    return D


resolution = 1024
subsampling = 3  # subsampling factor
width = 100
n_components = resolution // subsampling  # integer number of atoms

# Compute a wavelet dictionary
D_fixed = ricker_matrix(width=width, resolution=resolution,
                        n_components=n_components)
D_multi = np.r_[tuple(ricker_matrix(width=w, resolution=resolution,
                                    n_components=n_components // 5)
                for w in (10, 50, 100, 500, 1000))]

# Generate a signal
y = np.linspace(0, resolution - 1, resolution)
first_quarter = y < resolution / 4
y[first_quarter] = 3.
y[np.logical_not(first_quarter)] = -1.

# List the different sparse coding methods in the following format:
# (title, transform_algorithm, transform_alpha,
#  transform_n_nonzero_coefs, color)
estimators = [('OMP', 'omp', None, 15, 'navy'),
              ('Lasso', 'lasso_cd', 2, None, 'turquoise'), ]
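
To give some intuition for what the 'omp' entry above asks for, here is a minimal NumPy sketch of greedy orthogonal matching pursuit. It is a simplified illustration, not scikit-learn's implementation; the function name `omp_sketch` and the toy identity dictionary are hypothetical, and the rows of `D` are assumed to be unit-norm atoms, as produced by ricker_matrix.

```python
import numpy as np


def omp_sketch(D, y, n_nonzero):
    """Greedy OMP sketch; rows of D are assumed to be unit-norm atoms."""
    coef = np.zeros(D.shape[0])
    support = []
    residual = y.astype(float)
    for _ in range(n_nonzero):
        # Pick the atom most correlated with the current residual.
        k = int(np.argmax(np.abs(D @ residual)))
        if k in support:
            break  # no new atom to add
        support.append(k)
        # Refit all selected atoms jointly by least squares.
        sol, *_ = np.linalg.lstsq(D[support].T, y, rcond=None)
        coef[:] = 0.0
        coef[support] = sol
        residual = y - coef @ D
    return coef


# Toy check with a hypothetical orthonormal dictionary (identity rows).
D = np.eye(4)
y = np.array([3.0, 0.0, -1.0, 0.0])
coef = omp_sketch(D, y, n_nonzero=2)
print(coef)
```

The `n_nonzero` argument plays the role of transform_n_nonzero_coefs: it caps how many atoms may be used, which is why the OMP estimator above is configured with 15 rather than with a penalty such as transform_alpha.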

Plot Results

In [4]:
data = [[], []]
for plot, (D, title) in enumerate(zip((D_fixed, D_multi),
                                         ('fixed width', 'multiple widths'))):

    # Plot the signal amplitude on the vertical axis against the sample index
    original_signal = go.Scatter(y=y,
                                 mode='lines',
                                 name='Original signal',
                                 line=dict(dash='dash'))
    data[plot].append(original_signal)

    # Do a wavelet approximation
    for title, algo, alpha, n_nonzero, color in estimators:
        coder = SparseCoder(dictionary=D, transform_n_nonzero_coefs=n_nonzero,
                            transform_alpha=alpha, transform_algorithm=algo)
        x = coder.transform(y.reshape(1, -1))
        density = len(np.flatnonzero(x))
        # Reconstruct the signal from its sparse code
        x = np.ravel(np.dot(x, D))
        squared_error = np.sum((y - x) ** 2)
        non_zero_coeff = go.Scatter(y=x,
                                    mode='lines',
                                    line=dict(color=color),
                                    name='%s: %s nonzero coefs,<br>%.2f error'
                                         % (title, density, squared_error))
        data[plot].append(non_zero_coeff)

    # Soft thresholding debiasing
    coder = SparseCoder(dictionary=D, transform_algorithm='threshold',
                        transform_alpha=20)
    x = coder.transform(y.reshape(1, -1))
    _, idx = np.where(x != 0)
    # Refit the selected atoms by least squares to remove the shrinkage bias
    x[0, idx], _, _, _ = np.linalg.lstsq(D[idx, :].T, y)
    x = np.ravel(np.dot(x, D))
    squared_error = np.sum((y - x) ** 2)
    threshold = go.Scatter(y=x,
                           mode='lines',
                           line=dict(color='darkorange'),
                           name='Thresholding w/ debiasing:<br>%d nonzero coefs, %.2f error'
                           % (len(idx), squared_error))
    data[plot].append(threshold)
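
The thresholding-with-debiasing step at the end of the cell above can be sketched in plain NumPy as follows. This assumes that thresholding means soft-thresholding the correlations between the atoms and the signal; the toy dictionary, the random seed, and the value of `alpha` are hypothetical and chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy dictionary: 6 unit-norm atoms of length 20 (rows are atoms).
D = rng.standard_normal((6, 20))
D /= np.linalg.norm(D, axis=1, keepdims=True)

# Signal built from two of the atoms.
true_coef = np.zeros(6)
true_coef[[1, 4]] = [3.0, -2.0]
y = true_coef @ D

# Soft thresholding: shrink every correlation toward zero by alpha.
alpha = 0.5
corr = D @ y
coef = np.sign(corr) * np.maximum(np.abs(corr) - alpha, 0.0)
support = np.flatnonzero(coef)

# Debiasing: keep the selected atoms, refit their weights by least squares.
debiased = np.zeros_like(coef)
debiased[support], *_ = np.linalg.lstsq(D[support].T, y, rcond=None)

err_thresholded = np.sum((y - coef @ D) ** 2)
err_debiased = np.sum((y - debiased @ D) ** 2)
print(err_thresholded, err_debiased)
```

Because the least-squares refit is optimal over the selected atoms, its squared error cannot exceed that of the shrunken coefficients on the same support.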
    

Sparse coding against fixed width dictionary

In [5]:
layout = go.Layout(title='Sparse coding against fixed width dictionary')
fig = go.Figure(data = data[0], layout=layout)

py.iplot(fig)
Out[5]:

Sparse coding against multiple widths dictionary

In [6]:
layout = go.Layout(title='Sparse coding against multiple widths dictionary')
fig = go.Figure(data = data[1], layout=layout)

py.iplot(fig)
Out[6]: