diff --git a/include/tsl/robin_hash.h b/include/tsl/robin_hash.h index 58766ec..648b7e5 100644 --- a/include/tsl/robin_hash.h +++ b/include/tsl/robin_hash.h @@ -67,6 +67,11 @@ template struct is_power_of_two_policy>: std::true_type { }; +// Only available in C++17, we need to be compatible with C++11 +template +const T& clamp( const T& v, const T& lo, const T& hi) { + return std::min(hi, std::max(lo, v)); +} using truncated_hash_type = std::uint_least32_t; @@ -494,7 +499,9 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { const Hash& hash, const KeyEqual& equal, const Allocator& alloc, - float max_load_factor): Hash(hash), + float min_load_factor = DEFAULT_MIN_LOAD_FACTOR, + float max_load_factor = DEFAULT_MAX_LOAD_FACTOR): + Hash(hash), KeyEqual(equal), GrowthPolicy(bucket_count), m_buckets_data( @@ -506,14 +513,15 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { m_buckets(m_buckets_data.empty()?static_empty_bucket_ptr():m_buckets_data.data()), m_bucket_count(bucket_count), m_nb_elements(0), - m_grow_on_next_insert(false) + m_grow_on_next_insert(false), + m_try_skrink_on_next_insert(false) { if(m_bucket_count > 0) { tsl_rh_assert(!m_buckets_data.empty()); m_buckets_data.back().set_as_last_bucket(); } - + this->min_load_factor(min_load_factor); this->max_load_factor(max_load_factor); } #else @@ -529,14 +537,17 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { const Hash& hash, const KeyEqual& equal, const Allocator& alloc, - float max_load_factor): Hash(hash), + float min_load_factor = DEFAULT_MIN_LOAD_FACTOR, + float max_load_factor = DEFAULT_MAX_LOAD_FACTOR): + Hash(hash), KeyEqual(equal), GrowthPolicy(bucket_count), m_buckets_data(alloc), m_buckets(static_empty_bucket_ptr()), m_bucket_count(bucket_count), m_nb_elements(0), - m_grow_on_next_insert(false) + m_grow_on_next_insert(false), + m_try_skrink_on_next_insert(false) { if(bucket_count > max_bucket_count()) { TSL_RH_THROW_OR_TERMINATE(std::length_error, "The map exceeds its maxmimum bucket count."); @@ -550,7 +561,7 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { m_buckets_data.back().set_as_last_bucket(); } - + this->min_load_factor(min_load_factor); this->max_load_factor(max_load_factor); } #endif @@ -564,7 +575,9 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { m_nb_elements(other.m_nb_elements), m_load_threshold(other.m_load_threshold), m_max_load_factor(other.m_max_load_factor), - m_grow_on_next_insert(other.m_grow_on_next_insert) + m_grow_on_next_insert(other.m_grow_on_next_insert), + m_min_load_factor(other.m_min_load_factor), + m_try_skrink_on_next_insert(other.m_try_skrink_on_next_insert) { } @@ -581,7 +594,9 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { m_nb_elements(other.m_nb_elements), m_load_threshold(other.m_load_threshold), m_max_load_factor(other.m_max_load_factor), - m_grow_on_next_insert(other.m_grow_on_next_insert) + m_grow_on_next_insert(other.m_grow_on_next_insert), + m_min_load_factor(other.m_min_load_factor), + m_try_skrink_on_next_insert(other.m_try_skrink_on_next_insert) { other.GrowthPolicy::clear(); other.m_buckets_data.clear(); @@ -590,6 +605,7 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { other.m_nb_elements = 0; other.m_load_threshold = 0; other.m_grow_on_next_insert = false; + other.m_try_skrink_on_next_insert = false; } robin_hash& operator=(const robin_hash& other) { @@ -603,9 +619,13 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { m_buckets_data.data(); m_bucket_count = other.m_bucket_count; m_nb_elements = other.m_nb_elements; + m_load_threshold = other.m_load_threshold; m_max_load_factor = other.m_max_load_factor; m_grow_on_next_insert = other.m_grow_on_next_insert; + + m_min_load_factor = other.m_min_load_factor; + m_try_skrink_on_next_insert = other.m_try_skrink_on_next_insert; } return *this; @@ -791,6 +811,8 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { ++pos; } + m_try_skrink_on_next_insert = true; + return pos; } @@ -847,7 +869,8 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { ++icloser_bucket; ++ito_move_closer_value; } - + + m_try_skrink_on_next_insert = true; return iterator(m_buckets + ireturn_bucket); } @@ -863,6 +886,7 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { auto it = find(key, hash); if(it != end()) { erase_from_bucket(it); + m_try_skrink_on_next_insert = true; return 1; } @@ -888,6 +912,8 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { swap(m_load_threshold, other.m_load_threshold); swap(m_max_load_factor, other.m_max_load_factor); swap(m_grow_on_next_insert, other.m_grow_on_next_insert); + swap(m_min_load_factor, other.m_min_load_factor); + swap(m_try_skrink_on_next_insert, other.m_try_skrink_on_next_insert); } @@ -1010,12 +1036,22 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { return float(m_nb_elements)/float(bucket_count()); } + float min_load_factor() const { + return m_min_load_factor; + } + float max_load_factor() const { return m_max_load_factor; } + void min_load_factor(float ml) { + m_min_load_factor = clamp(ml, float(MINIMUM_MIN_LOAD_FACTOR), + float(MAXIMUM_MIN_LOAD_FACTOR)); + } + void max_load_factor(float ml) { - m_max_load_factor = std::max(0.1f, std::min(ml, 0.95f)); + m_max_load_factor = clamp(ml, float(MINIMUM_MAX_LOAD_FACTOR), + float(MAXIMUM_MAX_LOAD_FACTOR)); m_load_threshold = size_type(float(bucket_count())*m_max_load_factor); } @@ -1150,7 +1186,7 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { dist_from_ideal_bucket++; } - if(grow_on_high_load()) { + if(rehash_on_extreme_load()) { ibucket = bucket_for_hash(hash); dist_from_ideal_bucket = 0; @@ -1233,7 +1269,7 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { void rehash_impl(size_type count) { robin_hash new_table(count, static_cast(*this), static_cast(*this), - get_allocator(), m_max_load_factor); + get_allocator(), m_min_load_factor, m_max_load_factor); const bool use_stored_hash = USE_STORED_HASH_ON_REHASH(new_table.bucket_count()); for(auto& bucket: m_buckets_data) { @@ -1274,9 +1310,13 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { /** - * Return true if the map has been rehashed. + * Grow the table if m_grow_on_next_insert is true or we reached the max_load_factor. + * Shrink the table if m_try_skrink_on_next_insert is true (an erase occured) and + * we're below the min_load_factor. + * + * Return true if the table has been rehashed. */ - bool grow_on_high_load() { + bool rehash_on_extreme_load() { if(m_grow_on_next_insert || size() >= m_load_threshold) { rehash_impl(GrowthPolicy::next_bucket_count()); m_grow_on_next_insert = false; @@ -1284,13 +1324,36 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { return true; } + if(m_try_skrink_on_next_insert) { + m_try_skrink_on_next_insert = false; + if(m_min_load_factor != 0.0f && load_factor() < m_min_load_factor) { + reserve(size() + 1); + + return true; + } + } + return false; } public: static const size_type DEFAULT_INIT_BUCKETS_SIZE = 0; + static constexpr float DEFAULT_MAX_LOAD_FACTOR = 0.5f; + static constexpr float MINIMUM_MAX_LOAD_FACTOR = 0.2f; + static constexpr float MAXIMUM_MAX_LOAD_FACTOR = 0.95f; + + static constexpr float DEFAULT_MIN_LOAD_FACTOR = 0.0f; + static constexpr float MINIMUM_MIN_LOAD_FACTOR = 0.0f; + static constexpr float MAXIMUM_MIN_LOAD_FACTOR = 0.15f; + + static_assert(MINIMUM_MAX_LOAD_FACTOR < MAXIMUM_MAX_LOAD_FACTOR, + "MINIMUM_MAX_LOAD_FACTOR should be < MAXIMUM_MAX_LOAD_FACTOR"); + static_assert(MINIMUM_MIN_LOAD_FACTOR < MAXIMUM_MIN_LOAD_FACTOR, + "MINIMUM_MIN_LOAD_FACTOR should be < MAXIMUM_MIN_LOAD_FACTOR"); + static_assert(MAXIMUM_MIN_LOAD_FACTOR < MINIMUM_MAX_LOAD_FACTOR, + "MAXIMUM_MIN_LOAD_FACTOR should be < MINIMUM_MAX_LOAD_FACTOR"); private: static const distance_type REHASH_ON_HIGH_NB_PROBES__NPROBES = 128; @@ -1329,6 +1392,16 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy { float m_max_load_factor; bool m_grow_on_next_insert; + + float m_min_load_factor; + + /** + * We can't shrink down the map on erase operations as the erase methods need to return the next iterator. + * Shrinking the map would invalidate all the iterators and we could not return the next iterator in a meaningful way, + * On erase, we thus just indicate on erase that we should try to shrink the hash table on the next insert + * if we go below the min_load_factor. + */ + bool m_try_skrink_on_next_insert; }; } diff --git a/include/tsl/robin_map.h b/include/tsl/robin_map.h index 05c6dce..4765fc6 100644 --- a/include/tsl/robin_map.h +++ b/include/tsl/robin_map.h @@ -140,7 +140,7 @@ class robin_map { const Hash& hash = Hash(), const KeyEqual& equal = KeyEqual(), const Allocator& alloc = Allocator()): - m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR) + m_ht(bucket_count, hash, equal, alloc) { } @@ -600,7 +600,19 @@ class robin_map { * Hash policy */ float load_factor() const { return m_ht.load_factor(); } + + float min_load_factor() const { return m_ht.min_load_factor(); } float max_load_factor() const { return m_ht.max_load_factor(); } + + /** + * Set the `min_load_factor` to `ml`. When the `load_factor` of the map goes + * below `min_load_factor` after some erase operations, the map will be + * shrunk when an insertion occurs. The erase method itself never shrinks + * the map. + * + * The default value of `min_load_factor` is 0.0f, the map never shrinks by default. + */ + void min_load_factor(float ml) { m_ht.min_load_factor(ml); } void max_load_factor(float ml) { m_ht.max_load_factor(ml); } void rehash(size_type count) { m_ht.rehash(count); } diff --git a/include/tsl/robin_set.h b/include/tsl/robin_set.h index 219a33e..ca06754 100644 --- a/include/tsl/robin_set.h +++ b/include/tsl/robin_set.h @@ -124,7 +124,7 @@ class robin_set { const Hash& hash = Hash(), const KeyEqual& equal = KeyEqual(), const Allocator& alloc = Allocator()): - m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR) + m_ht(bucket_count, hash, equal, alloc) { } @@ -466,7 +466,19 @@ class robin_set { * Hash policy */ float load_factor() const { return m_ht.load_factor(); } + + float min_load_factor() const { return m_ht.min_load_factor(); } float max_load_factor() const { return m_ht.max_load_factor(); } + + /** + * Set the `min_load_factor` to `ml`. When the `load_factor` of the set goes + * below `min_load_factor` after some erase operations, the set will be + * shrunk when an insertion occurs. The erase method itself never shrinks + * the set. + * + * The default value of `min_load_factor` is 0.0f, the set never shrinks by default. + */ + void min_load_factor(float ml) { m_ht.min_load_factor(ml); } void max_load_factor(float ml) { m_ht.max_load_factor(ml); } void rehash(size_type count) { m_ht.rehash(count); } diff --git a/tests/robin_map_tests.cpp b/tests/robin_map_tests.cpp index 6367b6c..b1ae6fc 100644 --- a/tests/robin_map_tests.cpp +++ b/tests/robin_map_tests.cpp @@ -487,6 +487,69 @@ BOOST_AUTO_TEST_CASE(test_range_erase_same_iterators) { BOOST_CHECK_EQUAL(it_const.value(), -100); } +/** + * max_load_factor + */ +BOOST_AUTO_TEST_CASE(test_max_load_factor_extreme_factors) { + tsl::robin_map map; + + map.max_load_factor(0.0f); + BOOST_CHECK_GT(map.max_load_factor(), 0.0f); + + map.max_load_factor(10.0f); + BOOST_CHECK_LT(map.max_load_factor(), 1.0f); +} + +/** + * min_load_factor + */ +BOOST_AUTO_TEST_CASE(test_min_load_factor_extreme_factors) { + tsl::robin_map map; + + BOOST_CHECK_EQUAL(map.min_load_factor(), 0.0f); + BOOST_CHECK_LT(map.min_load_factor(), map.max_load_factor()); + + map.min_load_factor(-10.0f); + BOOST_CHECK_EQUAL(map.min_load_factor(), 0.0f); + + map.min_load_factor(0.9f); + map.max_load_factor(0.1f); + + // max_load_factor should always be > min_load_factor. + // Factors should have been clamped. + BOOST_CHECK_LT(map.min_load_factor(), map.max_load_factor()); +} + +BOOST_AUTO_TEST_CASE(test_min_load_factor) { + tsl::robin_map map; + + map.min_load_factor(0.15f); + BOOST_CHECK_EQUAL(map.min_load_factor(), 0.15f); + + map.max_load_factor(0.5f); + BOOST_CHECK_EQUAL(map.max_load_factor(), 0.5f); + + + map.rehash(100); + for(std::size_t i = 0; i < map.bucket_count()/2; i++) { + map.insert({utils::get_key(i), + utils::get_value(i)}); + } + + BOOST_CHECK_GT(map.load_factor(), map.min_load_factor()); + BOOST_CHECK_CLOSE(map.load_factor(), 0.5f, 0.05f); + + + while(map.load_factor() >= map.min_load_factor()) { + map.erase(map.begin()); + } + + // Shrink is done on insert. + map.insert({utils::get_key(map.bucket_count()), + utils::get_value(map.bucket_count())}); + BOOST_CHECK_GT(map.load_factor(), map.min_load_factor()); +} + /** * rehash */