
Manifold Learning on Handwritten Digits in Scikit-learn

An illustration of various embeddings on the digits dataset.

The RandomTreesEmbedding, from the sklearn.ensemble module, is not technically a manifold embedding method, as it learns a high-dimensional representation to which we then apply a dimensionality reduction method. It is nevertheless often useful to cast a dataset into a representation in which the classes are linearly separable.
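As a quick sketch of that two-step recipe, the snippet below chains the same two estimators used in the Random Trees cell further down; the make_pipeline wrapper is only a convenience and is not part of the original cells.

import numpy as np
from sklearn import datasets, decomposition, ensemble
from sklearn.pipeline import make_pipeline

X = datasets.load_digits(n_class=6).data

# RandomTreesEmbedding hashes each sample into a sparse, high-dimensional code;
# TruncatedSVD then reduces that code to 2 components for plotting.
rte_2d = make_pipeline(
    ensemble.RandomTreesEmbedding(n_estimators=200, max_depth=5, random_state=0),
    decomposition.TruncatedSVD(n_components=2))
X_reduced = rte_2d.fit_transform(X)  # shape: (n_samples, 2)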

In this example, t-SNE is initialized with the embedding generated by PCA, which is not the default setting. This ensures global stability of the embedding, i.e., that the embedding does not depend on random initialization.
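For reference, a minimal sketch of that call; the only difference from the default is init='pca', and the parameter values mirror the t-SNE cell at the end of this notebook.

from sklearn import datasets, manifold

X = datasets.load_digits(n_class=6).data

# The default is init='random'; seeding with the PCA embedding makes the
# result reproducible up to the t-SNE optimization itself.
tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
X_tsne = tsne.fit_transform(X)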

New to Plotly?

Plotly's Python library is free and open source! Get started by downloading the client and reading the primer.
You can set up Plotly to work in online or offline mode, or in Jupyter notebooks.
We also have a quick-reference cheatsheet (new!) to help you get started!

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go

from time import time

import numpy as np
import matplotlib.pyplot as plt
from sklearn import (manifold, datasets, decomposition, ensemble,
                     discriminant_analysis, random_projection)

Calculations

In [3]:
digits = datasets.load_digits(n_class=6)
X = digits.data
y = digits.target
n_samples, n_features = X.shape
n_neighbors = 30

Plot Results

In [4]:
def plot_embedding(X, title=None):
    # Scale the embedding to the unit square
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    X = (X - x_min) / (x_max - x_min)

    # One text annotation per sample, colored by its digit class
    anno = []
    for i in range(X.shape[0]):
        anno.append(dict(x=X[i, 0], y=X[i, 1], text=str(digits.target[i]),
                         showarrow=False,
                         font=dict(size=11,
                                   color='rgb'+str(plt.cm.Set1(y[i] / 10.)[:3]))))

    # Keep a subset of well-separated points to mark with white markers
    shown_images = np.array([[1., 1.]])  # just something big
    for i in range(X.shape[0]):
        dist = np.sum((X[i] - shown_images) ** 2, 1)
        if np.min(dist) < 4e-3:
            # don't show points that are too close
            continue
        shown_images = np.r_[shown_images, [X[i]]]

    # Drop the dummy first row and split into x/y coordinates
    x_ = shown_images[1:, 0]
    y_ = shown_images[1:, 1]

    trace = go.Scatter(x=x_, y=y_,
                       showlegend=False,
                       mode='markers',
                       marker=dict(color='white', size=15,
                                   line=dict(color='black', width=1)))
    layout = go.Layout(annotations=anno, title=title,
                       xaxis=dict(ticks='', showticklabels=False,
                                  showgrid=False, zeroline=False),
                       yaxis=dict(ticks='', showticklabels=False,
                                  showgrid=False, zeroline=False))
    fig = go.Figure(data=[trace], layout=layout)

    return fig

Plot Images of the Digits

In [5]:
n_img_per_row = 20
img = np.zeros((10 * n_img_per_row, 10 * n_img_per_row))
for i in range(n_img_per_row):
    ix = 10 * i + 1
    for j in range(n_img_per_row):
        iy = 10 * j + 1
        img[ix:ix + 8, iy:iy + 8] = X[i * n_img_per_row + j].reshape((8, 8))
In [6]:
def matplotlib_to_plotly(cmap, pl_entries):
    # Sample a matplotlib colormap into a Plotly colorscale
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []

    for k in range(pl_entries):
        C = list(map(np.uint8, np.array(cmap(k * h)[:3]) * 255))
        pl_colorscale.append([k * h, 'rgb' + str((C[0], C[1], C[2]))])

    return pl_colorscale

trace = go.Heatmap(z=img, 
                   colorscale=matplotlib_to_plotly(plt.cm.binary, 5),
                   showscale=False)

layout = go.Layout(title='A selection from the 64-dimensional digits dataset',
                   xaxis=dict(ticks='', showticklabels=False),
                   yaxis=dict(ticks='', showticklabels=False,
                              autorange='reversed'),
                      )
fig = go.Figure(data=[trace], layout=layout)
In [7]:
py.iplot(fig)
Out[7]:

Random 2D projection using a random unitary matrix

In [8]:
print("Computing random projection")
rp = random_projection.SparseRandomProjection(n_components=2, random_state=42)
X_projected = rp.fit_transform(X)
fig = plot_embedding(X_projected, "Random Projection of the digits")
Computing random projection
In [9]:
py.iplot(fig)
Out[9]:

Projection onto the first 2 principal components

In [10]:
print("Computing PCA projection")
t0 = time()
X_pca = decomposition.TruncatedSVD(n_components=2).fit_transform(X)

fig = plot_embedding(X_pca,
                    "Principal Components projection of the digits (time %.2fs)" %
                    (time() - t0))
Computing PCA projection
In [11]:
py.iplot(fig)
Out[11]:

Projection onto the first 2 linear discriminant components

In [12]:
print("Computing Linear Discriminant Analysis projection")
X2 = X.copy()
X2.flat[::X.shape[1] + 1] += 0.01  # Make X invertible
t0 = time()
X_lda = discriminant_analysis.LinearDiscriminantAnalysis(n_components=2).fit_transform(X2, y)

fig = plot_embedding(X_lda,
               "Linear Discriminant projection of the digits (time %.2fs)" %
               (time() - t0))
Computing Linear Discriminant Analysis projection
In [13]:
py.iplot(fig)
Out[13]:

Isomap projection of the digits dataset

In [14]:
print("Computing Isomap embedding")
t0 = time()
X_iso = manifold.Isomap(n_neighbors, n_components=2).fit_transform(X)
print("Done.")

fig = plot_embedding(X_iso,
                     "Isomap projection of the digits (time %.2fs)" %
                     (time() - t0))
Computing Isomap embedding
Done.
In [15]:
py.iplot(fig)
Out[15]:

Locally linear embedding of the digits dataset

In [16]:
print("Computing LLE embedding")
clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2,
                                      method='standard')
t0 = time()
X_lle = clf.fit_transform(X)
print("Done. Reconstruction error: %g" % clf.reconstruction_error_)

fig = plot_embedding(X_lle,
                    "Locally Linear Embedding of the digits (time %.2fs)" %
                    (time() - t0))
Computing LLE embedding
Done. Reconstruction error: 1.63544e-06
In [17]:
py.iplot(fig)
Out[17]:

Modified Locally linear embedding of the digits dataset

In [18]:
print("Computing modified LLE embedding")
clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2,
                                      method='modified')
t0 = time()
X_mlle = clf.fit_transform(X)
print("Done. Reconstruction error: %g" % clf.reconstruction_error_)

fig = plot_embedding(X_mlle,
                    "Modified Locally Linear Embedding of the digits (time %.2fs)" %
                    (time() - t0))
Computing modified LLE embedding
Done. Reconstruction error: 0.360668
In [19]:
py.iplot(fig)
Out[19]:

HLLE embedding of the digits dataset

In [20]:
print("Computing Hessian LLE embedding")
clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2,
                                      method='hessian')
t0 = time()
X_hlle = clf.fit_transform(X)
print("Done. Reconstruction error: %g" % clf.reconstruction_error_)

fig = plot_embedding(X_hlle,
                     "Hessian Locally Linear Embedding of the digits (time %.2fs)" %
                     (time() - t0))
Computing Hessian LLE embedding
Done. Reconstruction error: 0.212801
In [21]:
py.iplot(fig)
Out[21]:

LTSA embedding of the digits dataset

In [22]:
print("Computing LTSA embedding")
clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2,
                                      method='ltsa')
t0 = time()
X_ltsa = clf.fit_transform(X)
print("Done. Reconstruction error: %g" % clf.reconstruction_error_)
fig = plot_embedding(X_ltsa,
                     "Local Tangent Space Alignment of the digits (time %.2fs)" %
                      (time() - t0))
Computing LTSA embedding
Done. Reconstruction error: 0.212804
In [23]:
py.iplot(fig)
Out[23]:

MDS embedding of the digits dataset

In [24]:
print("Computing MDS embedding")
clf = manifold.MDS(n_components=2, n_init=1, max_iter=100)
t0 = time()
X_mds = clf.fit_transform(X)
print("Done. Stress: %f" % clf.stress_)
fig = plot_embedding(X_mds,
                     "MDS embedding of the digits (time %.2fs)" %
                      (time() - t0))
Computing MDS embedding
Done. Stress: 143118271.858794
In [25]:
py.iplot(fig)
Out[25]:

Random Trees embedding of the digits dataset

In [26]:
print("Computing Totally Random Trees embedding")
hasher = ensemble.RandomTreesEmbedding(n_estimators=200, random_state=0,
                                       max_depth=5)
t0 = time()
X_transformed = hasher.fit_transform(X)
pca = decomposition.TruncatedSVD(n_components=2)
X_reduced = pca.fit_transform(X_transformed)

fig = plot_embedding(X_reduced,
                    "Random forest embedding of the digits (time %.2fs)" %
                    (time() - t0))
Computing Totally Random Trees embedding
In [27]:
py.iplot(fig)
Out[27]:

Spectral embedding of the digits dataset

In [28]:
print("Computing Spectral embedding")
embedder = manifold.SpectralEmbedding(n_components=2, random_state=0,
                                      eigen_solver="arpack")
t0 = time()
X_se = embedder.fit_transform(X)

fig = plot_embedding(X_se,
                    "Spectral embedding of the digits (time %.2fs)" %
                     (time() - t0))
Computing Spectral embedding
In [29]:
py.iplot(fig)
Out[29]:

t-SNE embedding of the digits dataset

In [30]:
print("Computing t-SNE embedding")
tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
t0 = time()
X_tsne = tsne.fit_transform(X)

fig = plot_embedding(X_tsne,
                    "t-SNE embedding of the digits (time %.2fs)" %
                    (time() - t0))
Computing t-SNE embedding
In [31]:
py.iplot(fig)
Out[31]:

License

Authors:

      Fabian Pedregosa <fabian.pedregosa@inria.fr>

      Olivier Grisel <olivier.grisel@ensta.org>

      Mathieu Blondel <mathieu@mblondel.org>

      Gael Varoquaux

License:

      BSD 3 clause (C) INRIA 2011