forked from LeelaChessZero/lc0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path network_rr.cc
95 lines (75 loc) · 3.1 KB
/
network_rr.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
/*
This file is part of Leela Chess Zero.
Copyright (C) 2018-2020 The LCZero Authors
Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Leela Chess is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.
Additional permission under GNU GPL version 3 section 7
If you modify this Program, or any covered work, by linking or
combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
modified version of those libraries), containing parts covered by the
terms of the respective license agreement, the licensors of this
Program grant you additional permission to convey the resulting work.
*/
#include <atomic>
#include <condition_variable>
#include <memory>
#include <optional>
#include <queue>
#include <string>
#include <thread>
#include <vector>

#include "neural/factory.h"
#include "utils/exception.h"
namespace lczero {
namespace {
class RoundRobinNetwork : public Network {
public:
RoundRobinNetwork(const std::optional<WeightsFile>& weights,
const OptionsDict& options) {
const auto parents = options.ListSubdicts();
if (parents.empty()) {
// If options are empty, or multiplexer configured in root object,
// initialize on root object and default backend.
auto backends = NetworkFactory::Get()->GetBackendsList();
AddBackend(backends[0], weights, options);
}
for (const auto& name : parents) {
AddBackend(name, weights, options.GetSubdict(name));
}
}
void AddBackend(const std::string& name,
const std::optional<WeightsFile>& weights,
const OptionsDict& opts) {
const std::string backend = opts.GetOrDefault<std::string>("backend", name);
networks_.emplace_back(
NetworkFactory::Get()->Create(backend, weights, opts));
if (networks_.size() == 1) {
capabilities_ = networks_.back()->GetCapabilities();
} else {
capabilities_.Merge(networks_.back()->GetCapabilities());
}
}
std::unique_ptr<NetworkComputation> NewComputation() override {
const long long val = ++counter_;
return networks_[val % networks_.size()]->NewComputation();
}
const NetworkCapabilities& GetCapabilities() const override {
return capabilities_;
}
~RoundRobinNetwork() {}
private:
std::vector<std::unique_ptr<Network>> networks_;
std::atomic<long long> counter_;
NetworkCapabilities capabilities_;
};
// Factory function handed to REGISTER_NETWORK for the "roundrobin" backend.
// Builds a RoundRobinNetwork over the backends described by `options`.
// Any exception thrown while constructing the wrapped backends propagates.
std::unique_ptr<Network> MakeRoundRobinNetwork(
    const std::optional<WeightsFile>& weights, const OptionsDict& options) {
  std::unique_ptr<Network> network =
      std::make_unique<RoundRobinNetwork>(weights, options);
  return network;
}
// Registers MakeRoundRobinNetwork as the factory for the "roundrobin" backend
// with priority -999. NOTE(review): the low priority presumably keeps this
// wrapper from being picked as the default backend; confirm against
// NetworkFactory's ordering rules.
REGISTER_NETWORK("roundrobin", MakeRoundRobinNetwork, -999)
} // namespace
} // namespace lczero