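# Feature Agglomeration in Scikit-learn

Feature agglomeration clusters similar *features* — here, neighbouring pixels of the handwritten-digit images — and merges each cluster into a single feature. The plots below show four original digits, the same digits reconstructed from 32 agglomerated features, and the map assigning each pixel to its cluster.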

### Version

In [1]:
# Record which scikit-learn version produced the outputs shown in this notebook.
import sklearn
sklearn.__version__

Out[1]:
'0.18'

### Imports

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
import matplotlib.pyplot as plt

from sklearn import datasets, cluster
from sklearn.feature_extraction.image import grid_to_graph

Automatically created module for IPython interactive environment


### Calculations

In [3]:
# Flatten every digit image into a single row of pixel features, so X has one
# row per image and one column per pixel.
digits = datasets.load_digits()
images = digits.images
X = images.reshape(len(images), -1)

# Pixel adjacency graph over one image's grid: only neighbouring pixels may be
# merged into the same feature cluster.
connectivity = grid_to_graph(*images.shape[1:])

# Cluster the pixel features into 32 groups and project X onto them;
# fit() returns the estimator itself, so transform() can be chained.
agglo = cluster.FeatureAgglomeration(n_clusters=32,
                                     connectivity=connectivity)
X_reduced = agglo.fit(X).transform(X)

# Expand the 32 cluster values back to full pixel resolution and restore the
# original image shape so reconstructions can be plotted next to the originals.
X_restored = agglo.inverse_transform(X_reduced)
images_restored = X_restored.reshape(images.shape)


### Plot Result

In [4]:
# 3x4 grid: row 1 holds the original digits, row 2 their reconstructions, and
# row 3 a single label map in column 2 (the other row-3 cells stay empty).
# The nine titles line up, in row-major order, with the nine non-empty cells.
fig = tools.make_subplots(
    rows=3,
    cols=4,
    specs=[[{} for _col in range(4)],
           [{} for _col in range(4)],
           [None, {}, None, None]],
    subplot_titles=('', 'Original Data', '', '',
                    '', 'Agglomerated Data', '', '',
                    'Labels'),
    print_grid=False,
)

def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a matplotlib colormap into a Plotly colorscale.

    Parameters
    ----------
    cmap : callable
        A matplotlib colormap (e.g. ``plt.cm.gray``). It is called with a
        float in [0, 1] and must return an RGB(A) sequence whose first three
        components are floats in [0, 1].
    pl_entries : int
        Number of evenly spaced samples to take. Must be at least 2 so the
        scale contains both the 0 and the 1 endpoint Plotly requires.

    Returns
    -------
    list of [float, str]
        ``[[position, 'rgb(r, g, b)'], ...]`` with positions running from
        exactly 0.0 to exactly 1.0 and integer channels in 0-255.

    Raises
    ------
    ValueError
        If ``pl_entries`` is less than 2.
    """
    if pl_entries < 2:
        raise ValueError(
            "pl_entries must be at least 2 to span the 0..1 colorscale, "
            "got %r" % (pl_entries,))

    pl_colorscale = []
    # linspace pins the final position to exactly 1.0; accumulating
    # k * (1.0 / (n - 1)) can land a hair below 1.0 (e.g. n = 49), which
    # Plotly rejects as a colorscale endpoint.
    for position in np.linspace(0.0, 1.0, pl_entries):
        position = float(position)
        # Scale to 0-255 and truncate exactly like the original np.uint8 cast,
        # then convert to plain Python ints: a map object is not subscriptable
        # on Python 3, and numpy scalar reprs (e.g. "np.uint8(255)" on
        # NumPy 2) would corrupt the rgb(...) string.
        r, g, b = (np.array(cmap(position)[:3]) * 255).astype(np.uint8).tolist()
        pl_colorscale.append([position, 'rgb(%d, %d, %d)' % (r, g, b)])

    return pl_colorscale

# Every digit image, its reconstruction and the label map share the shape of
# images[0], so all colour scales use the same number of samples (the image
# height) and can be built once instead of once per trace.
n_colors = images.shape[1]
gray_colorscale = matplotlib_to_plotly(plt.cm.gray, n_colors)
# `plt.cm.spectral` was removed from matplotlib (deprecated in 2.0, removed in
# 2.2); `nipy_spectral` is its surviving replacement.
label_colorscale = matplotlib_to_plotly(plt.cm.nipy_spectral, n_colors)

# Rows 1 and 2: the first four digits before and after agglomeration.
for i in range(4):
    original = go.Heatmap(z=images[i], showscale=False,
                          colorscale=gray_colorscale)
    fig.append_trace(original, 1, i + 1)

    agglomerated = go.Heatmap(z=images_restored[i], showscale=False,
                              colorscale=gray_colorscale)
    fig.append_trace(agglomerated, 2, i + 1)

# Row 3, column 2: the cluster label assigned to each pixel.
labels = go.Heatmap(z=np.reshape(agglo.labels_, images[0].shape),
                    showscale=False,
                    colorscale=label_colorscale)
fig.append_trace(labels, 3, 2)

fig['layout'].update(height=900)

# Hide ticks on all 9 subplots (4 originals + 4 reconstructions + 1 label map)
# and reverse the y axes so row 0 of each image is drawn at the top.
for i in map(str, range(1, 10)):
    y = 'yaxis' + i
    x = 'xaxis' + i
    fig['layout'][y].update(autorange='reversed',
                            showticklabels=False, ticks='')
    fig['layout'][x].update(showticklabels=False, ticks='')

py.iplot(fig)

Out[4]:

Code source:

Gaël Varoquaux



Modified for documentation by Jaques Grobler

License: BSD 3 clause