Show Sidebar Hide Sidebar

# Vector Quantization Example in Scikit-learn

This example uses `face`, a 1024 x 768 grayscale image of a raccoon face, to illustrate how k-means can be used for vector quantization.

#### New to Plotly?

You can set up Plotly to work in online or offline mode, or in jupyter notebooks.
We also have a quick-reference cheatsheet (new!) to help you get started!

### Version

In [1]:
import sklearn
# Display the installed scikit-learn version as the cell's output.
sklearn.__version__

Out[1]:
'0.18'

### Imports

In [2]:
# Echo the module docstring; in a notebook this is IPython's auto-generated one.
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt

from sklearn import cluster
from sklearn.utils.testing import SkipTest
from sklearn.utils.fixes import sp_version

Automatically created module for IPython interactive environment

In [3]:
# scipy.misc.face() first shipped in SciPy 0.12; skip on anything older.
if sp_version < (0, 12):
    raise SkipTest("Skipping because SciPy version earlier than 0.12.0 and "
                   "thus does not include the scipy.misc.face() image.")

# Load the grayscale raccoon image, trying each place SciPy has exposed it.
try:
    face = sp.face(gray=True)
except AttributeError:
    try:
        # Many releases provide the loader as scipy.misc.face.
        from scipy import misc
        face = misc.face(gray=True)
    except (ImportError, AttributeError):
        # Recent SciPy dropped scipy.misc.face in favour of
        # scipy.datasets.face (which needs the optional `pooch` package
        # to fetch the image).
        from scipy import datasets
        face = datasets.face(gray=True)


### Calculations

In [4]:
n_clusters = 5
# Seed NumPy's global RNG so the k-means run below is repeatable.
np.random.seed(0)

# Flatten the image into one gray-level sample per pixel:
# scikit-learn expects an (n_samples, n_features) array.
X = face.reshape((-1, 1))
k_means = cluster.KMeans(n_clusters=n_clusters, n_init=4)
k_means.fit(X)
values = k_means.cluster_centers_.squeeze()  # one gray level per cluster
labels = k_means.labels_  # cluster index of every pixel

# Rebuild the image by replacing each pixel with its cluster centre.
# Plain indexing gives the same result as np.choose(labels, values) but is
# not subject to np.choose's cap on the number of choices, so n_clusters can
# be raised freely. reshape() avoids mutating the array's shape in place.
face_compressed = values[labels].reshape(face.shape)

# Gray-level range of the source image.
vmin = face.min()
vmax = face.max()


### Plot Results

In [5]:
# Lay out a single row of four panels; the traces are appended to it below.
fig = tools.make_subplots(rows=1, cols=4)

def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a matplotlib colormap into a Plotly colorscale.

    Parameters
    ----------
    cmap : callable
        Maps a float in [0, 1] to a sequence whose first three items are
        the red, green and blue components as floats in [0, 1]
        (e.g. ``plt.cm.gray``).
    pl_entries : int
        Number of evenly spaced samples to take; must be at least 2
        because the step is ``1 / (pl_entries - 1)``.

    Returns
    -------
    list
        ``[position, 'rgb(r, g, b)']`` pairs with positions running from
        0 to 1 and integer channel values in 0-255.
    """
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []

    for k in range(pl_entries):
        # Scale to 0-255 and truncate to uint8, then convert to plain ints.
        # On Python 3 ``map`` returns an iterator that cannot be indexed,
        # and NumPy scalars inside a tuple may not print as bare numbers.
        channels = (np.array(cmap(k * h)[:3]) * 255).astype(np.uint8)
        r, g, b = (int(c) for c in channels)
        pl_colorscale.append([k * h, 'rgb' + str((r, g, b))])

    return pl_colorscale

# Set the overall figure width to 900.
fig['layout'].update(width=900)

This is the format of your plot grid:
[ (1,1) x1,y1 ]  [ (1,2) x2,y2 ]  [ (1,3) x3,y3 ]  [ (1,4) x4,y4 ]



Original Face

In [6]:
# Panel (1, 1): the untouched grayscale image.
gray_scale = matplotlib_to_plotly(plt.cm.gray, len(face))
orignal_face = go.Heatmap(z=face, colorscale=gray_scale, showscale=False)
fig.append_trace(orignal_face, 1, 1)

# Hide tick marks and labels on both axes; flip y so row 0 is drawn at the top.
hidden_ticks = dict(showticklabels=False, ticks='')
fig['layout']['xaxis1'].update(**hidden_ticks)
fig['layout']['yaxis1'].update(autorange='reversed', **hidden_ticks)


Compressed Face

In [7]:
# Panel (1, 2): the image rebuilt from the k-means cluster centres.
# Pin the colour range to the source image's gray levels (vmin/vmax);
# otherwise the heatmap stretches the few centre values over the whole
# black-to-white scale and the panel is not comparable with the original.
compressed_face = go.Heatmap(z=face_compressed, showscale=False,
                             zmin=float(vmin), zmax=float(vmax),
                             colorscale=matplotlib_to_plotly(plt.cm.gray, len(face)))

fig.append_trace(compressed_face, 1, 2)
# Flip y so row 0 is drawn at the top, and hide ticks and labels.
fig['layout']['yaxis2'].update(autorange='reversed',
                               showticklabels=False, ticks='')
fig['layout']['xaxis2'].update(showticklabels=False, ticks='')


Equal Bins Face

In [8]:
# Quantize with n_clusters equal-width bins between 0 and 256 for comparison.
regular_values = np.linspace(0, 256, n_clusters + 1)  # bin edges
regular_labels = np.searchsorted(regular_values, face) - 1
regular_values = .5 * (regular_values[1:] + regular_values[:-1])  # bin centres
# mode="clip" folds the -1 label produced for pixels equal to 0 into bin 0.
regular_face = np.choose(regular_labels.ravel(), regular_values, mode="clip")
regular_face.shape = face.shape

# Panel (1, 3). Pin the colour range to the source image's gray levels
# (vmin/vmax) so the heatmap does not rescale to the bin centres alone and
# stays comparable with the original panel.
# Note: regular_face is rebound from the pixel array to the Plotly trace.
regular_face = go.Heatmap(z=regular_face, showscale=False,
                          zmin=float(vmin), zmax=float(vmax),
                          colorscale=matplotlib_to_plotly(plt.cm.gray, len(face)))

fig.append_trace(regular_face, 1, 3)
# Flip y so row 0 is drawn at the top, and hide ticks and labels.
fig['layout']['yaxis3'].update(autorange='reversed',
                               showticklabels=False, ticks='')
fig['layout']['xaxis3'].update(showticklabels=False, ticks='')


Histogram

In [9]:
# Gray level of every pixel. X is the (n_pixels, 1) sample array, so its
# single column is taken directly instead of being copied element by
# element in a Python loop.
k = X[:, 0]

hist = go.Histogram(x=k, showlegend=False,
                    marker=dict(color='rgb(211,211,211)'))

fig.append_trace(hist, 1, 4)

# K-means assigns each pixel to its nearest centre, so the bin boundaries
# are the midpoints between neighbouring centres. The centres are not
# guaranteed to come back in ascending order, so sort a copy first;
# `values` itself is left untouched.
sorted_centers = np.sort(values)
for center_1, center_2 in zip(sorted_centers[:-1], sorted_centers[1:]):
    boundary = .5 * (center_1 + center_2)
    # Solid vertical line marking a k-means boundary.
    axisline1 = go.Scatter(x=[boundary, boundary],
                           y=[0, 5000],
                           showlegend=False,
                           mode='lines',
                           line=dict(color='blue',
                                     width=2))
    fig.append_trace(axisline1, 1, 4)

# regular_values holds the equal-width bin centres (ascending), so the
# midpoints between neighbours are the interior bin edges.
for center_1, center_2 in zip(regular_values[:-1], regular_values[1:]):
    boundary = .5 * (center_1 + center_2)
    # Dashed vertical line marking an equal-width bin edge.
    axisline2 = go.Scatter(x=[boundary, boundary],
                           y=[0, 5000],
                           showlegend=False,
                           mode='lines',
                           line=dict(color='blue',
                                     width=2,
                                     dash='dash'))
    fig.append_trace(axisline2, 1, 4)

# Strip ticks, labels and grid lines from the histogram panel.
fig['layout']['yaxis4'].update(showticklabels=False, ticks='',
                               showgrid=False)
fig['layout']['xaxis4'].update(showticklabels=False, ticks='',
                               showgrid=False)

In [10]:
# Display the assembled four-panel figure through plotly.plotly's iplot.
py.iplot(fig)

Out[10]:

Code source:

        Gaël Varoquaux



        BSD 3 clause
Still need help?