Various Agglomerative Clustering on a 2D embedding of digits in Scikit-learn

An illustration of various linkage options for agglomerative clustering on a 2D embedding of the digits dataset.

The goal of this example is to show intuitively how the metrics behave, and not to find good clusters for the digits. This is why the example works on a 2D embedding.
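
As a rough illustration of what a linkage option measures, the sketch below scores how "far apart" two small clusters are under three criteria. It uses the standard textbook definitions, not scikit-learn's internal implementation, and the function names and the two point sets are hypothetical. At each step, agglomerative clustering merges the pair of clusters with the smallest score.

```python
from itertools import product
from math import dist  # available in Python 3.8+


def complete_linkage(a, b):
    """Largest pairwise distance between points of the two clusters."""
    return max(dist(p, q) for p, q in product(a, b))


def average_linkage(a, b):
    """Mean of all pairwise distances between points of the two clusters."""
    return sum(dist(p, q) for p, q in product(a, b)) / (len(a) * len(b))


def centroid(points):
    """Coordinate-wise mean of a list of points."""
    n = len(points)
    return tuple(sum(coords) / n for coords in zip(*points))


def ward_increase(a, b):
    """Increase in within-cluster sum of squares if a and b were merged."""
    d2 = dist(centroid(a), centroid(b)) ** 2
    return len(a) * len(b) / (len(a) + len(b)) * d2


# Two hypothetical clusters in the plane (assumed data, for illustration only)
cluster_a = [(0.0, 0.0), (0.0, 2.0)]
cluster_b = [(3.0, 0.0), (3.0, 2.0)]

complete_score = complete_linkage(cluster_a, cluster_b)
average_score = average_linkage(cluster_a, cluster_b)
ward_score = ward_increase(cluster_a, cluster_b)
print(complete_score, average_score, ward_score)
```

Complete linkage looks only at the worst-case pair, average linkage at all pairs, and Ward at how much the merge would inflate the within-cluster variance, which is why the three strategies can produce different merges on the same data.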

What this example shows is the “rich get richer” behavior of agglomerative clustering, which tends to create uneven cluster sizes. This behavior is especially pronounced for the average linkage strategy, which ends up with a couple of singleton clusters.
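
One simple way to check for uneven cluster sizes is to count the points per cluster once labels are available. The helper below is a minimal sketch: `cluster_sizes` is a hypothetical name, and the label list is made-up data standing in for the `labels_` attribute of a fitted `AgglomerativeClustering` model.

```python
from collections import Counter


def cluster_sizes(labels):
    """Return the number of points in each cluster, largest cluster first."""
    return sorted(Counter(labels).values(), reverse=True)


# Hypothetical labels for ten points (assumed data, not results from the digits run)
example_labels = [0, 0, 0, 0, 0, 0, 1, 1, 2, 3]

sizes = cluster_sizes(example_labels)
n_singletons = sum(1 for size in sizes if size == 1)
print("sizes:", sizes, "singletons:", n_singletons)
```

A size list dominated by one large value with several ones at the end is the signature of the uneven clusters described above.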

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18'

Imports

This tutorial imports AgglomerativeClustering.

In [2]:
print(__doc__)
from time import time

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
import matplotlib
from scipy import ndimage
from matplotlib import pyplot as plt

from sklearn import manifold, datasets
from sklearn.cluster import AgglomerativeClustering
Automatically created module for IPython interactive environment

Calculations

In [3]:
digits = datasets.load_digits(n_class=10)
X = digits.data
y = digits.target
n_samples, n_features = X.shape

np.random.seed(0)

def nudge_images(X, y):
    # Having a larger dataset shows more clearly the behavior of the
    # methods, but we multiply the size of the dataset only by 2, as the
    # cost of the hierarchical clustering methods is strongly
    # super-linear in n_samples
    shift = lambda x: ndimage.shift(x.reshape((8, 8)),
                                  .3 * np.random.normal(size=2),
                                  mode='constant',
                                  ).ravel()
    X = np.concatenate([X, np.apply_along_axis(shift, 1, X)])
    Y = np.concatenate([y, y], axis=0)
    return X, Y


X, y = nudge_images(X, y)

Plot Results

In [4]:
def matplotlib_to_plotly(cmap, pl_entries):
    h = 1.0/(pl_entries-1)
    pl_colorscale = []
    
    for k in range(pl_entries):
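        # Scale the RGB floats to 0-255 integers; build a list because
        # a map object cannot be indexed on Python 3
        C = [int(c) for c in np.array(cmap(k*h)[:3])*255]
        pl_colorscale.append([k*h, 'rgb'+str((C[0], C[1], C[2]))])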
        
    return pl_colorscale
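
To make the output format of this kind of helper concrete, here is a small standalone sketch that needs neither NumPy nor Matplotlib. Both `to_plotly_colorscale` and `gray_cmap` are hypothetical names; `gray_cmap` merely stands in for a Matplotlib colormap, which is assumed to return an RGBA tuple of floats in [0, 1].

```python
def to_plotly_colorscale(cmap, n_entries):
    """Sample `cmap` at n_entries evenly spaced positions in [0, 1]."""
    step = 1.0 / (n_entries - 1)
    colorscale = []
    for k in range(n_entries):
        # Drop the alpha channel and scale each float channel to a 0-255 integer
        r, g, b = (int(255 * c) for c in cmap(k * step)[:3])
        colorscale.append([k * step, 'rgb' + str((r, g, b))])
    return colorscale


def gray_cmap(value):
    # Hypothetical stand-in for a Matplotlib colormap: RGBA floats in [0, 1]
    return (value, value, value, 1.0)


scale = to_plotly_colorscale(gray_cmap, 3)
print(scale)
```

Each entry pairs a position in [0, 1] with an `'rgb(r, g, b)'` string, which is the list-of-pairs form Plotly accepts for a colorscale.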

Visualize the clustering

In [5]:
def plot_clustering(X_red, X, labels, title=None):
    annotations=[]
    x_min, x_max = np.min(X_red, axis=0), np.max(X_red, axis=0)
    X_red = (X_red - x_min) / (x_max - x_min)
    for i in range(X_red.shape[0]):
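        # plt.cm.spectral was removed from newer Matplotlib releases;
        # nipy_spectral is the same colormap under its current name
        rgb = plt.cm.nipy_spectral(labels[i] / 10.)[:3]
        # Plotly expects 0-255 channel values inside 'rgb(...)'
        color = 'rgb' + str(tuple(int(255 * c) for c in rgb))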
        annotation_ = dict(x=X_red[i, 0], y=X_red[i, 1],
                           text=str(y[i]),
                           showarrow=False,
                           font=dict(
                                size=9,
                                color=color)
                          )
       
        annotations.append(annotation_)
    return annotations

2D embedding of the digits dataset

In [6]:
print("Computing embedding")
X_red = manifold.SpectralEmbedding(n_components=2).fit_transform(X)
print("Done.")
Computing embedding
Done.
In [7]:
plot1 = []
for linkage in ('ward', 'average', 'complete'):
    clustering = AgglomerativeClustering(linkage=linkage, n_clusters=10)
    t0 = time()
    clustering.fit(X_red)
    print("%s : %.2fs" % (linkage, time() - t0))

    plot = plot_clustering(X_red, X, clustering.labels_, "%s linkage" % linkage)
    plot1.append(plot)
ward : 0.44s
average : 0.31s
complete : 0.31s

Plot 'ward'

In [8]:
layout = go.Layout(annotations=plot1[0],
                   title='ward linkage',
                   xaxis=dict(zeroline=False, showticklabels=False,
                              showgrid=False, ticks=''),
                   yaxis=dict(zeroline=False,  showticklabels=False,
                              showgrid=False, ticks='')
                  )
fig = go.Figure(data=[go.Scatter( )], layout=layout)

py.iplot(fig) 
Out[8]:

Plot 'average'

In [9]:
layout = go.Layout(annotations=plot1[1],
                   title='average linkage',
                   xaxis=dict(zeroline=False, showticklabels=False,
                              showgrid=False, ticks=''),
                   yaxis=dict(zeroline=False,  showticklabels=False,
                              showgrid=False, ticks='')
                  )
fig = go.Figure(data=[go.Scatter( )], layout=layout)

py.iplot(fig) 
Out[9]:

Plot 'complete'

In [10]:
layout = go.Layout(annotations=plot1[2],
                   title='complete linkage',
                   xaxis=dict(zeroline=False, showticklabels=False,
                              showgrid=False, ticks=''),
                   yaxis=dict(zeroline=False,  showticklabels=False,
                              showgrid=False, ticks='')
                  )
fig = go.Figure(data=[go.Scatter( )], layout=layout)

py.iplot(fig)
Out[10]:
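
The three plotting cells above differ only in the annotation list and the title. As a possible refactoring, the shared settings could be built by one helper; this is a sketch, and `linkage_layout_kwargs` is a hypothetical name that is not part of Plotly. It returns a plain dictionary so that it can be inspected without Plotly installed.

```python
def linkage_layout_kwargs(annotations, linkage):
    """Collect the layout settings shared by the three linkage plots."""
    hidden_axis = dict(zeroline=False, showticklabels=False,
                       showgrid=False, ticks='')
    return dict(annotations=annotations,
                title='%s linkage' % linkage,
                xaxis=dict(hidden_axis),   # separate copies, so the axes can
                yaxis=dict(hidden_axis))   # be customized independently later


# An empty annotation list stands in for one entry of `plot1`
kwargs = linkage_layout_kwargs([], 'ward')
print(kwargs['title'])
```

With Plotly imported as in the cells above, the dictionary would be unpacked as `go.Layout(**linkage_layout_kwargs(plot1[0], 'ward'))`, and a loop over the three linkage names would replace the repeated cells.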

License

Authors:

    Gael Varoquaux

License:

    BSD 3 clause (C) INRIA 2014