Commit a0f7da0

Moved all of the algorithm wrapper classes into separate files

1 parent 4cef691 · commit a0f7da0
20 files changed: +631 / -577 lines
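
With this change each wrapper class lives in its own module under the ann_benchmarks.algorithms package, so callers import the class they need explicitly instead of pulling everything from one file. A minimal sketch of the resulting import style (the module and class names are the ones introduced in the files below):

import_example.py (illustrative only):

    from ann_benchmarks.algorithms.base import BaseANN
    from ann_benchmarks.algorithms.annoy import Annoy
    from ann_benchmarks.algorithms.kdtree import KDTree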

ann_benchmarks/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1 +1,2 @@
-from main import *
+from __future__ import absolute_import
+from ann_benchmarks.main import *

ann_benchmarks/algorithms/__init__.py

Whitespace-only changes.

ann_benchmarks/algorithms/annoy.py

Lines changed: 19 additions & 0 deletions

from __future__ import absolute_import
import annoy
from ann_benchmarks.algorithms.base import BaseANN

class Annoy(BaseANN):
    def __init__(self, metric, n_trees, search_k):
        self._n_trees = n_trees
        self._search_k = search_k
        self._metric = metric
        self.name = 'Annoy(n_trees=%d, search_k=%d)' % (n_trees, search_k)

    def fit(self, X):
        self._annoy = annoy.AnnoyIndex(f=X.shape[1], metric=self._metric)
        for i, x in enumerate(X):
            self._annoy.add_item(i, x.tolist())
        self._annoy.build(self._n_trees)

    def query(self, v, n):
        return self._annoy.get_nns_by_vector(v.tolist(), n, self._search_k)
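
A minimal usage sketch of this wrapper (the dataset and parameter values are hypothetical, chosen only for illustration):

    import numpy
    from ann_benchmarks.algorithms.annoy import Annoy

    X = numpy.random.rand(1000, 20).astype(numpy.float32)  # hypothetical dataset
    algo = Annoy('angular', n_trees=100, search_k=10000)    # hypothetical parameters
    algo.fit(X)
    print(algo.query(X[0], 10))  # indices of the 10 approximate nearest neighbours of X[0]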

ann_benchmarks/algorithms/balltree.py

Lines changed: 21 additions & 0 deletions

from __future__ import absolute_import
import sklearn.neighbors
import sklearn.preprocessing
from ann_benchmarks.algorithms.base import BaseANN

class BallTree(BaseANN):
    def __init__(self, metric, leaf_size=20):
        self.name = 'BallTree(leaf_size=%d)' % leaf_size
        self._leaf_size = leaf_size
        self._metric = metric

    def fit(self, X):
        if self._metric == 'angular':
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        self._tree = sklearn.neighbors.BallTree(X, leaf_size=self._leaf_size)

    def query(self, v, n):
        if self._metric == 'angular':
            v = sklearn.preprocessing.normalize(v, axis=1, norm='l2')[0]
        dist, ind = self._tree.query(v, k=n)
        return ind[0]

ann_benchmarks/algorithms/base.py

Lines changed: 7 additions & 0 deletions

from __future__ import absolute_import

class BaseANN(object):
    def use_threads(self):
        return True

    def done(self):
        pass
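
Every wrapper in this commit follows the same contract: fit(X) builds the index from a matrix of points and query(v, n) returns up to n candidate indices, while use_threads() and done() are optional hooks for the benchmark driver. A hypothetical minimal subclass, shown only to illustrate the interface (not part of this commit):

    import numpy
    from ann_benchmarks.algorithms.base import BaseANN

    class ExactScan(BaseANN):
        """Illustrative exact scanner; assumes Euclidean distance."""
        def __init__(self, metric):
            self.name = 'ExactScan()'
            self._metric = metric  # ignored in this sketch

        def fit(self, X):
            self._index = numpy.array(X)

        def query(self, v, n):
            dists = ((self._index - v) ** 2).sum(-1)  # squared Euclidean distances
            return list(numpy.argsort(dists)[:n])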

ann_benchmarks/algorithms/bruteforce.py

Lines changed: 89 additions & 0 deletions

from __future__ import absolute_import
import numpy
import sklearn.neighbors
from ann_benchmarks.distance import metrics as pd
from ann_benchmarks.algorithms.base import BaseANN

class BruteForce(BaseANN):
    def __init__(self, metric):
        self._metric = metric
        self.name = 'BruteForce()'

    def fit(self, X):
        metric = {'angular': 'cosine', 'euclidean': 'l2', 'hamming': 'hamming'}[self._metric]
        self._nbrs = sklearn.neighbors.NearestNeighbors(algorithm='brute', metric=metric)
        self._nbrs.fit(X)

    def query(self, v, n):
        return list(self._nbrs.kneighbors([v],
            return_distance = False, n_neighbors = n)[0])

    def query_with_distances(self, v, n):
        (distances, positions) = self._nbrs.kneighbors([v],
            return_distance = True, n_neighbors = n)
        return zip(list(positions[0]), list(distances[0]))

class BruteForceBLAS(BaseANN):
    """kNN search that uses a linear scan = brute force."""
    def __init__(self, metric, precision=numpy.float32):
        if metric not in ('angular', 'euclidean', 'hamming'):
            raise NotImplementedError("BruteForceBLAS doesn't support metric %s" % metric)
        elif metric == 'hamming' and precision != numpy.bool:
            raise NotImplementedError("BruteForceBLAS doesn't support precision %s with Hamming distances" % precision)
        self._metric = metric
        self._precision = precision
        self.name = 'BruteForceBLAS()'

    def fit(self, X):
        """Initialize the search index."""
        lens = (X ** 2).sum(-1)  # precompute (squared) length of each vector
        if self._metric == 'angular':
            X /= numpy.sqrt(lens)[..., numpy.newaxis]  # normalize index vectors to unit length
            self.index = numpy.ascontiguousarray(X, dtype=self._precision)
        elif self._metric == 'euclidean':
            self.index = numpy.ascontiguousarray(X, dtype=self._precision)
            self.lengths = numpy.ascontiguousarray(lens, dtype=self._precision)
        elif self._metric == 'hamming':
            self.index = numpy.ascontiguousarray(
                map(numpy.packbits, X), dtype=numpy.uint8)
        else:
            assert False, "invalid metric"  # shouldn't get past the constructor!

    def query(self, v, n):
        return map(lambda (index, _): index, self.query_with_distances(v, n))

    popcount = []
    for i in xrange(256):
        popcount.append(bin(i).count("1"))

    def query_with_distances(self, v, n):
        """Find indices of `n` most similar vectors from the index to query vector `v`."""
        if self._metric == 'hamming':
            v = numpy.packbits(v)

        # use same precision for query as for index
        v = numpy.ascontiguousarray(v, dtype = self.index.dtype)

        # HACK we ignore query length as that's a constant not affecting the final ordering
        if self._metric == 'angular':
            # argmax_a cossim(a, b) = argmax_a dot(a, b) / |a||b| = argmin_a -dot(a, b)
            dists = -numpy.dot(self.index, v)
        elif self._metric == 'euclidean':
            # argmin_a (a - b)^2 = argmin_a a^2 - 2ab + b^2 = argmin_a a^2 - 2ab
            dists = self.lengths - 2 * numpy.dot(self.index, v)
        elif self._metric == 'hamming':
            diff = numpy.bitwise_xor(v, self.index)
            pc = BruteForceBLAS.popcount
            den = float(len(v) * 8)
            dists = [sum([pc[part] for part in point]) / den for point in diff]
        else:
            assert False, "invalid metric"  # shouldn't get past the constructor!
        indices = numpy.argpartition(dists, n)[:n]  # partition-sort by distance, get `n` closest
        def fix(index):
            ep = self.index[index]
            ev = v
            if self._metric == "hamming":
                ep = numpy.unpackbits(ep)
                ev = numpy.unpackbits(ev)
            return (index, pd[self._metric](ep, ev))
        return map(fix, indices)
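
For Hamming distance the index stores each bit vector packed into bytes (numpy.packbits), and the number of differing bits is obtained by XOR-ing the packed bytes and summing a 256-entry popcount table, as in BruteForceBLAS above. A small self-contained check of that idea, with made-up vectors:

    import numpy

    popcount = [bin(i).count("1") for i in range(256)]    # bits set per byte value

    a = numpy.array([1, 0, 1, 1, 0, 0, 1, 0], dtype=numpy.bool_)
    b = numpy.array([1, 1, 1, 0, 0, 0, 1, 1], dtype=numpy.bool_)
    pa, pb = numpy.packbits(a), numpy.packbits(b)         # one byte per vector
    diff = numpy.bitwise_xor(pa, pb)                      # set bits mark differing positions
    hamming = sum(popcount[byte] for byte in diff)        # 3 positions differ
    assert hamming == 3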

ann_benchmarks/algorithms/external.py

Lines changed: 75 additions & 0 deletions

from __future__ import absolute_import
import shlex
import subprocess
from ann_benchmarks.algorithms.base import BaseANN

class Subprocess(BaseANN):
    def __raw_line(self):
        return shlex.split( \
            self.__get_program_handle().stdout.readline().strip())

    def __line(self):
        line = self.__raw_line()
        while len(line) < 1 or line[0] != "epbprtv0":
            line = self.__raw_line()
        return line[1:]

    @staticmethod
    def __quote(token):
        return "'" + str(token).replace("'", "'\\'") + "'"

    def __write(self, string):
        self.__get_program_handle().stdin.write(string + "\n")

    def __get_program_handle(self):
        if not self._program:
            self._program = subprocess.Popen(
                self._args,
                bufsize = 1,  # line buffering
                stdin = subprocess.PIPE,
                stdout = subprocess.PIPE,
                universal_newlines = True)
            for key, value in self._params.iteritems():
                self.__write("%s %s" % \
                    (Subprocess.__quote(key), Subprocess.__quote(value)))
                assert(self.__line()[0] == "ok")
            self.__write("")
            assert(self.__line()[0] == "ok")
        return self._program

    def __init__(self, args, encoder, params):
        self.name = "Subprocess(program = %s, %s)" % (args[0], str(params))
        self._program = None
        self._args = args
        self._encoder = encoder
        self._params = params

    def fit(self, X):
        for entry in X:
            self.__write(self._encoder(entry))
            assert(self.__line()[0] == "ok")
        self.__write("")
        assert(self.__line()[0] == "ok")

    def query(self, v, n):
        self.__write("%s %d" % \
            (Subprocess.__quote(self._encoder(v)), n))
        status = self.__line()
        if status[0] == "ok":
            count = int(status[1])
            results = []
            i = 0
            while i < count:
                line = self.__line()
                results.append(int(line[0]))
                i += 1
            assert(len(results) == count)
            return results
        else:
            assert(status[0] == "fail")
            return []

    def use_threads(self):
        return False

    def done(self):
        if self._program:
            self._program.terminate()
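
The wrapper above drives an external program over a simple line protocol: quoted "key value" configuration pairs, then one encoded training point per line, then queries of the form "'vector' n", with each relevant response line starting with the token epbprtv0 followed by an ok/fail status. The sketch below is a hypothetical external program that would satisfy this wrapper, assuming the encoder emits space-separated floats; it is reconstructed from the wrapper code and is not part of this commit:

    from __future__ import absolute_import, print_function
    import sys
    import shlex
    import numpy

    def reply(*tokens):
        # every response line the wrapper accepts starts with "epbprtv0"
        print("epbprtv0 " + " ".join(str(t) for t in tokens))
        sys.stdout.flush()

    def read_phase():
        # yield stdin lines until a blank line ends the current phase
        for line in iter(sys.stdin.readline, ""):
            if not line.strip():
                return
            yield line

    def main():
        # configuration phase: quoted "key value" pairs, each acknowledged
        for line in read_phase():
            key, value = shlex.split(line)  # parameters are ignored in this sketch
            reply("ok")
        reply("ok")  # acknowledge the blank line that ends configuration

        # training phase: one encoded point per line, each acknowledged
        points = []
        for line in read_phase():
            points.append([float(x) for x in line.split()])
            reply("ok")
        reply("ok")  # acknowledge the blank line that ends training
        index = numpy.array(points)

        # query phase: "'encoded vector' n" per line; brute-force scan as a stand-in
        for line in read_phase():
            encoded, n = shlex.split(line)
            v = numpy.array([float(x) for x in encoded.split()])
            dists = ((index - v) ** 2).sum(-1)
            nearest = numpy.argsort(dists)[:int(n)]
            reply("ok", len(nearest))  # "ok <count>", then one index per line
            for i in nearest:
                reply(int(i))

    if __name__ == "__main__":
        main()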

ann_benchmarks/algorithms/falconn.py

Lines changed: 51 additions & 0 deletions

from __future__ import absolute_import
import numpy
import falconn
from ann_benchmarks.algorithms.base import BaseANN

class FALCONN(BaseANN):
    def __init__(self, metric, num_bits, num_tables, num_probes = None):
        if not num_probes:
            num_probes = num_tables
        self.name = 'FALCONN(K={}, L={}, T={})'.format(num_bits, num_tables, num_probes)
        self._metric = metric
        self._num_bits = num_bits
        self._num_tables = num_tables
        self._num_probes = num_probes
        self._center = None
        self._params = None
        self._index = None
        self._buf = None

    def fit(self, X):
        if X.dtype != numpy.float32:
            X = X.astype(numpy.float32)
        if self._metric == 'angular':
            X /= numpy.linalg.norm(X, axis=1).reshape(-1, 1)
        self._center = numpy.mean(X, axis=0)
        X -= self._center
        self._params = falconn.LSHConstructionParameters()
        self._params.dimension = X.shape[1]
        self._params.distance_function = 'euclidean_squared'
        self._params.lsh_family = 'cross_polytope'
        falconn.compute_number_of_hash_functions(self._num_bits, self._params)
        self._params.l = self._num_tables
        self._params.num_rotations = 1
        self._params.num_setup_threads = 0
        self._params.storage_hash_table = 'flat_hash_table'
        self._params.seed = 95225714
        self._index = falconn.LSHIndex(self._params)
        self._index.setup(X)
        self._index.set_num_probes(self._num_probes)
        self._buf = numpy.zeros((X.shape[1],), dtype=numpy.float32)

    def query(self, v, n):
        numpy.copyto(self._buf, v)
        if self._metric == 'angular':
            self._buf /= numpy.linalg.norm(self._buf)
        self._buf -= self._center
        return self._index.find_k_nearest_neighbors(self._buf, n)

    def use_threads(self):
        # See https://github.com/FALCONN-LIB/FALCONN/issues/6
        return False
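
The angular branch above works because, for unit-length vectors, ||x - q||^2 = 2 - 2*<x, q>, so FALCONN's euclidean_squared distance ranks candidates exactly like cosine similarity; and subtracting the same center from every index point and from the query shifts everything equally without changing any distance. A small numerical check of both identities, on random data for illustration only:

    import numpy

    rng = numpy.random.RandomState(0)
    X = rng.rand(100, 8).astype(numpy.float32)
    q = rng.rand(8).astype(numpy.float32)

    Xn = X / numpy.linalg.norm(X, axis=1).reshape(-1, 1)   # unit-length index vectors
    qn = q / numpy.linalg.norm(q)

    # ||x - q||^2 = 2 - 2<x, q> for unit vectors, so euclidean_squared ranks like cosine
    assert numpy.allclose(((Xn - qn) ** 2).sum(-1), 2 - 2 * Xn.dot(qn), atol=1e-5)

    # subtracting the same center from index and query leaves all distances unchanged
    center = Xn.mean(axis=0)
    before = ((Xn - qn) ** 2).sum(-1)
    after = (((Xn - center) - (qn - center)) ** 2).sum(-1)
    assert numpy.allclose(before, after, atol=1e-5)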

ann_benchmarks/algorithms/flann.py

Lines changed: 21 additions & 0 deletions

from __future__ import absolute_import
import pyflann
import sklearn.preprocessing
from ann_benchmarks.algorithms.base import BaseANN

class FLANN(BaseANN):
    def __init__(self, metric, target_precision):
        self._target_precision = target_precision
        self.name = 'FLANN(target_precision=%f)' % target_precision
        self._metric = metric

    def fit(self, X):
        self._flann = pyflann.FLANN(target_precision=self._target_precision, algorithm='autotuned', log_level='info')
        if self._metric == 'angular':
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        self._flann.build_index(X)

    def query(self, v, n):
        if self._metric == 'angular':
            v = sklearn.preprocessing.normalize(v, axis=1, norm='l2')[0]
        return self._flann.nn_index(v, n)[0][0]

ann_benchmarks/algorithms/itu.py

Lines changed: 66 additions & 0 deletions

from __future__ import absolute_import
import sys
sys.path.append('install/ann-filters/build/wrappers/swig/')
import numpy
import locality_sensitive
from ann_benchmarks.algorithms.base import BaseANN

class ITUFilteringDouble(BaseANN):
    def __init__(self, metric, alpha = None, beta = None, threshold = None, tau = None, kappa1 = None, kappa2 = None, m1 = None, m2 = None):
        self._loader = locality_sensitive.double_vector_loader()
        self._context = None
        self._strategy = None
        self._metric = metric
        self._alpha = alpha
        self._beta = beta
        self._threshold = threshold
        self._tau = tau
        self._kappa1 = kappa1
        self._kappa2 = kappa2
        self._m1 = m1
        self._m2 = m2
        self.name = ("ITUFilteringDouble(..., threshold = %f, ...)" % threshold)

    def fit(self, X):
        if self._metric == 'angular':
            X /= numpy.linalg.norm(X, axis=1).reshape(-1, 1)
        self._loader.add(X)
        self._context = locality_sensitive.double_vector_context(
            self._loader, self._alpha, self._beta)
        self._strategy = locality_sensitive.factories.make_double_filtering(
            self._context, self._threshold,
            locality_sensitive.filtering_configuration.from_values(
                self._kappa1, self._kappa2, self._tau, self._m1, self._m2))

    def query(self, v, n):
        if self._metric == 'angular':
            v /= numpy.linalg.norm(v)
        return self._strategy.find(v, n, None)

    def use_threads(self):
        return False

class ITUHashing(BaseANN):
    def __init__(self, seed, c = 2.0, r = 2.0):
        self._loader = locality_sensitive.bit_vector_loader()
        self._context = None
        self._strategy = None
        self._c = c
        self._r = r
        self._seed = seed
        self.name = ("ITUHashing(c = %f, r = %f, seed = %u)" % (c, r, seed))

    def fit(self, X):
        locality_sensitive.set_seed(self._seed)
        for entry in X:
            locality_sensitive.hacks.add(self._loader, entry.tolist())
        self._context = locality_sensitive.bit_vector_context(
            self._loader, self._c, self._r)
        self._strategy = locality_sensitive.factories.make_hashing(
            self._context)

    def query(self, v, n):
        return locality_sensitive.hacks.find(self._strategy, n, v.tolist())

    def use_threads(self):
        return False

ann_benchmarks/algorithms/kdtree.py

Lines changed: 21 additions & 0 deletions

from __future__ import absolute_import
import sklearn.neighbors
import sklearn.preprocessing
from ann_benchmarks.algorithms.base import BaseANN

class KDTree(BaseANN):
    def __init__(self, metric, leaf_size=20):
        self.name = 'KDTree(leaf_size=%d)' % leaf_size
        self._leaf_size = leaf_size
        self._metric = metric

    def fit(self, X):
        if self._metric == 'angular':
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        self._tree = sklearn.neighbors.KDTree(X, leaf_size=self._leaf_size)

    def query(self, v, n):
        if self._metric == 'angular':
            v = sklearn.preprocessing.normalize(v, axis=1, norm='l2')[0]
        dist, ind = self._tree.query(v, k=n)
        return ind[0]
