Show Sidebar Hide Sidebar

Vector Quantization Example in Scikit-learn

`face`, a 1024 x 768 grayscale image of a raccoon's face, is used here to illustrate how k-means can be applied to vector quantization.

New to Plotly?

Plotly's Python library is free and open source! Get started by downloading the client and reading the primer.
You can set up Plotly to work in online or offline mode, or in Jupyter notebooks.
We also have a quick-reference cheatsheet (new!) to help you get started!

Version

In [1]:
# Show the scikit-learn version this notebook was executed with.
import sklearn
sklearn.__version__
Out[1]:
'0.18'

Imports

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt

from sklearn import cluster
from sklearn.utils.testing import SkipTest
from sklearn.utils.fixes import sp_version
Automatically created module for IPython interactive environment
In [3]:
# The sample image is only shipped from SciPy 0.12.0 on; skip the example on
# older installations.
if sp_version < (0, 12):
    raise SkipTest("Skipping because SciPy version earlier than 0.12.0 and "
                   "thus does not include the scipy.misc.face() image.")

# Load the sample image; gray=True requests the grayscale version.
# Try the top-level scipy namespace first and fall back to scipy.misc when
# the attribute is missing there.
try:
    face = sp.face(gray=True)
except AttributeError:
    # Newer versions of scipy have face in misc
    from scipy import misc
    face = misc.face(gray=True)

Calculations

In [4]:
# Quantize the gray levels of ``face`` down to ``n_clusters`` values.
n_clusters = 5
# No random_state is passed to KMeans below, so seed NumPy's global RNG
# before fitting.
np.random.seed(0)

# KMeans expects (n_samples, n_features): one row per pixel, one column
# holding its gray level.
X = face.reshape(-1, 1)
k_means = cluster.KMeans(n_clusters=n_clusters, n_init=4).fit(X)
values = k_means.cluster_centers_.squeeze()  # one gray level per cluster
labels = k_means.labels_                     # cluster index of every pixel

# Look up each pixel's cluster centre, then fold the flat result back into
# the shape of the image.
face_compressed = values[labels].reshape(face.shape)

# Intensity range of the original image.
vmin, vmax = face.min(), face.max()

Plot Results

In [5]:
# One row of four panels; the cells below fill columns 1 to 4 in turn.
fig = tools.make_subplots(rows=1, cols=4)

def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a matplotlib colormap into a Plotly colorscale.

    Parameters
    ----------
    cmap : callable
        Maps a float in [0, 1] to a sequence whose first three items are
        the red, green and blue components in [0, 1] (for example
        ``plt.cm.gray``).
    pl_entries : int
        Number of evenly spaced samples to take; must be at least 2.

    Returns
    -------
    list
        ``[[position, 'rgb(r, g, b)'], ...]`` with positions running from
        exactly 0.0 to exactly 1.0 and integer channel values.

    Raises
    ------
    ValueError
        If ``pl_entries`` is smaller than 2 (a colorscale needs both a
        start and an end point).
    """
    if pl_entries < 2:
        raise ValueError(
            "pl_entries must be at least 2, got %r" % (pl_entries,))

    last_index = float(pl_entries - 1)
    pl_colorscale = []

    for k in range(pl_entries):
        # Divide by the last index instead of multiplying by a precomputed
        # step, so the final position is exactly 1.0 rather than a value
        # that may be off by a rounding error.
        position = k / last_index
        # Scale to 0..255 and truncate to uint8 as an array. The previous
        # ``map(...)`` result is a lazy iterator on Python 3 and cannot be
        # indexed.
        channels = (np.asarray(cmap(position)[:3]) * 255).astype(np.uint8)
        # Convert to plain ints so the string never contains a NumPy
        # scalar repr.
        red, green, blue = (int(c) for c in channels)
        pl_colorscale.append(
            [position, 'rgb(%d, %d, %d)' % (red, green, blue)])

    return pl_colorscale

# Fix the overall figure width for the four-panel row.
fig['layout'].update(width=900)
This is the format of your plot grid:
[ (1,1) x1,y1 ]  [ (1,2) x2,y2 ]  [ (1,3) x3,y3 ]  [ (1,4) x4,y4 ]

Original Face

In [6]:
# Panel 1: the unmodified grayscale image.
gray_colorscale = matplotlib_to_plotly(plt.cm.gray, len(face))
original_face = go.Heatmap(z=face, showscale=False,
                           colorscale=gray_colorscale)
fig.append_trace(original_face, 1, 1)

# Reverse the y axis (image rows run top to bottom) and strip tick marks
# and tick labels from both axes.
hidden_ticks = dict(showticklabels=False, ticks='')
fig['layout']['yaxis1'].update(autorange='reversed', **hidden_ticks)
fig['layout']['xaxis1'].update(**hidden_ticks)

Compressed Face

In [7]:
# Panel 2: every pixel replaced by the centre of its k-means cluster.
# Pin the colour range to the original image's [vmin, vmax]. Without it the
# heatmap autoscales to the handful of quantized levels, stretching them to
# the full black-to-white range, so the panel cannot be compared with the
# original.
compressed_face = go.Heatmap(
    z=face_compressed, showscale=False,
    zmin=float(vmin), zmax=float(vmax),
    colorscale=matplotlib_to_plotly(plt.cm.gray, len(face)))

fig.append_trace(compressed_face, 1, 2)
# Reverse the y axis (image rows run top to bottom) and hide all ticks.
fig['layout']['yaxis2'].update(autorange='reversed',
                               showticklabels=False, ticks='')
fig['layout']['xaxis2'].update(showticklabels=False, ticks='')

Equal Bins Face

In [8]:
# Panel 3: quantization with n_clusters equally wide bins over [0, 256].
regular_values = np.linspace(0, 256, n_clusters + 1)         # bin edges
regular_labels = np.searchsorted(regular_values, face) - 1   # bin per pixel
regular_values = .5 * (regular_values[1:] + regular_values[:-1])  # mean
# A pixel equal to 0 gets label -1 from the line above; mode="clip" maps it
# onto the first bin instead of raising.
regular_face = np.choose(regular_labels.ravel(), regular_values, mode="clip")
regular_face.shape = face.shape

# Pin the colour range to the original image's [vmin, vmax]. Without it the
# heatmap autoscales to the few bin centres and the panel cannot be compared
# with the original.
regular_face = go.Heatmap(
    z=regular_face, showscale=False,
    zmin=float(vmin), zmax=float(vmax),
    colorscale=matplotlib_to_plotly(plt.cm.gray, len(face)))

fig.append_trace(regular_face, 1, 3)
# Reverse the y axis (image rows run top to bottom) and hide all ticks.
fig['layout']['yaxis3'].update(autorange='reversed',
                               showticklabels=False, ticks='')
fig['layout']['xaxis3'].update(showticklabels=False, ticks='')

Histogram

In [9]:
# Panel 4: histogram of pixel intensities, overlaid with the thresholds of
# both quantization schemes.

def _threshold_lines(centers, dash=None):
    """Return one vertical line trace per midpoint of adjacent ``centers``.

    ``centers`` must be in ascending order for the midpoints to be the
    boundaries between neighbouring levels. ``dash`` is passed on as the
    Plotly line dash style; when None the line is left solid.
    """
    traces = []
    for center_1, center_2 in zip(centers[:-1], centers[1:]):
        midpoint = .5 * (center_1 + center_2)
        line_style = dict(color='blue', width=2)
        if dash is not None:
            line_style['dash'] = dash
        traces.append(go.Scatter(x=[midpoint, midpoint],
                                 y=[0, 5000],
                                 showlegend=False,
                                 mode='lines',
                                 line=line_style))
    return traces


# X has a single column, so pull all intensities out in one vectorised call
# instead of appending pixel by pixel in a Python loop.
pixel_values = X[:, 0].tolist()

hist = go.Histogram(x=pixel_values, showlegend=False,
                    marker=dict(color='rgb(211,211,211)'))

fig.append_trace(hist, 1, 4)

# Solid lines: k-means thresholds. cluster_centers_ carries no ordering
# guarantee, so sort first; midpoints of unsorted neighbours are not the
# boundaries between clusters.
for axisline1 in _threshold_lines(np.sort(values)):
    fig.append_trace(axisline1, 1, 4)

# Dashed lines: thresholds of the equal-width bins (regular_values is
# already ascending).
for axisline2 in _threshold_lines(regular_values, dash='dash'):
    fig.append_trace(axisline2, 1, 4)

fig['layout']['yaxis4'].update(showticklabels=False, ticks='',
                               showgrid=False)
fig['layout']['xaxis4'].update(showticklabels=False, ticks='',
                               showgrid=False)
In [10]:
# Render the assembled four-panel figure inline through plotly.plotly.
py.iplot(fig)
Out[10]:

License

Code source:

        Gaël Varoquaux

License:

        BSD 3 clause
Still need help?
Contact Us

For guaranteed 24 hour response turnarounds, upgrade to a Developer Support Plan.