
# Comparison of F-Test and Mutual Information in Scikit-learn

This example illustrates the differences between univariate F-test statistics and mutual information.

We consider three features x_1, x_2, x_3 distributed uniformly over [0, 1]. The target depends on them as follows: y = x_1 + sin(6 * pi * x_2) + 0.1 * N(0, 1); that is, the third feature is completely irrelevant. The code below plots y against each individual x_i, together with the normalized values of the univariate F-test statistics and mutual information.

Because the F-test captures only linear dependency, it rates x_1 as the most discriminative feature. Mutual information, on the other hand, can capture any kind of dependency between variables, and it rates x_2 as the most discriminative feature, which probably agrees better with our intuition for this example. Both methods correctly mark x_3 as irrelevant.
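To make the "linear dependency only" point concrete: for a single feature, the F-statistic of a simple linear regression can be written in terms of the Pearson correlation r as F = r^2 / (1 - r^2) * (n - 2). The sketch below uses only NumPy and rebuilds the same synthetic data as this example. The helper name `f_from_pearson` is hypothetical and introduced here for illustration; it is meant to be comparable to the univariate scores reported by `f_regression`, not a copy of its implementation.

```python
import numpy as np

# Same synthetic data as in this example (legacy NumPy seeding).
np.random.seed(0)
X = np.random.rand(1000, 3)
y = X[:, 0] + np.sin(6 * np.pi * X[:, 1]) + 0.1 * np.random.randn(1000)


def f_from_pearson(x, y):
    """F-statistic of a one-feature linear regression, via Pearson's r.

    Hypothetical helper for illustration only.
    """
    n = len(y)
    r = np.corrcoef(x, y)[0, 1]  # linear (Pearson) correlation
    return r ** 2 / (1.0 - r ** 2) * (n - 2)


f_stat = np.array([f_from_pearson(X[:, j], y) for j in range(X.shape[1])])
f_norm = f_stat / f_stat.max()  # normalize so the top feature scores 1
print(f_norm)
```

Because the score depends on the data only through r, a strong but oscillating relationship such as sin(6 * pi * x_2) contributes little to it.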

### Version

In [1]:
import sklearn
sklearn.__version__

Out[1]:
'0.18.1'

### Imports

This tutorial imports f_regression and mutual_info_regression.

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
from sklearn.feature_selection import f_regression, mutual_info_regression

Automatically created module for IPython interactive environment


### Calculations

In [3]:
np.random.seed(0)
X = np.random.rand(1000, 3)
y = X[:, 0] + np.sin(6 * np.pi * X[:, 1]) + 0.1 * np.random.randn(1000)

f_test, _ = f_regression(X, y)
f_test /= np.max(f_test)

mi = mutual_info_regression(X, y)
mi /= np.max(mi)
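
As a rough cross-check of why mutual information favors x_2, the sketch below estimates mutual information with a simple histogram (plug-in) estimator in NumPy. This is a deliberately swapped-in technique: it is not the estimator used by `mutual_info_regression`, which according to the scikit-learn documentation relies on nearest-neighbor entropy estimation. The helper name `hist_mutual_info` and the choice of 20 bins per axis are assumptions made for illustration.

```python
import numpy as np

# Same synthetic data as in this example (legacy NumPy seeding).
np.random.seed(0)
X = np.random.rand(1000, 3)
y = X[:, 0] + np.sin(6 * np.pi * X[:, 1]) + 0.1 * np.random.randn(1000)


def hist_mutual_info(x, y, bins=20):
    """Plug-in mutual information estimate (in nats) from a 2-D histogram.

    Hypothetical helper; the bin count is an arbitrary assumption.
    """
    counts, _, _ = np.histogram2d(x, y, bins=bins)
    p_xy = counts / counts.sum()             # joint distribution over bins
    p_x = p_xy.sum(axis=1, keepdims=True)    # marginal of x, shape (bins, 1)
    p_y = p_xy.sum(axis=0, keepdims=True)    # marginal of y, shape (1, bins)
    nonzero = p_xy > 0                       # skip empty cells (0 * log 0 = 0)
    ratio = p_xy[nonzero] / (p_x @ p_y)[nonzero]
    return float(np.sum(p_xy[nonzero] * np.log(ratio)))


mi_hist = np.array([hist_mutual_info(X[:, j], y) for j in range(X.shape[1])])
print(mi_hist / mi_hist.max())
```

A plug-in estimate like this is biased upward for finite samples, so the irrelevant feature x_3 receives a small positive value rather than exactly zero; only the ranking should be compared with the normalized scores computed above.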


### Plot Results

In [4]:
# One subplot title per feature, showing both normalized scores.
titles = []
for i in range(3):
    titles.append("F-test={:.2f}, MI={:.2f}".format(f_test[i], mi[i]))

fig = tools.make_subplots(rows=1, cols=3,
                          print_grid=False,
                          subplot_titles=tuple(titles))

In [5]:
# Scatter y against each feature in its own subplot.
for i in range(3):
    trace = go.Scatter(x=X[:, i], y=y,
                       mode='markers',
                       marker=dict(color='blue',
                                   line=dict(width=1, color='black')),
                       showlegend=False)
    fig.append_trace(trace, 1, i + 1)

# Label the axes of each subplot; the names avoid shadowing the target y.
for i in range(1, 4):
    xaxis = 'xaxis{}'.format(i)
    yaxis = 'yaxis{}'.format(i)
    fig['layout'][xaxis].update(title="<i>x_{}</i>".format(i),
                                showgrid=False, zeroline=False)
    fig['layout'][yaxis].update(title='<i>y</i>', showgrid=False,
                                zeroline=False)

In [6]:
py.iplot(fig)

Out[6]: