
# Robust Linear Model Estimation using RANSAC in Scikit-learn

This example shows how to robustly fit a linear model to data contaminated with outliers using the RANSAC algorithm.
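As background, RANSAC (random sample consensus) repeatedly fits a model to a small random subset of the data, counts how many points fall within a residual threshold of that candidate, keeps the candidate with the largest consensus set, and refits on those points. The following is a toy sketch of that idea for a straight line, assuming noise-free inliers and a fixed threshold. The function name `ransac_line` and its parameters are hypothetical and this is not scikit-learn's implementation.

```python
import numpy as np

def ransac_line(x, y, n_trials=100, threshold=1.0, seed=0):
    """Toy RANSAC for y = a*x + b; returns (a, b, inlier_mask)."""
    rng = np.random.default_rng(seed)
    best_mask = None
    for _ in range(n_trials):
        # 1. Fit a candidate line through two randomly chosen points.
        i, j = rng.choice(len(x), size=2, replace=False)
        if x[i] == x[j]:
            continue  # vertical pair, no finite slope
        a = (y[j] - y[i]) / (x[j] - x[i])
        b = y[i] - a * x[i]
        # 2. Points with a residual below the threshold form the consensus set.
        mask = np.abs(y - (a * x + b)) < threshold
        # 3. Keep the candidate with the largest consensus set.
        if best_mask is None or mask.sum() > best_mask.sum():
            best_mask = mask
    # 4. Refit by least squares on the best consensus set only.
    a, b = np.polyfit(x[best_mask], y[best_mask], deg=1)
    return a, b, best_mask

# Synthetic check: y = 2x + 1 with the first 10 of 100 points corrupted.
x = np.linspace(-5, 5, 100)
y = 2.0 * x + 1.0
y[:10] += 50.0
a, b, mask = ransac_line(x, y)
print(a, b, mask.sum())
```

In this synthetic check the corrupted points lie far from the true line, so they should be excluded from the consensus set and the refit should recover a slope near 2 and an intercept near 1.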


### Version¶

In [1]:
import sklearn
sklearn.__version__

Out[1]:
'0.18.1'

### Imports¶

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go

import numpy as np
from sklearn import linear_model, datasets


### Calculations¶

In [3]:
n_samples = 1000
n_outliers = 50

X, y, coef = datasets.make_regression(n_samples=n_samples, n_features=1,
                                      n_informative=1, noise=10,
                                      coef=True, random_state=0)

np.random.seed(0)
X[:n_outliers] = 3 + 0.5 * np.random.normal(size=(n_outliers, 1))
y[:n_outliers] = -3 + 10 * np.random.normal(size=n_outliers)

# Fit line using all data
model = linear_model.LinearRegression()
model.fit(X, y)

# Robustly fit linear model with RANSAC algorithm
model_ransac = linear_model.RANSACRegressor(linear_model.LinearRegression())
model_ransac.fit(X, y)

# Predict data of estimated models
line_X = np.arange(-5, 5)
line_y = model.predict(line_X[:, np.newaxis])
line_y_ransac = model_ransac.predict(line_X[:, np.newaxis])
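The `line_X[:, np.newaxis]` indexing in the cell above is needed because scikit-learn estimators expect a 2-D input of shape `(n_samples, n_features)`, while `np.arange` returns a 1-D array. A small sketch of the shape change, with `reshape(-1, 1)` shown as an equivalent alternative:

```python
import numpy as np

line_X = np.arange(-5, 5)            # 1-D array: -5, -4, ..., 4
col = line_X[:, np.newaxis]          # add a trailing axis for the single feature
col_alt = line_X.reshape(-1, 1)      # equivalent spelling

print(line_X.shape)                  # (10,)
print(col.shape)                     # (10, 1)
print(np.array_equal(col, col_alt))  # True
```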


Compare the estimated coefficients:

In [4]:
print("Estimated coefficients (true, normal, RANSAC):")
print(coef, model.coef_, model_ransac.estimator_.coef_)

Estimated coefficients (true, normal, RANSAC):
(array(82.1903908407869), array([ 54.17236387]), array([ 82.08533159]))
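To make the gap concrete, the relative error of each estimate can be computed from the values printed above. This is only a sketch. The numbers are copied from that output, and `relative_error` is a hypothetical helper name.

```python
# Values copied from the printed output above.
true_coef = 82.1903908407869
ols_coef = 54.17236387
ransac_coef = 82.08533159

def relative_error(estimate, truth):
    """Absolute deviation from the true value, as a fraction of it."""
    return abs(estimate - truth) / abs(truth)

ols_err = relative_error(ols_coef, true_coef)
ransac_err = relative_error(ransac_coef, true_coef)

print("Ordinary least squares: {:.1%}".format(ols_err))
print("RANSAC:                 {:.2%}".format(ransac_err))
```

On these numbers the ordinary least-squares slope is off by roughly a third, while the RANSAC slope is within a fraction of a percent of the true coefficient.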


### Plot Results¶

In [5]:
def data_to_plotly(x):
    k = []

    for i in range(0, len(x)):
        k.append(x[i][0])

    return k
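The helper above flattens an `(n_samples, 1)` array into a plain list. Assuming the input is a NumPy array, the same result could be obtained with column indexing and `tolist()`. The sketch below compares the two on a small made-up array (`X_demo` is hypothetical).

```python
import numpy as np

def data_to_plotly(x):
    # Same helper as in the cell above, repeated so this snippet runs on its own.
    k = []
    for i in range(0, len(x)):
        k.append(x[i][0])
    return k

X_demo = np.array([[0.5], [-1.2], [3.0]])  # hypothetical (n_samples, 1) array

flat_loop = data_to_plotly(X_demo)
flat_numpy = X_demo[:, 0].tolist()         # take the single column, convert to a list

print(flat_loop == flat_numpy)
```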

In [6]:
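lw = 2

# Boolean masks of the samples RANSAC kept as inliers and rejected as outliers
inlier_mask = model_ransac.inlier_mask_
outlier_mask = np.logical_not(inlier_mask)

p1 = go.Scatter(x=data_to_plotly(X[inlier_mask]), y=y[inlier_mask],
                mode='markers',
                marker=dict(color='yellowgreen', size=6),
                name='Inliers')
p2 = go.Scatter(x=data_to_plotly(X[outlier_mask]), y=y[outlier_mask],
                mode='markers',
                marker=dict(color='gold', size=6),
                name='Outliers')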

p3 = go.Scatter(x=line_X, y=line_y,
                mode='lines',
                line=dict(color='navy', width=lw),
                name='Linear regressor')
p4 = go.Scatter(x=line_X, y=line_y_ransac,
                mode='lines',
                line=dict(color='cornflowerblue', width=lw),
                name='RANSAC regressor')
data = [p1, p2, p3, p4]
layout = go.Layout(xaxis=dict(zeroline=False, showgrid=False),
                   yaxis=dict(zeroline=False, showgrid=False)
                   )
fig = go.Figure(data=data, layout=layout)

In [7]:
py.iplot(fig)

Out[7]: