Skip to content

Commit

Permalink
Merge pull request pytorch#43 from jma127/master
Browse files Browse the repository at this point in the history
Formatting for go game_context
  • Loading branch information
jma127 authored May 10, 2018
2 parents e3e8abb + 54f8f83 commit bdda539
Showing 1 changed file with 107 additions and 107 deletions.
214 changes: 107 additions & 107 deletions src_cpp/elfgames/go/game_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,100 +28,100 @@

class GameContext {
public:
GameContext(const ContextOptions& context_options, const GameOptions& options)
: _context_options(context_options), _go_feature(options) {
_context.reset(new elf::Context);
GameContext(const ContextOptions& contextOptions, const GameOptions& options)
: contextOptions_(contextOptions), goFeature_(options) {
context_.reset(new elf::Context);

auto net_options = get_net_options(context_options, options);
auto curr_timestamp = time(NULL);
auto netOptions = getNetOptions(contextOptions, options);
auto currTimestamp = time(NULL);

bool perform_training = false;
bool performTraining = false;

int num_games = context_options.num_games;
int numGames = contextOptions.num_games;

if (options.mode == "selfplay") {
_writer.reset(new elf::shared::Writer(net_options));
std::cout << _writer->info() << std::endl;
_eval_ctrl.reset(new EvalCtrl(
_context->getClient(), _writer.get(), options, num_games));
writer_.reset(new elf::shared::Writer(netOptions));
std::cout << writer_->info() << std::endl;
evalCtrl_.reset(new EvalCtrl(
context_->getClient(), writer_.get(), options, numGames));

std::cout << "Send ctrl " << curr_timestamp << std::endl;
_writer->Ctrl(std::to_string(curr_timestamp));
std::cout << "Send ctrl " << currTimestamp << std::endl;
writer_->Ctrl(std::to_string(currTimestamp));
} else if (options.mode == "online") {
_eval_ctrl.reset(new EvalCtrl(
_context->getClient(), _writer.get(), options, num_games));
evalCtrl_.reset(new EvalCtrl(
context_->getClient(), writer_.get(), options, numGames));
} else if (options.mode == "train") {
init_reader(num_games, options, context_options.mcts_options);
_online_loader.reset(new DataOnlineLoader(*_reader, net_options));
initReader(numGames, options, contextOptions.mcts_options);
onlineLoader_.reset(new DataOnlineLoader(*reader_, netOptions));

auto start_func = [&]() { _train_ctrl->RegRecordSender(); };
auto start_func = [&]() { trainCtrl_->RegRecordSender(); };

auto replier = [&](elf::shared::Reader* reader,
const std::string& identity,
std::string* msg) -> bool {
(void)reader;
// cout << "Replier: before sending msg to " << identity << endl;
_train_ctrl->onReply(identity, msg);
trainCtrl_->onReply(identity, msg);
// cout << "Replier: about to send to " << identity << ", msg = " <<
// *msg << endl; cout << _reader->info() << endl;
// *msg << endl; cout << reader_->info() << endl;
return true;
};

_online_loader->start(start_func, replier);
perform_training = true;
onlineLoader_->start(start_func, replier);
performTraining = true;

} else if (options.mode == "offline_train") {
init_reader(num_games, options, context_options.mcts_options);
_offline_loader.reset(
new DataOfflineLoaderJSON(*_reader, options.list_files));
_offline_loader->start();
std::cout << _reader->info() << std::endl;
_train_ctrl->RegRecordSender();
perform_training = true;
initReader(numGames, options, contextOptions.mcts_options);
offlineLoader_.reset(
new DataOfflineLoaderJSON(*reader_, options.list_files));
offlineLoader_->start();
std::cout << reader_->info() << std::endl;
trainCtrl_->RegRecordSender();
performTraining = true;

} else {
std::cout << "Option.mode not recognized!" << options.mode << std::endl;
throw std::range_error("Option.mode not recognized! " + options.mode);
}

const int batchsize = context_options.batchsize;
const int batchsize = contextOptions.batchsize;

// Register all functions.
_go_feature.registerExtractor(batchsize, _context->getExtractor());
goFeature_.registerExtractor(batchsize, context_->getExtractor());

if (perform_training) {
for (int i = 0; i < num_games; ++i) {
_games.emplace_back(new GoGameTrain(
if (performTraining) {
for (int i = 0; i < numGames; ++i) {
games_.emplace_back(new GoGameTrain(
i,
_context->getClient(),
context_options,
context_->getClient(),
contextOptions,
options,
_train_ctrl.get(),
_reader.get()));
trainCtrl_.get(),
reader_.get()));
}
} else {
for (int i = 0; i < num_games; ++i) {
_games.emplace_back(new GoGameSelfPlay(
for (int i = 0; i < numGames; ++i) {
games_.emplace_back(new GoGameSelfPlay(
i,
_context->getClient(),
context_options,
context_->getClient(),
contextOptions,
options,
_eval_ctrl.get()));
evalCtrl_.get()));
}
}

_context->setStartCallback(num_games, [this](int i, elf::GameClient*) {
if (_eval_ctrl != nullptr)
_eval_ctrl->RegGame(i);
_games[i]->mainLoop();
context_->setStartCallback(numGames, [this](int i, elf::GameClient*) {
if (evalCtrl_ != nullptr)
evalCtrl_->RegGame(i);
games_[i]->mainLoop();
});

_context->setCBAfterGameStart(
[this, options]() { load_offline_selfplay_data(options); });
context_->setCBAfterGameStart(
[this, options]() { loadOfflineSelfplayData(options); });
}

std::map<std::string, int> getParams() const {
return _go_feature.getParams();
return goFeature_.getParams();
}

const GoGameBase* getGame(int game_idx) const {
Expand All @@ -131,120 +131,120 @@ class GameContext {
return nullptr;
}

return _games[game_idx].get();
return games_[game_idx].get();
}

GameStats* getGameStats() {
return &_eval_ctrl->getGameStats();
return &evalCtrl_->getGameStats();
}

void waitForSufficientSelfplay(int64_t selfplay_ver) {
_train_ctrl->waitForSufficientSelfplay(selfplay_ver);
trainCtrl_->waitForSufficientSelfplay(selfplay_ver);
}

// Used in training side.
void notifyNewVersion(int64_t selfplay_ver, int64_t new_version) {
_train_ctrl->addNewModelForEvaluation(selfplay_ver, new_version);
trainCtrl_->addNewModelForEvaluation(selfplay_ver, new_version);
}

void setInitialVersion(int64_t init_version) {
_train_ctrl->setInitialVersion(init_version);
trainCtrl_->setInitialVersion(init_version);
}

void setEvalMode(int64_t new_ver, int64_t old_ver) {
_train_ctrl->setEvalMode(new_ver, old_ver);
trainCtrl_->setEvalMode(new_ver, old_ver);
}

// Used in client side.
void setRequest(
int64_t black_ver,
int64_t white_ver,
float thres,
int num_thread = -1) {
int numThreads = -1) {
MsgRequest request;
request.vers.black_ver = black_ver;
request.vers.white_ver = white_ver;
request.vers.mcts_opt = _context_options.mcts_options;
request.vers.mcts_opt = contextOptions_.mcts_options;
request.client_ctrl.black_resign_thres = thres;
request.client_ctrl.white_resign_thres = thres;
request.client_ctrl.num_game_thread_used = num_thread;
_eval_ctrl->sendRequest(request);
request.client_ctrl.num_game_thread_used = numThreads;
evalCtrl_->sendRequest(request);
}

elf::Context* ctx() {
return _context.get();
return context_.get();
}

~GameContext() {
// cout << "Ending train ctrl" << endl;
_train_ctrl.reset(nullptr);
trainCtrl_.reset(nullptr);

// cout << "Ending eval ctrl" << endl;
_eval_ctrl.reset(nullptr);
evalCtrl_.reset(nullptr);

// cout << "Ending offline loader" << endl;
_offline_loader.reset(nullptr);
offlineLoader_.reset(nullptr);

// cout << "Ending online loader" << endl;
_online_loader.reset(nullptr);
onlineLoader_.reset(nullptr);

// cout << "Ending reader" << endl;
_reader.reset(nullptr);
reader_.reset(nullptr);

// cout << "Ending writer" << endl;
_writer.reset(nullptr);
writer_.reset(nullptr);

// cout << "Ending games" << endl;
_games.clear();
games_.clear();

// cout << "Ending context" << endl;
_context.reset(nullptr);
context_.reset(nullptr);

// cout << "Finish all ..." << endl;
}

private:
std::unique_ptr<elf::Context> _context;
std::vector<std::unique_ptr<GoGameBase>> _games;
std::unique_ptr<elf::Context> context_;
std::vector<std::unique_ptr<GoGameBase>> games_;

ContextOptions _context_options;
ContextOptions contextOptions_;

std::unique_ptr<TrainCtrl> _train_ctrl;
std::unique_ptr<EvalCtrl> _eval_ctrl;
std::unique_ptr<TrainCtrl> trainCtrl_;
std::unique_ptr<EvalCtrl> evalCtrl_;

std::unique_ptr<elf::shared::Writer> _writer;
std::unique_ptr<elf::shared::ReaderQueuesT<Record>> _reader;
std::unique_ptr<elf::shared::Writer> writer_;
std::unique_ptr<elf::shared::ReaderQueuesT<Record>> reader_;

std::unique_ptr<DataOfflineLoaderJSON> _offline_loader;
std::unique_ptr<DataOnlineLoader> _online_loader;
std::unique_ptr<DataOfflineLoaderJSON> offlineLoader_;
std::unique_ptr<DataOnlineLoader> onlineLoader_;

GoFeature _go_feature;
GoFeature goFeature_;

elf::shared::Options get_net_options(
const ContextOptions& context_options,
elf::shared::Options getNetOptions(
const ContextOptions& contextOptions,
const GameOptions& options) {
elf::shared::Options net_options;
net_options.addr =
elf::shared::Options netOptions;
netOptions.addr =
options.server_addr == "" ? "localhost" : options.server_addr;
net_options.port = options.port;
net_options.use_ipv6 = true;
net_options.verbose = options.verbose;
net_options.identity = context_options.job_id;
netOptions.port = options.port;
netOptions.use_ipv6 = true;
netOptions.verbose = options.verbose;
netOptions.identity = contextOptions.job_id;

return net_options;
return netOptions;
}

void load_offline_selfplay_data(const GameOptions& options) {
void loadOfflineSelfplayData(const GameOptions& options) {
if (options.list_files.empty())
return;

std::atomic<int> count(0);
const size_t num_thread = 16;
const size_t numThreads = 16;

auto thread_main = [&options, this, &count, num_thread](size_t idx) {
for (size_t k = 0; k * num_thread + idx < options.list_files.size();
auto thread_main = [&options, this, &count, numThreads](size_t idx) {
for (size_t k = 0; k * numThreads + idx < options.list_files.size();
++k) {
const std::string& f = options.list_files[k * num_thread + idx];
const std::string& f = options.list_files[k * numThreads + idx];
std::cout << "Load offline data: Reading: " << f << std::endl;

std::vector<Record> records;
Expand All @@ -257,25 +257,25 @@ class GameContext {
r.offline = true;
}

std::vector<FeedResult> res = _train_ctrl->onSelfplayGames(records);
std::vector<FeedResult> res = trainCtrl_->onSelfplayGames(records);

std::mt19937 rng(time(NULL));

// If the record does not fit in _train_ctrl,
// If the record does not fit in trainCtrl_,
// we should just send it directly to the replay buffer.
for (size_t i = 0; i < records.size(); ++i) {
if (res[i] == FeedResult::FEEDED ||
res[i] == FeedResult::VERSION_MISMATCH) {
bool black_win = records[i].result.reward > 0;
_reader->InsertWithParity(std::move(records[i]), &rng, black_win);
reader_->InsertWithParity(std::move(records[i]), &rng, black_win);
count++;
}
}
}
};

std::vector<std::thread> threads;
for (size_t i = 0; i < num_thread; ++i) {
for (size_t i = 0; i < numThreads; ++i) {
threads.emplace_back(std::bind(thread_main, i));
}

Expand All @@ -286,11 +286,11 @@ class GameContext {
std::cout << "All offline data are loaded. #record read: " << count
<< " from " << options.list_files.size() << " files."
<< std::endl;
std::cout << _reader->info() << std::endl;
std::cout << reader_->info() << std::endl;
}

void init_reader(
int num_games,
void initReader(
int numGames,
const GameOptions& options,
const elf::ai::tree_search::TSOptions& mcts_opt) {
elf::shared::RQCtrl ctrl;
Expand All @@ -303,7 +303,7 @@ class GameContext {
if (rs == nullptr)
return false;
try {
_train_ctrl->onReceive(s);
trainCtrl_->onReceive(s);
rs->clear();
return true;
} catch (...) {
Expand All @@ -313,14 +313,14 @@ class GameContext {
}
};

_reader.reset(new elf::shared::ReaderQueuesT<Record>(ctrl));
_train_ctrl.reset(new TrainCtrl(
num_games, _context->getClient(), _reader.get(), options, mcts_opt));
_reader->setConverter(converter);
std::cout << _reader->info() << std::endl;
reader_.reset(new elf::shared::ReaderQueuesT<Record>(ctrl));
trainCtrl_.reset(new TrainCtrl(
numGames, context_->getClient(), reader_.get(), options, mcts_opt));
reader_->setConverter(converter);
std::cout << reader_->info() << std::endl;
}

bool _check_game_idx(int game_idx) const {
return game_idx < 0 || game_idx >= (int)_games.size();
return game_idx < 0 || game_idx >= (int)games_.size();
}
};

0 comments on commit bdda539

Please sign in to comment.