forked from leela-zero/leela-zero
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTraining.h
90 lines (75 loc) · 2.78 KB
/
Training.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
/*
This file is part of Leela Zero.
Copyright (C) 2017 Gian-Carlo Pascutto
Leela Zero is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Leela Zero is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Leela Zero. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRAINING_H_INCLUDED
#define TRAINING_H_INCLUDED
#include "config.h"

#include <cstddef>
#include <iosfwd>
#include <string>
#include <utility>
#include <vector>

#include "GameState.h"
#include "Network.h"
#include "UCTNode.h"
// A single recorded position from a game, held in memory until it is
// written out as training data.
//
// The scalar members carry default member initializers so that a TimeStep
// which is default-constructed (e.g. `TimeStep step;`) and not fully
// assigned never exposes indeterminate values -- reading an uninitialized
// int/float is undefined behavior. Value-initialization (`TimeStep{}`)
// and C++14 aggregate initialization keep working unchanged.
class TimeStep {
public:
    // Network input planes describing this position.
    Network::NNPlanes planes;
    // Per-move probabilities for this position (presumably the normalized
    // search visit distribution -- TODO confirm against Training::record).
    std::vector<float> probabilities;
    // Side to move in this position (presumably a board color constant --
    // TODO confirm).
    int to_move{0};
    // Winrate as estimated by the network evaluation (presumably the raw
    // net output; confirm against Training::record).
    float net_winrate{0.0f};
    // Winrate estimates from the tree search, at the root and at the
    // selected child respectively (presumably; confirm in Training.cpp).
    float root_uct_winrate{0.0f};
    float child_uct_winrate{0.0f};
    // Visit count associated with the best move (presumably the most
    // visited root child).
    int bestmove_visits{0};
};

// Stream serialization / deserialization of a TimeStep. The definitions
// (and therefore the on-stream format) live outside this header.
std::ostream& operator<< (std::ostream& stream, const TimeStep& timestep);
std::istream& operator>> (std::istream& stream, TimeStep& timestep);
// Collects serialized data in an in-memory buffer and emits it in batches
// ("chunks"), with CHUNK_SIZE games grouped per batch. Each chunk is named
// via gen_chunk_name() from the base name given at construction (presumably
// a file name with a chunk counter -- see the out-of-line definitions).
class OutputChunker {
public:
    // basename: prefix used when naming emitted chunks.
    // compress: whether chunks are written compressed (format defined in
    //           the implementation).
    OutputChunker(const std::string& basename, bool compress = false);

    // Defined out of line; presumably flushes any still-buffered data.
    ~OutputChunker();

    // Adds a serialized string (presumably one game) to the pending buffer.
    void append(const std::string& str);

    // Number of games grouped together into a single batch.
    static constexpr std::size_t CHUNK_SIZE = 32;

private:
    // Builds the name of the next chunk to be written.
    std::string gen_chunk_name() const;
    // Writes out the currently buffered data.
    void flush_chunks();

    // Bookkeeping counters.
    std::size_t m_game_count = 0;
    std::size_t m_chunk_count = 0;
    // Pending, not-yet-written data.
    std::string m_buffer;
    // Construction parameters.
    std::string m_basename;
    bool m_compress = false;
};
// Static-only interface for recording game positions and writing them out
// as training data. All recorded positions are kept in the shared m_data
// vector; the public entry points take file names, while the private
// overloads operate on an already-open OutputChunker or file stream.
class Training {
public:
    // ---- In-memory recording ------------------------------------------
    // Presumably discards everything currently held in m_data.
    static void clear_training();
    // Records a position from the given game state and search node.
    static void record(GameState& state, UCTNode& node);

    // ---- Output to files ----------------------------------------------
    // Writes the recorded positions, labelled with the game winner.
    static void dump_training(int winner_color,
                              const std::string& out_filename);
    // Writes diagnostic information about the recorded positions.
    static void dump_debug(const std::string& out_filename);
    // Converts the games in an SGF file into training data.
    static void dump_supervised(const std::string& sgf_file,
                                const std::string& out_filename);

    // ---- Persisting the in-memory buffer ------------------------------
    static void save_training(const std::string& filename);
    static void load_training(const std::string& filename);

private:
    // Emits training data for one game into outchunker; train_pos is an
    // in/out position counter (see the definition for its exact role).
    static void process_game(GameState& state, size_t& train_pos, int who_won,
                             const std::vector<int>& tree_moves,
                             OutputChunker& outchunker);

    // Stream-/chunker-level overloads backing the public entry points.
    static void dump_training(int winner_color,
                              OutputChunker& outchunker);
    static void dump_debug(OutputChunker& outchunker);
    static void save_training(std::ofstream& out);
    static void load_training(std::ifstream& in);

    // Positions recorded so far.
    static std::vector<TimeStep> m_data;
};
#endif