
Image Denoising using Dictionary Learning in Scikit-learn

An example comparing the effect of reconstructing noisy fragments of a raccoon face image using, first, online Dictionary Learning and then various transform methods.

The dictionary is fitted on the distorted left half of the image, and subsequently used to reconstruct the right half. Note that even better performance could be achieved by fitting to an undistorted (i.e. noiseless) image, but here we start from the assumption that it is not available.

A common practice for evaluating the results of image denoising is to look at the difference between the reconstruction and the original image. For a good reconstruction, this difference should look like unstructured Gaussian noise, with no visible edges or image features. It can be seen from the plots that the results of Orthogonal Matching Pursuit (OMP) with two non-zero coefficients are a bit less biased than when keeping only one (the edges look less prominent). They are also closer to the ground truth in Frobenius norm.
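To make "OMP with one or two non-zero coefficients" concrete, here is a simplified NumPy sketch of the greedy idea behind `transform_n_nonzero_coefs`. It is not scikit-learn's implementation; the function name `omp` is hypothetical. It assumes, as in this example, that atoms are the rows of the dictionary (like `V = dico.components_`), so a patch is approximated as `code @ D`.

```python
import numpy as np


def omp(x, D, n_nonzero_coefs):
    """Hypothetical, simplified Orthogonal Matching Pursuit for one signal.

    x : (n_features,) signal to encode
    D : (n_atoms, n_features) dictionary with unit-norm rows (atoms)
    Returns a code of length n_atoms with at most n_nonzero_coefs non-zeros.
    """
    residual = x.copy()
    support = []
    coefs = np.zeros(0)
    for _ in range(n_nonzero_coefs):
        # Greedy step: pick the atom most correlated with the current residual
        k = int(np.argmax(np.abs(D @ residual)))
        if k not in support:
            support.append(k)
        # Orthogonal step: re-fit all selected coefficients jointly by least squares
        coefs, *_ = np.linalg.lstsq(D[support].T, x, rcond=None)
        residual = x - D[support].T @ coefs
    code = np.zeros(D.shape[0])
    code[support] = coefs
    return code


# Usage: an overcomplete random dictionary (100 atoms of 49 = 7 x 7 pixels)
rng = np.random.default_rng(0)
D = rng.standard_normal((100, 49))
D /= np.linalg.norm(D, axis=1, keepdims=True)
x = 3.0 * D[5] - 2.0 * D[42]
code = omp(x, D, n_nonzero_coefs=2)
print(np.flatnonzero(code), code[[5, 42]])
```

Allowing a second atom lets the least-squares refit correct what a single atom cannot express, which is one way to read the lower bias reported for the 2-atom result.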

The result of Least Angle Regression is much more strongly biased: the difference is reminiscent of the local intensity value of the original image.

Thresholding is clearly not useful for denoising, but it is included here to show that it can produce a suggestive output very quickly, and thus be useful for other tasks such as object classification, where performance is not necessarily related to visualisation.
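Thresholding is fast because it involves no iterative optimization: to our understanding, scikit-learn's `'threshold'` transform soft-thresholds the correlation between each patch and each atom, with `transform_alpha` as the threshold. The sketch below (function name `threshold_encode` is hypothetical) shows that computation in plain NumPy, again with atoms stored as rows of `D`.

```python
import numpy as np


def threshold_encode(X, D, alpha):
    """Hypothetical sketch of soft-threshold sparse coding.

    X : (n_samples, n_features) signals
    D : (n_atoms, n_features) dictionary, atoms as rows
    """
    cov = X @ D.T  # one correlation per (sample, atom) pair
    # Shrink every correlation toward zero by alpha; small ones become exactly 0
    return np.sign(cov) * np.maximum(np.abs(cov) - alpha, 0.0)


# Usage with a trivial identity dictionary, so cov == X
X = np.array([[1.0, 0.05, -0.5]])
codes = threshold_encode(X, np.eye(3), alpha=0.1)
print(codes)
```

Because the coefficients are neither refit nor constrained in number, their scale is not matched to the data, which is consistent with the denoising loop having to rescale the thresholded patches to [0, 1].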


Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18'

Imports

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

from time import time
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp

from sklearn.decomposition import MiniBatchDictionaryLearning
from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.feature_extraction.image import reconstruct_from_patches_2d
from sklearn.utils.testing import SkipTest
from sklearn.utils.fixes import sp_version

if sp_version < (0, 12):
    raise SkipTest("Skipping because SciPy version earlier than 0.12.0 and "
                   "thus does not include the scipy.misc.face() image.")

Calculations

In [3]:
try:
    from scipy import misc
    face = misc.face(gray=True)
except AttributeError:
    # Old versions of scipy have face in the top level package
    face = sp.face(gray=True)

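# Convert from uint8 (values 0-255) to floats in [0, 1], so that the noise
# level below is meaningful and noisy values are not truncated to integers
face = face / 255.

height, width = face.shape

# Distort the right half of the image
print('Distorting image...')
distorted = face.copy()
distorted[:, width // 2:] += 0.075 * np.random.randn(height, width // 2)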

# Extract all reference patches from the left half of the image
print('Extracting reference patches...')
t0 = time()
patch_size = (7, 7)
data = extract_patches_2d(distorted[:, :width // 2], patch_size)
data = data.reshape(data.shape[0], -1)
data = data - np.mean(data, axis=0)
data = data/np.std(data, axis=0)
print('done in %.2fs.' % (time() - t0))
Distorting image...
Extracting reference patches...
done in 0.20s.
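`extract_patches_2d` returns every overlapping window, so an H x W image yields (H - p + 1)(W - q + 1) patches of size p x q. The sketch below reproduces that count on a small toy array (not the face image) using NumPy's `sliding_window_view`, then applies the same per-pixel centering and scaling as the cell above. It assumes NumPy 1.20 or later and that no pixel position is constant across patches (otherwise the division by the standard deviation would fail).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Toy 10 x 12 "image"; the example itself uses the left half of the face
image = np.arange(120, dtype=float).reshape(10, 12)
patch_size = (7, 7)

# One window per valid top-left corner, each flattened to 7 * 7 = 49 features
windows = sliding_window_view(image, patch_size)          # shape (4, 6, 7, 7)
patches = windows.reshape(-1, patch_size[0] * patch_size[1])
print(patches.shape)  # (10 - 7 + 1) * (12 - 7 + 1) = 24 patches

# Center and scale each pixel position (column) across all patches
centered = patches - patches.mean(axis=0)
standardized = centered / centered.std(axis=0)
```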

Plot Results

Learn the dictionary from reference patches

In [4]:
print('Learning the dictionary...')
t0 = time()
dico = MiniBatchDictionaryLearning(n_components=100, alpha=1, n_iter=500)
V = dico.fit(data).components_
dt = time() - t0
print('done in %.2fs.' % dt)
Learning the dictionary...
done in 12.13s.
In [5]:
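def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a matplotlib colormap into a Plotly colorscale with pl_entries stops."""
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []

    for k in range(pl_entries):
        # cmap returns RGBA floats in [0, 1]; keep RGB and scale to 0-255 integers
        r, g, b = (int(255 * c) for c in cmap(k * h)[:3])
        pl_colorscale.append([k * h, 'rgb(%d, %d, %d)' % (r, g, b)])

    return pl_colorscale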

Plot Dictionary Learned from Face Patches

In [6]:
fig = tools.make_subplots(rows=10, cols=10,
                          print_grid=False)

for i, comp in enumerate(V[:100]):
    trace = go.Heatmap(z=comp.reshape(patch_size),
                       colorscale=matplotlib_to_plotly(plt.cm.gray_r, patch_size[0]),
                       showscale=False)
    # Fill the 10 x 10 grid row by row; Plotly rows and columns start at 1
    row, col = divmod(i, 10)
    fig.append_trace(trace, row + 1, col + 1)

fig['layout'].update(title='Dictionary learned from face patches<br>' +
                           'Train time %.1fs on %d patches' % (dt, len(data)),
                     height=1000)

# Hide ticks on all 100 subplots and flip the y axes so patches are not upside down
for i in map(str, range(1, 101)):
    y = 'yaxis' + i
    x = 'xaxis' + i
    fig['layout'][y].update(autorange='reversed',
                            showticklabels=False, ticks='')
    fig['layout'][x].update(showticklabels=False, ticks='')
In [7]:
py.iplot(fig)
Out[7]:

Display the distorted image

In [8]:
def show_with_diff(image, reference, title):
    """Helper function to display denoising"""
    
    trace1 = go.Heatmap(z=image, 
                        showscale=False,
                        colorscale=matplotlib_to_plotly(plt.cm.gray, 20),
                       )
    
    difference = image - reference
    fig = tools.make_subplots(rows=1, cols=2,
                              print_grid=False,
                              subplot_titles=('Image',
                                              'Difference (norm: %.2f)' % np.sqrt(np.sum(difference ** 2)))
                             )
    
    trace2 = go.Heatmap(z=difference, 
                        showscale=False,
                        colorscale=matplotlib_to_plotly(plt.cm.Blues, 10),
                       )
    fig.append_trace(trace1, 1, 1)
    fig.append_trace(trace2, 1, 2)
    
    for i in map(str, range(1,3)):
        y = 'yaxis' + i
        x = 'xaxis' + i
        fig['layout'][y].update(autorange='reversed',
                                showticklabels=False, ticks='')
        fig['layout'][x].update(showticklabels=False, ticks='')
    
    fig['layout'].update(title=title)
    return fig
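The norm printed in each "Difference" subplot title, `np.sqrt(np.sum(difference ** 2))`, is the Frobenius norm of the difference image, which is also what `np.linalg.norm` returns by default for a 2-D array. The sketch below checks that equivalence on random toy data (not the face) and illustrates a useful baseline: a difference made only of Gaussian noise with standard deviation sigma has a norm of roughly sigma * sqrt(number of pixels).

```python
import numpy as np

rng = np.random.default_rng(0)
reference = rng.random((4, 5))
image = reference + 0.075 * rng.standard_normal((4, 5))
difference = image - reference

norm_manual = np.sqrt(np.sum(difference ** 2))     # expression used in show_with_diff
norm_default = np.linalg.norm(difference)          # 2-D default is the Frobenius norm
norm_explicit = np.linalg.norm(difference, 'fro')

# Baseline: pure noise with sigma = 0.075 on a toy 200 x 300 array
noise = 0.075 * rng.standard_normal((200, 300))
noise_norm = np.linalg.norm(noise)
expected = 0.075 * np.sqrt(noise.size)
print(norm_manual, norm_default, noise_norm, expected)
```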
In [9]:
py.iplot(show_with_diff(distorted, face, 'Distorted image'))
Out[9]:

Extract noisy patches and reconstruct them using the dictionary

In [10]:
print('Extracting noisy patches... ')
t0 = time()
data = extract_patches_2d(distorted[:, width // 2:], patch_size)
data = data.reshape(data.shape[0], -1)
intercept = np.mean(data, axis=0)
data = data - intercept
print('done in %.2fs.' % (time() - t0))
Extracting noisy patches... 
done in 0.10s.
In [11]:
transform_algorithms = [
            ('Orthogonal Matching Pursuit<br>1 atom', 'omp',
             {'transform_n_nonzero_coefs': 1}),
            ('Orthogonal Matching Pursuit<br>2 atoms', 'omp',
             {'transform_n_nonzero_coefs': 2}),
            ('Least-angle regression<br>5 atoms', 'lars',
             {'transform_n_nonzero_coefs': 5}),
            ('Thresholding<br> alpha=0.1', 'threshold', {'transform_alpha': .1})
            ]

reconstructions = {}
plot = []

for title, transform_algorithm, kwargs in transform_algorithms:
    print(title + '...')
    reconstructions[title] = face.copy()
    t0 = time()
    dico.set_params(transform_algorithm=transform_algorithm, **kwargs)
    code = dico.transform(data)
    patches = np.dot(code, V)

    patches += intercept
    patches = patches.reshape(len(data), *patch_size)
    if transform_algorithm == 'threshold':
        patches -= patches.min()
        patches /= patches.max()
    reconstructions[title][:, width // 2:] = reconstruct_from_patches_2d(
        patches, (height, width // 2))
    dt = time() - t0
    print('done in %.2fs.' % dt)
    
    plot.append(show_with_diff(reconstructions[title], face,
                               title + ' (time: %.1fs)' % dt))
Orthogonal Matching Pursuit<br>1 atom...
done in 20.44s.
Orthogonal Matching Pursuit<br>2 atoms...
done in 40.17s.
Least-angle regression<br>5 atoms...
done in 220.06s.
Thresholding<br> alpha=0.1...
done in 2.02s.
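Each denoised patch above is `code @ V + intercept`, and `reconstruct_from_patches_2d` then fills the image with the patches from left to right, top to bottom, averaging overlapping regions. The function below is a hypothetical, loop-based NumPy sketch of that averaging step (not the library code), assuming patches are ordered row by row as produced by a sliding window. Reconstructing from the exact patches of an image returns the image itself.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def average_patches(patches, image_size):
    """Hypothetical sketch: rebuild an image by averaging overlapping patches."""
    i_h, i_w = image_size
    p_h, p_w = patches.shape[1:]
    n_w = i_w - p_w + 1                # number of patch positions per row
    total = np.zeros(image_size)
    counts = np.zeros(image_size)
    for idx, patch in enumerate(patches):
        r, c = divmod(idx, n_w)        # top-left corner of this patch
        total[r:r + p_h, c:c + p_w] += patch
        counts[r:r + p_h, c:c + p_w] += 1
    return total / counts              # every pixel is covered at least once


# Usage: a round trip on a random toy image
rng = np.random.default_rng(0)
img = rng.random((9, 11))
toy_patches = sliding_window_view(img, (3, 3)).reshape(-1, 3, 3)
rebuilt = average_patches(toy_patches, img.shape)
```

With noisy or approximate patches, the same averaging also smooths out disagreements between overlapping reconstructions.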

Orthogonal Matching Pursuit 1 atom

In [12]:
py.iplot(plot[0])
Out[12]:

Orthogonal Matching Pursuit 2 atoms

In [13]:
py.iplot(plot[1])
Out[13]:

Least-angle regression 5 atoms

In [14]:
py.iplot(plot[2])
Out[14]:

Thresholding alpha=0.1

In [15]:
py.iplot(plot[3])
Out[15]: