Robust Linear Model Estimation using RANSAC in Scikit-learn

This example shows how to robustly fit a linear model to faulty data using the RANSAC (RANdom SAmple Consensus) algorithm.
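In outline, RANSAC repeatedly fits a model to a small random subset of the data, counts the points whose residual falls within a threshold of that model (its consensus set), keeps the candidate with the largest consensus set, and finally refits on those inliers. The NumPy sketch below illustrates that loop for a straight line. It is not scikit-learn's `RANSACRegressor` implementation: the helper names `ransac_line` and `required_trials`, the fixed number of trials, the threshold, and the made-up data are all assumptions for illustration. `required_trials` is the standard textbook estimate of how many random samples are needed to draw at least one all-inlier sample with probability `p`.

```python
import math

import numpy as np


def required_trials(p, inlier_ratio, sample_size):
    """Textbook RANSAC estimate: number of random samples needed so that, with
    probability p, at least one sample is drawn entirely from inliers."""
    return math.ceil(math.log(1 - p) / math.log(1 - inlier_ratio ** sample_size))


def ransac_line(x, y, residual_threshold, n_trials=100, seed=0):
    """Hypothetical helper: fit y = slope * x + intercept with a basic RANSAC loop.

    Returns (slope, intercept, inlier_mask).
    """
    rng = np.random.default_rng(seed)
    best_mask, best_count = None, -1
    for _ in range(n_trials):
        # Minimal sample for a line: two distinct points.
        i, j = rng.choice(len(x), size=2, replace=False)
        if x[i] == x[j]:
            continue  # degenerate sample; the slope is undefined
        slope = (y[j] - y[i]) / (x[j] - x[i])
        intercept = y[i] - slope * x[i]
        # Consensus set: points whose absolute residual is within the threshold.
        mask = np.abs(y - (slope * x + intercept)) <= residual_threshold
        if mask.sum() > best_count:
            best_mask, best_count = mask, mask.sum()
    # Final model: ordinary least squares on the largest consensus set only.
    slope, intercept = np.polyfit(x[best_mask], y[best_mask], deg=1)
    return slope, intercept, best_mask


# Made-up data: a slope-2 line plus 20 gross outliers out of 200 points.
rng = np.random.default_rng(1)
x = rng.uniform(-3, 3, size=200)
y = 2.0 * x + 1.0 + rng.normal(scale=0.1, size=200)
y[:20] = rng.uniform(-30, 30, size=20)

slope, intercept, inliers = ransac_line(x, y, residual_threshold=0.5)
print(f"slope={slope:.2f} intercept={intercept:.2f} inliers={inliers.sum()}")

# Trials needed for 99% confidence with 95% inliers and two-point samples.
print(required_trials(0.99, 0.95, 2))  # 2
```

scikit-learn's `RANSACRegressor` exposes related controls through parameters such as `max_trials`, `residual_threshold`, and `stop_probability`, so in practice these are tuned there rather than reimplemented.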

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

In [2]:
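# plotly.plotly is the online API of Plotly 3.x and earlier; in Plotly 4+ it moved to chart_studio.plotly
import plotly.plotly as py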
import plotly.graph_objs as go

import numpy as np
from sklearn import linear_model, datasets

Calculations

In [3]:
n_samples = 1000
n_outliers = 50


X, y, coef = datasets.make_regression(n_samples=n_samples, n_features=1,
                                      n_informative=1, noise=10,
                                      coef=True, random_state=0)

# Add outlier data
np.random.seed(0)
X[:n_outliers] = 3 + 0.5 * np.random.normal(size=(n_outliers, 1))
y[:n_outliers] = -3 + 10 * np.random.normal(size=n_outliers)

# Fit line using all data
model = linear_model.LinearRegression()
model.fit(X, y)

# Robustly fit a linear model with the RANSAC algorithm
model_ransac = linear_model.RANSACRegressor(linear_model.LinearRegression())
model_ransac.fit(X, y)
# inlier_mask_ is a boolean array marking the samples RANSAC classified as inliers
inlier_mask = model_ransac.inlier_mask_
outlier_mask = np.logical_not(inlier_mask)

# Predict y along a grid of x values with both fitted models
line_X = np.arange(-5, 5)
# predict() expects a 2-D array of shape (n_samples, n_features)
line_y = model.predict(line_X[:, np.newaxis])
line_y_ransac = model_ransac.predict(line_X[:, np.newaxis])
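`RANSACRegressor` is constructed above without a `residual_threshold`, so scikit-learn falls back to its documented default: the median absolute deviation (MAD) of the target values `y`. A sample is then treated as an inlier when its absolute residual stays within that threshold. The sketch below reproduces only the default-threshold computation in NumPy; the function name `default_residual_threshold` is hypothetical and the toy array is made up.

```python
import numpy as np


def default_residual_threshold(y):
    """Hypothetical helper: MAD of the targets, median(|y - median(y)|)."""
    y = np.asarray(y, dtype=float)
    return np.median(np.abs(y - np.median(y)))


y_demo = np.array([1.0, 2.0, 3.0, 4.0, 100.0])
# median is 3.0; absolute deviations are [2, 1, 0, 1, 97]; their median is 1.0
print(default_residual_threshold(y_demo))  # 1.0
```

Because this default depends on the spread of `y` itself rather than on the residuals of a good fit, it can be loose when `y` varies strongly with `X`, as it does in this dataset; passing `residual_threshold=...` to `RANSACRegressor` sets it explicitly.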

Compare estimated coefficients

In [4]:
print("Estimated coefficients (true, normal, RANSAC):")
print(coef, model.coef_, model_ransac.estimator_.coef_)
Estimated coefficients (true, normal, RANSAC):
(array(82.1903908407869), array([ 54.17236387]), array([ 82.08533159]))
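The gap between the true coefficient (about 82.19) and the ordinary least-squares estimate (about 54.17) comes from leverage: the injected outliers sit around X ≈ 3, away from the bulk of the data, with y ≈ -3, and least squares minimizes the squared residuals of every point, outliers included. A small worked example with the closed-form single-feature slope cov(x, y) / var(x) shows the effect; the helper `ols_slope` and the toy data are made up for illustration.

```python
import numpy as np


def ols_slope(x, y):
    """Closed-form least-squares slope for one feature: cov(x, y) / var(x)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    xc = x - x.mean()
    return np.dot(xc, y - y.mean()) / np.dot(xc, xc)


# Made-up clean points lying exactly on a line with slope 80.
x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
y = 80.0 * x
print(ols_slope(x, y))  # 80.0

# Add two points in the spirit of the notebook's outliers: large x, small y.
x_out = np.append(x, [3.0, 3.0])
y_out = np.append(y, [-3.0, -3.0])
print(round(ols_slope(x_out, y_out), 2))  # 34.44
```

Two points out of seven more than halve the fitted slope here. The notebook shows the same effect at a larger scale: 50 outliers among 1000 samples pull the least-squares estimate from about 82 down to about 54, while RANSAC, which refits only on its consensus set, stays close to the true value.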

Plot Results

In [5]:
def data_to_plotly(x):
    # Flatten an (n, 1) array into a list of its first-column values
    return [row[0] for row in x]
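Because `inlier_mask_` is a 1-D boolean array with one entry per row of `X`, `X[inlier_mask]` keeps only the inlier rows, and adding the column index `0` flattens the result in one step, which is an alternative to the row-by-row helper. The snippet below checks that equivalence on small made-up arrays (`X_demo` and `mask_demo` are stand-ins, not notebook variables). Plotly's `go.Scatter` also accepts NumPy arrays for `x` and `y`, so either form can be passed.

```python
import numpy as np

# Made-up stand-ins for the notebook's X (shape (n, 1)) and inlier_mask (shape (n,)).
X_demo = np.array([[0.5], [3.1], [-1.2], [2.9]])
mask_demo = np.array([True, False, True, False])

# Row-by-row extraction, as data_to_plotly does.
via_loop = [row[0] for row in X_demo[mask_demo]]

# Vectorized: the boolean mask selects rows, the index 0 picks the single column.
via_slicing = X_demo[mask_demo, 0]

print(via_slicing.tolist())                # [0.5, -1.2]
print(np.logical_not(mask_demo).tolist())  # [False, True, False, True]
```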
In [6]:
lw = 2

p1 = go.Scatter(x=data_to_plotly(X[inlier_mask]), y=y[inlier_mask],
                mode='markers',
                marker=dict(color='yellowgreen', size=6),
                name='Inliers')
p2 = go.Scatter(x=data_to_plotly(X[outlier_mask]), y=y[outlier_mask],
                mode='markers',
                marker=dict(color='gold', size=6),
                name='Outliers')

p3 = go.Scatter(x=line_X, y=line_y,
                mode='lines',
                line=dict(color='navy', width=lw),
                name='Linear regressor')
p4 = go.Scatter(x=line_X, y=line_y_ransac,
                mode='lines',
                line=dict(color='cornflowerblue', width=lw),
                name='RANSAC regressor')
data = [p1, p2, p3, p4]
layout = go.Layout(xaxis=dict(zeroline=False, showgrid=False),
                   yaxis=dict(zeroline=False, showgrid=False))
fig = go.Figure(data=data, layout=layout)
In [7]:
py.iplot(fig)
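Out[7]:
[Interactive Plotly figure: Inliers and Outliers scatter points with the Linear regressor and RANSAC regressor lines]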