Skip to content

Commit

Permalink
Merge pull request pytorch#44 from jma127/master
Browse files Browse the repository at this point in the history
Misc logging fixes and migrate game_context output
  • Loading branch information
jma127 authored May 11, 2018
2 parents bdda539 + 353a505 commit 9621c27
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 61 deletions.
5 changes: 3 additions & 2 deletions src_cpp/elf/logging/IndexedLoggerFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
*
* private:
* static IndexedLoggerFactory* getLoggerFactory() {
* IndexedLoggerFactory factory([](const std::string& name) {
* return spdlog::stdout_color_mt(name);
* static IndexedLoggerFactory factory([](const std::string& name) {
* return spdlog::stderr_color_mt(name);
* });
* return &factory;
* }
*
* std::shared_ptr<spdlog::logger> logger_;
Expand Down
109 changes: 53 additions & 56 deletions src_cpp/elfgames/go/game_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@

#pragma once

// TODO: Figure out how to remove this (ssengupta@fb)
#include <time.h>

#include <iostream>
#include <memory>
#include <vector>

// TODO: Figure out how to remove this (ssengupta@fb)
#include <time.h>

#include "base/board_feature.h"
#include "data_loader.h"
#include "elf/base/context.h"
#include "elf/legacy/python_options_utils_cpp.h"
#include "elf/logging/IndexedLoggerFactory.h"
#include "game_selfplay.h"
#include "game_train.h"
#include "mcts/ai.h"
Expand All @@ -29,7 +30,10 @@
class GameContext {
public:
GameContext(const ContextOptions& contextOptions, const GameOptions& options)
: contextOptions_(contextOptions), goFeature_(options) {
: contextOptions_(contextOptions),
goFeature_(options),
logger_(
getLoggerFactory()->makeLogger("elfgames::go::GameContext-", "")) {
context_.reset(new elf::Context);

auto netOptions = getNetOptions(contextOptions, options);
Expand All @@ -41,11 +45,11 @@ class GameContext {

if (options.mode == "selfplay") {
writer_.reset(new elf::shared::Writer(netOptions));
std::cout << writer_->info() << std::endl;
logger_->info("Writer info: {}", writer_->info());
evalCtrl_.reset(new EvalCtrl(
context_->getClient(), writer_.get(), options, numGames));

std::cout << "Send ctrl " << currTimestamp << std::endl;
logger_->info("Send ctrl with timestamp {}", currTimestamp);
writer_->Ctrl(std::to_string(currTimestamp));
} else if (options.mode == "online") {
evalCtrl_.reset(new EvalCtrl(
Expand All @@ -60,10 +64,12 @@ class GameContext {
const std::string& identity,
std::string* msg) -> bool {
(void)reader;
// cout << "Replier: before sending msg to " << identity << endl;
trainCtrl_->onReply(identity, msg);
// cout << "Replier: about to send to " << identity << ", msg = " <<
// *msg << endl; cout << reader_->info() << endl;
logger_->info(
"Replier: about to send: recipient {}; msg {}; reader {}",
identity,
*msg,
reader_->info());
return true;
};

Expand All @@ -75,13 +81,12 @@ class GameContext {
offlineLoader_.reset(
new DataOfflineLoaderJSON(*reader_, options.list_files));
offlineLoader_->start();
std::cout << reader_->info() << std::endl;
logger_->info("Offline train; reader info {}", reader_->info());
trainCtrl_->RegRecordSender();
performTraining = true;

} else {
std::cout << "Option.mode not recognized!" << options.mode << std::endl;
throw std::range_error("Option.mode not recognized! " + options.mode);
throw std::range_error("options.mode not recognized! " + options.mode);
}

const int batchsize = contextOptions.batchsize;
Expand Down Expand Up @@ -126,8 +131,7 @@ class GameContext {

const GoGameBase* getGame(int game_idx) const {
if (_check_game_idx(game_idx)) {
std::cout << "Invalid game_idx [" + std::to_string(game_idx) + "]"
<< std::endl;
logger_->error("Invalid game_idx [{}]", game_idx);
return nullptr;
}

Expand Down Expand Up @@ -176,50 +180,17 @@ class GameContext {
}

~GameContext() {
// cout << "Ending train ctrl" << endl;
trainCtrl_.reset(nullptr);

// cout << "Ending eval ctrl" << endl;
evalCtrl_.reset(nullptr);

// cout << "Ending offline loader" << endl;
offlineLoader_.reset(nullptr);

// cout << "Ending online loader" << endl;
onlineLoader_.reset(nullptr);

// cout << "Ending reader" << endl;
reader_.reset(nullptr);

// cout << "Ending writer" << endl;
writer_.reset(nullptr);

// cout << "Ending games" << endl;
games_.clear();

// cout << "Ending context" << endl;
context_.reset(nullptr);

// cout << "Finish all ..." << endl;
}

private:
std::unique_ptr<elf::Context> context_;
std::vector<std::unique_ptr<GoGameBase>> games_;

ContextOptions contextOptions_;

std::unique_ptr<TrainCtrl> trainCtrl_;
std::unique_ptr<EvalCtrl> evalCtrl_;

std::unique_ptr<elf::shared::Writer> writer_;
std::unique_ptr<elf::shared::ReaderQueuesT<Record>> reader_;

std::unique_ptr<DataOfflineLoaderJSON> offlineLoader_;
std::unique_ptr<DataOnlineLoader> onlineLoader_;

GoFeature goFeature_;

elf::shared::Options getNetOptions(
const ContextOptions& contextOptions,
const GameOptions& options) {
Expand All @@ -245,11 +216,11 @@ class GameContext {
for (size_t k = 0; k * numThreads + idx < options.list_files.size();
++k) {
const std::string& f = options.list_files[k * numThreads + idx];
std::cout << "Load offline data: Reading: " << f << std::endl;
logger_->info("Loading offline data, reading file {}", f);

std::vector<Record> records;
if (!Record::loadBatchFromJsonFile(f, &records)) {
std::cout << "Offline data loading: Error reading " << f << std::endl;
logger_->error("Offline data loader: error reading {}", f);
return;
}

Expand Down Expand Up @@ -283,10 +254,12 @@ class GameContext {
t.join();
}

std::cout << "All offline data are loaded. #record read: " << count
<< " from " << options.list_files.size() << " files."
<< std::endl;
std::cout << reader_->info() << std::endl;
logger_->info(
"All offline data is loaded. Read {} records from {} files. Reader "
"info {}",
count,
options.list_files.size(),
reader_->info());
}

void initReader(
Expand All @@ -307,8 +280,7 @@ class GameContext {
rs->clear();
return true;
} catch (...) {
std::cout << "Data malformed! ..." << std::endl;
std::cout << s << std::endl;
logger_->error("Data malformed! String is {}", s);
return false;
}
};
Expand All @@ -317,10 +289,35 @@ class GameContext {
trainCtrl_.reset(new TrainCtrl(
numGames, context_->getClient(), reader_.get(), options, mcts_opt));
reader_->setConverter(converter);
std::cout << reader_->info() << std::endl;
logger_->info("Finished initializing reader {}", reader_->info());
}

bool _check_game_idx(int game_idx) const {
return game_idx < 0 || game_idx >= (int)games_.size();
}

static elf::logging::IndexedLoggerFactory* getLoggerFactory() {
static elf::logging::IndexedLoggerFactory factory(
[](const std::string& name) { return spdlog::stderr_color_mt(name); });
return &factory;
}

private:
std::unique_ptr<elf::Context> context_;
std::vector<std::unique_ptr<GoGameBase>> games_;

ContextOptions contextOptions_;

std::unique_ptr<TrainCtrl> trainCtrl_;
std::unique_ptr<EvalCtrl> evalCtrl_;

std::unique_ptr<elf::shared::Writer> writer_;
std::unique_ptr<elf::shared::ReaderQueuesT<Record>> reader_;

std::unique_ptr<DataOfflineLoaderJSON> offlineLoader_;
std::unique_ptr<DataOnlineLoader> onlineLoader_;

GoFeature goFeature_;

std::shared_ptr<spdlog::logger> logger_;
};
2 changes: 1 addition & 1 deletion src_py/elfgames/go/mcts_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


_logger_factory = logging.IndexedLoggerFactory(
lambda name: logging.stdout_color_mt(name))
lambda name: logging.stderr_color_mt(name))


class MCTSPrediction(object):
Expand Down
4 changes: 2 additions & 2 deletions src_py/rlpytorch/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


_logger_factory = logging.IndexedLoggerFactory(
lambda name: logging.stdout_color_mt(name))
lambda name: logging.stderr_color_mt(name))


def load_module(mod):
Expand Down Expand Up @@ -190,7 +190,7 @@ def _on_get_args(self, *args, **kwargs):
DeprecationWarning)


_load_env_logger = logging.stdout_color_mt('rlpytorch.model_loader.load_env')
_load_env_logger = logging.stderr_color_mt('rlpytorch.model_loader.load_env')


def load_env(
Expand Down

0 comments on commit 9621c27

Please sign in to comment.