Skip to content

Commit

Permalink
Fix NNCache sizing with playout/visits limits.
Browse files Browse the repository at this point in the history
* Limit the NNcache size when visits are limited.
* Fixed handling of -p 0 and -v 0 which mean "no limit".

Also avoid potential overflow in the size calculation.

Pull request leela-zero#1281.
  • Loading branch information
Hersmunch authored and gcp committed Apr 30, 2018
1 parent bb1fd49 commit 4b070f8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 16 deletions.
16 changes: 15 additions & 1 deletion src/Leela.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,22 @@ static void parse_commandline(int argc, char *argv[]) {
"Add --noponder if you want a weakened engine.\n");
exit(EXIT_FAILURE);
}

// 0 may be specified to mean "no limit"
if (cfg_max_playouts == 0) {
cfg_max_playouts =
std::numeric_limits<decltype(cfg_max_playouts)>::max();
}
}

if (vm.count("visits")) {
cfg_max_visits = vm["visits"].as<int>();

// 0 may be specified to mean "no limit"
if (cfg_max_visits == 0) {
cfg_max_visits =
std::numeric_limits<decltype(cfg_max_visits)>::max();
}
}

if (vm.count("resignpct")) {
Expand Down Expand Up @@ -305,7 +317,9 @@ void init_global_objects() {
// improves reproducibility across platforms.
Random::get_Rng().seedrandom(cfg_rng_seed);

NNCache::get_NNCache().set_size_from_playouts(cfg_max_playouts);
// When visits are limited ensure cache size is still limited.
auto playouts = std::min(cfg_max_playouts, cfg_max_visits);
NNCache::get_NNCache().set_size_from_playouts(playouts);

// Initialize network
Network::initialize();
Expand Down
8 changes: 7 additions & 1 deletion src/NNCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ void NNCache::set_size_from_playouts(int max_playouts) {
// cache hits are generally from last several moves so setting cache
// size based on playouts increases the hit rate while balancing memory
// usage for low playout instances. 150'000 cache entries is ~225 MB
auto max_size = std::min(150'000, std::max(6'000, 3 * max_playouts));
constexpr auto num_cache_moves = 3;
auto max_playouts_per_move =
std::min(max_playouts,
std::numeric_limits<decltype(max_playouts)>::max() /
num_cache_moves);
auto max_size = num_cache_moves * max_playouts_per_move;
max_size = std::min(150'000, std::max(6'000, max_size));
NNCache::get_NNCache().resize(max_size);
}

Expand Down
22 changes: 8 additions & 14 deletions src/UCTSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,24 +696,18 @@ void UCTSearch::set_playout_limit(int playouts) {
static_assert(std::is_convertible<decltype(playouts),
decltype(m_maxplayouts)>::value,
"Inconsistent types for playout amount.");
if (playouts == 0) {
// Divide max by 2 to prevent overflow when multithreading.
m_maxplayouts = std::numeric_limits<decltype(m_maxplayouts)>::max()
/ 2;
} else {
m_maxplayouts = playouts;
}
// Limit to type max / 2 to prevent overflow when multithreading.
m_maxplayouts =
std::min(playouts,
std::numeric_limits<decltype(m_maxplayouts)>::max() / 2);
}

void UCTSearch::set_visit_limit(int visits) {
static_assert(std::is_convertible<decltype(visits),
decltype(m_maxvisits)>::value,
"Inconsistent types for visits amount.");
if (visits == 0) {
// Divide max by 2 to prevent overflow when multithreading.
m_maxvisits = std::numeric_limits<decltype(m_maxvisits)>::max()
/ 2;
} else {
m_maxvisits = visits;
}
// Limit to type max / 2 to prevent overflow when multithreading.
m_maxvisits =
std::min(visits,
std::numeric_limits<decltype(m_maxvisits)>::max() / 2);
}

0 comments on commit 4b070f8

Please sign in to comment.