Skip to content

Commit

Permalink
Netresult memory size optimization.
Browse files Browse the repository at this point in the history
* Instead of keeping a std::vector<std::pair<float,int>>,
  just keep a flat array of floats.
  Reduces NNCache per-entry max memory footprint by roughly half.

* Increased NNCache max size from 50k to 150k,
  reflecting memory size optimizations

Pull request leela-zero#1203.
  • Loading branch information
ihavnoid authored and gcp committed Apr 20, 2018
1 parent 447a366 commit 7789926
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 48 deletions.
4 changes: 2 additions & 2 deletions src/NNCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ void NNCache::resize(int size) {
void NNCache::set_size_from_playouts(int max_playouts) {
// cache hits are generally from last several moves so setting cache
// size based on playouts increases the hit rate while balancing memory
// usage for low playout instances. 50'000 cache entries is ~250 MB
auto max_size = std::min(50'000, std::max(6'000, 3 * max_playouts));
// usage for low playout instances. 150'000 cache entries is ~225 MB
auto max_size = std::min(150'000, std::max(6'000, 3 * max_playouts));
NNCache::get_NNCache().resize(max_size);
}

Expand Down
4 changes: 2 additions & 2 deletions src/NNCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class NNCache {
void dump_stats();

private:
NNCache(int size = 50000); // ~ 250MB
NNCache(int size = 150000); // ~ 225MB

std::mutex m_mutex;

Expand All @@ -67,7 +67,7 @@ class NNCache {
struct Entry {
Entry( const Network::Netresult& r)
: result(r) {}
Network::Netresult result; // ~ 3KB
Network::Netresult result; // ~ 1.5KB
};

// Map from hash to {features, result}
Expand Down
60 changes: 28 additions & 32 deletions src/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -861,12 +861,12 @@ Network::Netresult Network::get_scored_moves(

if (ensemble == DIRECT) {
assert(rotation >= 0 && rotation <= 7);
result = get_scored_moves_internal(state, planes, rotation);
result = get_scored_moves_internal(planes, rotation);
} else {
assert(ensemble == RANDOM_ROTATION);
assert(rotation == -1);
const auto rand_rot = Random::get_Rng().randfix<8>();
result = get_scored_moves_internal(state, planes, rand_rot);
result = get_scored_moves_internal(planes, rand_rot);
}

// Insert result into cache.
Expand All @@ -876,7 +876,7 @@ Network::Netresult Network::get_scored_moves(
}

Network::Netresult Network::get_scored_moves_internal(
const GameState* const state, const NNPlanes & planes, const int rotation) {
const NNPlanes & planes, const int rotation) {
assert(rotation >= 0 && rotation <= 7);
assert(INPUT_CHANNELS == planes.size());
constexpr auto width = BOARD_SIZE;
Expand Down Expand Up @@ -930,45 +930,31 @@ Network::Netresult Network::get_scored_moves_internal(
// Sigmoid
const auto winrate_sig = (1.0f + std::tanh(winrate_out[0])) / 2.0f;

std::vector<scored_node> result;
for (auto idx = size_t{0}; idx < outputs.size(); idx++) {
if (idx < BOARD_SQUARES) {
const auto rot_idx = rotate_nn_idx_table[rotation][idx];
const auto x = rot_idx % BOARD_SIZE;
const auto y = rot_idx / BOARD_SIZE;
const auto rot_vtx = state->board.get_vertex(x, y);
if (state->board.get_square(rot_vtx) == FastBoard::EMPTY) {
result.emplace_back(outputs[idx], rot_vtx);
}
} else {
result.emplace_back(outputs[idx], FastBoard::PASS);
}
Netresult result;

for (auto idx = size_t{0}; idx < BOARD_SQUARES; idx++) {
const auto rot_idx = rotate_nn_idx_table[rotation][idx];
result.policy[rot_idx] = outputs[idx];
}

return std::make_pair(result, winrate_sig);
result.policy_pass = outputs[BOARD_SQUARES];
result.winrate = winrate_sig;

return result;
}

void Network::show_heatmap(const FastState* const state,
const Netresult& result,
const bool topmoves) {
auto moves = result.first;
std::vector<std::string> display_map;
std::string line;

for (unsigned int y = 0; y < BOARD_SIZE; y++) {
for (unsigned int x = 0; x < BOARD_SIZE; x++) {
const auto vtx = state->board.get_vertex(x, y);

const auto item = std::find_if(moves.cbegin(), moves.cend(),
[&vtx](scored_node const& test_item) {
return test_item.second == vtx;
});

auto score = 0;
// Non-empty squares won't be scored
if (item != moves.cend()) {
score = int(item->first * 1000);
assert(vtx == item->second);
const auto vertex = state->board.get_vertex(x, y);
if (state->board.get_square(vertex) == FastBoard::EMPTY) {
score = result.policy[y * BOARD_SIZE + x] * 1000;
}

line += boost::str(boost::format("%3d ") % score);
Expand All @@ -981,12 +967,22 @@ void Network::show_heatmap(const FastState* const state,
for (int i = display_map.size() - 1; i >= 0; --i) {
myprintf("%s\n", display_map[i].c_str());
}
assert(result.first.back().second == FastBoard::PASS);
const auto pass_score = int(result.first.back().first * 1000);
const auto pass_score = int(result.policy_pass * 1000);
myprintf("pass: %d\n", pass_score);
myprintf("winrate: %f\n", result.second);
myprintf("winrate: %f\n", result.winrate);

if (topmoves) {
std::vector<Network::ScoreVertexPair> moves;
for (auto i=0; i < BOARD_SQUARES; i++) {
const auto x = i % BOARD_SIZE;
const auto y = i / BOARD_SIZE;
const auto vertex = state->board.get_vertex(x, y);
if (state->board.get_square(vertex) == FastBoard::EMPTY) {
moves.emplace_back(result.policy[i], vertex);
}
}
moves.emplace_back(result.policy_pass, FastBoard::PASS);

std::stable_sort(rbegin(moves), rend(moves));

auto cum = 0.0f;
Expand Down
18 changes: 15 additions & 3 deletions src/Network.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,20 @@ class Network {
};
using BoardPlane = std::bitset<BOARD_SQUARES>;
using NNPlanes = std::vector<BoardPlane>;
using scored_node = std::pair<float, int>;
using Netresult = std::pair<std::vector<scored_node>, float>;
using ScoreVertexPair = std::pair<float,int>;

struct Netresult {
    // Policy output for the board intersections: one float per square,
    // BOARD_SQUARES entries in total. Consumers index it row-major,
    // i.e. policy[y * BOARD_SIZE + x].
    // Parentheses are deliberate: this is the count constructor, not
    // an initializer list.
    std::vector<float> policy = std::vector<float>(BOARD_SQUARES);

    // Policy output for the pass move, kept apart from the board squares.
    float policy_pass = 0.0f;

    // Winrate output of the network for the evaluated position.
    float winrate = 0.0f;

    // A default-constructed result holds a zero-filled policy of
    // BOARD_SQUARES entries, a zero pass score and a zero winrate.
    Netresult() = default;
};

static Netresult get_scored_moves(const GameState* const state,
const Ensemble ensemble,
Expand Down Expand Up @@ -94,7 +106,7 @@ class Network {
static void fill_input_plane_pair(
const FullBoard& board, BoardPlane& black, BoardPlane& white);
static Netresult get_scored_moves_internal(
const GameState* const state, const NNPlanes & planes, const int rotation);
const NNPlanes & planes, const int rotation);
#if defined(USE_BLAS)
static void forward_cpu(const std::vector<float>& input,
std::vector<float>& output_pol,
Expand Down
2 changes: 1 addition & 1 deletion src/Training.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ void Training::record(GameState& state, UCTNode& root) {

auto result =
Network::get_scored_moves(&state, Network::Ensemble::DIRECT, 0);
step.net_winrate = result.second;
step.net_winrate = result.winrate;

const auto& best_node = root.get_best_root_child(step.to_move);
step.root_uct_winrate = root.get_eval(step.to_move);
Expand Down
18 changes: 11 additions & 7 deletions src/UCTNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,24 +81,28 @@ bool UCTNode::create_children(std::atomic<int>& nodecount,
&state, Network::Ensemble::RANDOM_ROTATION);

// DCNN returns winrate as side to move
m_net_eval = raw_netlist.second;
m_net_eval = raw_netlist.winrate;
const auto to_move = state.board.get_to_move();
// our search functions evaluate from black's point of view
if (state.board.white_to_move()) {
m_net_eval = 1.0f - m_net_eval;
}
eval = m_net_eval;

std::vector<Network::scored_node> nodelist;
std::vector<Network::ScoreVertexPair> nodelist;

auto legal_sum = 0.0f;
for (const auto& node : raw_netlist.first) {
auto vertex = node.second;
for (auto i = 0; i < BOARD_SQUARES; i++) {
const auto x = i % BOARD_SIZE;
const auto y = i / BOARD_SIZE;
const auto vertex = state.board.get_vertex(x, y);
if (state.is_move_legal(to_move, vertex)) {
nodelist.emplace_back(node);
legal_sum += node.first;
nodelist.emplace_back(raw_netlist.policy[i], vertex);
legal_sum += raw_netlist.policy[i];
}
}
nodelist.emplace_back(raw_netlist.policy_pass, FastBoard::PASS);
legal_sum += raw_netlist.policy_pass;

if (legal_sum > std::numeric_limits<float>::min()) {
// re-normalize after removing illegal moves.
Expand All @@ -118,7 +122,7 @@ bool UCTNode::create_children(std::atomic<int>& nodecount,
}

void UCTNode::link_nodelist(std::atomic<int>& nodecount,
std::vector<Network::scored_node>& nodelist,
std::vector<Network::ScoreVertexPair>& nodelist,
float min_psa_ratio) {
assert(min_psa_ratio < m_min_psa_ratio_children);

Expand Down
2 changes: 1 addition & 1 deletion src/UCTNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class UCTNode {
ACTIVE
};
void link_nodelist(std::atomic<int>& nodecount,
std::vector<Network::scored_node>& nodelist,
std::vector<Network::ScoreVertexPair>& nodelist,
float min_psa_ratio);
double get_blackevals() const;
void accumulate_eval(float eval);
Expand Down

0 comments on commit 7789926

Please sign in to comment.