
Hyper-Parameters of Approximate Nearest Neighbors in Scikit-learn

This example demonstrates how the accuracy of nearest-neighbor queries with scikit-learn's Locality Sensitive Hashing Forest varies as the number of candidates and the number of estimators (trees) change.

In the first plot, accuracy is plotted against the number of candidates. Here, "number of candidates" refers to the maximum number of distinct points retrieved from each tree when computing distances; the nearest neighbors are then selected from this candidate pool. The number of estimators is held at three fixed levels (1, 5, 10).

In the second plot, the number of candidates is fixed at 50, the number of trees is varied, and accuracy is plotted against those values. Measuring accuracy requires the true nearest neighbors, so sklearn.neighbors.NearestNeighbors is used to compute the exact neighbors.
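Accuracy here is simply the fraction of queries whose approximate nearest neighbor matches the exact one. A minimal sketch of that comparison, using made-up index arrays in place of the real query results:

```python
import numpy as np

# Hypothetical neighbor indices for 5 queries, one neighbor each:
# exact results vs. an approximate index that misses one query.
neighbors_exact = np.array([[3], [7], [1], [4], [9]])
neighbors_approx = np.array([[3], [7], [2], [4], [9]])  # query 2 mismatches

n_queries = neighbors_exact.shape[0]
# Fraction of queries where the approximate neighbor equals the exact one
accuracy = np.sum(np.equal(neighbors_approx, neighbors_exact)) / n_queries
print(accuracy)  # 4 of 5 queries matched -> 0.8
```

The cells below apply exactly this comparison, averaged over several random seeds because the LSH Forest index is stochastic.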

New to Plotly?

Plotly's Python library is free and open source! Get started by downloading the client and reading the primer.
You can set up Plotly to work in online or offline mode, or in Jupyter notebooks.
We also have a quick-reference cheatsheet (new!) to help you get started!

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

This tutorial imports make_blobs, LSHForest and NearestNeighbors.

In [2]:
from __future__ import division

import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools

import numpy as np
from sklearn.datasets.samples_generator import make_blobs
from sklearn.neighbors import LSHForest
from sklearn.neighbors import NearestNeighbors

print(__doc__)
Automatically created module for IPython interactive environment

Calculations

In [3]:
# Initialize size of the database, iterations and required neighbors.
n_samples = 10000
n_features = 100
n_queries = 30
rng = np.random.RandomState(42)

# Generate sample data
X, _ = make_blobs(n_samples=n_samples + n_queries,
                  n_features=n_features, centers=10,
                  random_state=0)
X_index = X[:n_samples]
X_query = X[n_samples:]
# Get exact neighbors
nbrs = NearestNeighbors(n_neighbors=1, algorithm='brute',
                        metric='cosine').fit(X_index)
neighbors_exact = nbrs.kneighbors(X_query, return_distance=False)

# Set `n_candidates` values
n_candidates_values = np.linspace(10, 500, 5).astype(int)
n_estimators_for_candidate_value = [1, 5, 10]
n_iter = 10
stds_accuracies = np.zeros((len(n_estimators_for_candidate_value),
                            n_candidates_values.shape[0]),
                           dtype=float)
accuracies_c = np.zeros((len(n_estimators_for_candidate_value),
                         n_candidates_values.shape[0]), dtype=float)

LSH Forest is a stochastic index: several iterations are performed to estimate the expected accuracy and its standard deviation, which is displayed as error bars in the plots.

In [4]:
for j, value in enumerate(n_estimators_for_candidate_value):
    for i, n_candidates in enumerate(n_candidates_values):
        accuracy_c = []
        for seed in range(n_iter):
            lshf = LSHForest(n_estimators=value,
                             n_candidates=n_candidates, n_neighbors=1,
                             random_state=seed)
            # Build the LSH Forest index
            lshf.fit(X_index)
            # Get neighbors
            neighbors_approx = lshf.kneighbors(X_query,
                                               return_distance=False)
            accuracy_c.append(np.sum(np.equal(neighbors_approx,
                                              neighbors_exact)) /
                              n_queries)

        stds_accuracies[j, i] = np.std(accuracy_c)
        accuracies_c[j, i] = np.mean(accuracy_c)

# Set `n_estimators` values
n_estimators_values = [1, 5, 10, 20, 30, 40, 50]
accuracies_trees = np.zeros(len(n_estimators_values), dtype=float)

Calculate the average accuracy for each value of n_estimators.

In [5]:
for i, n_estimators in enumerate(n_estimators_values):
    lshf = LSHForest(n_estimators=n_estimators, n_neighbors=1)
    # Build the LSH Forest index
    lshf.fit(X_index)
    # Get neighbors
    neighbors_approx = lshf.kneighbors(X_query, return_distance=False)
    accuracies_trees[i] = np.sum(np.equal(neighbors_approx,
                                          neighbors_exact))/n_queries

Plot the Accuracy Variation with n_candidates

In [6]:
colors = ['cyan', 'magenta', 'yellow']

fig = tools.make_subplots(rows=1, cols=2,
                          print_grid=False,
                          subplot_titles=("Accuracy variation with n_candidates",
                                          "Accuracy variation with n_estimators"))

for i, n_estimators in enumerate(n_estimators_for_candidate_value):
    label = 'n_estimators = %d ' % n_estimators
    trace = go.Scatter(x=n_candidates_values, 
                       y=accuracies_c[i, :],
                       error_y=dict(visible=True, 
                                    arrayminus=stds_accuracies[i, :]),
                       line=dict(color=colors[i]), name=label)
    fig.append_trace(trace, 1, 1)
    
fig['layout']['xaxis1'].update(title="n_candidates")
fig['layout']['yaxis1'].update(title="Accuracy")

# Plot the accuracy variation with `n_estimators`
trace1 = go.Scatter(x=n_estimators_values, y=accuracies_trees, 
                    mode='markers', marker=dict(color='black'))
fig.append_trace(trace1, 1, 2)
trace2 = go.Scatter(x=n_estimators_values, y=accuracies_trees,
                    mode='lines', line=dict(color='green', width=2))
fig.append_trace(trace2, 1, 2)

fig['layout']['xaxis2'].update(title="n_estimators")
fig['layout']['yaxis2'].update(title="Accuracy")

fig['layout'].update(hovermode='closest',
                     showlegend=False, height=600)
In [7]:
py.iplot(fig)
Out[7]:

License

Author:

    Maheshakya Wijewardena <maheshakya.10@cse.mrt.ac.lk>

License:

    BSD 3 clause