Show Sidebar Hide Sidebar

# Normal and Shrinkage Linear Discriminant Analysis for classification in Scikit-learn

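Shows how shrinkage improves Linear Discriminant Analysis classification when there are few training samples relative to the number of features.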

#### New to Plotly?

You can set up Plotly to work in online or offline mode, or in Jupyter notebooks.
We also have a quick-reference cheatsheet (new!) to help you get started!

### Version

In [1]:
# Record which scikit-learn version this notebook was run against; the bare
# expression on the last line is what the notebook shows as the cell output.
import sklearn
sklearn.__version__

Out[1]:
'0.18'

### Imports

This tutorial imports `make_blobs` and `LinearDiscriminantAnalysis`.

In [2]:
from __future__ import division

import plotly.plotly as py
import plotly.graph_objs as go

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_blobs
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


### Calculations

In [3]:
# Experiment configuration.
n_train, n_test = 20, 200     # training / test samples drawn per repetition
n_averages = 50               # repetitions averaged for every feature count
n_features_max, step = 75, 4  # feature counts swept: 1, 1 + step, ... <= n_features_max

def generate_data(n_samples, n_features, random_state=None):
    """Generate random blob-ish data with noisy features.

    Only one feature contains discriminative information: it is drawn from
    two 1-D blobs centred at -2 and +2. All other features contain only
    standard-normal noise.

    Parameters
    ----------
    n_samples : int
        Number of rows to generate.
    n_features : int
        Total number of columns, including the single informative one.
        Must be at least 1.
    random_state : None, int or numpy.random.RandomState, optional
        Seed or generator used for BOTH the blobs and the noise columns, so a
        run can be reproduced. The default ``None`` draws from NumPy's global
        random state, exactly like the previous implementation.

    Returns
    -------
    X : ndarray of shape (n_samples, n_features)
        Input data; column 0 is the informative feature.
    y : ndarray of shape (n_samples,)
        The n_samples target labels produced by ``make_blobs``.

    Raises
    ------
    ValueError
        If ``n_features`` is less than 1 (the informative column is always
        present, so the requested shape could not be honoured).
    """
    # Local import: sklearn is already a dependency of this notebook.
    from sklearn.utils import check_random_state

    if n_features < 1:
        raise ValueError("n_features must be >= 1, got %r" % (n_features,))

    # check_random_state(None) returns NumPy's global RandomState, i.e. the
    # same stream np.random.randn and make_blobs(random_state=None) use.
    rng = check_random_state(random_state)

    X, y = make_blobs(n_samples=n_samples, n_features=1, centers=[[-2], [2]],
                      random_state=rng)

    if n_features > 1:
        # Append pure-noise columns after the informative one.
        X = np.hstack([X, rng.randn(n_samples, n_features - 1)])
    return X, y

# Mean test accuracy per feature count: acc_clf1 for LDA with automatic
# shrinkage (shrinkage='auto'), acc_clf2 for LDA without shrinkage.
acc_clf1, acc_clf2 = [], []
n_features_range = range(1, n_features_max + 1, step)

for n_features in n_features_range:
    shrunk_scores = []
    plain_scores = []
    for _ in range(n_averages):
        # Fresh small training set for this repetition.
        X_train, y_train = generate_data(n_train, n_features)
        shrunk = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')
        plain = LinearDiscriminantAnalysis(solver='lsqr', shrinkage=None)
        shrunk.fit(X_train, y_train)
        plain.fit(X_train, y_train)

        # Score both models on an independent, larger test set.
        X_test, y_test = generate_data(n_test, n_features)
        shrunk_scores.append(shrunk.score(X_test, y_test))
        plain_scores.append(plain.score(X_test, y_test))

    acc_clf1.append(sum(shrunk_scores) / n_averages)
    acc_clf2.append(sum(plain_scores) / n_averages)

# x-axis values for the plot: number of features per training sample.
features_samples_ratio = np.array(n_features_range) / n_train


### Plots

In [4]:
# Accuracy curve for LDA with shrinkage (navy, explicit 2px line).
lda_shrinkage = go.Scatter(
    x=features_samples_ratio,
    y=acc_clf1,
    mode="lines",
    name="Linear Discriminant Analysis with shrinkage",
    line={"color": "navy", "width": 2},
)

# Accuracy curve for plain LDA (gold).
lda_plot = go.Scatter(
    x=features_samples_ratio,
    y=acc_clf2,
    mode="lines",
    name="Linear Discriminant Analysis",
    line={"color": "gold"},
)

# Trace order: plain LDA first, then the shrinkage variant.
data = [lda_plot, lda_shrinkage]
layout = go.Layout(
    xaxis={"title": "n_features/n_samples", "showgrid": False},
    yaxis={"title": "Classification Accuracy", "showgrid": False},
)
fig = go.Figure(data=data, layout=layout)

py.iplot(fig)

Out[4]: