forked from LeelaChessZero/lc0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfactory.h
140 lines (114 loc) · 5.15 KB
/
factory.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/*
This file is part of Leela Chess Zero.
Copyright (C) 2018-2020 The LCZero Authors
Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Leela Chess is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.
Additional permission under GNU GPL version 3 section 7
If you modify this Program, or any covered work, by linking or
combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
modified version of those libraries), containing parts covered by the
terms of the respective license agreement, the licensors of this
Program grant you additional permission to convey the resulting work.
*/
#pragma once
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "neural/loader.h"
#include "neural/network.h"
#include "utils/optionsdict.h"
#include "utils/optionsparser.h"
namespace lczero {
// Registry of network backends. Backends are stored as (name, factory
// function, priority) entries and can be instantiated by name via Create().
class NetworkFactory {
 public:
  // Callable that builds a backend: receives the (optional) weights file and
  // the backend options, and returns an owning pointer to the new Network.
  using FactoryFunc = std::function<std::unique_ptr<Network>(
      const std::optional<WeightsFile>&, const OptionsDict&)>;

  // Returns a pointer to the factory. The constructor is private, so this is
  // how callers obtain an instance (presumably a process-wide singleton --
  // the definition is not in this header).
  static NetworkFactory* Get();

  // Registers network so it can be created by name.
  // @name -- name of the backend.
  // @factory -- function that creates the network.
  // @priority -- how high should be the network in the list. The network with
  //              the highest priority is the default.
  class Register {
   public:
    Register(const std::string& name, FactoryFunc factory, int priority = 0);
  };

  // Add the network/backend parameters to the options dictionary.
  static void PopulateOptions(OptionsParser* options);

  // Returns list of backend names, sorted by priority (higher priority first).
  std::vector<std::string> GetBackendsList() const;

  // Creates a backend given name and config.
  // @network -- name of the backend to create.
  // The unnamed second parameter is the optional weights file and @options
  // are the backend options; both match the FactoryFunc parameter types.
  std::unique_ptr<Network> Create(const std::string& network,
                                  const std::optional<WeightsFile>&,
                                  const OptionsDict& options);

  // Helper function to load the network from the options. Returns nullptr
  // if no network options changed since the previous call.
  static std::unique_ptr<Network> LoadNetwork(const OptionsDict& options);

  // Parameter IDs.
  static const OptionId kWeightsId;
  static const OptionId kBackendId;
  static const OptionId kBackendOptionsId;

  // Value type holding the three strings that describe a backend setup.
  // Supports equality and a strict ordering, so it can be compared or used as
  // a key in ordered containers.
  struct BackendConfiguration {
    BackendConfiguration() = default;
    // Builds the configuration from @options (defined out of line).
    // NOTE(review): intentionally left non-explicit to keep existing implicit
    // conversions from OptionsDict compiling.
    BackendConfiguration(const OptionsDict& options);

    std::string weights_path;
    std::string backend;
    std::string backend_options;

    bool operator==(const BackendConfiguration& other) const;
    bool operator!=(const BackendConfiguration& other) const {
      return !operator==(other);
    }
    // Lexicographic order over (weights_path, backend, backend_options).
    bool operator<(const BackendConfiguration& other) const {
      return std::tie(weights_path, backend, backend_options) <
             std::tie(other.weights_path, other.backend,
                      other.backend_options);
    }
  };

 private:
  // Adds a backend entry; private, reachable from Register through the friend
  // declaration below.
  void RegisterNetwork(const std::string& name, FactoryFunc factory,
                       int priority);

  NetworkFactory() = default;

  // One registered backend.
  struct Factory {
    // @factory is taken by value and moved into the member, so the
    // std::function is not copied a second time.
    Factory(const std::string& name, FactoryFunc factory, int priority)
        : name(name), factory(std::move(factory)), priority(priority) {}

    // Orders entries by descending priority; equal priorities are ordered by
    // ascending name.
    bool operator<(const Factory& other) const {
      if (priority != other.priority) return priority > other.priority;
      return name < other.name;
    }

    std::string name;
    FactoryFunc factory;
    int priority;
  };

  std::vector<Factory> factories_;

  friend class Register;
};
// Innermost helper behind REGISTER_NETWORK. Expands to an anonymous-namespace
// definition of a NetworkFactory::Register object named regH38fhs<counter>
// (@counter is token-pasted as given), constructed with @name, a lambda that
// forwards its (weights, options) arguments to @func, and @priority.
#define REGISTER_NETWORK_WITH_COUNTER2(name, func, priority, counter) \
namespace { \
static NetworkFactory::Register regH38fhs##counter( \
name, \
[](const std::optional<WeightsFile>& w, const OptionsDict& o) { \
return func(w, o); \
}, \
priority); \
}
// Extra level of macro indirection: arguments are macro-expanded while being
// passed on to REGISTER_NETWORK_WITH_COUNTER2, so a @counter such as __LINE__
// is replaced by its numeric value before it is token-pasted with `##` there.
#define REGISTER_NETWORK_WITH_COUNTER(name, func, priority, counter) \
REGISTER_NETWORK_WITH_COUNTER2(name, func, priority, counter)
// Registers a Network backend by defining a NetworkFactory::Register object.
// @name -- name under which the backend will be known in configs.
// @func -- Factory function for a backend. It is invoked as
//          func(const std::optional<WeightsFile>&, const OptionsDict&)
//          and its result is returned from a NetworkFactory::FactoryFunc,
//          i.e. it must be convertible to std::unique_ptr<Network>.
// @priority -- numeric priority of a backend. Higher is higher, highest number
//              is the default backend.
// The registration object's name is derived from __LINE__, so two uses of
// this macro on the same source line of one translation unit would define
// the same variable twice.
#define REGISTER_NETWORK(name, func, priority) \
REGISTER_NETWORK_WITH_COUNTER(name, func, priority, __LINE__)
} // namespace lczero