diff --git a/docs/sphinx/rst/bibliography.rst b/docs/sphinx/rst/bibliography.rst index fb9a44ceeb..bbcd98d084 100644 --- a/docs/sphinx/rst/bibliography.rst +++ b/docs/sphinx/rst/bibliography.rst @@ -96,6 +96,9 @@ Bibliography .. [CASCADEHASHING] **Fast and Accurate Image Matching with Cascade Hashing for 3D Reconstruction.** Jian Cheng, Cong Leng, Jiaxiang Wu, Hainan Cui, Hanqing Lu. CVPR 2014. +.. [HNSW] **Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs.** + Yu. A. Malkov, D. A. Yashunin, TPAMI 2018. + .. [Magnus] **Two-View Orthographic Epipolar Geometry: Minimal and Optimal Solvers.** Magnus Oskarsson. In Journal of Mathematical Imaging and Vision, 2017. diff --git a/docs/sphinx/rst/openMVG/matching/matching.rst b/docs/sphinx/rst/openMVG/matching/matching.rst index 844285bbb1..11ca4bb490 100644 --- a/docs/sphinx/rst/openMVG/matching/matching.rst +++ b/docs/sphinx/rst/openMVG/matching/matching.rst @@ -15,6 +15,7 @@ Three implementations are available: * a Brute force, * an Approximate Nearest Neighbor [FLANN]_, * a Cascade hashing Nearest Neighbor [CASCADEHASHING]_. +* an approximate nearest neighbor search using Hierarchical Navigable Small World graphs [HNSW] This module works for data of any dimensionality, it could be use to match: diff --git a/src/openMVG/matching/matcher_hnsw.hpp b/src/openMVG/matching/matcher_hnsw.hpp new file mode 100644 index 0000000000..19bb684045 --- /dev/null +++ b/src/openMVG/matching/matcher_hnsw.hpp @@ -0,0 +1,170 @@ +// This file is part of OpenMVG, an Open Multiple View Geometry C++ library. + +// Copyright (c) 2019 Romain Janvier and Pierre Moulon + +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef OPENMVG_MATCHING_MATCHER_HNSW_HPP +#define OPENMVG_MATCHING_MATCHER_HNSW_HPP + +#include +#ifdef OPENMVG_USE_OPENMP +#include +#endif +#include + +#include "openMVG/matching/matching_interface.hpp" +#include "openMVG/matching/metric.hpp" + +#include "third_party/hnswlib/hnswlib.h" + +using namespace hnswlib; + +namespace openMVG { +namespace matching { + +// By default compute square(L2 distance). +template > +class HNSWMatcher: public ArrayMatcher +{ + public: + using DistanceType = typename Metric::ResultType; + + HNSWMatcher() = default; + virtual ~HNSWMatcher()= default; + + /** + * Build the matching structure + * + * \param[in] dataset Input data. + * \param[in] nbRows The number of component. + * \param[in] dimension Length of the data contained in the dataset. + * + * \return True if success. + */ + bool Build + ( + const Scalar * dataset, + int nbRows, + int dimension + ) override + { + if (nbRows < 1) + { + HNSWmetric.reset(nullptr); + HNSWmatcher.reset(nullptr); + return false; + } + + dimension_ = dimension; + + // Here this is tricky since there is no specialization + if(typeid(DistanceType)== typeid(int)) { + HNSWmetric.reset(dynamic_cast *>(new L2SpaceI(dimension))); + } else + if (typeid(DistanceType) == typeid(float)) { + HNSWmetric.reset(dynamic_cast *>(new L2Space(dimension))); + } else { + std::cerr << "HNSW matcher: this type of distance is not handled Yet" << std::endl; + } + + HNSWmatcher.reset(new HierarchicalNSW(HNSWmetric.get(), nbRows, 16, 100) ); + HNSWmatcher->setEf(16); + + // add first point.. + HNSWmatcher->addPoint((void *)(dataset), (size_t) 0); + //...and the other in // + #ifdef OPENMVG_USE_OPENMP + #pragma omp parallel for + #endif + for (int i = 1; i < nbRows; i++) { + HNSWmatcher->addPoint((void *) (dataset + dimension * i), (size_t) i); + } + + return true; + }; + + /** + * Search the nearest Neighbor of the scalar array query. + * + * \param[in] query The query array. + * \param[out] indice The indice of array in the dataset that. + * have been computed as the nearest array. + * \param[out] distance The distance between the two arrays. + * + * \return True if success. + */ + bool SearchNeighbour + ( + const Scalar * query, + int * indice, + DistanceType * distance + ) override + { + if (HNSWmatcher.get() == nullptr) + return false; + auto result = HNSWmatcher->searchKnn(query, 1).top(); + *indice = result.second; + *distance = result.first; + return true; + } + + /** + * Search the N nearest Neighbor of the scalar array query. + * + * \param[in] query The query array. + * \param[in] nbQuery The number of query rows. + * \param[out] indices The corresponding (query, neighbor) indices. + * \param[out] distances The distances between the matched arrays. + * \param[in] NN The number of maximal neighbor that will be searched. + * + * \return True if success. + */ + bool SearchNeighbours + ( + const Scalar * query, int nbQuery, + IndMatches * pvec_indices, + std::vector * pvec_distances, + size_t NN + ) override + { + if (HNSWmatcher.get() == nullptr) + { + return false; + } + pvec_indices->reserve(nbQuery * NN); + pvec_distances->reserve(nbQuery * NN); + #ifdef OPENMVG_USE_OPENMP + #pragma omp parallel for + #endif + for (int i = 0; i < nbQuery; i++) { + auto result = HNSWmatcher->searchKnn((const void *) (query + dimension_ * i), NN, + [](const std::pair &a, const std::pair &b) -> bool { + return a.first < b.first; + }); + #ifdef OPENMVG_USE_OPENMP + #pragma omp critical + #endif + { + for (const auto & res : result) + { + pvec_indices->emplace_back(i, res.second); + pvec_distances->emplace_back(res.first); + } + } + } + return true; + }; + +private: + int dimension_; + std::unique_ptr> HNSWmetric; + std::unique_ptr> HNSWmatcher; +}; + +} // namespace matching +} // namespace openMVG + +#endif // OPENMVG_MATCHING_MATCHER_HNSW_HPP \ No newline at end of file diff --git a/src/openMVG/matching/matcher_type.hpp b/src/openMVG/matching/matcher_type.hpp index 5778513ef3..c75d2ea507 100644 --- a/src/openMVG/matching/matcher_type.hpp +++ b/src/openMVG/matching/matcher_type.hpp @@ -17,6 +17,7 @@ enum EMatcherType : unsigned char BRUTE_FORCE_L2, ANN_L2, CASCADE_HASHING_L2, + HNSW_L2, BRUTE_FORCE_HAMMING }; diff --git a/src/openMVG/matching/matching_test.cpp b/src/openMVG/matching/matching_test.cpp index 260c320b99..0569e63478 100644 --- a/src/openMVG/matching/matching_test.cpp +++ b/src/openMVG/matching/matching_test.cpp @@ -11,6 +11,7 @@ #include "openMVG/matching/matcher_brute_force.hpp" #include "openMVG/matching/matcher_cascade_hashing.hpp" #include "openMVG/matching/matcher_kdtree_flann.hpp" +#include "openMVG/matching/matcher_hnsw.hpp" #include "openMVG/numeric/eigen_alias_definition.hpp" @@ -118,6 +119,38 @@ TEST(Matching, ArrayMatcher_Kdtree_Flann_Simple__NN) EXPECT_EQ(IndMatch(0,4), vec_nIndice[4]); } +TEST(Matching, ArrayMatcher_Hnsw_Simple__NN) +{ + const float array[] = {0, 1, 2, 5, 6}; + // no 3, because it involve the same dist as 1,1 + + HNSWMatcher matcher; + EXPECT_TRUE( matcher.Build(array, 5, 1) ); + + const float query[] = {2}; + IndMatches vec_nIndice; + vector vec_fDistance; + const int NN = 5; + EXPECT_TRUE( matcher.SearchNeighbours(query, 1, &vec_nIndice, &vec_fDistance, NN) ); + + EXPECT_EQ( 5, vec_nIndice.size()); + EXPECT_EQ( 5, vec_fDistance.size()); + + // Check distances: + EXPECT_NEAR( vec_fDistance[0], Square(2.0f-2.0f), 1e-6); + EXPECT_NEAR( vec_fDistance[1], Square(1.0f-2.0f), 1e-6); + EXPECT_NEAR( vec_fDistance[2], Square(0.0f-2.0f), 1e-6); + EXPECT_NEAR( vec_fDistance[3], Square(5.0f-2.0f), 1e-6); + EXPECT_NEAR( vec_fDistance[4], Square(6.0f-2.0f), 1e-6); + + // Check indexes: + EXPECT_EQ(IndMatch(0,2), vec_nIndice[0]); + EXPECT_EQ(IndMatch(0,1), vec_nIndice[1]); + EXPECT_EQ(IndMatch(0,0), vec_nIndice[2]); + EXPECT_EQ(IndMatch(0,3), vec_nIndice[3]); + EXPECT_EQ(IndMatch(0,4), vec_nIndice[4]); +} + //-- Test LIMIT case (empty arrays) TEST(Matching, ArrayMatcherBruteForce_Simple_EmptyArrays) diff --git a/src/openMVG/matching/regions_matcher.cpp b/src/openMVG/matching/regions_matcher.cpp index 744c18deaf..3b3fac08ad 100644 --- a/src/openMVG/matching/regions_matcher.cpp +++ b/src/openMVG/matching/regions_matcher.cpp @@ -10,6 +10,7 @@ #include "openMVG/matching/matcher_brute_force.hpp" #include "openMVG/matching/matcher_cascade_hashing.hpp" #include "openMVG/matching/matcher_kdtree_flann.hpp" +#include "openMVG/matching/matcher_hnsw.hpp" #include "openMVG/matching/metric.hpp" #include "openMVG/matching/metric_hamming.hpp" @@ -84,6 +85,13 @@ std::unique_ptr RegionMatcherFactory region_matcher.reset(new matching::RegionsMatcherT(regions, true)); } break; + case HNSW_L2: + { + using MetricT = L2; + using MatcherT = HNSWMatcher; + region_matcher.reset(new matching::RegionsMatcherT(regions, true)); + } + break; case CASCADE_HASHING_L2: { using MetricT = L2; @@ -114,6 +122,13 @@ std::unique_ptr RegionMatcherFactory region_matcher.reset(new matching::RegionsMatcherT(regions, true)); } break; + case HNSW_L2: + { + using MetricT = L2; + using MatcherT = HNSWMatcher; + region_matcher.reset(new matching::RegionsMatcherT(regions, true)); + } + break; case CASCADE_HASHING_L2: { using MetricT = L2; diff --git a/src/software/SfM/main_ComputeMatches.cpp b/src/software/SfM/main_ComputeMatches.cpp index dc50fcbf6f..d6cecb9137 100644 --- a/src/software/SfM/main_ComputeMatches.cpp +++ b/src/software/SfM/main_ComputeMatches.cpp @@ -126,6 +126,7 @@ int main(int argc, char **argv) << " AUTO: auto choice from regions type,\n" << " For Scalar based regions descriptor:\n" << " BRUTEFORCEL2: L2 BruteForce matching,\n" + << " HNSWL2: L2 Approximate Matching with Hierarchical Navigable Small World graphs (float only),\n" << " ANNL2: L2 Approximate Nearest Neighbor matching,\n" << " CASCADEHASHINGL2: L2 Cascade Hashing matching.\n" << " FASTCASCADEHASHINGL2: (default)\n" @@ -335,6 +336,12 @@ int main(int argc, char **argv) collectionMatcher.reset(new Matcher_Regions(fDistRatio, BRUTE_FORCE_HAMMING)); } else + if (sNearestMatchingMethod == "HNSWL2") + { + std::cout << "Using HNSWL2 matcher" << std::endl; + collectionMatcher.reset(new Matcher_Regions(fDistRatio, HNSW_L2)); + } + else if (sNearestMatchingMethod == "ANNL2") { std::cout << "Using ANN_L2 matcher" << std::endl; diff --git a/src/third_party/CMakeLists.txt b/src/third_party/CMakeLists.txt index 816a9414f7..b71dfab120 100644 --- a/src/third_party/CMakeLists.txt +++ b/src/third_party/CMakeLists.txt @@ -96,11 +96,11 @@ if (DEFINED OpenMVG_USE_INTERNAL_EIGEN) add_subdirectory(eigen) endif() -list(APPEND directories cmdLine histogram htmlDoc progress vectorGraphics) +list(APPEND directories cmdLine histogram htmlDoc progress vectorGraphics hnswlib) foreach(inDirectory ${directories}) install( DIRECTORY ./${inDirectory} - DESTINATION include/openMVG/third_party/ + DESTINATION ${CMAKE_INSTALL_PREFIX}/include/openMVG/third_party/ COMPONENT headers FILES_MATCHING PATTERN "*.hpp" PATTERN "*.h" ) diff --git a/src/third_party/hnswlib/bruteforce.h b/src/third_party/hnswlib/bruteforce.h new file mode 100644 index 0000000000..5b1bd655ac --- /dev/null +++ b/src/third_party/hnswlib/bruteforce.h @@ -0,0 +1,170 @@ +#pragma once +#include +#include +#include +#include + +namespace hnswlib { + template + class BruteforceSearch : public AlgorithmInterface { + public: + BruteforceSearch(SpaceInterface *s) { + + } + BruteforceSearch(SpaceInterface *s, const std::string &location) { + loadIndex(location, s); + } + + BruteforceSearch(SpaceInterface *s, size_t maxElements) { + maxelements_ = maxElements; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxElements * size_per_element_); + if (data_ == nullptr) + std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); + cur_element_count = 0; + } + + ~BruteforceSearch() { + free(data_); + } + + char *data_; + size_t maxelements_; + size_t cur_element_count; + size_t size_per_element_; + + size_t data_size_; + DISTFUNC fstdistfunc_; + void *dist_func_param_; + std::mutex index_lock; + + std::unordered_map dict_external_to_internal; + + void addPoint(const void *datapoint, labeltype label) { + + int idx; + { + std::unique_lock lock(index_lock); + + + + auto search=dict_external_to_internal.find(label); + if (search != dict_external_to_internal.end()) { + idx=search->second; + } + else{ + if (cur_element_count >= maxelements_) { + throw std::runtime_error("The number of elements exceeds the specified limit\n"); + } + idx=cur_element_count; + dict_external_to_internal[label] = idx; + cur_element_count++; + } + } + memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); + memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); + + + + + }; + + void removePoint(labeltype cur_external) { + size_t cur_c=dict_external_to_internal[cur_external]; + + dict_external_to_internal.erase(cur_external); + + labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); + dict_external_to_internal[label]=cur_c; + memcpy(data_ + size_per_element_ * cur_c, + data_ + size_per_element_ * (cur_element_count-1), + data_size_+sizeof(labeltype)); + cur_element_count--; + + } + + + std::priority_queue> + searchKnn(const void *query_data, size_t k) const { + std::priority_queue> topResults; + if (cur_element_count == 0) return topResults; + for (int i = 0; i < k; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + + data_size_)))); + } + dist_t lastdist = topResults.top().first; + for (int i = k; i < cur_element_count; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + if (dist <= lastdist) { + topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + + data_size_)))); + if (topResults.size() > k) + topResults.pop(); + lastdist = topResults.top().first; + } + + } + return topResults; + }; + + template + std::vector> + searchKnn(const void* query_data, size_t k, Comp comp) { + std::vector> result; + if (cur_element_count == 0) return result; + + auto ret = searchKnn(query_data, k); + + while (!ret.empty()) { + result.push_back(ret.top()); + ret.pop(); + } + + std::sort(result.begin(), result.end(), comp); + + return result; + } + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + + writeBinaryPOD(output, maxelements_); + writeBinaryPOD(output, size_per_element_); + writeBinaryPOD(output, cur_element_count); + + output.write(data_, maxelements_ * size_per_element_); + + output.close(); + } + + void loadIndex(const std::string &location, SpaceInterface *s) { + + + std::ifstream input(location, std::ios::binary); + std::streampos position; + + readBinaryPOD(input, maxelements_); + readBinaryPOD(input, size_per_element_); + readBinaryPOD(input, cur_element_count); + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxelements_ * size_per_element_); + if (data_ == nullptr) + std::runtime_error("Not enough memory: loadIndex failed to allocate data"); + + input.read(data_, maxelements_ * size_per_element_); + + input.close(); + + } + + }; +} diff --git a/src/third_party/hnswlib/hnswalg.h b/src/third_party/hnswlib/hnswalg.h new file mode 100644 index 0000000000..afc1222d3e --- /dev/null +++ b/src/third_party/hnswlib/hnswalg.h @@ -0,0 +1,987 @@ +#pragma once + +#include "visited_list_pool.h" +#include "hnswlib.h" +#include +#include +#include +#include + + +namespace hnswlib { + typedef unsigned int tableint; + typedef unsigned int linklistsizeint; + + template + class HierarchicalNSW : public AlgorithmInterface { + public: + + HierarchicalNSW(SpaceInterface *s) { + + } + + HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { + loadIndex(location, s, max_elements); + } + + HierarchicalNSW(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : + link_list_locks_(max_elements), element_levels_(max_elements) { + max_elements_ = max_elements; + + has_deletions_=false; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + M_ = M; + maxM_ = M_; + maxM0_ = M_ * 2; + ef_construction_ = std::max(ef_construction,M_); + ef_ = 10; + + level_generator_.seed(random_seed); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); + offsetData_ = size_links_level0_; + label_offset_ = size_links_level0_ + data_size_; + offsetLevel0_ = 0; + + data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory"); + + cur_element_count = 0; + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + + + //initializations for special treatment of the first node + enterpoint_node_ = -1; + maxlevel_ = -1; + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + mult_ = 1 / log(1.0 * M_); + revSize_ = 1.0 / mult_; + } + + struct CompareByFirst { + constexpr bool operator()(std::pair const &a, + std::pair const &b) const noexcept { + return a.first < b.first; + } + }; + + ~HierarchicalNSW() { + + free(data_level0_memory_); + for (tableint i = 0; i < cur_element_count; i++) { + if (element_levels_[i] > 0) + free(linkLists_[i]); + } + free(linkLists_); + delete visited_list_pool_; + } + + size_t max_elements_; + size_t cur_element_count; + size_t size_data_per_element_; + size_t size_links_per_element_; + + size_t M_; + size_t maxM_; + size_t maxM0_; + size_t ef_construction_; + + double mult_, revSize_; + int maxlevel_; + + + VisitedListPool *visited_list_pool_; + std::mutex cur_element_count_guard_; + + std::vector link_list_locks_; + tableint enterpoint_node_; + + + size_t size_links_level0_; + size_t offsetData_, offsetLevel0_; + + + char *data_level0_memory_; + char **linkLists_; + std::vector element_levels_; + + size_t data_size_; + + bool has_deletions_; + + + size_t label_offset_; + DISTFUNC fstdistfunc_; + void *dist_func_param_; + std::unordered_map label_lookup_; + + std::default_random_engine level_generator_; + + inline labeltype getExternalLabel(tableint internal_id) const { + labeltype return_label; + memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); + return return_label; + } + + inline void setExternalLabel(tableint internal_id, labeltype label) const { + memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); + } + + inline labeltype *getExternalLabeLp(tableint internal_id) const { + return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + } + + inline char *getDataByInternalId(tableint internal_id) const { + return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + } + + int getRandomLevel(double reverse_size) { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(level_generator_)) * reverse_size; + return (int) r; + } + + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayer(tableint ep_id, const void *data_point, int layer) { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidateSet; + + dist_t lowerBound; + if (!isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + top_candidates.emplace(dist, ep_id); + lowerBound = dist; + candidateSet.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidateSet.emplace(-lowerBound, ep_id); + } + visited_array[ep_id] = visited_array_tag; + + while (!candidateSet.empty()) { + std::pair curr_el_pair = candidateSet.top(); + if ((-curr_el_pair.first) > lowerBound) { + break; + } + candidateSet.pop(); + + tableint curNodeNum = curr_el_pair.second; + + std::unique_lock lock(link_list_locks_[curNodeNum]); + + int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); + if (layer == 0) { + data = (int*)get_linklist0(curNodeNum); + } else { + data = (int*)get_linklist(curNodeNum, layer); +// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); + } + size_t size = getListCount((linklistsizeint*)data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); +#endif + + for (size_t j = 0; j < size; j++) { + tableint candidate_id = *(datal + j); +// if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); +#endif + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + char *currObj1 = (getDataByInternalId(candidate_id)); + + dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); + if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { + candidateSet.emplace(-dist1, candidate_id); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); +#endif + + if (!isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist1, candidate_id); + + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + visited_list_pool_->releaseVisitedList(vl); + + return top_candidates; + } + + template + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidate_set; + + dist_t lowerBound; + if (!has_deletions || !isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + lowerBound = dist; + top_candidates.emplace(dist, ep_id); + candidate_set.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidate_set.emplace(-lowerBound, ep_id); + } + + visited_array[ep_id] = visited_array_tag; + + while (!candidate_set.empty()) { + + std::pair current_node_pair = candidate_set.top(); + + if ((-current_node_pair.first) > lowerBound) { + break; + } + candidate_set.pop(); + + tableint current_node_id = current_node_pair.second; + int *data = (int *) get_linklist0(current_node_id); + size_t size = getListCount((linklistsizeint*)data); +// bool cur_node_deleted = isMarkedDeleted(current_node_id); + +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + _mm_prefetch((char *) (data + 2), _MM_HINT_T0); +#endif + + for (size_t j = 1; j <= size; j++) { + int candidate_id = *(data + j); +// if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, + _MM_HINT_T0);//////////// +#endif + if (!(visited_array[candidate_id] == visited_array_tag)) { + + visited_array[candidate_id] = visited_array_tag; + + char *currObj1 = (getDataByInternalId(candidate_id)); + dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); + + if (top_candidates.size() < ef || lowerBound > dist) { + candidate_set.emplace(-dist, candidate_id); +#ifdef USE_SSE + _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + + offsetLevel0_,/////////// + _MM_HINT_T0);//////////////////////// +#endif + + if (!has_deletions || !isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist, candidate_id); + + if (top_candidates.size() > ef) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + } + + visited_list_pool_->releaseVisitedList(vl); + return top_candidates; + } + + void getNeighborsByHeuristic2( + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M) { + if (top_candidates.size() < M) { + return; + } + std::priority_queue> queue_closest; + std::vector> return_list; + while (top_candidates.size() > 0) { + queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); + top_candidates.pop(); + } + + while (queue_closest.size()) { + if (return_list.size() >= M) + break; + std::pair curent_pair = queue_closest.top(); + dist_t dist_to_query = -curent_pair.first; + queue_closest.pop(); + bool good = true; + for (std::pair second_pair : return_list) { + dist_t curdist = + fstdistfunc_(getDataByInternalId(second_pair.second), + getDataByInternalId(curent_pair.second), + dist_func_param_);; + if (curdist < dist_to_query) { + good = false; + break; + } + } + if (good) { + return_list.push_back(curent_pair); + } + + + } + + for (std::pair curent_pair : return_list) { + + top_candidates.emplace(-curent_pair.first, curent_pair.second); + } + } + + + linklistsizeint *get_linklist0(tableint internal_id) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + }; + + linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + }; + + linklistsizeint *get_linklist(tableint internal_id, int level) const { + return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); + }; + + void mutuallyConnectNewElement(const void *data_point, tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> top_candidates, + int level) { + + size_t Mcurmax = level ? maxM_ : maxM0_; + getNeighborsByHeuristic2(top_candidates, M_); + if (top_candidates.size() > M_) + throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); + + std::vector selectedNeighbors; + selectedNeighbors.reserve(M_); + while (top_candidates.size() > 0) { + selectedNeighbors.push_back(top_candidates.top().second); + top_candidates.pop(); + } + + { + linklistsizeint *ll_cur; + if (level == 0) + ll_cur = get_linklist0(cur_c); + else + ll_cur = get_linklist(cur_c, level); + + if (*ll_cur) { + throw std::runtime_error("The newly inserted element should have blank link list"); + } + setListCount(ll_cur,selectedNeighbors.size()); + tableint *data = (tableint *) (ll_cur + 1); + + + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + if (data[idx]) + throw std::runtime_error("Possible memory corruption"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + data[idx] = selectedNeighbors[idx]; + + } + } + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + + std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); + + + linklistsizeint *ll_other; + if (level == 0) + ll_other = get_linklist0(selectedNeighbors[idx]); + else + ll_other = get_linklist(selectedNeighbors[idx], level); + + size_t sz_link_list_other = getListCount(ll_other); + + if (sz_link_list_other > Mcurmax) + throw std::runtime_error("Bad value of sz_link_list_other"); + if (selectedNeighbors[idx] == cur_c) + throw std::runtime_error("Trying to connect an element to itself"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + tableint *data = (tableint *) (ll_other + 1); + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue, std::vector>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); + + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_), data[j]); + } + + getNeighborsByHeuristic2(candidates, Mcurmax); + + int indx = 0; + while (candidates.size() > 0) { + data[indx] = candidates.top().second; + candidates.pop(); + indx++; + } + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; + } + } + if (indx >= 0) { + data[indx] = cur_c; + } */ + } + + } + } + + std::mutex global; + size_t ef_; + + void setEf(size_t ef) { + ef_ = ef; + } + + + std::priority_queue> searchKnnInternal(void *query_data, int k) { + std::priority_queue> top_candidates; + if (cur_element_count == 0) return top_candidates; + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (size_t level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + int *data; + data = (int *) get_linklist(currObj,level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + if (has_deletions_) { + std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, + ef_); + top_candidates.swap(top_candidates1); + } + else{ + std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, + ef_); + top_candidates.swap(top_candidates1); + } + + while (top_candidates.size() > k) { + top_candidates.pop(); + } + return top_candidates; + }; + + void resizeIndex(size_t new_max_elements){ + if (new_max_elements(new_max_elements).swap(link_list_locks_); + + + // Reallocate base layer + char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + memcpy(data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_); + free(data_level0_memory_); + data_level0_memory_=data_level0_memory_new; + + // Reallocate all other layers + char ** linkLists_new = (char **) malloc(sizeof(void *) * new_max_elements); + if (linkLists_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); + memcpy(linkLists_new, linkLists_,cur_element_count * sizeof(void *)); + free(linkLists_); + linkLists_=linkLists_new; + + max_elements_=new_max_elements; + + } + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + + writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); + writeBinaryPOD(output, label_offset_); + writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); + + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); + + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + output.close(); + } + + void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { + + + std::ifstream input(location, std::ios::binary); + + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); + + + // get file size: + input.seekg(0,input.end); + std::streampos total_filesize=input.tellg(); + input.seekg(0,input.beg); + + readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements=max_elements_i; + if(max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); + readBinaryPOD(input, label_offset_); + readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + + auto pos=input.tellg(); + + + /// Optional - check if index is ok: + + input.seekg(cur_element_count * size_data_per_element_,input.cur); + for (size_t i = 0; i < cur_element_count; i++) { + if(input.tellg() < 0 || input.tellg()>=total_filesize){ + throw std::runtime_error("Index seems to be corrupted or unsupported"); + } + + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize != 0) { + input.seekg(linkListSize,input.cur); + } + } + + // throw exception if it either corrupted or old index + if(input.tellg()!=total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); + + input.clear(); + + /// Optional check end + + input.seekg(pos,input.beg); + + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + + + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { + label_lookup_[getExternalLabel(i)]=i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } + + has_deletions_=false; + + for (size_t i = 0; i < cur_element_count; i++) { + if(isMarkedDeleted(i)) + has_deletions_=true; + } + + input.close(); + + return; + } + + template + std::vector getDataByLabel(labeltype label) + { + tableint label_c; + auto search = label_lookup_.find(label); + if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + throw std::runtime_error("Label not found"); + } + label_c = search->second; + + char* data_ptrv = getDataByInternalId(label_c); + size_t dim = *((size_t *) dist_func_param_); + std::vector data; + data_t* data_ptr = (data_t*) data_ptrv; + for (int i = 0; i < dim; i++) { + data.push_back(*data_ptr); + data_ptr += 1; + } + return data; + } + + static const unsigned char DELETE_MARK = 0x01; +// static const unsigned char REUSE_MARK = 0x10; + /** + * Marks an element with the given label deleted, does NOT really change the current graph. + * @param label + */ + void markDelete(labeltype label) + { + has_deletions_=true; + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + markDeletedInternal(search->second); + } + + /** + * Uses the first 8 bits of the memory for the linked list to store the mark, + * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases. + * @param internalId + */ + void markDeletedInternal(tableint internalId) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + } + + /** + * Remove the deleted mark of the node. + * @param internalId + */ + void unmarkDeletedInternal(tableint internalId) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur &= ~DELETE_MARK; + } + + /** + * Checks the first 8 bits of the memory to see if the element is marked deleted. + * @param internalId + * @return + */ + bool isMarkedDeleted(tableint internalId) const { + unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2; + return *ll_cur & DELETE_MARK; + } + + unsigned short int getListCount(linklistsizeint * ptr) const { + return *((unsigned short int *)ptr); + } + + void setListCount(linklistsizeint * ptr, unsigned short int size) const { + *((unsigned short int*)(ptr))=*((unsigned short int *)&size); + } + + void addPoint(const void *data_point, labeltype label) { + addPoint(data_point, label,-1); + } + + tableint addPoint(const void *data_point, labeltype label, int level) { + tableint cur_c = 0; + { + std::unique_lock lock(cur_element_count_guard_); + if (cur_element_count >= max_elements_) { + throw std::runtime_error("The number of elements exceeds the specified limit"); + }; + + cur_c = cur_element_count; + cur_element_count++; + + auto search = label_lookup_.find(label); + if (search != label_lookup_.end()) { + std::unique_lock lock_el(link_list_locks_[search->second]); + has_deletions_ = true; + markDeletedInternal(search->second); + } + label_lookup_[label] = cur_c; + } + + std::unique_lock lock_el(link_list_locks_[cur_c]); + int curlevel = getRandomLevel(mult_); + if (level > 0) + curlevel = level; + + element_levels_[cur_c] = curlevel; + + + std::unique_lock templock(global); + int maxlevelcopy = maxlevel_; + if (curlevel <= maxlevelcopy) + templock.unlock(); + tableint currObj = enterpoint_node_; + tableint enterpoint_copy = enterpoint_node_; + + + memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); + + // Initialisation of the data and label + memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); + memcpy(getDataByInternalId(cur_c), data_point, data_size_); + + + if (curlevel) { + linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); + if (linkLists_[cur_c] == nullptr) + throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); + memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); + } + + if ((signed)currObj != -1) { + + if (curlevel < maxlevelcopy) { + + dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxlevelcopy; level > curlevel; level--) { + + + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist(currObj,level); + int size = getListCount(data); + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + bool epDeleted = isMarkedDeleted(enterpoint_copy); + for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { + if (level > maxlevelcopy || level < 0) // possible? + throw std::runtime_error("Level error"); + + std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( + currObj, data_point, level); + if (epDeleted) { + top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + } + mutuallyConnectNewElement(data_point, cur_c, top_candidates, level); + + currObj = top_candidates.top().second; + } + + + } else { + // Do nothing for the first element + enterpoint_node_ = 0; + maxlevel_ = curlevel; + + } + + //Releasing lock for the maximum level + if (curlevel > maxlevelcopy) { + enterpoint_node_ = cur_c; + maxlevel_ = curlevel; + } + return cur_c; + }; + + std::priority_queue> + searchKnn(const void *query_data, size_t k) const { + std::priority_queue> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + if (has_deletions_) { + std::priority_queue, std::vector>, CompareByFirst> top_candidates1=searchBaseLayerST( + currObj, query_data, std::max(ef_, k)); + top_candidates.swap(top_candidates1); + } + else{ + std::priority_queue, std::vector>, CompareByFirst> top_candidates1=searchBaseLayerST( + currObj, query_data, std::max(ef_, k)); + top_candidates.swap(top_candidates1); + } + while (top_candidates.size() > k) { + top_candidates.pop(); + } + while (top_candidates.size() > 0) { + std::pair rez = top_candidates.top(); + result.push(std::pair(rez.first, getExternalLabel(rez.second))); + top_candidates.pop(); + } + return result; + }; + + template + std::vector> + searchKnn(const void* query_data, size_t k, Comp comp) { + std::vector> result; + if (cur_element_count == 0) return result; + + auto ret = searchKnn(query_data, k); + + while (!ret.empty()) { + result.push_back(ret.top()); + ret.pop(); + } + + std::sort(result.begin(), result.end(), comp); + + return result; + } + + }; + +} diff --git a/src/third_party/hnswlib/hnswlib.h b/src/third_party/hnswlib/hnswlib.h new file mode 100644 index 0000000000..dbfb16561e --- /dev/null +++ b/src/third_party/hnswlib/hnswlib.h @@ -0,0 +1,88 @@ +#pragma once +#ifndef NO_MANUAL_VECTORIZATION +#ifdef __SSE__ +#define USE_SSE +#ifdef __AVX__ +#define USE_AVX +#endif +#endif +#endif + +#if defined(USE_AVX) || defined(USE_SSE) +#ifdef _MSC_VER +#include +#include +#else +#include +#endif + +#if defined(__GNUC__) +#define PORTABLE_ALIGN32 __attribute__((aligned(32))) +#else +#define PORTABLE_ALIGN32 __declspec(align(32)) +#endif +#endif + +#include +#include + +#include + +namespace hnswlib { + typedef size_t labeltype; + + template + class pairGreater { + public: + bool operator()(const T& p1, const T& p2) { + return p1.first > p2.first; + } + }; + + template + static void writeBinaryPOD(std::ostream &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); + } + + template + static void readBinaryPOD(std::istream &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); + } + + template + using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); + + + template + class SpaceInterface { + public: + //virtual void search(void *); + virtual size_t get_data_size() = 0; + + virtual DISTFUNC get_dist_func() = 0; + + virtual void *get_dist_func_param() = 0; + + virtual ~SpaceInterface() {} + }; + + template + class AlgorithmInterface { + public: + virtual void addPoint(const void *datapoint, labeltype label)=0; + virtual std::priority_queue> searchKnn(const void *, size_t) const = 0; + template + std::vector> searchKnn(const void*, size_t, Comp) { + } + virtual void saveIndex(const std::string &location)=0; + virtual ~AlgorithmInterface(){ + } + }; + + +} + +#include "space_l2.h" +#include "space_ip.h" +#include "bruteforce.h" +#include "hnswalg.h" diff --git a/src/third_party/hnswlib/space_ip.h b/src/third_party/hnswlib/space_ip.h new file mode 100644 index 0000000000..e94674730c --- /dev/null +++ b/src/third_party/hnswlib/space_ip.h @@ -0,0 +1,248 @@ +#pragma once +#include "hnswlib.h" + +namespace hnswlib { + + static float + InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + float res = 0; + for (unsigned i = 0; i < qty; i++) { + res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; + } + return (1.0f - res); + + } + +#if defined(USE_AVX) + +// Favor using AVX if available. + static float + InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } + + __m128 v1, v2; + __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); + + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];; + return 1.0f - sum; +} + +#elif defined(USE_SSE) + + static float + InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return 1.0f - sum; + } + +#endif + +#if defined(USE_AVX) + + static float + InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } + + _mm256_store_ps(TmpRes, sum256); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + + return 1.0f - sum; + } + +#elif defined(USE_SSE) + + static float + InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return 1.0f - sum; + } + +#endif + + class InnerProductSpace : public SpaceInterface { + + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + InnerProductSpace(size_t dim) { + fstdistfunc_ = InnerProduct; + #if defined(USE_AVX) || defined(USE_SSE) + if (dim % 4 == 0) + fstdistfunc_ = InnerProductSIMD4Ext; + if (dim % 16 == 0) + fstdistfunc_ = InnerProductSIMD16Ext; +#endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + + ~InnerProductSpace() {} + }; + + +} diff --git a/src/third_party/hnswlib/space_l2.h b/src/third_party/hnswlib/space_l2.h new file mode 100644 index 0000000000..4d3ac69ac4 --- /dev/null +++ b/src/third_party/hnswlib/space_l2.h @@ -0,0 +1,244 @@ +#pragma once +#include "hnswlib.h" + +namespace hnswlib { + + static float + L2Sqr(const void *pVect1, const void *pVect2, const void *qty_ptr) { + //return *((float *)pVect2); + size_t qty = *((size_t *) qty_ptr); + float res = 0; + for (unsigned i = 0; i < qty; i++) { + float t = ((float *) pVect1)[i] - ((float *) pVect2)[i]; + res += t * t; + } + return (res); + + } + +#if defined(USE_AVX) + + // Favor using AVX if available. + static float + L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m256 diff, v1, v2; + __m256 sum = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + } + + _mm256_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + + return (res); +} + +#elif defined(USE_SSE) + + static float + L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + // size_t qty4 = qty >> 2; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + // const float* pEnd2 = pVect1 + (qty4 << 2); + // const float* pEnd3 = pVect1 + qty; + + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + } + _mm_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return (res); + } +#endif + + +#ifdef USE_SSE + static float + L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + + // size_t qty4 = qty >> 2; + size_t qty16 = qty >> 2; + + const float *pEnd1 = pVect1 + (qty16 << 2); + + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + } + _mm_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return (res); + } +#endif + + class L2Space : public SpaceInterface { + + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + L2Space(size_t dim) { + fstdistfunc_ = L2Sqr; + #if defined(USE_SSE) || defined(USE_AVX) + if (dim % 4 == 0) + fstdistfunc_ = L2SqrSIMD4Ext; + if (dim % 16 == 0) + fstdistfunc_ = L2SqrSIMD16Ext; + /*else{ + throw runtime_error("Data type not supported!"); + }*/ + #endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + + ~L2Space() {} + }; + + static int + L2SqrI(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { + + size_t qty = *((size_t *) qty_ptr); + int res = 0; + unsigned char *a = (unsigned char *) pVect1; + unsigned char *b = (unsigned char *) pVect2; + /*for (int i = 0; i < qty; i++) { + int t = int((a)[i]) - int((b)[i]); + res += t*t; + }*/ + + qty = qty >> 2; + for (size_t i = 0; i < qty; i++) { + + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + + + } + + return (res); + + } + + class L2SpaceI : public SpaceInterface { + + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + L2SpaceI(size_t dim) { + fstdistfunc_ = L2SqrI; + dim_ = dim; + data_size_ = dim * sizeof(unsigned char); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + + ~L2SpaceI() {} + }; + + +} diff --git a/src/third_party/hnswlib/visited_list_pool.h b/src/third_party/hnswlib/visited_list_pool.h new file mode 100644 index 0000000000..6b0f445878 --- /dev/null +++ b/src/third_party/hnswlib/visited_list_pool.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include + +namespace hnswlib { + typedef unsigned short int vl_type; + + class VisitedList { + public: + vl_type curV; + vl_type *mass; + unsigned int numelements; + + VisitedList(int numelements1) { + curV = -1; + numelements = numelements1; + mass = new vl_type[numelements]; + } + + void reset() { + curV++; + if (curV == 0) { + memset(mass, 0, sizeof(vl_type) * numelements); + curV++; + } + }; + + ~VisitedList() { delete[] mass; } + }; +/////////////////////////////////////////////////////////// +// +// Class for multi-threaded pool-management of VisitedLists +// +///////////////////////////////////////////////////////// + + class VisitedListPool { + std::deque pool; + std::mutex poolguard; + int numelements; + + public: + VisitedListPool(int initmaxpools, int numelements1) { + numelements = numelements1; + for (int i = 0; i < initmaxpools; i++) + pool.push_front(new VisitedList(numelements)); + } + + VisitedList *getFreeVisitedList() { + VisitedList *rez; + { + std::unique_lock lock(poolguard); + if (pool.size() > 0) { + rez = pool.front(); + pool.pop_front(); + } else { + rez = new VisitedList(numelements); + } + } + rez->reset(); + return rez; + }; + + void releaseVisitedList(VisitedList *vl) { + std::unique_lock lock(poolguard); + pool.push_front(vl); + }; + + ~VisitedListPool() { + while (pool.size()) { + VisitedList *rez = pool.front(); + pool.pop_front(); + delete rez; + } + }; + }; +} +