Show Sidebar Hide Sidebar

Compare BIRCH and MiniBatchKMeans in Scikit-learn

Note: this page is part of the documentation for version 3 of Plotly.py, which is not the most recent version.
See our Version 4 Migration Guide for information about how to upgrade.

This example compares the timing of Birch (with and without the global clustering step) and MiniBatchKMeans on a synthetic dataset having 100,000 samples and 2 features generated using make_blobs.

If n_clusters is set to None, Birch reduces the 100,000 samples to a set of 158 subclusters. This can be viewed as a preprocessing step before the final (global) clustering step, which further reduces these 158 subclusters to 100 clusters.

New to Plotly?

Plotly's Python library is free and open source! Get started by downloading the client and reading the primer.
You can set up Plotly to work in online or offline mode, or in jupyter notebooks.
We also have a quick-reference cheatsheet (new!) to help you get started!


In [1]:
import sklearn


This tutorial imports Birch, MiniBatchKMeans and make_blobs.

In [2]:

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

from itertools import cycle
from time import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

from sklearn.preprocessing import StandardScaler
from sklearn.cluster import Birch, MiniBatchKMeans
from sklearn.datasets.samples_generator import make_blobs
Automatically created module for IPython interactive environment


In [3]:
# Place the 100 blob centres on a regular 10 x 10 grid covering [-22, 22]
# along both axes.
xx, yy = np.meshgrid(np.linspace(-22, 22, 10), np.linspace(-22, 22, 10))
n_centres = np.column_stack((xx.ravel(), yy.ravel()))

# Draw 100,000 samples around those centres; this is the dataset on which
# MiniBatchKMeans and Birch are compared.
X, y = make_blobs(n_samples=100000, centers=n_centres, random_state=0)

# Cycle through every named colour in matplotlib's colour-name table.
colors_ = cycle(colors.cnames)

# Two Birch variants sharing one threshold: n_clusters=None skips the global
# clustering step, n_clusters=100 applies it. final_step describes them in
# the same order for the timing printout.
birch_models = [Birch(threshold=1.7, n_clusters=n) for n in (None, 100)]
final_step = ['without global clustering', 'with global clustering']
In [4]:
# One row of three panels: Birch without the global clustering step,
# Birch with it, and MiniBatchKMeans for comparison.
fig = tools.make_subplots(rows=1, cols=3,
                          subplot_titles=('BIRCH without global clustering',
                                          'BIRCH with global clustering',
                                          'MiniBatchKMeans'))
This is the format of your plot grid:
[ (1,1) x1,y1 ]  [ (1,2) x2,y2 ]  [ (1,3) x3,y3 ]

Compute clustering with BIRCH

In [5]:
# Fit each Birch variant, time the fit, and draw its clusters into its own
# subplot column. Column 1 holds Birch without the global clustering step and
# column 2 holds Birch with it, matching the order of birch_models/final_step.
fignum = 1

for ind, (birch_model, info) in enumerate(zip(birch_models, final_step)):
    t = time()
    # The fit is what is being timed; it also populates labels_ and
    # subcluster_centers_, which are read below.
    birch_model.fit(X)
    time_ = time() - t
    print("Birch %s as the final step took %0.2f seconds" % (info, time_))

    # Plot result
    labels = birch_model.labels_
    centroids = birch_model.subcluster_centers_
    n_clusters = np.unique(labels).size
    print("n_clusters : %d" % n_clusters)
    for this_centroid, k, col in zip(centroids, range(n_clusters), colors_):
        mask = labels == k
        birch = go.Scattergl(x=X[mask, 0], y=X[mask, 1],
                             mode='markers', marker=dict(size=2, color=col))
        fig.append_trace(birch, 1, fignum)
        # Without a global step, label k is subcluster k, so centroid k is
        # that cluster's centre. With the global step the labels index the
        # final clusters instead, so the subcluster centres are not drawn.
        if birch_model.n_clusters is None:
            # Wrap the scalar coordinates in lists so x/y are array-like.
            center1 = go.Scatter(x=[this_centroid[0]], y=[this_centroid[1]],
                                 mode='markers', marker=dict(color='black'))
            fig.append_trace(center1, 1, fignum)

    # Advance to the next subplot column for the next Birch variant.
    fignum += 1
Birch without global clustering as the final step took 2.49 seconds
n_clusters : 158
Birch with global clustering as the final step took 2.30 seconds
n_clusters : 100

Compute clustering with MiniBatchKMeans.

In [6]:
# Fit MiniBatchKMeans with the same target of 100 clusters and time the fit.
mbk = MiniBatchKMeans(init='k-means++', n_clusters=100, batch_size=100,
                      n_init=10, max_no_improvement=10, verbose=0,
                      random_state=0)
t0 = time()
mbk.fit(X)
t_mini_batch = time() - t0
print("Time taken to run MiniBatchKMeans %0.2f seconds" % t_mini_batch)
mbk_means_labels_unique = np.unique(mbk.labels_)

# Draw each MiniBatchKMeans cluster and its centre into subplot column 3.
# range(mbk.n_clusters) pairs cluster index k with row k of cluster_centers_
# and does not rely on the n_clusters value left over from the Birch cell.
for this_centroid, k, col in zip(mbk.cluster_centers_,
                                 range(mbk.n_clusters), colors_):
    mask = mbk.labels_ == k
    minibatchkmeans_ = go.Scattergl(x=X[mask, 0], y=X[mask, 1],
                                    mode='markers',
                                    marker=dict(size=2, color=col))
    center2 = go.Scatter(x=[this_centroid[0]], y=[this_centroid[1]],
                         mode='markers', marker=dict(color='black', size=6))
    fig.append_trace(minibatchkmeans_, 1, 3)
    fig.append_trace(center2, 1, 3)
Time taken to run MiniBatchKMeans 3.11 seconds
In [7]:
# Hide zero lines, grids, ticks and tick labels on all three subplots so only
# the clustered points remain, then size the figure and render it.
# The axis names are built inline rather than stored in `x`/`y`, so this cell
# no longer rebinds `y`, which holds the make_blobs labels.
hidden_axis = dict(zeroline=False, showgrid=False,
                   ticks='', showticklabels=False)
for i in map(str, range(1, 4)):
    fig['layout']['xaxis' + i].update(**hidden_axis)
    fig['layout']['yaxis' + i].update(**hidden_axis)

fig['layout'].update(height=900, width=900,
                     margin=dict(l=10, r=10))
py.iplot(fig, validate=False)



Authors:

       Manoj Kumar

       Alexandre Gramfort

License:

       BSD 3 clause