Skip to content

Commit

Permalink
no thread limits, instead use thread pool for queries
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Bernhardsson committed Jul 11, 2016
1 parent 9467c0d commit 522dcac
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions ann_benchmarks/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
print('resetting memory limit from', soft, 'to', memory_limit)
resource.setrlimit(resource.RLIMIT_DATA, (memory_limit, hard))

os.environ['OMP_THREAD_LIMIT'] = '1' # just to limit number of processors

# Nmslib specific code
# Remove old indices stored on disk
INDEX_DIR='indices'
#import shutil
#if os.path.exists(INDEX_DIR):
# shutil.rmtree(INDEX_DIR)
import shutil
if os.path.exists(INDEX_DIR):
shutil.rmtree(INDEX_DIR)

class BaseANN(object):
    """Marker base class that every ANN algorithm wrapper inherits from.

    Carries no behavior of its own; it exists solely so the benchmark
    driver can treat all algorithm adapters uniformly.
    """
Expand Down Expand Up @@ -222,7 +220,6 @@ def __init__(self, metric, P, index_params, save_index):
self._save_index = save_index

def fit(self, X):
os.environ['OMP_THREAD_LIMIT'] = '40'
import pykgraph

if X.dtype != numpy.float32:
Expand All @@ -237,9 +234,13 @@ def fit(self, X):
self._kgraph.build(**self._index_params) #iterations=30, L=100, delta=0.002, recall=0.99, K=25)
if not os.path.exists(INDEX_DIR):
os.makedirs(INDEX_DIR)
<<<<<<< 9467c0df281b03597098ef968f180de8bf4d4671
if self._save_index:
self._kgraph.save(path)
os.environ['OMP_THREAD_LIMIT'] = '1'
=======
self._kgraph.save(path)
>>>>>>> no thread limits, instead use thread pool for queries

def query(self, v, n):
if v.dtype != numpy.float32:
Expand All @@ -262,7 +263,6 @@ def __init__(self, metric, method_name, index_param, save_index, query_param):
os.makedirs(d)

def fit(self, X):
os.environ['OMP_THREAD_LIMIT'] = '40'
import nmslib_vector
if self._method_name == 'vptree':
# To avoid this issue:
Expand All @@ -287,8 +287,6 @@ def fit(self, X):

nmslib_vector.setQueryTimeParams(self._index, self._query_param)

os.environ['OMP_THREAD_LIMIT'] = '1'

def query(self, v, n):
import nmslib_vector
return nmslib_vector.knnQuery(self._index, n, v.tolist())
Expand Down Expand Up @@ -393,6 +391,7 @@ def query(self, v, n):
indices = numpy.argpartition(dists, n)[:n] # partition-sort by distance, get `n` closest
return sorted(indices, key=lambda index: dists[index]) # sort `n` closest into correct order


def get_dataset(which='glove', limit=-1, random_state = 2, test_size = 10000):
cache = 'queries/%s-%d-%d-%d.npz' % (which, test_size, limit, random_state)
if os.path.exists(cache):
Expand Down Expand Up @@ -445,9 +444,12 @@ def run_algo(args, library, algo, results_fn):
for i in xrange(3): # Do multiple times to warm up page cache, use fastest
t0 = time.time()
k = 0.0
for v, correct in queries:
def single_query(t):
v, correct = t
found = algo.query(v, 10)
k += len(set(found).intersection(correct))
pool = multiprocessing.pool.ThreadPool()
pool.map(single_query, queries)
search_time = (time.time() - t0) / len(queries)
precision = k / (len(queries) * 10)
best_search_time = min(best_search_time, search_time)
Expand Down

0 comments on commit 522dcac

Please sign in to comment.