Skip to content

Commit

Permalink
Add support for min_load_factor (Tessil#17).
Browse files Browse the repository at this point in the history
  • Loading branch information
Tessil authored Apr 28, 2019
1 parent 5d3f52f commit a279380
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 16 deletions.
101 changes: 87 additions & 14 deletions include/tsl/robin_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ template<std::size_t GrowthFactor>
struct is_power_of_two_policy<tsl::rh::power_of_two_growth_policy<GrowthFactor>>: std::true_type {
};

// Only available in C++17, we need to be compatible with C++11
template<class T>
const T& clamp( const T& v, const T& lo, const T& hi) {
return std::min(hi, std::max(lo, v));
}


using truncated_hash_type = std::uint_least32_t;
Expand Down Expand Up @@ -494,7 +499,9 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
const Hash& hash,
const KeyEqual& equal,
const Allocator& alloc,
float max_load_factor): Hash(hash),
float min_load_factor = DEFAULT_MIN_LOAD_FACTOR,
float max_load_factor = DEFAULT_MAX_LOAD_FACTOR):
Hash(hash),
KeyEqual(equal),
GrowthPolicy(bucket_count),
m_buckets_data(
Expand All @@ -506,14 +513,15 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
m_buckets(m_buckets_data.empty()?static_empty_bucket_ptr():m_buckets_data.data()),
m_bucket_count(bucket_count),
m_nb_elements(0),
m_grow_on_next_insert(false)
m_grow_on_next_insert(false),
m_try_skrink_on_next_insert(false)
{
if(m_bucket_count > 0) {
tsl_rh_assert(!m_buckets_data.empty());
m_buckets_data.back().set_as_last_bucket();
}


this->min_load_factor(min_load_factor);
this->max_load_factor(max_load_factor);
}
#else
Expand All @@ -529,14 +537,17 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
const Hash& hash,
const KeyEqual& equal,
const Allocator& alloc,
float max_load_factor): Hash(hash),
float min_load_factor = DEFAULT_MIN_LOAD_FACTOR,
float max_load_factor = DEFAULT_MAX_LOAD_FACTOR):
Hash(hash),
KeyEqual(equal),
GrowthPolicy(bucket_count),
m_buckets_data(alloc),
m_buckets(static_empty_bucket_ptr()),
m_bucket_count(bucket_count),
m_nb_elements(0),
m_grow_on_next_insert(false)
m_grow_on_next_insert(false),
m_try_skrink_on_next_insert(false)
{
if(bucket_count > max_bucket_count()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The map exceeds its maxmimum bucket count.");
Expand All @@ -550,7 +561,7 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
m_buckets_data.back().set_as_last_bucket();
}


this->min_load_factor(min_load_factor);
this->max_load_factor(max_load_factor);
}
#endif
Expand All @@ -564,7 +575,9 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
m_nb_elements(other.m_nb_elements),
m_load_threshold(other.m_load_threshold),
m_max_load_factor(other.m_max_load_factor),
m_grow_on_next_insert(other.m_grow_on_next_insert)
m_grow_on_next_insert(other.m_grow_on_next_insert),
m_min_load_factor(other.m_min_load_factor),
m_try_skrink_on_next_insert(other.m_try_skrink_on_next_insert)
{
}

Expand All @@ -581,7 +594,9 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
m_nb_elements(other.m_nb_elements),
m_load_threshold(other.m_load_threshold),
m_max_load_factor(other.m_max_load_factor),
m_grow_on_next_insert(other.m_grow_on_next_insert)
m_grow_on_next_insert(other.m_grow_on_next_insert),
m_min_load_factor(other.m_min_load_factor),
m_try_skrink_on_next_insert(other.m_try_skrink_on_next_insert)
{
other.GrowthPolicy::clear();
other.m_buckets_data.clear();
Expand All @@ -590,6 +605,7 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
other.m_nb_elements = 0;
other.m_load_threshold = 0;
other.m_grow_on_next_insert = false;
other.m_try_skrink_on_next_insert = false;
}

robin_hash& operator=(const robin_hash& other) {
Expand All @@ -603,9 +619,13 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
m_buckets_data.data();
m_bucket_count = other.m_bucket_count;
m_nb_elements = other.m_nb_elements;

m_load_threshold = other.m_load_threshold;
m_max_load_factor = other.m_max_load_factor;
m_grow_on_next_insert = other.m_grow_on_next_insert;

m_min_load_factor = other.m_min_load_factor;
m_try_skrink_on_next_insert = other.m_try_skrink_on_next_insert;
}

return *this;
Expand Down Expand Up @@ -791,6 +811,8 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
++pos;
}

m_try_skrink_on_next_insert = true;

return pos;
}

Expand Down Expand Up @@ -847,7 +869,8 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
++icloser_bucket;
++ito_move_closer_value;
}


m_try_skrink_on_next_insert = true;

return iterator(m_buckets + ireturn_bucket);
}
Expand All @@ -863,6 +886,7 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
auto it = find(key, hash);
if(it != end()) {
erase_from_bucket(it);
m_try_skrink_on_next_insert = true;

return 1;
}
Expand All @@ -888,6 +912,8 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
swap(m_load_threshold, other.m_load_threshold);
swap(m_max_load_factor, other.m_max_load_factor);
swap(m_grow_on_next_insert, other.m_grow_on_next_insert);
swap(m_min_load_factor, other.m_min_load_factor);
swap(m_try_skrink_on_next_insert, other.m_try_skrink_on_next_insert);
}


Expand Down Expand Up @@ -1010,12 +1036,22 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
return float(m_nb_elements)/float(bucket_count());
}

float min_load_factor() const {
return m_min_load_factor;
}

float max_load_factor() const {
return m_max_load_factor;
}

void min_load_factor(float ml) {
m_min_load_factor = clamp(ml, float(MINIMUM_MIN_LOAD_FACTOR),
float(MAXIMUM_MIN_LOAD_FACTOR));
}

void max_load_factor(float ml) {
m_max_load_factor = std::max(0.1f, std::min(ml, 0.95f));
m_max_load_factor = clamp(ml, float(MINIMUM_MAX_LOAD_FACTOR),
float(MAXIMUM_MAX_LOAD_FACTOR));
m_load_threshold = size_type(float(bucket_count())*m_max_load_factor);
}

Expand Down Expand Up @@ -1150,7 +1186,7 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
dist_from_ideal_bucket++;
}

if(grow_on_high_load()) {
if(rehash_on_extreme_load()) {
ibucket = bucket_for_hash(hash);
dist_from_ideal_bucket = 0;

Expand Down Expand Up @@ -1233,7 +1269,7 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {

void rehash_impl(size_type count) {
robin_hash new_table(count, static_cast<Hash&>(*this), static_cast<KeyEqual&>(*this),
get_allocator(), m_max_load_factor);
get_allocator(), m_min_load_factor, m_max_load_factor);

const bool use_stored_hash = USE_STORED_HASH_ON_REHASH(new_table.bucket_count());
for(auto& bucket: m_buckets_data) {
Expand Down Expand Up @@ -1274,23 +1310,50 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {


/**
* Return true if the map has been rehashed.
* Grow the table if m_grow_on_next_insert is true or we reached the max_load_factor.
* Shrink the table if m_try_skrink_on_next_insert is true (an erase occured) and
* we're below the min_load_factor.
*
* Return true if the table has been rehashed.
*/
bool grow_on_high_load() {
bool rehash_on_extreme_load() {
if(m_grow_on_next_insert || size() >= m_load_threshold) {
rehash_impl(GrowthPolicy::next_bucket_count());
m_grow_on_next_insert = false;

return true;
}

if(m_try_skrink_on_next_insert) {
m_try_skrink_on_next_insert = false;
if(m_min_load_factor != 0.0f && load_factor() < m_min_load_factor) {
reserve(size() + 1);

return true;
}
}

return false;
}


public:
static const size_type DEFAULT_INIT_BUCKETS_SIZE = 0;

static constexpr float DEFAULT_MAX_LOAD_FACTOR = 0.5f;
static constexpr float MINIMUM_MAX_LOAD_FACTOR = 0.2f;
static constexpr float MAXIMUM_MAX_LOAD_FACTOR = 0.95f;

static constexpr float DEFAULT_MIN_LOAD_FACTOR = 0.0f;
static constexpr float MINIMUM_MIN_LOAD_FACTOR = 0.0f;
static constexpr float MAXIMUM_MIN_LOAD_FACTOR = 0.15f;

static_assert(MINIMUM_MAX_LOAD_FACTOR < MAXIMUM_MAX_LOAD_FACTOR,
"MINIMUM_MAX_LOAD_FACTOR should be < MAXIMUM_MAX_LOAD_FACTOR");
static_assert(MINIMUM_MIN_LOAD_FACTOR < MAXIMUM_MIN_LOAD_FACTOR,
"MINIMUM_MIN_LOAD_FACTOR should be < MAXIMUM_MIN_LOAD_FACTOR");
static_assert(MAXIMUM_MIN_LOAD_FACTOR < MINIMUM_MAX_LOAD_FACTOR,
"MAXIMUM_MIN_LOAD_FACTOR should be < MINIMUM_MAX_LOAD_FACTOR");

private:
static const distance_type REHASH_ON_HIGH_NB_PROBES__NPROBES = 128;
Expand Down Expand Up @@ -1329,6 +1392,16 @@ class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
float m_max_load_factor;

bool m_grow_on_next_insert;

float m_min_load_factor;

/**
* We can't shrink down the map on erase operations as the erase methods need to return the next iterator.
* Shrinking the map would invalidate all the iterators and we could not return the next iterator in a meaningful way,
* On erase, we thus just indicate on erase that we should try to shrink the hash table on the next insert
* if we go below the min_load_factor.
*/
bool m_try_skrink_on_next_insert;
};

}
Expand Down
14 changes: 13 additions & 1 deletion include/tsl/robin_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class robin_map {
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()):
m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
m_ht(bucket_count, hash, equal, alloc)
{
}

Expand Down Expand Up @@ -600,7 +600,19 @@ class robin_map {
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }

float min_load_factor() const { return m_ht.min_load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }

/**
* Set the `min_load_factor` to `ml`. When the `load_factor` of the map goes
* below `min_load_factor` after some erase operations, the map will be
* shrunk when an insertion occurs. The erase method itself never shrinks
* the map.
*
* The default value of `min_load_factor` is 0.0f, the map never shrinks by default.
*/
void min_load_factor(float ml) { m_ht.min_load_factor(ml); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }

void rehash(size_type count) { m_ht.rehash(count); }
Expand Down
14 changes: 13 additions & 1 deletion include/tsl/robin_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class robin_set {
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()):
m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
m_ht(bucket_count, hash, equal, alloc)
{
}

Expand Down Expand Up @@ -466,7 +466,19 @@ class robin_set {
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }

float min_load_factor() const { return m_ht.min_load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }

/**
* Set the `min_load_factor` to `ml`. When the `load_factor` of the set goes
* below `min_load_factor` after some erase operations, the set will be
* shrunk when an insertion occurs. The erase method itself never shrinks
* the set.
*
* The default value of `min_load_factor` is 0.0f, the set never shrinks by default.
*/
void min_load_factor(float ml) { m_ht.min_load_factor(ml); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }

void rehash(size_type count) { m_ht.rehash(count); }
Expand Down
63 changes: 63 additions & 0 deletions tests/robin_map_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,69 @@ BOOST_AUTO_TEST_CASE(test_range_erase_same_iterators) {
BOOST_CHECK_EQUAL(it_const.value(), -100);
}

/**
* max_load_factor
*/
BOOST_AUTO_TEST_CASE(test_max_load_factor_extreme_factors) {
tsl::robin_map<std::int64_t, std::int64_t> map;

map.max_load_factor(0.0f);
BOOST_CHECK_GT(map.max_load_factor(), 0.0f);

map.max_load_factor(10.0f);
BOOST_CHECK_LT(map.max_load_factor(), 1.0f);
}

/**
* min_load_factor
*/
BOOST_AUTO_TEST_CASE(test_min_load_factor_extreme_factors) {
tsl::robin_map<std::int64_t, std::int64_t> map;

BOOST_CHECK_EQUAL(map.min_load_factor(), 0.0f);
BOOST_CHECK_LT(map.min_load_factor(), map.max_load_factor());

map.min_load_factor(-10.0f);
BOOST_CHECK_EQUAL(map.min_load_factor(), 0.0f);

map.min_load_factor(0.9f);
map.max_load_factor(0.1f);

// max_load_factor should always be > min_load_factor.
// Factors should have been clamped.
BOOST_CHECK_LT(map.min_load_factor(), map.max_load_factor());
}

BOOST_AUTO_TEST_CASE(test_min_load_factor) {
tsl::robin_map<std::int64_t, std::int64_t> map;

map.min_load_factor(0.15f);
BOOST_CHECK_EQUAL(map.min_load_factor(), 0.15f);

map.max_load_factor(0.5f);
BOOST_CHECK_EQUAL(map.max_load_factor(), 0.5f);


map.rehash(100);
for(std::size_t i = 0; i < map.bucket_count()/2; i++) {
map.insert({utils::get_key<std::int64_t>(i),
utils::get_value<std::int64_t>(i)});
}

BOOST_CHECK_GT(map.load_factor(), map.min_load_factor());
BOOST_CHECK_CLOSE(map.load_factor(), 0.5f, 0.05f);


while(map.load_factor() >= map.min_load_factor()) {
map.erase(map.begin());
}

// Shrink is done on insert.
map.insert({utils::get_key<std::int64_t>(map.bucket_count()),
utils::get_value<std::int64_t>(map.bucket_count())});
BOOST_CHECK_GT(map.load_factor(), map.min_load_factor());
}

/**
* rehash
*/
Expand Down

0 comments on commit a279380

Please sign in to comment.