Using FunctionTransformer to Select Columns in Scikit-learn

Shows how to use a function transformer in a pipeline. If you know that your dataset’s first principal component is irrelevant for a classification task, you can use a FunctionTransformer to select all but the first column of the PCA-transformed data.
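The sketch below shows the mechanism on a small toy array. FunctionTransformer wraps an ordinary function so that it can be used as a pipeline step, and here the wrapped function slices off column 0. The array values and the name drop_first are for illustration only.

```python
import numpy as np
from sklearn.preprocessing import FunctionTransformer


def drop_first(X):
    # Keep every row, drop column 0.
    return X[:, 1:]


# Illustrative 3x3 input.
X = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0],
              [7.0, 8.0, 9.0]])

selector = FunctionTransformer(drop_first)
X_selected = selector.fit_transform(X)  # fit is a no-op; transform calls drop_first
print(X_selected.shape)  # (3, 2)
print(X_selected[0])     # [2. 3.]
```

The transformer is stateless, so it adds no fitting cost. Any function that takes and returns an array can be dropped into a pipeline this way.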

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

This tutorial imports train_test_split, PCA, make_pipeline and FunctionTransformer.

In [2]:
import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

Calculations

In [3]:
def _generate_vector(shift=0.5, noise=15):
    return np.arange(1000) + (np.random.rand(1000) - shift) * noise


def generate_dataset():
    """
    This dataset is two lines with a slope ~ 1, where one has
    a y offset of ~100
    """
    return np.vstack((
        np.vstack((
            _generate_vector(),
            _generate_vector() + 100,
        )).T,
        np.vstack((
            _generate_vector(),
            _generate_vector(),
        )).T,
    )), np.hstack((np.zeros(1000), np.ones(1000)))


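def all_but_first_column(X):
    """Drop column 0, i.e. the first principal component after PCA."""
    return X[:, 1:]


def drop_first_component(X, y):
    """
    Fit a pipeline of PCA followed by the column selector on a training
    split, and return the transformed test split with its labels.
    """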
    pipeline = make_pipeline(
        PCA(), FunctionTransformer(all_but_first_column),
    )
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    pipeline.fit(X_train, y_train)
    return pipeline.transform(X_test), y_test
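Why is the first component safe to drop here? Both lines run along the diagonal, so the first principal component lies close to the (1, 1) direction and carries almost all of the variance. The ~100 y offset between the lines projects by roughly the same amount (about 70) onto both components. Along the first component, however, each class stretches over more than 1,400 units, so the offset is swamped. Along the second component, each class is only a few units wide.

The sketch below checks this numerically. A few assumptions apply:

- It repeats the dataset construction so that it runs on its own.
- It fixes the random seed for repeatability.
- It uses an ad hoc separation score that is not part of the original tutorial: the absolute difference of the class means divided by the average within-class standard deviation.

```python
import numpy as np
from sklearn.decomposition import PCA


def _generate_vector(shift=0.5, noise=15):
    # Points along 0..999 with uniform noise in [-7.5, 7.5] (for the defaults).
    return np.arange(1000) + (np.random.rand(1000) - shift) * noise


def generate_dataset():
    # Same construction as above: class 0 is offset by ~100 in y.
    X = np.vstack((
        np.vstack((_generate_vector(), _generate_vector() + 100)).T,
        np.vstack((_generate_vector(), _generate_vector())).T,
    ))
    y = np.hstack((np.zeros(1000), np.ones(1000)))
    return X, y


np.random.seed(0)  # assumption: fixed seed so the numbers are repeatable
X, y = generate_dataset()

pca = PCA().fit(X)
Z = pca.transform(X)

print("components:\n", pca.components_)
print("explained variance ratio:", pca.explained_variance_ratio_)

# Ad hoc separation score per component:
# |difference of class means| / average within-class standard deviation.
scores = []
for k in range(Z.shape[1]):
    z0, z1 = Z[y == 0, k], Z[y == 1, k]
    gap = abs(z0.mean() - z1.mean())
    spread = (z0.std() + z1.std()) / 2
    scores.append(gap / spread)
print("separation score per component:", np.round(scores, 2))
```

You should see the following:

- The first component is roughly ±(0.70, 0.71), with an explained variance ratio of about 0.99.
- The separation score is well below 1 for the first component and above 10 for the second.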

Plot Results

In [4]:
if __name__ == '__main__':
    X, y = generate_dataset()
    fig = tools.make_subplots(rows=1, cols=2,
                              print_grid=False)

    # Left panel: the raw two-class dataset.
    p1 = go.Scatter(x=X[:, 0], y=X[:, 1],
                    mode='markers',
                    marker=dict(color=y, colorscale="Jet"),
                    showlegend=False)
    fig.append_trace(p1, 1, 1)

    # Right panel: the test split of the same data, projected onto the
    # remaining (second) principal component and drawn along y = 0.
    X_transformed, y_transformed = drop_first_component(X, y)

    p2 = go.Scatter(x=X_transformed[:, 0],
                    y=np.zeros(len(X_transformed)),
                    mode='markers',
                    marker=dict(color=y_transformed, colorscale="Jet"),
                    showlegend=False)
    fig.append_trace(p2, 1, 2)
In [5]:
py.iplot(fig)
Out[5]:
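The example above only visualizes the selected column. The sketch below covers the classification use case from the introduction by extending the same pipeline. It makes these additions:

- A StandardScaler and a LogisticRegression, appended after the column selector. Neither is part of the original tutorial.
- A hypothetical helper, first_column_only, used to build a variant that keeps only the first component for comparison.
- A hypothetical helper, noisy_line, used to regenerate the data with a fixed seed so the snippet runs on its own.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer, StandardScaler


def noisy_line(rng, offset=0.0):
    # 1000 points along 0..999 with uniform noise in [-7.5, 7.5], shifted by offset.
    return np.arange(1000) + (rng.rand(1000) - 0.5) * 15 + offset


def all_but_first_column(X):
    return X[:, 1:]


def first_column_only(X):
    # Hypothetical counterpart that keeps only the first principal component.
    return X[:, :1]


rng = np.random.RandomState(0)  # assumption: fixed seed for repeatability
# Class 0 has a y offset of ~100, class 1 has none (as in generate_dataset()).
X = np.vstack((
    np.column_stack((noisy_line(rng), noisy_line(rng, offset=100))),
    np.column_stack((noisy_line(rng), noisy_line(rng))),
))
y = np.hstack((np.zeros(1000), np.ones(1000)))
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)


def accuracy(selector):
    # PCA -> column selector -> scaling -> linear classifier, scored on the test split.
    model = make_pipeline(PCA(), FunctionTransformer(selector),
                          StandardScaler(), LogisticRegression())
    return model.fit(X_train, y_train).score(X_test, y_test)


acc_drop_first = accuracy(all_but_first_column)
acc_first_only = accuracy(first_column_only)
print("drop first component:", acc_drop_first)
print("keep only first component:", acc_first_only)
```

Under these assumptions, you should see near-perfect test accuracy when the first component is dropped. Keeping only the first component should give accuracy close to chance, because the two classes overlap almost completely along it.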