Multilabel classification in Scikit-learn

This example simulates a multi-label document classification problem. The dataset is generated randomly based on the following process:

pick the number of labels: n ~ Poisson(n_labels)
n times, choose a class c: c ~ Multinomial(theta)
pick the document length: k ~ Poisson(length)
k times, choose a word: w ~ Multinomial(theta_c)

In the above process, rejection sampling is used to make sure that n is never more than n_classes (and never zero, unless unlabeled samples are allowed), and that the document length is never zero. Likewise, classes that have already been chosen are rejected. Documents assigned to both classes are plotted surrounded by two colored circles.
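The process above can be sketched in NumPy as follows. This is an illustrative sketch, not scikit-learn's implementation of `make_multilabel_classification`: the function name `sample_document`, the Dirichlet-drawn `theta` and `theta_c`, and the way words are drawn for a document with several labels (an equal-weight mixture of the chosen classes' word distributions) or with none (uniform) are all assumptions.

```python
import numpy as np


def sample_document(rng, theta, theta_c, n_labels=1, length=50,
                    allow_unlabeled=True):
    """Draw one (word_counts, label_indicator) pair via the process above.

    Hypothetical helper for illustration; theta is the class prior and
    theta_c holds one word distribution per class (one row per class).
    """
    n_classes, n_words = theta_c.shape

    # Pick the number of labels n ~ Poisson(n_labels); reject n > n_classes,
    # and reject n == 0 unless unlabeled documents are allowed.
    n = n_classes + 1
    while n > n_classes or (n == 0 and not allow_unlabeled):
        n = rng.poisson(n_labels)

    # n times, choose a class c ~ Multinomial(theta); repeats are rejected,
    # so keep drawing until n distinct classes have been collected.
    labels = set()
    while len(labels) < n:
        labels.add(int(rng.choice(n_classes, p=theta)))

    # Pick the document length k ~ Poisson(length); reject k == 0.
    k = 0
    while k == 0:
        k = rng.poisson(length)

    # k times, choose a word. ASSUMPTION: with several labels, theta_c is the
    # equal-weight mixture of the chosen classes' word distributions;
    # an unlabeled document draws its words uniformly.
    if labels:
        p_w = theta_c[sorted(labels)].mean(axis=0)
    else:
        p_w = np.full(n_words, 1.0 / n_words)
    word_counts = rng.multinomial(k, p_w)  # all k word draws at once

    y = np.zeros(n_classes, dtype=int)
    y[list(labels)] = 1
    return word_counts, y


rng = np.random.default_rng(0)
n_classes, n_words = 2, 20
theta = rng.dirichlet(np.ones(n_classes))             # hypothetical class prior
theta_c = rng.dirichlet(np.ones(n_words), n_classes)  # hypothetical P(w | c)

docs = [sample_document(rng, theta, theta_c) for _ in range(5)]
for counts, y in docs:
    print(y, counts.sum())
```

Drawing the k words with a single `multinomial` call yields the bag-of-words counts directly, which is equivalent to k independent word draws.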

The classification is performed by projecting the data onto its first two components, found either by PCA or by CCA, for visualisation purposes, and then using the sklearn.multiclass.OneVsRestClassifier metaclassifier with two linear-kernel SVCs to learn a discriminative model for each class. Note that PCA performs an unsupervised dimensionality reduction, while CCA performs a supervised one.

Note: in the plot, “unlabeled samples” does not mean that we don’t know the labels (as in semi-supervised learning) but that the samples simply do not have a label.
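To see what an unlabeled sample is in practice, the sketch below counts the rows of the label-indicator matrix Y that contain no 1 at all, using the same generator settings as the plots. With `allow_unlabeled=False` that count is zero by construction; the exact count for `allow_unlabeled=True` depends on the random seed, so it is not stated here.

```python
from sklearn.datasets import make_multilabel_classification

n_unlabeled = {}
for allow in (True, False):
    _, Y = make_multilabel_classification(n_classes=2, n_labels=1,
                                          allow_unlabeled=allow,
                                          random_state=1)
    # A sample is "unlabeled" when its row of Y is all zeros
    n_unlabeled[allow] = int((Y.sum(axis=1) == 0).sum())
    print(f"allow_unlabeled={allow}: {n_unlabeled[allow]} samples with no label")
```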

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18'

Imports

In [2]:
import numpy as np

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

from sklearn.datasets import make_multilabel_classification
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.preprocessing import LabelBinarizer
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA

Calculations

In [3]:
print(__doc__)
fig = tools.make_subplots(rows=2, cols=2, 
                          subplot_titles=('With unlabeled samples + CCA',
                                          'With unlabeled samples + PCA', 
                                          'Without unlabeled samples + CCA',
                                          'Without unlabeled samples + PCA')
                         )
def plot_hyperplane(clf, min_x, max_x, name, shape, leg):
    # The linear decision boundary is w[0]*x + w[1]*y + b = 0,
    # i.e. y = -(w[0] / w[1]) * x - b / w[1]
    w = clf.coef_[0]
    a = -w[0] / w[1]
    xx = np.linspace(min_x - 5, max_x + 5)  # make sure the line is long enough
    yy = a * xx - (clf.intercept_[0]) / w[1]
    return go.Scatter(x=xx, y=yy, name=name, mode="lines",
                      showlegend=bool(leg),
                      line=dict(color='black', width=1.5, dash=shape))

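def plot_subfigure(X, Y, subplot_row, subplot_col, transform, leg):
    # Project X to 2-D: PCA ignores Y (unsupervised), CCA uses it (supervised)
    if transform == "pca":
        X = PCA(n_components=2).fit_transform(X)
    elif transform == "cca":
        X = CCA(n_components=2).fit(X, Y).transform(X)
    else:
        raise ValueError("transform must be 'pca' or 'cca', got %r" % transform)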

    min_x = np.min(X[:, 0])
    max_x = np.max(X[:, 0])

    min_y = np.min(X[:, 1])
    max_y = np.max(X[:, 1])

    classif = OneVsRestClassifier(SVC(kernel='linear'))
    classif.fit(X, Y)
    # Indices of the samples carrying each label
    zero_class = np.where(Y[:, 0])[0]
    one_class = np.where(Y[:, 1])[0]

    # All samples as gray dots; unlabeled samples get no colored circle
    trace1 = go.Scatter(x=X[:, 0], y=X[:, 1], mode="markers",
                        showlegend=False,
                        marker=dict(color='gray', size=10,
                                    line=dict(width=2, color="black")))

    # Blue circles around samples with label 1; drawn larger than the orange
    # ones so that samples with both labels show two colored circles
    trace2 = go.Scatter(x=X[zero_class, 0], y=X[zero_class, 1],
                        name="Class 1", showlegend=leg,
                        mode='markers',
                        marker=dict(size=18, color='white',
                                    line=dict(width=3, color='blue')))

    # Orange circles around samples with label 2
    trace3 = go.Scatter(x=X[one_class, 0], y=X[one_class, 1],
                        name='Class 2', showlegend=leg,
                        mode='markers',
                        marker=dict(size=14, color='white',
                                    line=dict(width=3, color='orange')))

    
    fig.append_trace(trace2, subplot_row, subplot_col)
    fig.append_trace(trace3, subplot_row, subplot_col)
    fig.append_trace(trace1, subplot_row, subplot_col)
    
    trace4 = plot_hyperplane(classif.estimators_[0], min_x, max_x, 
                    'Boundary<br>for class 1','dash', leg,)
    
    trace5 = plot_hyperplane(classif.estimators_[1], min_x, max_x, 
                    'Boundary<br>for class 2','dashdot',leg,)
    
    fig.append_trace(trace4, subplot_row, subplot_col)
    fig.append_trace(trace5, subplot_row, subplot_col)
    
    
    fig['layout']['xaxis1'].update(range=[-3, 3], zeroline=False,
                                  showgrid=False)
    fig['layout']['yaxis1'].update(range=[-5, 5], zeroline=False,
                                  showgrid=False)
    fig['layout']['xaxis3'].update(range=[-4, 4], zeroline=False,
                                  showgrid=False)
    fig['layout']['yaxis3'].update(range=[-4, 4], zeroline=False,
                                  showgrid=False)
    fig['layout']['xaxis4'].update(range=[-8, 8], zeroline=False,
                                  showgrid=False)
    fig['layout']['yaxis4'].update(range=[-10, 10], zeroline=False,
                                  showgrid=False)
    fig['layout']['xaxis2'].update(title='First principal component', range=[-3, 8],
                                   zeroline=False, showgrid=False)
    fig['layout']['yaxis2'].update(title='Second principal component', range=[-10, 10],
                                   zeroline= False, showgrid=False)
    fig['layout'].update(height=900, width=1000)
Automatically created module for IPython interactive environment
This is the format of your plot grid:
[ (1,1) x1,y1 ]  [ (1,2) x2,y2 ]
[ (2,1) x3,y3 ]  [ (2,2) x4,y4 ]
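The `plot_hyperplane` helper relies on the decision function of a linear SVC being w[0]*x + w[1]*y + b, so each boundary is the line where that value is zero: y = -(w[0]/w[1])*x - b/w[1]. The check below fits an `SVC` on six made-up 2-D points (toy data, not from this example) and verifies that points generated with that formula have a decision value of zero. The formula divides by w[1], so it assumes the boundary is not vertical.

```python
import numpy as np
from sklearn.svm import SVC

# Two small, linearly separable 2-D clusters (toy data for illustration)
X = np.array([[0.0, 0.0], [1.0, 0.5], [0.5, 1.0],
              [3.0, 3.0], [4.0, 2.5], [3.5, 4.0]])
y = np.array([0, 0, 0, 1, 1, 1])
clf = SVC(kernel="linear").fit(X, y)

# Boundary: w0*x + w1*y + b = 0  =>  y = -(w0/w1)*x - b/w1 (needs w1 != 0)
w = clf.coef_[0]
b = clf.intercept_[0]
xx = np.linspace(-1, 5, 7)
yy = -(w[0] / w[1]) * xx - b / w[1]

# Every point on that line should have a decision value of (numerically) zero
print(np.abs(clf.decision_function(np.column_stack([xx, yy]))).max())
```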

Plot Result

In [4]:
X, Y = make_multilabel_classification(n_classes=2, n_labels=1,
                                      allow_unlabeled=True,
                                      random_state=1)

plot_subfigure(X, Y, 1, 1, "cca", True)
plot_subfigure(X, Y, 1, 2, "pca", False)

X, Y = make_multilabel_classification(n_classes=2, n_labels=1,
                                      allow_unlabeled=False,
                                      random_state=1)

plot_subfigure(X, Y, 2, 1, "cca", False)
plot_subfigure(X, Y, 2, 2, "pca", False)

py.iplot(fig, filename="multilabel-classification")