K-Means Clustering on the Handwritten Digits Data in Scikit-learn

In this example we compare the various initialization strategies for K-means in terms of runtime and quality of the results.
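The stochastic strategies differ only in how the starting centers are drawn. Below is a minimal NumPy sketch that assumes plain D² sampling. `'random'` picks `n_clusters` distinct samples uniformly. k-means++ picks each further center with probability proportional to its squared distance from the nearest center chosen so far. The helper names `kmeans_plus_plus_init` and `random_init` and the three-blob toy data are hypothetical. scikit-learn's own k-means++ also tries several candidates per step, which this sketch leaves out.

```python
import numpy as np


def kmeans_plus_plus_init(X, n_clusters, rng):
    """Hypothetical helper: plain k-means++ (D^2 sampling) seeding."""
    n_samples = X.shape[0]
    # First center: one sample chosen uniformly at random.
    centers = [X[rng.integers(n_samples)]]
    # Squared distance from every sample to its closest center chosen so far.
    closest_sq = np.sum((X - centers[0]) ** 2, axis=1)
    for _ in range(1, n_clusters):
        # Draw the next center with probability proportional to D(x)^2.
        probs = closest_sq / closest_sq.sum()
        idx = rng.choice(n_samples, p=probs)
        centers.append(X[idx])
        closest_sq = np.minimum(closest_sq, np.sum((X - X[idx]) ** 2, axis=1))
    return np.array(centers)


def random_init(X, n_clusters, rng):
    """Hypothetical helper: 'random' seeding, n_clusters distinct samples."""
    idx = rng.choice(X.shape[0], size=n_clusters, replace=False)
    return X[idx]


rng = np.random.default_rng(0)
# Toy data: three well-separated 2-D blobs of 20 points each.
X = np.vstack([rng.normal(loc, 0.1, size=(20, 2)) for loc in (0.0, 5.0, 10.0)])
print(kmeans_plus_plus_init(X, 3, rng))
print(random_init(X, 3, rng))
```

A sample that is already a center has distance zero, so it can never be drawn twice. This is why k-means++ tends to spread the seeds across the blobs.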

As the ground truth is known here, we also apply different cluster quality metrics to judge the goodness of fit of the cluster labels to the ground truth.

Cluster quality metrics evaluated (see Clustering performance evaluation for definitions and discussions of the metrics):

Shorthand     Full name
homo          Homogeneity Score
compl         Completeness Score
v-meas        V Measure
ARI           Adjusted Rand Index
AMI           Adjusted Mutual Information
silhouette    Silhouette Coefficient
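Evaluating these scores on tiny hand-made labelings is a quick way to build intuition for them. The sketch below assumes only `sklearn.metrics`, and the label arrays are made up. Relabeling the clusters leaves every score unchanged. Putting every sample into its own cluster keeps homogeneity at 1.0 but lowers completeness and drops the ARI to 0. The silhouette coefficient is left out because it scores the geometry of the data rather than agreement with the ground truth.

```python
from sklearn import metrics

labels_true = [0, 0, 0, 1, 1, 1]
# Same grouping as the ground truth, only the cluster ids are swapped.
labels_perm = [1, 1, 1, 0, 0, 0]
# Every sample in its own cluster: perfectly homogeneous, poorly complete.
labels_split = [0, 1, 2, 3, 4, 5]

for name, pred in [("permuted", labels_perm), ("all split", labels_split)]:
    print(name,
          round(metrics.homogeneity_score(labels_true, pred), 3),
          round(metrics.completeness_score(labels_true, pred), 3),
          round(metrics.v_measure_score(labels_true, pred), 3),
          round(metrics.adjusted_rand_score(labels_true, pred), 3))
```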

In [1]:
import sklearn


This tutorial imports KMeans, load_digits, PCA and scale.
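`scale` standardizes every feature of the digits data before clustering. The sketch below shows the equivalent NumPy computation. It assumes the documented behavior of `sklearn.preprocessing.scale`: subtract the column mean, divide by the population standard deviation, and leave zero-variance columns undivided. The toy matrix is made up. Its constant column stands in for pixel positions that are blank in every digit image.

```python
import numpy as np
from sklearn.preprocessing import scale

# Toy data; the last column is constant.
X = np.array([[1.0, 10.0, 5.0],
              [2.0, 20.0, 5.0],
              [3.0, 60.0, 5.0]])

std = X.std(axis=0)      # population standard deviation (ddof=0)
std[std == 0.0] = 1.0    # zero-variance columns are centered but not divided
manual = (X - X.mean(axis=0)) / std

print(np.allclose(scale(X), manual))
```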

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go

from time import time
import numpy as np
import matplotlib.pyplot as plt

from sklearn import metrics
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale


In [3]:

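digits = load_digits()
data = scale(digits.data)  # standardize each of the 64 pixel features

n_samples, n_features = data.shape
n_digits = len(np.unique(digits.target))
labels = digits.target  # ground-truth digit labels used by the metrics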

sample_size = 300

print("n_digits: %d, \t n_samples %d, \t n_features %d"
      % (n_digits, n_samples, n_features))

print(79 * '_')
print('% 9s' % 'init'
      '    time  inertia    homo   compl  v-meas     ARI AMI  silhouette')

def bench_k_means(estimator, name, data):
    # Fit the estimator, time the fit, then score its labels against the ground truth.
    t0 = time()
    estimator.fit(data)
    print('% 9s   %.2fs    %i   %.3f   %.3f   %.3f   %.3f   %.3f    %.3f'
          % (name, (time() - t0), estimator.inertia_,
             metrics.homogeneity_score(labels, estimator.labels_),
             metrics.completeness_score(labels, estimator.labels_),
             metrics.v_measure_score(labels, estimator.labels_),
             metrics.adjusted_rand_score(labels, estimator.labels_),
             metrics.adjusted_mutual_info_score(labels, estimator.labels_),
             metrics.silhouette_score(data, estimator.labels_,
                                      metric='euclidean',
                                      sample_size=sample_size)))

bench_k_means(KMeans(init='k-means++', n_clusters=n_digits, n_init=10),
              name="k-means++", data=data)

bench_k_means(KMeans(init='random', n_clusters=n_digits, n_init=10),
              name="random", data=data)

# in this case the seeding of the centers is deterministic, hence we run the
# kmeans algorithm only once with n_init=1
pca = PCA(n_components=n_digits).fit(data)
bench_k_means(KMeans(init=pca.components_, n_clusters=n_digits, n_init=1),
              name="PCA-based", data=data)
print(79 * '_')

n_digits: 10, 	 n_samples 1797, 	 n_features 64
init    time  inertia    homo   compl  v-meas     ARI AMI  silhouette
k-means++   0.29s    69432   0.602   0.650   0.625   0.465   0.598    0.146
   random   0.21s    69694   0.669   0.710   0.689   0.553   0.666    0.147
PCA-based   0.04s    70804   0.671   0.698   0.684   0.561   0.668    0.118
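Two properties behind this table can be checked directly on a small problem. First, fitting `KMeans` twice from the same explicit `init` array gives the same labels, which is why the PCA-based run uses `n_init=1`. Second, `inertia_` is the sum of squared distances from each sample to its closest cluster center. The sketch below assumes toy blob data and hand-picked seed rows, both made up for illustration.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
# Toy data: three 2-D blobs of 30 points each.
X = np.vstack([rng.normal(c, 0.5, size=(30, 2)) for c in (0.0, 4.0, 8.0)])
init_centers = X[[0, 30, 60]]  # hypothetical seeds: one sample from each blob

# With an explicit init array there is no random seeding to repeat.
km_a = KMeans(n_clusters=3, init=init_centers, n_init=1).fit(X)
km_b = KMeans(n_clusters=3, init=init_centers, n_init=1).fit(X)
print(np.array_equal(km_a.labels_, km_b.labels_))

# inertia_: sum over samples of the squared distance to the closest center.
sq_dist = ((X[:, None, :] - km_a.cluster_centers_[None, :, :]) ** 2).sum(axis=2)
print(np.isclose(sq_dist.min(axis=1).sum(), km_a.inertia_))
```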

Plot Results on PCA-reduced data
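The plot uses the two dimensions produced by `PCA(n_components=2)`. The sketch below shows what that projection computes, using random toy data: center the data and project it onto the top two right singular vectors. The result should match scikit-learn's `fit_transform` up to the sign of each component, because principal directions are only defined up to sign.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
# Toy data: 50 samples with 5 correlated features.
X = rng.normal(size=(50, 5)) @ rng.normal(size=(5, 5))

# Center, then project onto the top-2 right singular vectors.
Xc = X - X.mean(axis=0)
U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
manual_2d = Xc @ Vt[:2].T

sk_2d = PCA(n_components=2).fit_transform(X)
# Compare column by column, allowing a sign flip per component.
same = all(np.allclose(manual_2d[:, i], sk_2d[:, i]) or
           np.allclose(manual_2d[:, i], -sk_2d[:, i]) for i in range(2))
print(same)
```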

In [4]:
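def matplotlib_to_plotly(cmap, pl_entries):
    # Sample the matplotlib colormap at pl_entries evenly spaced positions in
    # [0, 1] and convert each color to a Plotly [position, 'rgb(r, g, b)'] pair.
    pl_colorscale = []
    for k in range(pl_entries):
        pos = k / (pl_entries - 1)  # exactly 1.0 for the last entry
        r, g, b = (int(c) for c in np.array(cmap(pos)[:3]) * 255)
        pl_colorscale.append([pos, 'rgb(%d, %d, %d)' % (r, g, b)])
    return pl_colorscale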
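`matplotlib_to_plotly` turns a matplotlib colormap into the `[position, 'rgb(r, g, b)']` list that Plotly accepts as a `colorscale`. The self-contained sketch below repeats a Python 3 version of the helper. It calls the helper with `gray`, a hypothetical stand-in colormap, so the sketch runs without matplotlib. A matplotlib colormap such as `plt.cm.Paired` is likewise callable with a float in [0, 1] and returns an RGBA tuple. Computing each position as `k / (pl_entries - 1)` makes the last position exactly 1.0, the top end a Plotly colorscale must define.

```python
import numpy as np


def matplotlib_to_plotly(cmap, pl_entries):
    """Sample `cmap` at pl_entries evenly spaced points in [0, 1] and return
    a Plotly colorscale: a list of [position, 'rgb(r, g, b)'] pairs."""
    pl_colorscale = []
    for k in range(pl_entries):
        pos = k / (pl_entries - 1)
        # cmap(pos) gives RGBA floats in [0, 1]; keep RGB and scale to 0-255.
        r, g, b = (int(c) for c in np.array(cmap(pos)[:3]) * 255)
        pl_colorscale.append([pos, 'rgb(%d, %d, %d)' % (r, g, b)])
    return pl_colorscale


def gray(x):
    """Hypothetical stand-in colormap: x in [0, 1] -> grayscale RGBA tuple."""
    return (x, x, x, 1.0)


print(matplotlib_to_plotly(gray, 3))
# [[0.0, 'rgb(0, 0, 0)'], [0.5, 'rgb(127, 127, 127)'], [1.0, 'rgb(255, 255, 255)']]
```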
In [5]:
reduced_data = PCA(n_components=2).fit_transform(data)
kmeans = KMeans(init='k-means++', n_clusters=n_digits, n_init=10)
kmeans.fit(reduced_data)

# Step size of the mesh. Decrease to increase the quality of the VQ.
h = .02

# Plot the decision boundary. For that, we will assign a color to each
# point in the mesh [x_min, x_max]x[y_min, y_max].
x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

# Obtain labels for each point in mesh. Use last trained model.
Z = kmeans.predict(np.c_[xx.ravel(), yy.ravel()])

# Put the result into a color plot
Z = Z.reshape(xx.shape)

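# Row i of Z lies at y = yy[i, 0]; column j lies at x = xx[0, j].
back = go.Heatmap(x=xx[0],
                  y=yy[:, 0],
                  z=Z,
                  showscale=False,
                  colorscale=matplotlib_to_plotly(plt.cm.Paired, len(Z)))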

markers = go.Scatter(x=reduced_data[:, 0],
                     y=reduced_data[:, 1],
                     mode='markers',
                     showlegend=False,
                     marker=dict(size=3, color='black'))

# Plot the centroids as white markers
centroids = kmeans.cluster_centers_
center = go.Scatter(x=centroids[:, 0],
                    y=centroids[:, 1],
                    mode='markers',
                    showlegend=False,
                    marker=dict(size=10, color='white'))
data = [back, markers, center]
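The decision-boundary plot colors every mesh point by the cluster `kmeans.predict` assigns to it. For k-means, that is the nearest cluster center. The NumPy-only sketch below runs the same pipeline on a tiny grid, using two hypothetical centroids. It flattens the mesh with `ravel`, assigns each point to its nearest centroid, and calls `reshape` to restore the grid. After the reshape, row `i` of `Z` lines up with `yy[i, 0]` and column `j` with `xx[0, j]`.

```python
import numpy as np

# Hypothetical 2-D centroids standing in for kmeans.cluster_centers_.
centroids = np.array([[-1.0, 0.0],
                      [1.0, 0.0]])

h = 0.5
xx, yy = np.meshgrid(np.arange(-2.0, 2.0, h), np.arange(-1.0, 1.0, h))

# Flatten the grid into an (n_points, 2) array of coordinates.
grid = np.c_[xx.ravel(), yy.ravel()]

# Nearest-centroid rule: index of the closest center for every grid point.
sq_dist = ((grid[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
Z = sq_dist.argmin(axis=1)

# Restore the grid shape (rows follow y, columns follow x).
Z = Z.reshape(xx.shape)
print(xx.shape, Z.shape)
print(Z)
```

Points exactly halfway between the two centroids go to the lower index, because `argmin` returns the first minimum.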
In [6]:
layout = go.Layout(title='K-means clustering on the digits dataset (PCA-reduced data)<br>'
                         'Centroids are marked in white',
                   xaxis=dict(ticks='', showticklabels=False),
                   yaxis=dict(ticks='', showticklabels=False))
fig = go.Figure(data=data, layout=layout)
# Render the figure with the plotly.plotly client imported above.
py.iplot(fig)
