
Normal and Shrinkage Linear Discriminant Analysis for classification in Scikit-learn

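This example shows how shrinkage improves Linear Discriminant Analysis (LDA) classification when there are few training samples relative to the number of features.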
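Shrinkage regularizes the covariance estimate that LDA has to invert. The minimal sketch below assumes the convex-combination form used by scikit-learn's `sklearn.covariance.shrunk_covariance`: the shrunk estimate is (1 - alpha) * S + alpha * mu * I. Here S is the empirical covariance, mu = trace(S) / p is its average eigenvalue, and alpha in [0, 1] is the shrinkage intensity. With `shrinkage='auto'`, LDA picks alpha with the Ledoit-Wolf lemma instead of taking a fixed value. The helper `shrink` is hypothetical and only illustrates the formula. The sketch checks it against the library function.

```python
import numpy as np
from sklearn.covariance import empirical_covariance, shrunk_covariance


def shrink(emp_cov, alpha):
    """Hypothetical helper: blend the empirical covariance with a scaled identity."""
    n_features = emp_cov.shape[0]
    mu = np.trace(emp_cov) / n_features  # average eigenvalue of emp_cov
    return (1 - alpha) * emp_cov + alpha * mu * np.identity(n_features)


rng = np.random.RandomState(0)
X = rng.randn(20, 5)                # 20 samples, 5 features of Gaussian noise
emp_cov = empirical_covariance(X)   # maximum-likelihood covariance estimate

manual = shrink(emp_cov, alpha=0.3)
reference = shrunk_covariance(emp_cov, shrinkage=0.3)
print(np.allclose(manual, reference))  # True
```

Setting alpha to 0 recovers the plain empirical covariance. Setting it to 1 replaces the covariance with the scaled identity mu * I.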



In [1]:
import sklearn


This tutorial imports make_blobs and LinearDiscriminantAnalysis.

In [2]:
from __future__ import division

import plotly.plotly as py
import plotly.graph_objs as go

import numpy as np

from sklearn.datasets import make_blobs
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


In [3]:
n_train = 20  # samples for training
n_test = 200  # samples for testing
n_averages = 50  # how often to repeat classification
n_features_max = 75  # maximum number of features
step = 4  # increment between tested numbers of features

def generate_data(n_samples, n_features):
    """Generate random blob-ish data with noisy features.

    This returns an array of input data with shape `(n_samples, n_features)`
    and an array of `n_samples` target labels.

    Only one feature contains discriminative information; the other features
    contain only noise.
    """
    X, y = make_blobs(n_samples=n_samples, n_features=1, centers=[[-2], [2]])

    # add non-discriminative features
    if n_features > 1:
        X = np.hstack([X, np.random.randn(n_samples, n_features - 1)])
    return X, y

acc_clf1, acc_clf2 = [], []
n_features_range = range(1, n_features_max + 1, step)
for n_features in n_features_range:
    score_clf1, score_clf2 = 0, 0
    for _ in range(n_averages):
        X, y = generate_data(n_train, n_features)

        clf1 = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto').fit(X, y)
        clf2 = LinearDiscriminantAnalysis(solver='lsqr', shrinkage=None).fit(X, y)

        X, y = generate_data(n_test, n_features)
        score_clf1 += clf1.score(X, y)
        score_clf2 += clf2.score(X, y)

    acc_clf1.append(score_clf1 / n_averages)
    acc_clf2.append(score_clf2 / n_averages)

features_samples_ratio = np.array(n_features_range) / n_train
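The sketch below shows what `generate_data` produces. It calls the generator once and compares the class means of the discriminative first column with those of a noise column. It repeats the function from the cell above so that it runs on its own. The fixed random seed is an added assumption for reproducibility and is not part of the original experiment.

```python
import numpy as np
from sklearn.datasets import make_blobs


def generate_data(n_samples, n_features):
    """Same generator as above: one informative feature plus Gaussian noise features."""
    X, y = make_blobs(n_samples=n_samples, n_features=1, centers=[[-2], [2]])
    if n_features > 1:
        X = np.hstack([X, np.random.randn(n_samples, n_features - 1)])
    return X, y


np.random.seed(0)  # assumption: fixed seed so the check is reproducible
X, y = generate_data(200, 5)
print(X.shape, y.shape)  # (200, 5) (200,)

# Column 0 is centered near -2 for class 0 and +2 for class 1;
# the noise columns have roughly zero mean in both classes.
print(X[y == 0, 0].mean(), X[y == 1, 0].mean())
print(X[y == 0, 1].mean(), X[y == 1, 1].mean())
```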


In [4]:
# Mean test accuracy of each classifier against the features/samples ratio
lda_shrinkage = go.Scatter(x=features_samples_ratio,
                           y=acc_clf1, mode="lines",
                           name="Linear Discriminant Analysis with shrinkage",
                           line=dict(color='navy', width=2))
lda_plot = go.Scatter(x=features_samples_ratio, y=acc_clf2,
                      mode="lines",
                      name="Linear Discriminant Analysis",
                      line=dict(color='gold', width=2))

data = [lda_plot, lda_shrinkage]
layout = go.Layout(xaxis=dict(title="n_features/n_samples"),
                   yaxis=dict(title="Classification Accuracy"))
fig = go.Figure(data=data, layout=layout)
py.iplot(fig)
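Shrinkage matters here because there are only n_train = 20 training samples. Once the number of features exceeds the number of samples minus one, the empirical covariance of a sample cannot be full rank, since centering removes one degree of freedom. Without shrinkage, the covariance that LDA has to invert is therefore singular, or at best badly conditioned. The sketch below illustrates this on pure Gaussian noise. That choice is an assumption: it does not reproduce LDA's per-class estimates. The sketch uses scikit-learn's `ledoit_wolf`, the estimator behind the Ledoit-Wolf choice that `shrinkage='auto'` makes, to show that the shrunk matrix stays full rank. The helper `rank_report` is hypothetical.

```python
import numpy as np
from sklearn.covariance import empirical_covariance, ledoit_wolf


def rank_report(n_samples, n_features, seed=0):
    """Hypothetical helper: ranks of the empirical and Ledoit-Wolf covariances."""
    rng = np.random.RandomState(seed)
    X = rng.randn(n_samples, n_features)
    emp_cov = empirical_covariance(X)   # rank is at most n_samples - 1
    shrunk_cov, alpha = ledoit_wolf(X)  # shrunk estimate and chosen intensity
    return (np.linalg.matrix_rank(emp_cov),
            np.linalg.matrix_rank(shrunk_cov),
            alpha)


n_train = 20  # same training-set size as the experiment above
for n_features in (5, 41):
    print(n_features, rank_report(n_train, n_features))
```

Adding any positive multiple of the identity makes the matrix positive definite. This is why the shrunk estimate remains invertible even when n_features/n_samples is well above 1.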
