
Spectral Clustering for Image Segmentation in Scikit-learn

In this example, an image with connected circles is generated and spectral clustering is used to separate the circles.

In this setting, the spectral clustering approach solves the problem known as ‘normalized graph cuts’: the image is seen as a graph of connected pixels, and the spectral clustering algorithm amounts to choosing graph cuts that define regions while minimizing the ratio of the gradient along the cut to the volume of the region. Because the algorithm tries to balance the volume (i.e. balance the region sizes), the segmentation fails if we take circles of markedly different sizes.
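The normalized-cut criterion and its spectral relaxation can be tried on a toy graph. The sketch below is a minimal NumPy illustration, not scikit-learn's implementation: it assumes a small dense symmetric weight matrix W, and normalized_cut and spectral_bipartition are hypothetical helper names. It splits the nodes on the sign of the second eigenvector of the symmetric normalized Laplacian (the classic two-way relaxation of the normalized cut), whereas spectral_clustering embeds the nodes with several eigenvectors and then assigns labels.

```python
import numpy as np


def normalized_cut(W, in_a):
    """Ncut(A, B) = cut(A, B) / vol(A) + cut(A, B) / vol(B) for a boolean split."""
    in_b = ~in_a
    cut = W[np.ix_(in_a, in_b)].sum()  # total weight of edges crossing the cut
    degrees = W.sum(axis=1)            # vol(S) is the sum of degrees in S
    return cut / degrees[in_a].sum() + cut / degrees[in_b].sum()


def spectral_bipartition(W):
    """Relaxed two-way normalized cut (hypothetical helper).

    Uses the second-smallest eigenvector of L = I - D^{-1/2} W D^{-1/2},
    rescaled by D^{-1/2}, and splits the nodes on its sign.
    """
    d = W.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(d)
    L = np.eye(len(d)) - d_inv_sqrt[:, None] * W * d_inv_sqrt[None, :]
    eigvals, eigvecs = np.linalg.eigh(L)  # eigenvalues in ascending order
    fiedler = d_inv_sqrt * eigvecs[:, 1]
    return fiedler > 0


# Toy graph: two triangles (nodes 0-2 and 3-5) with unit-weight edges,
# joined by one weak edge of weight 0.1 between nodes 2 and 3.
W = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5)]:
    W[i, j] = W[j, i] = 1.0
W[2, 3] = W[3, 2] = 0.1

split = spectral_bipartition(W)
lonely = np.array([True, False, False, False, False, False])  # cut off node 0 only

print(split)
print(normalized_cut(W, split), normalized_cut(W, lonely))
```

The spectral split separates the two triangles. Cutting off node 0 alone scores far worse, partly because more edge weight crosses that cut and partly because the 1/vol(A) term grows as a region shrinks, which is the balancing pressure described above.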

In addition, as there is no useful information in the intensity of the image, or its gradient, we choose to perform the spectral clustering on a graph that is only weakly informed by the gradient. This is close to performing a Voronoi partition of the graph.
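To see what "weakly informed by the gradient" means, the transform used in this example, exp(-g / std(g)), can be compared with a sharper one. A minimal sketch, assuming made-up gradient values and a hypothetical beta multiplier that is not part of the original code (beta=1 reproduces the example's transform):

```python
import numpy as np


def gradient_to_affinity(gradients, beta=1.0):
    """Map edge gradients g to similarities exp(-beta * g / std(g)).

    beta is a hypothetical knob for this sketch: beta=1 matches the weak
    dependence used in this example, larger values follow edges more closely.
    """
    gradients = np.asarray(gradients, dtype=float)
    return np.exp(-beta * gradients / gradients.std())


g = np.array([0.0, 0.1, 0.2, 1.5])  # made-up gradients; the last edge is sharp
weak = gradient_to_affinity(g, beta=1.0)
strong = gradient_to_affinity(g, beta=10.0)
print(weak.round(3))
print(strong.round(3))
```

With beta=1 even the sharpest edge keeps a non-negligible weight, so the cut is driven mostly by the shape of the graph (hence the Voronoi-like result); with beta=10 that edge is effectively removed.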

We also use the mask of the objects to restrict the graph to the pixels inside the objects: in this example, we are interested in separating the objects from one another, not from the background.
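The effect of the mask on the graph can be checked directly on a tiny image. A minimal sketch with made-up data (two separate 2x2 squares plus seeded noise): img_to_graph creates one node per pixel, or one node per masked pixel when mask is given, and scipy's connected_components then counts how many pieces the masked graph falls into.

```python
import numpy as np
from scipy.sparse.csgraph import connected_components
from sklearn.feature_extraction import image

# Made-up 6x6 image: two separate 2x2 squares on a background, plus noise.
mask = np.zeros((6, 6), dtype=bool)
mask[0:2, 0:2] = True
mask[4:6, 3:5] = True
rng = np.random.RandomState(0)
noisy = mask + 0.1 * rng.randn(*mask.shape)

# Without a mask every pixel becomes a node; with it, only foreground pixels do.
full_graph = image.img_to_graph(noisy)
masked_graph = image.img_to_graph(noisy, mask=mask)
print(full_graph.shape, masked_graph.shape)

# Background pixels no longer link the squares, so each one is its own component.
n_components, component_labels = connected_components(masked_graph, directed=False)
print(n_components)
```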

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18'

Imports

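This tutorial imports the image module from sklearn.feature_extraction (for img_to_graph) and spectral_clustering from sklearn.cluster, plus Plotly for plotting.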

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
from sklearn.feature_extraction import image
from sklearn.cluster import spectral_clustering

Calculations

In [3]:
l = 100
x, y = np.indices((l, l))

center1 = (28, 24)
center2 = (40, 50)
center3 = (67, 58)
center4 = (24, 70)

radius1, radius2, radius3, radius4 = 16, 14, 15, 14

circle1 = (x - center1[0]) ** 2 + (y - center1[1]) ** 2 < radius1 ** 2
circle2 = (x - center2[0]) ** 2 + (y - center2[1]) ** 2 < radius2 ** 2
circle3 = (x - center3[0]) ** 2 + (y - center3[1]) ** 2 < radius3 ** 2
circle4 = (x - center4[0]) ** 2 + (y - center4[1]) ** 2 < radius4 ** 2

Plot Results

In [4]:
fig = tools.make_subplots(rows=1, cols=2)
This is the format of your plot grid:
[ (1,1) x1,y1 ]  [ (1,2) x2,y2 ]

4 circles

In [5]:
img = circle1 + circle2 + circle3 + circle4

# We use a mask that limits to the foreground: the problem that we are
# interested in here is not separating the objects from the background,
# but separating them one from the other.
mask = img.astype(bool)

img = img.astype(float)
img += 1 + 0.2 * np.random.randn(*img.shape)

# Convert the image into a graph with the value of the gradient on the
# edges.
graph = image.img_to_graph(img, mask=mask)

# Take a decreasing function of the gradient. Because the weights depend
# only weakly on the gradient, the segmentation is close to a Voronoi
# partition.
graph.data = np.exp(-graph.data / graph.data.std())

# Force the solver to be arpack, since amg is numerically
# unstable on this example
labels = spectral_clustering(graph, n_clusters=4, eigen_solver='arpack')
label_im = -np.ones(mask.shape)
label_im[mask] = labels
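spectral_clustering works in two stages: it embeds the graph nodes with eigenvectors of the normalized Laplacian, then assigns labels in that embedding. The assign_labels parameter selects the second stage ('kmeans', the default, or 'discretize'), and random_state seeds its random initialization. A minimal sketch, assuming seeded noise for repeatable runs, that rebuilds the four-circle graph and compares the two strategies by the number of pixels given to each label:

```python
import numpy as np
from sklearn.cluster import spectral_clustering
from sklearn.feature_extraction import image

# Rebuild the four-circle image from this tutorial, with seeded noise.
rng = np.random.RandomState(0)
x, y = np.indices((100, 100))
centers = [(28, 24), (40, 50), (67, 58), (24, 70)]
radii = [16, 14, 15, 14]
mask = np.zeros((100, 100), dtype=bool)
for (cx, cy), r in zip(centers, radii):
    mask |= (x - cx) ** 2 + (y - cy) ** 2 < r ** 2
img = mask.astype(float) + 1 + 0.2 * rng.randn(*mask.shape)

graph = image.img_to_graph(img, mask=mask)
graph.data = np.exp(-graph.data / graph.data.std())

sizes = {}
for strategy in ('kmeans', 'discretize'):
    labels = spectral_clustering(graph, n_clusters=4, eigen_solver='arpack',
                                 assign_labels=strategy, random_state=0)
    # Number of foreground pixels assigned to each cluster label.
    sizes[strategy] = np.bincount(labels, minlength=4)
    print(strategy, sizes[strategy])
```

Comparing the per-label counts with the circle areas is a quick way to spot a cluster that swallowed parts of two circles.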
In [6]:
normal = go.Heatmap(z=img, showscale=False,
                    colorscale='YlGnBu')

spectral_clustering_ = go.Heatmap(z=label_im, showscale=False,
                                  colorscale='YlGnBu')

fig.append_trace(normal, 1, 1)
fig.append_trace(spectral_clustering_, 1, 2)

fig['layout']['yaxis1'].update(autorange='reversed')
fig['layout']['yaxis2'].update(autorange='reversed')

py.iplot(fig)

2 circles

In [7]:
img = circle1 + circle2
mask = img.astype(bool)
img = img.astype(float)

img += 1 + 0.2 * np.random.randn(*img.shape)

graph = image.img_to_graph(img, mask=mask)
graph.data = np.exp(-graph.data / graph.data.std())

labels = spectral_clustering(graph, n_clusters=2, eigen_solver='arpack')
label_im = -np.ones(mask.shape)
label_im[mask] = labels
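The cell above repeats the four-circle pipeline almost line for line. One way to avoid the duplication is to wrap it in a function. A minimal sketch: segment_objects is a hypothetical helper (not part of scikit-learn), and the seeded RandomState and random_state argument are additions for repeatable runs.

```python
import numpy as np
from sklearn.cluster import spectral_clustering
from sklearn.feature_extraction import image


def segment_objects(shape_mask, n_clusters, noise=0.2, random_state=0):
    """Hypothetical helper: noisy image -> masked graph -> spectral labels.

    Returns the noisy image and a label image with -1 on the background.
    """
    rng = np.random.RandomState(random_state)
    mask = shape_mask.astype(bool)
    img = mask.astype(float) + 1 + noise * rng.randn(*mask.shape)

    # Gradient graph restricted to the foreground, weakly weighted as above.
    graph = image.img_to_graph(img, mask=mask)
    graph.data = np.exp(-graph.data / graph.data.std())

    labels = spectral_clustering(graph, n_clusters=n_clusters,
                                 eigen_solver='arpack', random_state=random_state)
    label_im = -np.ones(mask.shape)
    label_im[mask] = labels
    return img, label_im


x, y = np.indices((100, 100))
circle1 = (x - 28) ** 2 + (y - 24) ** 2 < 16 ** 2
circle2 = (x - 40) ** 2 + (y - 50) ** 2 < 14 ** 2
objects = circle1 | circle2
noisy_img, label_im = segment_objects(objects, n_clusters=2)
```

The four-circle case then becomes a single call with the union of all four masks and n_clusters=4.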
In [8]:
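# Start a new figure: appending to the previous one would add these
# heatmaps to the four-circle figure instead of replacing its traces.
fig = tools.make_subplots(rows=1, cols=2)

normal1 = go.Heatmap(z=img, showscale=False,
                     colorscale='YlGnBu')

spectral_clustering1 = go.Heatmap(z=label_im, showscale=False,
                                  colorscale='YlGnBu')

fig.append_trace(normal1, 1, 1)
fig.append_trace(spectral_clustering1, 1, 2)

fig['layout']['yaxis1'].update(autorange='reversed')
fig['layout']['yaxis2'].update(autorange='reversed')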

py.iplot(fig)

License

Authors:

      Emmanuelle Gouillart <emmanuelle.gouillart@normalesup.org>

      Gael Varoquaux <gael.varoquaux@normalesup.org>

License:

      BSD 3 clause