Comparing Different Clustering Algorithms on Toy Datasets in Scikit-learn

This example aims at showing characteristics of different clustering algorithms on datasets that are “interesting” but still in 2D. The last dataset is an example of a ‘null’ situation for clustering: the data is homogeneous, and there is no good clustering.

While these examples give some intuition about the algorithms, this intuition might not apply to very high-dimensional data.
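One commonly cited reason is that pairwise distances concentrate as the dimension grows. The nearest and farthest neighbors of a point end up at similar distances, which weakens the distance and density contrasts these 2D pictures rely on. The NumPy sketch below uses a hypothetical helper, `distance_contrast`. It compares the ratio of the largest to the smallest pairwise distance for uniform random points in 2 and in 1000 dimensions. The exact numbers depend on the seed and the number of points.

```python
import numpy as np


def distance_contrast(n_points, n_dims, seed=0):
    """Ratio of the largest to the smallest pairwise Euclidean distance
    among uniformly random points in the unit hypercube (hypothetical helper)."""
    rng = np.random.RandomState(seed)
    X = rng.rand(n_points, n_dims)
    # squared distances via |a - b|^2 = |a|^2 + |b|^2 - 2 a.b
    sq_norms = (X ** 2).sum(axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2 * X @ X.T
    # keep each unordered pair once (strict upper triangle, no diagonal)
    iu = np.triu_indices(n_points, k=1)
    dists = np.sqrt(np.maximum(sq_dists[iu], 0.0))
    return dists.max() / dists.min()


low = distance_contrast(200, 2)
high = distance_contrast(200, 1000)
print("2D max/min distance ratio:", low)
print("1000D max/min distance ratio:", high)
```

In 2D the ratio is typically in the hundreds. In 1000 dimensions it stays close to 1, so "near" and "far" lose most of their meaning.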

The results could be improved by tweaking the parameters of each clustering strategy, for instance by setting the number of clusters for the methods that need it. Note that affinity propagation tends to create many clusters, so in this example its two parameters (damping and per-point preference) were set to mitigate this behavior.
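To see why the number of clusters is an input rather than an output for some methods, here is a minimal full-batch Lloyd's k-means in NumPy. It is a plain sketch, not the `MiniBatchKMeans` estimator used below. The helper name `lloyd_kmeans` and the choice of initial centers (first and last sample) are assumptions for illustration. The number of clusters is fixed by how many initial centers are passed in, and the algorithm never changes it.

```python
import numpy as np


def lloyd_kmeans(X, init_centers, n_iter=10):
    """Plain (full-batch) Lloyd's k-means (hypothetical helper).
    The number of clusters equals len(init_centers) and is chosen up front."""
    centers = np.array(init_centers, dtype=float)
    for _ in range(n_iter):
        # assign every point to its nearest center (squared Euclidean distance)
        d = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d.argmin(axis=1)
        # move each non-empty cluster's center to the mean of its points
        for k in range(len(centers)):
            if np.any(labels == k):
                centers[k] = X[labels == k].mean(axis=0)
    return labels, centers


# two well-separated Gaussian blobs of 50 points each
rng = np.random.RandomState(0)
X = np.vstack([rng.randn(50, 2) + [0, 0], rng.randn(50, 2) + [8, 8]])
labels, centers = lloyd_kmeans(X, init_centers=X[[0, -1]])
print(np.bincount(labels))  # points per cluster
```

Passing three initial centers would force three clusters on the same two blobs. This is why the estimators below that take `n_clusters` are all given `n_clusters=2`.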

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18'

Imports

In [2]:
import time

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np

from sklearn import cluster, datasets
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler

Calculations

In [3]:
np.random.seed(0)

# Generate datasets. We choose the size big enough to see the scalability
# of the algorithms, but not too big to avoid too long running times
n_samples = 1500
noisy_circles = datasets.make_circles(n_samples=n_samples, factor=.5,
                                      noise=.05)
noisy_moons = datasets.make_moons(n_samples=n_samples, noise=.05)
blobs = datasets.make_blobs(n_samples=n_samples, random_state=8)
no_structure = np.random.rand(n_samples, 2), None

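# 7 base colors tiled 20 times (140 entries); a label of -1 (DBSCAN noise)
# indexes the last entry, 'black'
colors = np.array(['blue', 'green', 'red', 'cyan',
                   'magenta', 'yellow', 'black'])
colors = np.hstack([colors] * 20)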

clustering_names = [
    'MiniBatchKMeans', 'Affinity<br>Propagation', 'MeanShift',
    'Spectral<br>Clustering', 'Ward', 'Agglomerative<br>Clustering',
    'DBSCAN', 'Birch']
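Predicted labels are later used directly as indices into `colors`. The sketch below rebuilds the same palette and applies it to a few hypothetical labels. Label 7 wraps around to 'blue' because the seven base colors repeat. The DBSCAN noise label -1 picks the last entry, 'black', through NumPy's negative indexing. A label of 140 or more, which affinity propagation or mean shift could in principle produce, would raise an `IndexError`.

```python
import numpy as np

# Same palette construction as the cell above: 7 colors tiled 20 times.
colors = np.hstack([np.array(['blue', 'green', 'red', 'cyan',
                              'magenta', 'yellow', 'black'])] * 20)

# Hypothetical predicted labels; DBSCAN marks noise points with -1.
y_pred = np.array([0, 1, 1, -1, 7])

point_colors = colors[y_pred].tolist()
print(point_colors)  # ['blue', 'green', 'green', 'black', 'blue']
```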

Plot Results

In [4]:
fig = tools.make_subplots(rows=8, cols=4,
                          print_grid=False)

fig['layout'].update(height=1000)
# Label each row with its algorithm name on the first-column y-axis
for j, i in enumerate(range(1, 33, 4)):
    fig['layout']['yaxis' + str(i)].update(title=clustering_names[j])

for i in map(str, range(1, 33)):
    y = 'yaxis'+i
    x = 'xaxis'+i
    fig['layout'][y].update(ticks='', showticklabels=False,
                            zeroline=False, showgrid=False)
    
    fig['layout'][x].update(ticks='', showticklabels=False,
                            zeroline=False, showgrid=False)
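The title loop assumes that `make_subplots` numbers its axes row by row from the top-left cell. Under that assumption, the subplot at (row, col) in a 4-column grid uses axis suffix `(row - 1) * 4 + col`, so the first-column y-axes are `yaxis1`, `yaxis5`, ..., `yaxis29`. The helper `axis_index` below is hypothetical and only illustrates that arithmetic.

```python
def axis_index(row, col, n_cols=4):
    """1-based axis suffix for the subplot at (row, col), assuming
    row-major numbering from the top-left cell (hypothetical helper)."""
    return (row - 1) * n_cols + col


# First-column y-axes, which carry the algorithm names as titles
first_column = [axis_index(r, 1) for r in range(1, 9)]
print(first_column)  # [1, 5, 9, 13, 17, 21, 25, 29]
```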
In [7]:
row_num = 0
col_num = 0
# named toy_datasets to avoid shadowing the sklearn.datasets module
toy_datasets = [noisy_circles, noisy_moons, blobs, no_structure]

for i_dataset, dataset in enumerate(toy_datasets):
    X, y = dataset
    # normalize dataset for easier parameter selection
    X = StandardScaler().fit_transform(X)

    # estimate bandwidth for mean shift
    bandwidth = cluster.estimate_bandwidth(X, quantile=0.3)

    # connectivity matrix for structured Ward
    connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)
    # make connectivity symmetric
    connectivity = 0.5 * (connectivity + connectivity.T)

    # create clustering estimators
    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
    two_means = cluster.MiniBatchKMeans(n_clusters=2)
    ward = cluster.AgglomerativeClustering(n_clusters=2, linkage='ward',
                                           connectivity=connectivity)
    spectral = cluster.SpectralClustering(n_clusters=2,
                                          eigen_solver='arpack',
                                          affinity="nearest_neighbors")
    dbscan = cluster.DBSCAN(eps=.2)
    affinity_propagation = cluster.AffinityPropagation(damping=.9,
                                                       preference=-200)

    average_linkage = cluster.AgglomerativeClustering(
        linkage="average", affinity="cityblock", n_clusters=2,
        connectivity=connectivity)

    birch = cluster.Birch(n_clusters=2)
    clustering_algorithms = [
        two_means, affinity_propagation, ms, spectral, ward, average_linkage,
        dbscan, birch]
    
    for name, algorithm in zip(clustering_names, clustering_algorithms):
        # predict cluster memberships; t1 - t0 is the fit time
        # (measured but not displayed in this figure)
        t0 = time.time()
        algorithm.fit(X)
        t1 = time.time()
        if hasattr(algorithm, 'labels_'):
            y_pred = algorithm.labels_.astype(int)
        else:
            y_pred = algorithm.predict(X)

        # plot the points, colored by predicted cluster label
        trace = go.Scatter(x=X[:, 0], y=X[:, 1],
                           showlegend=False,
                           mode='markers',
                           marker=dict(color=colors[y_pred].tolist(),
                                       size=3))
        fig.append_trace(trace, row_num % 8 + 1, col_num % 4 + 1)

        # overlay cluster centers only for estimators that expose them;
        # appending inside the if avoids re-plotting the previous
        # algorithm's centers on this subplot
        if hasattr(algorithm, 'cluster_centers_'):
            centers = algorithm.cluster_centers_
            center_colors = colors[:len(centers)]
            center = go.Scatter(x=centers[:, 0],
                                y=centers[:, 1],
                                showlegend=False,
                                mode='markers',
                                marker=dict(color=center_colors.tolist(),
                                            size=10,
                                            line=dict(color='black',
                                                      width=1)))
            fig.append_trace(center, row_num % 8 + 1, col_num % 4 + 1)

        row_num += 1
    col_num += 1
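The preprocessing at the top of the loop can be sketched in plain NumPy to show what each step computes. The sketch covers standardization, a k-nearest-neighbor connectivity graph that is then symmetrized, and a bandwidth estimate for mean shift. It is a minimal dense illustration under stated assumptions, not scikit-learn's implementation. `standardize`, `knn_connectivity` and `bandwidth_estimate` are hypothetical names. The O(n²) distance matrix is fine for a few points, whereas `kneighbors_graph` returns a sparse matrix. The bandwidth helper follows our understanding of the idea behind `estimate_bandwidth`: the mean, over samples, of the distance to the k-th nearest neighbor (counting the point itself), with k = int(n_samples * quantile).

```python
import numpy as np


def standardize(X):
    """Zero mean and unit (population) variance per feature, like StandardScaler."""
    return (X - X.mean(axis=0)) / X.std(axis=0)


def pairwise_dists(X):
    """Dense matrix of Euclidean distances between all rows of X."""
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=2))


def knn_connectivity(X, n_neighbors):
    """Dense 0/1 k-nearest-neighbor graph that excludes each point itself."""
    d = pairwise_dists(X)
    np.fill_diagonal(d, np.inf)  # analogous to include_self=False
    nn = np.argsort(d, axis=1)[:, :n_neighbors]
    A = np.zeros_like(d)
    A[np.arange(len(X))[:, None], nn] = 1.0
    return A


def bandwidth_estimate(X, quantile):
    """Mean distance to the k-th nearest neighbor, the point itself included."""
    k = int(len(X) * quantile)
    d = np.sort(pairwise_dists(X), axis=1)  # column 0 is the self-distance 0
    return d[:, k - 1].mean()


rng = np.random.RandomState(0)
X = standardize(rng.rand(20, 2))
A = knn_connectivity(X, n_neighbors=3)
# same symmetrization as the cell: 1.0 for mutual neighbors, 0.5 for one-sided
A_sym = 0.5 * (A + A.T)
print(A.sum(axis=1))  # 3 neighbors per row
print(bandwidth_estimate(X, quantile=0.3))
```

Symmetrizing matters because the k-nearest-neighbor relation is not symmetric. Point a can be among b's neighbors without b being among a's, while the connectivity constraint passed to `AgglomerativeClustering` is meant to describe an undirected graph.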
In [6]:
py.iplot(fig, validate=False)
Out[6]: