
K-Means Assumptions in Scikit-learn

This example illustrates situations where k-means produces unintuitive and possibly unexpected clusters. In the first three plots, the input data violates an implicit assumption of k-means (that the requested number of clusters matches the data, that clusters are isotropic, and that clusters have similar variance, respectively), and undesirable clusters are produced as a result. In the last plot, k-means returns intuitive clusters despite unevenly sized blobs.
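
These assumptions follow from the objective k-means minimizes: the sum of squared Euclidean distances from each point to its assigned centroid. The following is a minimal sketch, not part of the original example, that recomputes that quantity by hand and compares it with the `inertia_` attribute of a fitted `KMeans` model. The data set size, the seed, and the name `manual_inertia` are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Small illustrative data set (sizes and seed are arbitrary assumptions).
X, _ = make_blobs(n_samples=300, random_state=0)

# n_init is passed explicitly so the call behaves the same across versions.
km = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

# Sum of squared Euclidean distances from each point to its assigned centroid.
manual_inertia = np.sum((X - km.cluster_centers_[km.labels_]) ** 2)

print(manual_inertia, km.inertia_)
```

Because every direction contributes equally to a squared Euclidean distance, the objective favors compact, roughly spherical groups, which is why stretched or unequal-variance blobs can be split in unintuitive ways.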


Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18'

Imports

This tutorial imports KMeans and make_blobs.

In [2]:
from plotly import tools
import plotly.plotly as py
import plotly.graph_objs as go


import numpy as np

from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
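
As a quick orientation, the sketch below inspects what `make_blobs` returns with the `n_samples` and `random_state` values used later in this tutorial. It is an illustrative check rather than part of the original example, and the name `counts` is introduced here only for that purpose.

```python
import numpy as np
from sklearn.datasets import make_blobs

# Same settings as the tutorial: 1500 samples, default centers and features.
X, y = make_blobs(n_samples=1500, random_state=170)

print(X.shape)       # number of samples and number of features
print(np.unique(y))  # labels of the generating blobs

# Number of points drawn from each blob.
counts = np.bincount(y)
print(counts)
```

With the default settings the data has two features and three equally sized blobs, which is what makes asking for two clusters an incorrect choice and what allows the blobs to be subsampled to 500, 100, and 10 points later on.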

Plot Results

In [3]:
fig = tools.make_subplots(rows=2, cols=2,
                          print_grid=False,
                          subplot_titles=('Incorrect Number of Blobs',
                                          'Anisotropically Distributed Blobs',
                                          'Unequal Variance',
                                          'Unevenly Sized Blobs'))


n_samples = 1500
random_state = 170
X, y = make_blobs(n_samples=n_samples, random_state=random_state)

# Incorrect number of clusters: the data contains three blobs, but two are requested
y_pred = KMeans(n_clusters=2, random_state=random_state).fit_predict(X)


Incorrect_Number = go.Scatter(x=X[:, 0],
                              y=X[:, 1], 
                              mode='markers',
                              showlegend=False,
                              marker=dict(color=y_pred,
                                          line=dict(
                                               width=1, color='black')
                                         ))

fig.append_trace(Incorrect_Number, 1, 1)

# Anisotropically distributed data: a linear transformation stretches the blobs
transformation = [[0.60834549, -0.63667341], [-0.40887718, 0.85253229]]
X_aniso = np.dot(X, transformation)
y_pred = KMeans(n_clusters=3, random_state=random_state).fit_predict(X_aniso)

Anisotropically_distributed = go.Scatter(x=X_aniso[:, 0],
                                         y=X_aniso[:, 1],
                                         mode='markers',
                                         showlegend=False,
                                         marker=dict(color=y_pred,
                                                     line=dict(
                                                         width=1, color='black')
                                                     ))
fig.append_trace(Anisotropically_distributed, 1, 2)

# Unequal variance: the three blobs have standard deviations 1.0, 2.5, and 0.5
X_varied, y_varied = make_blobs(n_samples=n_samples,
                                cluster_std=[1.0, 2.5, 0.5],
                                random_state=random_state)
y_pred = KMeans(n_clusters=3, random_state=random_state).fit_predict(X_varied)

Different_variance = go.Scatter(x=X_varied[:, 0], 
                                y=X_varied[:, 1], 
                                mode='markers',
                                showlegend=False,
                                marker=dict(color=y_pred,
                                            line=dict(
                                                  width=1, color='black')
                                            ))
fig.append_trace(Different_variance, 2, 1)

# Unevenly sized blobs: keep 500, 100, and 10 points from the three original blobs
X_filtered = np.vstack((X[y == 0][:500], X[y == 1][:100], X[y == 2][:10]))
y_pred = KMeans(n_clusters=3, random_state=random_state).fit_predict(X_filtered)

Unevenly_sized = go.Scatter(x=X_filtered[:, 0], 
                            y=X_filtered[:, 1], 
                            mode='markers',
                            showlegend=False,
                            marker=dict(color=y_pred,
                                        line=dict(
                                              width=1, color='black')
                                     ))
fig.append_trace(Unevenly_sized, 2, 2)

fig['layout'].update(height=800)

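# Hide zero lines and grids on all four subplots. The loop variables are named
# x_axis and y_axis so that they do not overwrite the label array y.
for i in map(str, range(1, 5)):
    x_axis = 'xaxis' + i
    y_axis = 'yaxis' + i
    fig['layout'][x_axis].update(zeroline=False, showgrid=False)
    fig['layout'][y_axis].update(zeroline=False, showgrid=False)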
    
py.iplot(fig)
Out[3]:
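
The plots show the effect visually. One way to put a number on it, sketched here as an addition that is not part of the original example, is to compare the k-means labels with the labels of the generating blobs using `adjusted_rand_score` from `sklearn.metrics`. The helper `kmeans_ari`, the array `y_filtered`, and the dictionary `scores` are hypothetical names introduced for this sketch, and `n_init=10` is passed explicitly, which matches the default in the scikit-learn 0.18 release shown above.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score

n_samples = 1500
random_state = 170
X, y = make_blobs(n_samples=n_samples, random_state=random_state)


def kmeans_ari(data, true_labels, n_clusters):
    """Fit k-means and score its labels against the generating blob labels."""
    pred = KMeans(n_clusters=n_clusters, n_init=10,
                  random_state=random_state).fit_predict(data)
    return adjusted_rand_score(true_labels, pred)


# Same four data sets as in the plots above.
transformation = [[0.60834549, -0.63667341], [-0.40887718, 0.85253229]]
X_aniso = np.dot(X, transformation)

X_varied, y_varied = make_blobs(n_samples=n_samples,
                                cluster_std=[1.0, 2.5, 0.5],
                                random_state=random_state)

X_filtered = np.vstack((X[y == 0][:500], X[y == 1][:100], X[y == 2][:10]))
# True labels matching the row order of X_filtered.
y_filtered = np.repeat([0, 1, 2], [500, 100, 10])

scores = {
    'Incorrect number of blobs': kmeans_ari(X, y, 2),
    'Anisotropically distributed blobs': kmeans_ari(X_aniso, y, 3),
    'Unequal variance': kmeans_ari(X_varied, y_varied, 3),
    'Unevenly sized blobs': kmeans_ari(X_filtered, y_filtered, 3),
}

for name, score in scores.items():
    print('%s: %.3f' % (name, score))
```

A score of 1 means the predicted clusters match the generating blobs exactly, while values near 0 indicate agreement no better than chance. The exact numbers depend on the scikit-learn version and its initialization defaults, so they are best read as a rough comparison between the four scenarios.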

License

Author:

    Phil Roth <mr.phil.roth@gmail.com>

License:

    BSD 3 clause