
Scalability of Approximate Nearest Neighbors in Scikit-learn

This example studies the scalability profile of approximate 10-nearest-neighbor queries using LSHForest with n_estimators=20 and n_candidates=200 as the number of samples in the dataset varies. The first plot shows how the query time of LSHForest grows with the index size and compares it with exact nearest neighbor search by brute force on the same index sizes. Brute-force queries scale predictably and linearly with the index size, because every query is a full scan. The LSHForest index has a sub-linear scalability profile but can be slower on small datasets.
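
The sub-linear profile comes from how an LSH Forest finds candidates. scikit-learn's LSHForest hashes points with random projections and ranks candidates by cosine distance. The toy version below is our own simplified sketch under that assumption. It is not scikit-learn's code, and the class name `TinyLSHForest`, the `n_bits` parameter and the 16-bit codes are hypothetical choices. Each tree stores the points sorted by a binary code built from random hyperplanes. A query uses binary search to find the points that share the longest code prefix with it. It shortens that prefix in all trees at once until at least `n_candidates` points are collected. Only then does it compute exact cosine similarities, and only for those candidates.

```python
import numpy as np


class TinyLSHForest:
    """Toy LSH Forest for cosine distance; an illustration, not scikit-learn's code."""

    def __init__(self, n_estimators=20, n_bits=16, n_candidates=200, seed=0):
        self.n_estimators = n_estimators
        self.n_bits = n_bits
        self.n_candidates = n_candidates
        self.rng = np.random.RandomState(seed)

    def _hash(self, X, planes):
        # One bit per random hyperplane (which side of it each point lies on),
        # packed into an integer whose most significant bit is the first plane.
        bits = (X @ planes.T > 0).astype(np.int64)
        weights = 2 ** np.arange(self.n_bits - 1, -1, -1, dtype=np.int64)
        return bits @ weights

    def fit(self, X):
        self.X = np.asarray(X, dtype=float)
        self.trees = []
        for _ in range(self.n_estimators):
            planes = self.rng.randn(self.n_bits, self.X.shape[1])
            codes = self._hash(self.X, planes)
            order = np.argsort(codes)
            # Sorted codes let a binary search find every point sharing a prefix.
            self.trees.append((planes, codes[order], order))
        return self

    def kneighbors(self, query, n_neighbors=10):
        q = np.asarray(query, dtype=float).ravel()
        q_codes = [int(self._hash(q[None, :], planes)[0])
                   for planes, _, _ in self.trees]
        candidates = set()
        # Start from the full code length and shorten the shared prefix,
        # in all trees at once, until enough candidates are collected.
        for depth in range(self.n_bits, -1, -1):
            shift = self.n_bits - depth
            for code, (_, sorted_codes, order) in zip(q_codes, self.trees):
                lo = (code >> shift) << shift   # smallest code with this prefix
                hi = lo + (1 << shift)          # first code past this prefix
                start, stop = np.searchsorted(sorted_codes, [lo, hi])
                candidates.update(order[start:stop].tolist())
            if len(candidates) >= self.n_candidates:
                break
        idx = np.fromiter(candidates, dtype=np.int64)
        # Re-rank only the candidates with the exact cosine similarity.
        cand = self.X[idx]
        sims = cand @ q / (np.linalg.norm(cand, axis=1) * np.linalg.norm(q))
        return idx[np.argsort(-sims)[:n_neighbors]]


# Usage on hypothetical random data.
X_demo = np.random.RandomState(0).randn(2000, 50)
forest = TinyLSHForest().fit(X_demo)
print(forest.kneighbors(X_demo[0], n_neighbors=10))
```

Each binary search costs O(log n), and exact distances are computed for only roughly `n_candidates` points. Brute force computes all n distances instead. The approximate index therefore only pays off once n is large compared with the number of candidates it re-ranks.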

The second plot shows the speedup of approximate queries over brute-force exact queries. The speedup tends to increase with the dataset size but should typically plateau for datasets with millions of samples and a few hundred dimensions. Higher-dimensional datasets tend to benefit more from LSHForest indexing.
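
The example computes the speedup per query: it divides the brute-force time by the LSHForest time and then averages those ratios. Below is a minimal sketch of that measurement. It assumes hypothetical helper names (`cosine_knn`, `per_query_speedups`) and uses `time.perf_counter`, which suits millisecond-scale calls better than `time.time`.

```python
import time
import numpy as np


def cosine_knn(index, query, k=10):
    """Exact brute-force k-NN by cosine similarity: a full scan of `index`."""
    index_n = index / np.linalg.norm(index, axis=1, keepdims=True)
    query_n = query / np.linalg.norm(query, axis=1, keepdims=True)
    return np.argsort(-(query_n @ index_n.T), axis=1)[:, :k]


def per_query_speedups(exact_fn, approx_fn, queries, n_iter=5, seed=42):
    """Mean and std of exact_time / approx_time over randomly picked queries."""
    rng = np.random.RandomState(seed)
    ratios = []
    for _ in range(n_iter):
        # Fancy indexing with a list keeps the (1, n_features) query shape.
        q = queries[[rng.randint(0, len(queries))]]
        t0 = time.perf_counter()
        exact_fn(q)
        t_exact = time.perf_counter() - t0
        t0 = time.perf_counter()
        approx_fn(q)
        t_approx = time.perf_counter() - t0
        ratios.append(t_exact / t_approx)
    ratios = np.array(ratios)
    return ratios.mean(), ratios.std()


# Usage with hypothetical data. Scanning only the first 10% of the index
# is a crude stand-in for an approximate method, not an ANN algorithm.
data = np.random.RandomState(0).randn(20000, 100)
mean, std = per_query_speedups(lambda q: cosine_knn(data, q),
                               lambda q: cosine_knn(data[:2000], q),
                               data[:100])
print("speedup: %0.1f +/- %0.1f" % (mean, std))
```

Averaging per-query ratios, as the example does, gives a spread for each index size. That spread is what the error bars in the speedup plot show.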

The break-even point (speedup = 1) depends on the dimensionality and structure of the indexed data and on the parameters of the LSHForest index.

The precision of approximate queries should decrease slowly as the dataset grows. The rate of that decrease depends mostly on the LSHForest parameters and on the dimensionality of the data.
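
Precision here means precision@10: the fraction of the 10 approximate neighbors that also appear among the 10 exact neighbors. Below is a small helper that computes it row by row. The name `precision_at_k` is ours. For a single query, as in each iteration of this example, it gives the same value as `np.in1d(approx_neighbors, exact_neighbors).mean()`.

```python
import numpy as np


def precision_at_k(approx_neighbors, exact_neighbors):
    """Mean fraction of approximate neighbors that are true k-nearest neighbors."""
    approx = np.asarray(approx_neighbors)
    exact = np.asarray(exact_neighbors)
    # Compare each query's approximate indices with that same query's exact set.
    hits = [np.isin(a, e).mean() for a, e in zip(approx, exact)]
    return float(np.mean(hits))


print(precision_at_k([[0, 1, 2, 3]], [[0, 1, 5, 6]]))  # 0.5
```

Working row by row matters when several queries are batched. Flattening with `np.in1d` would check each approximate index against the exact neighbors of all queries at once.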

Version

In [1]:
import sklearn
sklearn.__version__
Out[1]:
'0.18.1'

Imports

In [2]:
from __future__ import division  # must be the first statement in the cell
print(__doc__)

import time
import numpy as np

import plotly.plotly as py
import plotly.graph_objs as go
from sklearn.datasets import make_blobs
from sklearn.neighbors import LSHForest
from sklearn.neighbors import NearestNeighbors
Automatically created module for IPython interactive environment

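This notebook was run with scikit-learn 0.18.1. LSHForest was deprecated in scikit-learn 0.19 and removed in 0.21, so the import above fails on current releases. Below is a minimal sketch of a guarded import that reports the problem instead of stopping with a traceback.

```python
# LSHForest was deprecated in scikit-learn 0.19 and removed in 0.21.
try:
    from sklearn.neighbors import LSHForest
except ImportError:  # also raised if scikit-learn is not installed at all
    LSHForest = None

if LSHForest is None:
    print("LSHForest is unavailable; use scikit-learn < 0.21 to run this example.")
```
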
Calculations

In [3]:
# Parameters of the study
n_samples_min = int(1e3)
n_samples_max = int(1e5)
n_features = 100
n_centers = 100
n_queries = 100
n_steps = 6
n_iter = 5

# Initialize the range of `n_samples`
n_samples_values = np.logspace(np.log10(n_samples_min),
                               np.log10(n_samples_max),
                               n_steps).astype(int)

# Generate some structured data
rng = np.random.RandomState(42)
all_data, _ = make_blobs(n_samples=n_samples_max + n_queries,
                         n_features=n_features, centers=n_centers, shuffle=True,
                         random_state=0)
queries = all_data[:n_queries]
index_data = all_data[n_queries:]

# Metrics to collect for the plots
average_times_exact = []
average_times_approx = []
std_times_approx = []
accuracies = []
std_accuracies = []
average_speedups = []
std_speedups = []

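The six index sizes are evenly spaced on a log scale between n_samples_min and n_samples_max, so each size is 10**0.4 ≈ 2.5 times the previous one. Below is a standalone check of that grid. It makes the same `np.logspace` call and casts with the built-in `int`, because the `np.int` alias was removed in NumPy 1.24. The values match the index sizes printed further down.

```python
import numpy as np

n_samples_min, n_samples_max, n_steps = int(1e3), int(1e5), 6
# Exponents 3.0, 3.4, ..., 5.0; the sizes are truncated to integers.
n_samples_values = np.logspace(np.log10(n_samples_min),
                               np.log10(n_samples_max),
                               n_steps).astype(int)
print(n_samples_values.tolist())
# [1000, 2511, 6309, 15848, 39810, 100000]
```
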
Calculate the average query time

In [4]:
for n_samples in n_samples_values:
    X = index_data[:n_samples]
    # Index the data with LSHForest (approximate) and brute-force cosine search (exact) for 10-neighbor queries
    lshf = LSHForest(n_estimators=20, n_candidates=200,
                     n_neighbors=10).fit(X)
    nbrs = NearestNeighbors(algorithm='brute', metric='cosine',
                            n_neighbors=10).fit(X)
    time_approx = []
    time_exact = []
    accuracy = []

    for i in range(n_iter):
        # pick one query at random to study query time variability in LSHForest
        query = queries[[rng.randint(0, n_queries)]]

        t0 = time.time()
        exact_neighbors = nbrs.kneighbors(query, return_distance=False)
        time_exact.append(time.time() - t0)

        t0 = time.time()
        approx_neighbors = lshf.kneighbors(query, return_distance=False)
        time_approx.append(time.time() - t0)

        accuracy.append(np.in1d(approx_neighbors, exact_neighbors).mean())

    average_time_exact = np.mean(time_exact)
    average_time_approx = np.mean(time_approx)
    speedup = np.array(time_exact) / np.array(time_approx)
    average_speedup = np.mean(speedup)
    mean_accuracy = np.mean(accuracy)
    std_accuracy = np.std(accuracy)
    print("Index size: %d, exact: %0.3fs, LSHF: %0.3fs, speedup: %0.1f, "
          "accuracy: %0.2f +/-%0.2f" %
          (n_samples, average_time_exact, average_time_approx, average_speedup,
           mean_accuracy, std_accuracy))

    accuracies.append(mean_accuracy)
    std_accuracies.append(std_accuracy)
    average_times_exact.append(average_time_exact)
    average_times_approx.append(average_time_approx)
    std_times_approx.append(np.std(time_approx))
    average_speedups.append(average_speedup)
    std_speedups.append(np.std(speedup))
Index size: 1000, exact: 0.001s, LSHF: 0.006s, speedup: 0.1, accuracy: 1.00 +/-0.00
Index size: 2511, exact: 0.002s, LSHF: 0.007s, speedup: 0.2, accuracy: 1.00 +/-0.00
Index size: 6309, exact: 0.005s, LSHF: 0.008s, speedup: 0.7, accuracy: 1.00 +/-0.00
Index size: 15848, exact: 0.015s, LSHF: 0.010s, speedup: 1.5, accuracy: 1.00 +/-0.00
Index size: 39810, exact: 0.027s, LSHF: 0.009s, speedup: 2.9, accuracy: 1.00 +/-0.00
Index size: 100000, exact: 0.070s, LSHF: 0.016s, speedup: 4.8, accuracy: 0.98 +/-0.04
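
In this run the printed speedups cross 1 between 6309 and 15848 samples. Below is a rough estimate of the break-even index size. It assumes, as a simplification, that log(speedup) varies linearly with log(index size) between neighboring measurements. It uses the rounded speedups printed above, and the function name `break_even_size` is ours.

```python
import numpy as np


def break_even_size(n_samples, speedups):
    """Log-log linear interpolation of the index size where speedup crosses 1."""
    x = np.log10(np.asarray(n_samples, dtype=float))
    y = np.log10(np.asarray(speedups, dtype=float))
    for i in range(len(x) - 1):
        if y[i] < 0 <= y[i + 1]:
            # Fraction of the way from point i to point i + 1 where log10(speedup) == 0.
            t = -y[i] / (y[i + 1] - y[i])
            return 10 ** (x[i] + t * (x[i + 1] - x[i]))
    return None  # no crossing inside the measured range


sizes = [1000, 2511, 6309, 15848, 39810, 100000]
speedups = [0.1, 0.2, 0.7, 1.5, 2.9, 4.8]  # rounded values printed by In [4]
print("break-even at about %d samples" % break_even_size(sizes, speedups))
```

Each speedup averages only n_iter=5 single-query timings and depends on the machine. The estimate therefore describes this particular run, not LSHForest in general.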

Plot average query time against n_samples

In [5]:
p1 = go.Scatter(x=n_samples_values,
                y=average_times_approx,
                error_y=dict(type='data', visible=True,
                             array=std_times_approx),
                line=dict(color='red', width=2),
                name='LSHForest')

p2 = go.Scatter(x=n_samples_values, y=average_times_exact,
                mode='lines', line=dict(color='blue', width=2),
                name="NearestNeighbors(algorithm='brute', metric='cosine')")

layout = go.Layout(title="Impact of index size on response time for "
                         "10-nearest-neighbors queries",
                   xaxis=dict(title="n_samples"),
                   yaxis=dict(title="Average query time in seconds"))
fig = go.Figure(data=[p1, p2], layout=layout)
py.iplot(fig)
Out[5]:

Plot average query speedup versus index size

In [6]:
p1 = go.Scatter(x=n_samples_values,
                y=average_speedups,
                error_y=dict(type='data', visible=True,
                             array=std_speedups),
                line=dict(color='red', width=2))

layout = go.Layout(title="Speedup of the approximate NN queries vs brute force",
                   xaxis=dict(title="n_samples"),
                   yaxis=dict(title="Average speedup"))
fig = go.Figure(data=[p1], layout=layout)
py.iplot(fig)
Out[6]:

Plot average precision versus index size

In [7]:
p1 = go.Scatter(x=n_samples_values, y=accuracies,
                error_y=dict(type='data', visible=True,
                             array=std_accuracies),
                line=dict(color='cyan', width=2))

layout = go.Layout(title="Precision of 10-nearest-neighbors queries with index size",
                   xaxis=dict(title="n_samples"),
                   yaxis=dict(title="precision@10", range=[0, 1.1]))
fig = go.Figure(data=[p1], layout=layout)
py.iplot(fig)
Out[7]:

License

Authors:

    Maheshakya Wijewardena <maheshakya.10@cse.mrt.ac.lk>

    Olivier Grisel <olivier.grisel@ensta.org>

License:

     BSD 3 clause