
# SVM Separating Hyperplane for Unbalanced Classes in Scikit-learn

Find the optimal separating hyperplane using an SVC for classes that are unbalanced.

We first find the separating hyperplane with a plain SVC, and then plot (dashed) the separating hyperplane obtained with automatic correction for unbalanced classes.
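
In scikit-learn's `SVC`, the `class_weight` parameter sets the penalty `C` for class `i` to `class_weight[i] * C`, and the `'balanced'` mode derives the weights as `n_samples / (n_classes * np.bincount(y))`. The sketch below is an illustration rather than part of the original example. It computes those balanced weights for the class sizes used later on this page (1000 and 100 samples) and compares them with a manual weight of 10 on class 1. The variable names `manual`, `balanced`, and `ratio` are chosen here for illustration.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Labels with the same class sizes as the example: 1000 vs. 100 samples.
y = np.array([0] * 1000 + [1] * 100)
classes = np.unique(y)

# The 'balanced' heuristic written out by hand.
manual = len(y) / (len(classes) * np.bincount(y))

# The same weights as computed by scikit-learn.
balanced = compute_class_weight(class_weight='balanced', classes=classes, y=y)

# Relative weight of the minority class with respect to the majority class.
ratio = balanced[1] / balanced[0]
print(manual, balanced, ratio)
```

For these class sizes the relative weight of the minority class is 10, the same ratio that a manual `class_weight={1: 10}` produces. The absolute weights differ, so the two fitted models need not be identical.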


### Version

In [1]:
import sklearn
sklearn.__version__

Out[1]:
'0.18.1'

### Imports

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go

import numpy as np
from sklearn import svm

Automatically created module for IPython interactive environment


### Calculations

In [3]:
# create two clusters of random points with unbalanced sizes (1000 vs. 100)
rng = np.random.RandomState(0)
n_samples_1 = 1000
n_samples_2 = 100
X = np.r_[1.5 * rng.randn(n_samples_1, 2),
          0.5 * rng.randn(n_samples_2, 2) + [2, 2]]
y = [0] * (n_samples_1) + [1] * (n_samples_2)

# fit the model and get the separating hyperplane
clf = svm.SVC(kernel='linear', C=1.0)
clf.fit(X, y)

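# the boundary satisfies w[0] * x + w[1] * y + intercept = 0;
# solving for y gives a line with slope a and offset -intercept / w[1]
w = clf.coef_[0]
a = -w[0] / w[1]
xx = np.linspace(-5, 5)
yy = a * xx - clf.intercept_[0] / w[1]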

# get the separating hyperplane using weighted classes:
# errors on class 1 are penalized 10 times more than errors on class 0
wclf = svm.SVC(kernel='linear', class_weight={1: 10})
wclf.fit(X, y)

ww = wclf.coef_[0]
wa = -ww[0] / ww[1]
wyy = wa * xx - wclf.intercept_[0] / ww[1]
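
To see what the class weighting changes in practice, one option is to compare how many minority-class samples each model recovers. The sketch below is an illustration rather than part of the original example. It refits both classifiers on the same data and measures recall for class 1 on the training set with `sklearn.metrics.recall_score`. The names `recall_plain` and `recall_weighted` are chosen here, and in-sample recall only shows the direction of the effect. It does not estimate generalization.

```python
import numpy as np
from sklearn import svm
from sklearn.metrics import recall_score

# Same data as in the cell above: 1000 majority and 100 minority samples.
rng = np.random.RandomState(0)
n_samples_1, n_samples_2 = 1000, 100
X = np.r_[1.5 * rng.randn(n_samples_1, 2),
          0.5 * rng.randn(n_samples_2, 2) + [2, 2]]
y = np.array([0] * n_samples_1 + [1] * n_samples_2)

# Plain and class-weighted linear SVCs.
clf = svm.SVC(kernel='linear', C=1.0).fit(X, y)
wclf = svm.SVC(kernel='linear', class_weight={1: 10}).fit(X, y)

# Fraction of true class-1 samples that each model labels as class 1.
recall_plain = recall_score(y, clf.predict(X), pos_label=1)
recall_weighted = recall_score(y, wclf.predict(X), pos_label=1)

print('minority recall without weights:', recall_plain)
print('minority recall with weights:   ', recall_weighted)
```

The weighted model trades some majority-class accuracy for better coverage of the minority class, which is why its boundary is expected to sit closer to the majority cluster.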


### Plot Results

In [6]:
# plot separating hyperplanes and samples

h0 = go.Scatter(x=xx, y=yy,
                mode='lines',
                line=dict(color='black', width=1),
                name='no weights')
h1 = go.Scatter(x=xx, y=wyy,
                mode='lines',
                line=dict(color='black', width=1,
                          dash='dash'),
                name='with weights')

p1 = go.Scatter(x=X[:, 0], y=X[:, 1],
                mode='markers',
                showlegend=False,
                marker=dict(color=y,
                            colorscale='Jet',
                            line=dict(color='black', width=1)))
layout = go.Layout(xaxis=dict(zeroline=False),
                   yaxis=dict(zeroline=False),
                   hovermode='closest')
fig = go.Figure(data=[h0, h1, p1], layout=layout)

py.iplot(fig)
