Ledoit-Wolf vs OAS Estimation in Scikit-learn

The usual maximum likelihood estimate of the covariance can be regularized using shrinkage. Ledoit and Wolf proposed a closed-form formula for the asymptotically optimal shrinkage parameter (the one minimizing an MSE criterion), which yields the Ledoit-Wolf covariance estimate.
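
As a minimal sketch of what "shrinkage" means here: the empirical covariance is replaced by a convex combination of itself and a scaled identity matrix, which is the form documented for scikit-learn's shrunk covariance estimators. The coefficient `alpha` and the toy data below are assumptions chosen by hand for illustration; Ledoit-Wolf and OAS instead estimate the coefficient from the data.

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples, n_features = 20, 5
X = rng.normal(size=(n_samples, n_features))

# Maximum likelihood covariance estimate, assuming the data are already centered.
emp_cov = np.dot(X.T, X) / n_samples

# Shrinkage target: a scaled identity with the same average variance.
mu = np.trace(emp_cov) / n_features
target = mu * np.eye(n_features)

# Hypothetical, hand-picked shrinkage coefficient in [0, 1].
alpha = 0.3
shrunk_cov = (1 - alpha) * emp_cov + alpha * target

print("condition number before:", np.linalg.cond(emp_cov))
print("condition number after: ", np.linalg.cond(shrunk_cov))
```

Shrinking towards the scaled identity leaves the trace unchanged and pulls the eigenvalues towards their mean, so the shrunk estimate is better conditioned than the raw one.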

Chen et al. proposed an improvement of the Ledoit-Wolf shrinkage parameter, the OAS coefficient, whose convergence is significantly better under the assumption that the data are Gaussian.
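
A quick way to see the two coefficients side by side is to fit both estimators on a single small Gaussian sample. This is only a sketch: the sample size, feature count and seed are arbitrary assumptions, and a single draw says nothing about average behaviour, which is what the simulation further down measures.

```python
import numpy as np
from sklearn.covariance import LedoitWolf, OAS

rng = np.random.RandomState(42)
# Few samples relative to the number of features: the regime where shrinkage matters.
X = rng.normal(size=(15, 40))

lw = LedoitWolf(assume_centered=True).fit(X)
oas = OAS(assume_centered=True).fit(X)

# shrinkage_ is the coefficient each estimator selected for this sample.
print("Ledoit-Wolf shrinkage:", lw.shrinkage_)
print("OAS shrinkage:        ", oas.shrinkage_)
```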

This example, inspired by Chen’s publication [1], compares the estimated MSE of the LW and OAS methods on Gaussian-distributed data.

[1] “Shrinkage Algorithms for MMSE Covariance Estimation” Chen et al., IEEE Trans. on Sign. Proc., Volume 58, Issue 10, October 2010.

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18'

Imports

This tutorial imports toeplitz, cholesky, LedoitWolf and OAS.

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go

import numpy as np
from scipy.linalg import toeplitz, cholesky

from sklearn.covariance import LedoitWolf, OAS
Automatically created module for IPython interactive environment

Calculations

In [3]:
np.random.seed(0)
n_features = 100
# simulation covariance matrix (AR(1) process)
r = 0.1
real_cov = toeplitz(r ** np.arange(n_features))
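# lower=True returns L with L L^T = real_cov, so that Z.dot(L.T) has covariance real_cov
coloring_matrix = cholesky(real_cov, lower=True)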

n_samples_range = np.arange(6, 31, 1)
repeat = 100
lw_mse = np.zeros((n_samples_range.size, repeat))
oa_mse = np.zeros((n_samples_range.size, repeat))
lw_shrinkage = np.zeros((n_samples_range.size, repeat))
oa_shrinkage = np.zeros((n_samples_range.size, repeat))
for i, n_samples in enumerate(n_samples_range):
    for j in range(repeat):
        X = np.dot(
            np.random.normal(size=(n_samples, n_features)), coloring_matrix.T)

        lw = LedoitWolf(store_precision=False, assume_centered=True)
        lw.fit(X)
        lw_mse[i, j] = lw.error_norm(real_cov, scaling=False)
        lw_shrinkage[i, j] = lw.shrinkage_

        oa = OAS(store_precision=False, assume_centered=True)
        oa.fit(X)
        oa_mse[i, j] = oa.error_norm(real_cov, scaling=False)
        oa_shrinkage[i, j] = oa.shrinkage_
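
The simulation above draws white Gaussian noise and "colours" it with a Cholesky factor of the true covariance. The following NumPy-only sketch checks that idea on a small case. It is an illustration under assumptions, not part of the original example: it uses `np.linalg.cholesky` (which returns the lower-triangular factor) instead of the SciPy call, a 4-feature AR(1) covariance, and a large sample so that the empirical covariance is close to the target.

```python
import numpy as np

n_features = 4
r = 0.1
# AR(1) covariance: entry (i, j) equals r ** |i - j|.
idx = np.arange(n_features)
real_cov = r ** np.abs(idx[:, None] - idx[None, :])

# Lower-triangular factor L such that L L^T = real_cov.
L = np.linalg.cholesky(real_cov)

rng = np.random.RandomState(0)
Z = rng.normal(size=(200000, n_features))  # white noise, identity covariance
X = np.dot(Z, L.T)                         # coloured samples, covariance L L^T

# Empirical covariance of the coloured samples (the data have zero mean by construction).
sample_cov = np.dot(X.T, X) / X.shape[0]
print("largest deviation from real_cov:", np.abs(sample_cov - real_cov).max())
```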

Plot MSE

In [4]:
lw_trace = go.Scatter(x=n_samples_range,
                      y=lw_mse.mean(1),
                      # symmetric error bars of one standard deviation over the repeats
                      error_y=dict(type='data', array=lw_mse.std(1), visible=True),
                      name='Ledoit-Wolf',
                      mode='lines',
                      line=dict(color='navy', width=2)
                     )
# named oas_trace so that the imported OAS estimator class is not shadowed
oas_trace = go.Scatter(x=n_samples_range,
                       y=oa_mse.mean(1),
                       error_y=dict(type='data', array=oa_mse.std(1), visible=True),
                       name='OAS',
                       mode='lines',
                       line=dict(color='#FF8C00', width=2)
                      )

data = [lw_trace, oas_trace]
layout = go.Layout(title="Comparison of covariance estimators",
                   yaxis=dict(title="Squared error"),
                   xaxis=dict(title="n_samples")
                  )

fig = go.Figure(data=data, layout=layout)
In [5]:
py.iplot(fig)
Out[5]:
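
The quantity plotted here comes from `error_norm(real_cov, scaling=False)`. With the method's default arguments (`norm='frobenius'`, `squared=True`), this is the sum of squared entry-wise differences between the fitted and the reference covariance, and `scaling=False` means the result is not divided by the number of features. The sketch below checks that reading by hand on toy data; the identity ground truth and the sample size are assumptions made only for this illustration.

```python
import numpy as np
from sklearn.covariance import LedoitWolf

rng = np.random.RandomState(0)
n_samples, n_features = 10, 6
real_cov = np.eye(n_features)  # hypothetical ground-truth covariance
X = rng.normal(size=(n_samples, n_features))

lw = LedoitWolf(assume_centered=True).fit(X)

# Squared Frobenius norm of the estimation error, computed by hand.
by_hand = np.sum((lw.covariance_ - real_cov) ** 2)
# The same quantity through the estimator's error_norm method.
from_api = lw.error_norm(real_cov, scaling=False)

print(by_hand, from_api)
```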

Plot shrinkage coefficient

In [6]:
lw_trace = go.Scatter(x=n_samples_range,
                      y=lw_shrinkage.mean(1),
                      # error bars use the spread of the shrinkage coefficient, not of the MSE
                      error_y=dict(type='data', array=lw_shrinkage.std(1), visible=True),
                      name='Ledoit-Wolf',
                      mode='lines',
                      line=dict(color='navy', width=2)
                     )

# named oas_trace so that the imported OAS estimator class is not shadowed
oas_trace = go.Scatter(x=n_samples_range,
                       y=oa_shrinkage.mean(1),
                       error_y=dict(type='data', array=oa_shrinkage.std(1), visible=True),
                       name='OAS',
                       mode='lines',
                       line=dict(color='#FF8C00', width=2)
                      )

data = [lw_trace, oas_trace]
layout = go.Layout(title="Comparison of covariance estimators",
                   yaxis=dict(title="Shrinkage"),
                   xaxis=dict(title="n_samples")
                  )

fig = go.Figure(data=data, layout=layout)
In [7]:
py.iplot(fig)
Out[7]: