forked from Ttl/leela-zero
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUCTNode.h
103 lines (88 loc) · 3.11 KB
/
UCTNode.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/*
This file is part of Leela Zero.
Copyright (C) 2017 Gian-Carlo Pascutto
Leela Zero is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Leela Zero is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Leela Zero. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef UCTNODE_H_INCLUDED
#define UCTNODE_H_INCLUDED
#include "config.h"

#include <atomic>
#include <limits>
#include <tuple>
#include <vector>

#include "SMP.h"
#include "GameState.h"
#include "Network.h"
/*
 * A node of the Monte-Carlo search tree.
 *
 * Children are kept as a singly linked list: m_firstchild points at the
 * first child and each child's m_nextsibling points at the next one.
 * Several counters are std::atomic and a per-node SMP::Mutex is exposed
 * through get_mutex(), because nodes are shared between search threads
 * (see VIRTUAL_LOSS_COUNT).
 * NOTE(review): which operations must be performed with get_mutex() held
 * is decided in UCTNode.cpp and is not visible from this header.
 */
class UCTNode {
public:
    // Tuple used as a sort record when ordering nodes. The meaning of each
    // element is defined by the sorting code in UCTNode.cpp -- TODO document.
    using sortnode_t = std::tuple<float, int, float, UCTNode*>;

    // When we visit a node, add this amount of virtual losses
    // to it to encourage other CPUs to explore other parts of the
    // search tree.
    static constexpr auto VIRTUAL_LOSS_COUNT = 3;

    // Creates a node for move `vertex` with initial score `score`
    // (presumably stored into m_move / m_score -- confirm in UCTNode.cpp).
    explicit UCTNode(int vertex, float score);
    // NOTE(review): whether this frees the child list is decided in
    // UCTNode.cpp; delete_child() suggests the node owns its children.
    ~UCTNode();

    // A node is wired into the tree through raw child/sibling pointers, so a
    // member-wise copy would alias another node's subtree. The std::atomic
    // members already made the implicit copy operations deleted; spelling it
    // out documents the intent and gives a clearer diagnostic. (No implicit
    // move operations exist either, because ~UCTNode() is user-declared.)
    UCTNode(const UCTNode&) = delete;
    UCTNode& operator=(const UCTNode&) = delete;

    // ---- Expansion and tree maintenance ----

    // True when this node is being visited for the first time
    // (presumably m_visits == 0 -- confirm in UCTNode.cpp).
    bool first_visit() const;
    // True once children have been linked (presumably reads m_has_children).
    bool has_children() const;
    // Expands this node for position `state`.
    //   nodecount: shared counter of allocated nodes (presumably incremented
    //              for every child created).
    //   eval:      out-parameter; presumably receives the network evaluation
    //              of `state`.
    // Returns whether this call performed the expansion; presumably false
    // when another thread is already expanding (see m_is_expanding).
    // TODO confirm.
    bool create_children(std::atomic<int> & nodecount,
                         GameState & state, float & eval);
    // Presumably invalidates children whose move would repeat an earlier
    // position (superko) in `state`; see m_valid. TODO confirm.
    void kill_superkos(KoState & state);
    // Unlinks `child` from this node's child list (and presumably frees it).
    void delete_child(UCTNode * child);
    // Marks this node as no longer alive; valid() then reports false
    // (presumably via m_valid, "node alive (not superko)").
    void invalidate();
    bool valid() const;

    // ---- Statistics ----

    int get_move() const;
    int get_visits() const;
    float get_score() const;
    void set_score(float score);
    // Evaluation from the point of view of `tomove`; presumably derived from
    // m_blackevals, m_visits and m_virtual_loss. TODO confirm.
    float get_eval(int tomove) const;
    // Accumulated evaluation sum (m_blackevals); per its name, presumably
    // expressed from Black's point of view.
    double get_blackevals() const;
    void set_visits(int visits);
    void set_blackevals(double blackevals);
    void set_eval(float eval);
    // Adds `eval` to the accumulated evaluation (presumably m_blackevals).
    void accumulate_eval(float eval);
    // Apply / revert VIRTUAL_LOSS_COUNT virtual losses (presumably on
    // m_virtual_loss) while a thread is descending through this node.
    void virtual_loss();
    void virtual_loss_undo();
    // Presumably mixes Dirichlet(alpha) noise into the children's scores
    // with weight `epsilon`. TODO confirm.
    void dirichlet_noise(float epsilon, float alpha);
    void randomize_first_proportionally();
    // Records a visit. `eval` defaults to quiet NaN, which presumably means
    // "no evaluation to accumulate" -- confirm in UCTNode.cpp.
    void update(float eval = std::numeric_limits<float>::quiet_NaN());

    // ---- Navigation and selection ----

    // Picks the child to descend into for side `color`.
    UCTNode* uct_select_child(int color);
    UCTNode* get_first_child() const;
    UCTNode* get_pass_child() const;
    UCTNode* get_nopass_child(FastState& state) const;
    UCTNode* get_sibling() const;
    void sort_root_children(int color);
    void sort_children();

    // Per-node mutex for callers that need exclusive access to this node.
    SMP::Mutex & get_mutex();

private:
    // Private default constructor; its uses are in UCTNode.cpp. The member
    // initializers below keep such a node fully initialized.
    UCTNode();
    // Links `newchild` into this node's child list.
    void link_child(UCTNode * newchild);
    // Creates and links children described by `nodelist`, presumably
    // updating `nodecount` for each one.
    void link_nodelist(std::atomic<int> & nodecount,
                       std::vector<Network::scored_node> & nodelist);

    // Tree data
    std::atomic<bool> m_has_children{false};
    UCTNode* m_firstchild{nullptr};
    UCTNode* m_nextsibling{nullptr};
    // Move. Given a default so a node built by UCTNode() never holds an
    // indeterminate value; the (vertex, score) constructor presumably
    // overwrites it.
    int m_move{0};
    // UCT
    std::atomic<int> m_visits{0};
    std::atomic<int> m_virtual_loss{0};
    // UCT eval (defaulted for the same reason as m_move).
    float m_score{0.0f};
    std::atomic<double> m_blackevals{0};
    // node alive (not superko)
    std::atomic<bool> m_valid{true};
    // Is someone adding scores to this node?
    // We don't need to unset this.
    // NOTE(review): plain bool, not atomic -- presumably only accessed with
    // m_nodemutex held; confirm in UCTNode.cpp.
    bool m_is_expanding{false};
    SMP::Mutex m_nodemutex;
};
#endif