Skip to content

Commit

Permalink
Small cleanup to visit counting.
Browse files Browse the repository at this point in the history
Pull request leela-zero#816.
  • Loading branch information
sethtroisi authored and gcp committed Feb 7, 2018
1 parent 0e8f1c6 commit b1e505b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
23 changes: 12 additions & 11 deletions src/UCTNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,17 @@ void UCTNode::dirichlet_noise(float epsilon, float alpha) {
}

void UCTNode::randomize_first_proportionally() {
auto accum = std::uint32_t{0};
auto accum = 0;
auto accum_vector = std::vector<decltype(accum)>{};
for (const auto& child : m_children) {
accum += child->get_visits();
accum_vector.emplace_back(accum);
}
assert(accum + 1 == get_visits());

auto pick = Random::get_Rng().randuint32(accum);
auto pick = int(Random::get_Rng().randuint32(accum));
auto index = size_t{0};
for (size_t i = 0; i < accum_vector.size(); i++) {
for (auto i = size_t{0}; i < accum_vector.size(); i++) {
if (pick < accum_vector[i]) {
index = i;
break;
Expand Down Expand Up @@ -256,10 +257,6 @@ bool UCTNode::has_children() const {
return m_has_children;
}

void UCTNode::set_visits(int visits) {
m_visits = visits;
}

float UCTNode::get_score() const {
return m_score;
}
Expand Down Expand Up @@ -318,15 +315,19 @@ UCTNode* UCTNode::uct_select_child(int color) {

LOCK(get_mutex(), lock);

// Count parentvisits.
// We do this manually to avoid issues with transpositions.
auto parentvisits = size_t{0};
#ifndef NDEBUG
auto parentvisits = 0;
for (const auto& child : m_children) {
if (child->valid()) {
parentvisits += child->get_visits();
}
}
auto numerator = static_cast<float>(std::sqrt((double)parentvisits));
assert(get_visits() - 1 == parentvisits);
#endif

// Remove the initial visit to expand the node.
auto childvisits = get_visits() - 1;
auto numerator = static_cast<float>(std::sqrt((double)childvisits));

for (const auto& child : m_children) {
if (!child->valid()) {
Expand Down
3 changes: 1 addition & 2 deletions src/UCTNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,13 @@ class UCTNode {
void set_score(float score);
float get_eval(int tomove) const;
double get_blackevals() const;
void set_visits(int visits);
void set_blackevals(double blacevals);
void accumulate_eval(float eval);
void virtual_loss(void);
void virtual_loss_undo(void);
void dirichlet_noise(float epsilon, float alpha);
void randomize_first_proportionally();
void update(float eval = std::numeric_limits<float>::quiet_NaN());
void update(float eval);

UCTNode* uct_select_child(int color);
UCTNode* get_first_child() const;
Expand Down
1 change: 1 addition & 0 deletions src/UCTSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ int UCTSearch::think(int color, passflag_t passflag) {
float root_eval;
if (!m_root->has_children()) {
m_root->create_children(m_nodes, m_rootstate, root_eval);
m_root->update(root_eval);
} else {
root_eval = m_root->get_eval(color);
}
Expand Down

0 comments on commit b1e505b

Please sign in to comment.