
Recognizing Hand-Written Digits in Scikit-learn

An example showing how scikit-learn can be used to recognize images of hand-written digits.


Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18'

Imports

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, svm, metrics
Automatically created module for IPython interactive environment
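The cells below rely on two properties of the dataset returned by `datasets.load_digits()`: each sample is an 8x8 array of gray levels, and flattening it row by row gives the 64-feature vector a classifier expects. A minimal sketch checking both, assuming only scikit-learn and NumPy are installed (the name `flat` is illustrative):

```python
import numpy as np
from sklearn import datasets

digits = datasets.load_digits()

# One 8x8 image per sample, with its digit label in `target`.
n_samples = len(digits.images)
print(digits.images.shape)   # (n_samples, 8, 8)
print(digits.target[:4])     # labels of the first four images

# Flatten each 8x8 image into a 64-element row: shape (n_samples, 64).
flat = digits.images.reshape((n_samples, -1))
print(flat.shape)

# Row-major flattening maps pixel (row r, column c) to feature index 8*r + c.
r, c = 3, 5
assert flat[0, 8 * r + c] == digits.images[0, r, c]

# The dataset already ships this flattened view as `digits.data`.
assert np.array_equal(flat, digits.data)
```

Because `reshape` with `-1` infers the second dimension, the same line works for any number of samples as long as every image has the same size.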

Calculations and Plots

In [3]:
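# The digits dataset
digits = datasets.load_digits()

# A 2x4 grid: training images on the top row, test predictions on the bottom.
# The prediction titles are hard-coded to match the predictions of this run.
fig = tools.make_subplots(rows=2, cols=4,
                          subplot_titles=('Training: 0', 'Training: 1',
                                          'Training: 2', 'Training: 3',
                                          'Prediction: 8', 'Prediction: 8',
                                          'Prediction: 4', 'Prediction: 9'))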

# The data consists of 8x8 images of digits. Let's look at the first 4 images,
# stored in the `images` attribute of the dataset. If we were working from
# image files, we could load them with matplotlib.pyplot.imread; note that
# every image must have the same size. For these images we know which digit
# each one represents: it is given in the `target` attribute of the dataset.

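def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a matplotlib colormap into a Plotly colorscale list."""
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []

    for k in range(pl_entries):
        # cmap() returns RGBA floats in [0, 1]; keep RGB and scale to 0-255.
        r, g, b = (np.array(cmap(k * h)[:3]) * 255).astype(np.uint8)
        pl_colorscale.append([k * h, 'rgb(%d, %d, %d)' % (r, g, b)])

    return pl_colorscale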

images_and_labels = list(zip(digits.images, digits.target))

for index, (image, label) in enumerate(images_and_labels[:4]):
    trace = go.Heatmap(z=image,
                       colorscale=matplotlib_to_plotly(plt.cm.gray_r, len(image)),
                       showscale=False,
                       name='Training: %i' % label)
    # Place each trace by its position, not its label, so the layout does not
    # depend on the first four targets happening to be 0-3.
    fig.append_trace(trace, 1, index + 1)

# Reverse the y-axes so row 0 of each image is drawn at the top, and hide ticks.
for i in range(1, 5):
    fig['layout']['yaxis%d' % i].update(autorange='reversed',
                                        showticklabels=False, ticks='')
    fig['layout']['xaxis%d' % i].update(showticklabels=False, ticks='')

# To apply a classifier to this data, we need to flatten each image, turning
# the data into a (samples, features) matrix:
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

# Create a classifier: a support vector classifier
classifier = svm.SVC(gamma=0.001)

# Train on the first half of the samples (// keeps the index an integer)
classifier.fit(data[:n_samples // 2], digits.target[:n_samples // 2])

# Now predict the digits in the second half:
expected = digits.target[n_samples // 2:]
predicted = classifier.predict(data[n_samples // 2:])

print("Classification report for classifier %s:\n%s\n"
      % (classifier, metrics.classification_report(expected, predicted)))
print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))

images_and_predictions = list(zip(digits.images[n_samples // 2:], predicted))

for index, (image, prediction) in enumerate(images_and_predictions[:4]):
    trace = go.Heatmap(z=image,
                       colorscale=matplotlib_to_plotly(plt.cm.gray_r, len(image)),
                       showscale=False,
                       name='Prediction: %i' % prediction)
    fig.append_trace(trace, 2, index + 1)

# Same axis settings for the bottom row (axes 5-8).
for i in range(5, 9):
    fig['layout']['yaxis%d' % i].update(autorange='reversed',
                                        showticklabels=False, ticks='')
    fig['layout']['xaxis%d' % i].update(showticklabels=False, ticks='')

fig['layout'].update(height=700)
This is the format of your plot grid:
[ (1,1) x1,y1 ]  [ (1,2) x2,y2 ]  [ (1,3) x3,y3 ]  [ (1,4) x4,y4 ]
[ (2,1) x5,y5 ]  [ (2,2) x6,y6 ]  [ (2,3) x7,y7 ]  [ (2,4) x8,y8 ]

Classification report for classifier SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False):
             precision    recall  f1-score   support

          0       1.00      0.99      0.99        88
          1       0.99      0.97      0.98        91
          2       0.99      0.99      0.99        86
          3       0.98      0.87      0.92        91
          4       0.99      0.96      0.97        92
          5       0.95      0.97      0.96        91
          6       0.99      0.99      0.99        91
          7       0.96      0.99      0.97        89
          8       0.94      1.00      0.97        88
          9       0.93      0.98      0.95        92

avg / total       0.97      0.97      0.97       899


Confusion matrix:
[[87  0  0  0  1  0  0  0  0  0]
 [ 0 88  1  0  0  0  0  0  1  1]
 [ 0  0 85  1  0  0  0  0  0  0]
 [ 0  0  0 79  0  3  0  4  5  0]
 [ 0  0  0  0 88  0  0  0  0  4]
 [ 0  0  0  0  0 88  1  0  0  2]
 [ 0  1  0  0  0  0 90  0  0  0]
 [ 0  0  0  0  0  1  0 88  0  0]
 [ 0  0  0  0  0  0  0  0 88  0]
 [ 0  0  0  1  0  1  0  0  0 90]]
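The classification report can be recovered by hand from the confusion matrix printed above. `metrics.confusion_matrix` puts true digits on the rows and predicted digits on the columns, so a digit's precision is its diagonal entry divided by its column sum, its recall is the diagonal entry divided by its row sum (the "support"), and overall accuracy is the trace divided by the total count. A small sketch using the matrix copied from this run's output:

```python
import numpy as np

# Confusion matrix copied from the output above
# (rows: true digit, columns: predicted digit). Float dtype avoids
# integer division under Python 2.
cm = np.array([
    [87,  0,  0,  0,  1,  0,  0,  0,  0,  0],
    [ 0, 88,  1,  0,  0,  0,  0,  0,  1,  1],
    [ 0,  0, 85,  1,  0,  0,  0,  0,  0,  0],
    [ 0,  0,  0, 79,  0,  3,  0,  4,  5,  0],
    [ 0,  0,  0,  0, 88,  0,  0,  0,  0,  4],
    [ 0,  0,  0,  0,  0, 88,  1,  0,  0,  2],
    [ 0,  1,  0,  0,  0,  0, 90,  0,  0,  0],
    [ 0,  0,  0,  0,  0,  1,  0, 88,  0,  0],
    [ 0,  0,  0,  0,  0,  0,  0,  0, 88,  0],
    [ 0,  0,  0,  1,  0,  1,  0,  0,  0, 90],
], dtype=float)

correct = np.diag(cm)                 # correctly classified samples per digit
support = cm.sum(axis=1)              # true samples per digit (row sums)
precision = correct / cm.sum(axis=0)  # divide by predicted counts (column sums)
recall = correct / support
f1 = 2 * precision * recall / (precision + recall)
accuracy = correct.sum() / cm.sum()

for digit in range(10):
    print("%d  %.2f  %.2f  %.2f  %d"
          % (digit, precision[digit], recall[digit], f1[digit], support[digit]))
print("accuracy: %.3f" % accuracy)
```

The rounded values match the report row for row; for example, digit 3 has recall 79/91 (about 0.87) because 12 of its images were predicted as 5, 7 or 8, while its precision stays high at 79/81.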
In [4]:
py.iplot(fig)
Out[4]:
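The classifier used here, `svm.SVC(gamma=0.001)`, runs with the default RBF kernel (note `kernel='rbf'` in the printed repr), which scores the similarity of two flattened images as K(x, x') = exp(-gamma * ||x - x'||^2). Pixel values in this dataset are integers from 0 to 16, so squared distances between 64-pixel images can be large, and gamma controls how quickly similarity decays with distance. A sketch comparing a hand-written kernel against `sklearn.metrics.pairwise.rbf_kernel` for two gamma values (the helper `rbf` is hypothetical, written only for this illustration):

```python
import numpy as np
from sklearn import datasets
from sklearn.metrics.pairwise import rbf_kernel

digits = datasets.load_digits()
X = digits.data[:2]   # first two flattened images (a "0" and a "1")

def rbf(x, y, gamma):
    """Hypothetical helper: RBF kernel exp(-gamma * ||x - y||^2)."""
    return np.exp(-gamma * np.sum((x - y) ** 2))

sq_dist = np.sum((X[0] - X[1]) ** 2)
print("squared distance: %g" % sq_dist)

for gamma in (0.001, 1.0):
    manual = rbf(X[0], X[1], gamma)
    library = rbf_kernel(X[:1], X[1:], gamma=gamma)[0, 0]
    print("gamma=%g  manual=%.6g  rbf_kernel=%.6g" % (gamma, manual, library))
```

With a large gamma the kernel value for two different digits underflows toward zero, so every training image looks unrelated to every other; a small value such as 0.001 keeps the similarities in a usable range for this pixel scale.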

License

Author:

    Gael Varoquaux <gael dot varoquaux at normalesup dot org>

License:

    BSD 3 clause