Robust Linear Estimator Fitting in Scikit-learn

Here a sine function is fit with a polynomial of order 3, for values close to zero.
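
The Taylor expansion of the sine at 0 shows why an order-3 polynomial is adequate close to zero. Its first two nonzero terms already form a cubic, and what is left over (the modelling error of the first scenario) starts at the fifth power of x. The least-squares cubic fitted below will not coincide exactly with this Taylor polynomial, so the derivation is only a guide to the size of the error:

```latex
\sin x = x - \frac{x^3}{6} + \frac{x^5}{120} - \cdots
\qquad\Longrightarrow\qquad
\sin x - \left(x - \frac{x^3}{6}\right) = \frac{x^5}{120} + O(x^7)
```

The leading remainder term x^5/120 is about 0.00026 at x = 0.5 but about 0.27 at x = 2, so a cubic tracks the sine well only while |x| stays small.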

Robust fitting is demonstrated in different situations:

  • No measurement errors, only modelling errors (fitting a sine with a polynomial)
  • Measurement errors in X
  • Measurement errors in y

The mean squared error on new, uncorrupted test data (X_test, y_test) is used to judge the quality of each fit.
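
The test set is generated without corruption, so squaring the residuals is not a problem here. On noisy evaluation data, though, a median-based score such as sklearn.metrics.median_absolute_error would be much less sensitive to a few extreme residuals. The residuals below are made-up numbers, used only to show the difference between the two scores:

```python
import numpy as np

# Hypothetical prediction errors: most are small, two are huge
errors = np.array([0.1, -0.2, 0.1, 0.0, -0.1, 5.0, -4.0])

# Mean squared error: every residual is squared, so the two large
# errors dominate the score
mse = np.mean(errors ** 2)

# Median absolute error: the median ignores how extreme the few
# large residuals are
medae = np.median(np.abs(errors))

print("MSE = %.3f, median absolute error = %.3f" % (mse, medae))
```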

We can see that:

  • RANSAC is good for strong outliers in the y direction.
  • Theil-Sen is good for small outliers, in both the X and y directions, but has a breakdown point above which it performs worse than OLS.
  • The scores of HuberRegressor cannot be compared directly with those of Theil-Sen and RANSAC, because it does not try to filter out the outliers completely but only lessens their effect.
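
The core idea behind Theil-Sen can be sketched for the single-feature case: the slope is the median of the slopes over all pairs of points, so a minority of corrupted samples cannot drag it far as long as their fraction stays below the breakdown point. The helper theil_sen_line below is a hypothetical minimal sketch on synthetic data, not scikit-learn's TheilSenRegressor, which generalizes the idea to several features using a spatial median:

```python
import itertools

import numpy as np


def theil_sen_line(x, y):
    """Hypothetical single-feature Theil-Sen sketch.

    The slope is the median of the slopes over all pairs of points with
    distinct x; the intercept is the median of y - slope * x.
    """
    slopes = [(y[j] - y[i]) / (x[j] - x[i])
              for i, j in itertools.combinations(range(len(x)), 2)
              if x[j] != x[i]]
    slope = np.median(slopes)
    intercept = np.median(y - slope * x)
    return slope, intercept


rng = np.random.RandomState(0)
x = np.linspace(0, 1, 30)
y = 2 * x + 1 + rng.normal(scale=0.05, size=x.size)  # true slope 2, intercept 1

# Overwrite every fifth target (20% of the data) with a large deviant
y_corrupt = y.copy()
y_corrupt[::5] = 10

ols_slope, ols_intercept = np.polyfit(x, y_corrupt, 1)  # ordinary least squares
ts_slope, ts_intercept = theil_sen_line(x, y_corrupt)

print("slope: true 2.00 | OLS %.2f | Theil-Sen %.2f" % (ols_slope, ts_slope))
```

With 20% corruption the median of the pairwise slopes still lands among the slopes of clean pairs, whereas the least-squares slope is pulled well away from 2.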

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

In [2]:
import plotly.plotly as py  # plotly < 4.0; newer versions provide this module as chart_studio.plotly
import plotly.graph_objs as go
import numpy as np

from sklearn.linear_model import (
    LinearRegression, TheilSenRegressor, RANSACRegressor, HuberRegressor)
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline
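
Each model below is a pipeline. PolynomialFeatures(3) turns the single input column into the columns 1, x, x^2 and x^3, and each estimator then fits a model that is linear in those expanded features, so "linear" refers to the coefficients, not to x. The small input array here is only an illustration of that expansion and of the step names that make_pipeline generates:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

# A single input column with two samples
x = np.array([[2.0], [-1.0]])

# Degree-3 expansion: each row becomes [1, x, x**2, x**3]
features = PolynomialFeatures(3).fit_transform(x)
print(features)

# make_pipeline names each step after its lowercased class name
model = make_pipeline(PolynomialFeatures(3), LinearRegression())
print(list(model.named_steps))
```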

Calculations

In [3]:
np.random.seed(42)

X = np.random.normal(size=400)
y = np.sin(X)
# Make sure that X is 2D
X = X[:, np.newaxis]

X_test = np.random.normal(size=200)
y_test = np.sin(X_test)
X_test = X_test[:, np.newaxis]

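# Corrupt every third sample (indices 0, 3, 6, ...) with a fixed small deviant
y_errors = y.copy()
y_errors[::3] = 3

X_errors = X.copy()
X_errors[::3] = 3

# Same corruption pattern, but with much larger deviants
y_errors_large = y.copy()
y_errors_large[::3] = 10

X_errors_large = X.copy()
X_errors_large[::3] = 10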

estimators = [('OLS', LinearRegression()),
              ('Theil-Sen', TheilSenRegressor(random_state=42)),
              ('RANSAC', RANSACRegressor(random_state=42)),
              ('HuberRegressor', HuberRegressor())]
colors = {'OLS': 'turquoise', 'Theil-Sen': 'gold', 'RANSAC': 'lightgreen', 'HuberRegressor': 'black'}
linestyle = {'OLS': 'dash', 'Theil-Sen': 'dashdot', 'RANSAC': 'dot', 'HuberRegressor': 'dot'}
lw = 3

x_plot = np.linspace(X.min(), X.max())
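
Because every third sample is overwritten, a third of the training set is corrupted. That is above the breakdown point of about 29.3% (1 - 1/sqrt(2)) that the scikit-learn documentation gives for Theil-Sen in simple linear regression, and the user guide also notes that Theil-Sen's robustness decreases quickly as the number of features grows (the pipeline feeds it several polynomial columns). This is consistent with the breakdown-point remark above. A quick arithmetic check, assuming the notebook's 400 training samples:

```python
import numpy as np

n_samples = 400  # same size as X in the notebook

# The notebook corrupts X[::3] / y[::3]: indices 0, 3, 6, ..., 399
corrupted = np.arange(n_samples)[::3]
fraction = corrupted.size / n_samples

# Theil-Sen breakdown point for simple (one-feature) linear regression
theil_sen_breakdown = 1 - 1 / np.sqrt(2)

print("corrupted: %d of %d (%.1f%%)"
      % (corrupted.size, n_samples, 100 * fraction))
print("Theil-Sen breakdown point (simple regression): %.1f%%"
      % (100 * theil_sen_breakdown))
```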

Plot Results

In [4]:
plots = []

for title, this_X, this_y in [
        ('Modeling Errors Only', X, y),
        ('Corrupt X, Small Deviants', X_errors, y),
        ('Corrupt y, Small Deviants', X, y_errors),
        ('Corrupt X, Large Deviants', X_errors_large, y),
        ('Corrupt y, Large Deviants', X, y_errors_large)]:
    
    data = []
    trace = go.Scatter(x=this_X[:, 0], y=this_y, 
                       mode='markers', 
                       marker=dict(color='blue', size=3),
                       showlegend=False)
    data.append(trace)
    
    for name, estimator in estimators:
        model = make_pipeline(PolynomialFeatures(3), estimator)
        model.fit(this_X, this_y)
        mse = mean_squared_error(y_test, model.predict(X_test))  # (y_true, y_pred)
        y_plot = model.predict(x_plot[:, np.newaxis])
        
        trace = go.Scatter(x=x_plot, y=y_plot, 
                           mode='lines',
                           line=dict(color=colors[name], dash=linestyle[name],
                                     width=lw),
                           name='%s: error = %.3f' % (name, mse))
        data.append(trace)
        
    layout = go.Layout(title=title,
                       xaxis=dict(range=[-4, 10], showgrid=False,
                                  zeroline=False),
                       yaxis=dict(range=[-2, 10], showgrid=False,
                                  zeroline=False)
                      )
    fig = go.Figure(data=data, layout=layout)
    
    plots.append(fig)
    
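
The plots only show the fitted curves. For RANSAC, one can also check which training samples the selected model treated as outliers through its inlier_mask_ attribute. The sketch below is an illustration, not part of the original notebook: it refits only the "Corrupt y, Large Deviants" case with the same seed and corruption pattern as above, and compares the mask with the indices that were overwritten:

```python
import numpy as np
from sklearn.linear_model import RANSACRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

# Rebuild the "Corrupt y, Large Deviants" training set from the notebook
np.random.seed(42)
X = np.random.normal(size=400)[:, np.newaxis]
y_errors_large = np.sin(X[:, 0])
y_errors_large[::3] = 10

model = make_pipeline(PolynomialFeatures(3), RANSACRegressor(random_state=42))
model.fit(X, y_errors_large)

# Boolean mask of the samples RANSAC kept as inliers
inlier_mask = model.named_steps['ransacregressor'].inlier_mask_

corrupted = np.zeros(len(X), dtype=bool)
corrupted[::3] = True

print("corrupted samples flagged as outliers: %d of %d"
      % ((~inlier_mask & corrupted).sum(), corrupted.sum()))
print("clean samples flagged as outliers: %d"
      % (~inlier_mask & ~corrupted).sum())
```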

Modeling Errors Only

In [5]:
py.iplot(plots[0])
Out[5]:

Corrupt X, Small Deviants

In [6]:
py.iplot(plots[1])
Out[6]:

Corrupt y, Small Deviants

In [7]:
py.iplot(plots[2])
Out[7]:

Corrupt X, Large Deviants

In [8]:
py.iplot(plots[3])
Out[8]:

Corrupt y, Large Deviants

In [9]:
py.iplot(plots[4])
Out[9]: