Comparison of F-Test and Mutual Information in Scikit-learn

This example illustrates the differences between univariate F-test statistics and mutual information.

We consider three features x_1, x_2, x_3 distributed uniformly over [0, 1]. The target depends on them as follows: y = x_1 + sin(6 * pi * x_2) + 0.1 * N(0, 1); that is, the third feature is completely irrelevant. The code below plots y against each individual x_i, together with the normalized values of the univariate F-test statistics and mutual information.

Because the F-test captures only linear dependency, it rates x_1 as the most discriminative feature. Mutual information, on the other hand, can capture any kind of dependency between variables, and it rates x_2 as the most discriminative feature, which probably agrees better with our intuition for this example. Both methods correctly mark x_3 as irrelevant.
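To see why the F-test is limited to linear dependency, it may help to note that the univariate F statistic is a monotone function of the squared Pearson correlation between a feature and the target. The sketch below is not part of the original example: it recomputes the statistic by hand as r^2 / (1 - r^2) * (n - 2) and compares it with the output of f_regression, assuming the default centering. The names r, f_manual and f_sklearn are illustrative only.

```python
import numpy as np
from sklearn.feature_selection import f_regression

# Same synthetic data as in this example
rng = np.random.RandomState(0)
n_samples = 1000
X = rng.rand(n_samples, 3)
y = X[:, 0] + np.sin(6 * np.pi * X[:, 1]) + 0.1 * rng.randn(n_samples)

# Pearson correlation between each feature and the target
r = np.array([np.corrcoef(X[:, j], y)[0, 1] for j in range(X.shape[1])])

# F statistic derived from the correlation, with n - 2 degrees of freedom
# (assumes f_regression is called with its default centering)
f_manual = r ** 2 / (1 - r ** 2) * (n_samples - 2)

f_sklearn, _ = f_regression(X, y)
print(np.allclose(f_manual, f_sklearn))
```

Because the score depends only on the linear correlation, a strong but oscillating relationship such as sin(6 * pi * x_2) contributes little to it.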

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

This tutorial imports f_regression and mutual_info_regression.

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
from sklearn.feature_selection import f_regression, mutual_info_regression
Automatically created module for IPython interactive environment

Calculations

In [3]:
np.random.seed(0)
X = np.random.rand(1000, 3)
y = X[:, 0] + np.sin(6 * np.pi * X[:, 1]) + 0.1 * np.random.randn(1000)

f_test, _ = f_regression(X, y)
f_test /= np.max(f_test)

mi = mutual_info_regression(X, y)
mi /= np.max(mi)
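As a possible follow-up, scores like these can drive an actual feature selection step. The sketch below is an addition to the original example: it assumes SelectKBest with k=2 and fixes random_state for mutual_info_regression through functools.partial so that the result is repeatable. The variable names are illustrative.

```python
from functools import partial

import numpy as np
from sklearn.feature_selection import (SelectKBest, f_regression,
                                       mutual_info_regression)

# Same synthetic data as in this example
rng = np.random.RandomState(0)
X = rng.rand(1000, 3)
y = X[:, 0] + np.sin(6 * np.pi * X[:, 1]) + 0.1 * rng.randn(1000)

# Keep the two highest-scoring features under each criterion
f_selector = SelectKBest(score_func=f_regression, k=2).fit(X, y)
mi_selector = SelectKBest(
    score_func=partial(mutual_info_regression, random_state=0), k=2
).fit(X, y)

# Boolean masks marking which of the three features were kept
f_mask = f_selector.get_support()
mi_mask = mi_selector.get_support()
print(f_mask, mi_mask)
```

Both criteria rank the first two features differently, but with k=2 each is expected to discard the irrelevant third feature.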

Plot Results

In [4]:
titles = []
for i in range(3):
    titles.append("F-test={:.2f}, MI={:.2f}".format(f_test[i], mi[i]))
    
fig = tools.make_subplots(rows=1, cols=3,
                          print_grid=False,
                          subplot_titles=tuple(titles))
In [5]:
for i in range(3):
    trace = go.Scatter(x=X[:, i], y=y,
                       mode='markers',
                       marker=dict(color='blue', 
                                   line=dict(width=1, color='black')
                                  ),
                       showlegend=False
                      )
    fig.append_trace(trace, 1, i+1)
    
for i in range(1, 4):
    # Use distinct names so the target array y is not overwritten
    x_axis = 'xaxis{}'.format(i)
    y_axis = 'yaxis{}'.format(i)
    fig['layout'][x_axis].update(title="<i>x_{}</i>".format(i),
                                 showgrid=False, zeroline=False)
    fig['layout'][y_axis].update(title='<i>y</i>', showgrid=False,
                                 zeroline=False)
In [6]:
py.iplot(fig)
Out[6]: