Skip to content

Commit

Permalink
Merge pull request pytorch#81 from GuixingLin/logger
Browse files Browse the repository at this point in the history
cleanup: replace cout with logger
  • Loading branch information
jma127 authored Aug 10, 2018
2 parents a0ee179 + bd1afe6 commit b92efc7
Show file tree
Hide file tree
Showing 39 changed files with 473 additions and 397 deletions.
16 changes: 11 additions & 5 deletions src_cpp/elf/ai/tree_search/mcts.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <vector>

#include "elf/ai/ai.h"
#include "elf/logging/IndexedLoggerFactory.h"
#include "elf/utils/member_check.h"
#include "elf/utils/utils.h"

Expand All @@ -40,7 +41,9 @@ class MCTSAI_T : public AI_T<typename Actor::State, typename Actor::Action> {
MCTSAI_T(
const elf::ai::tree_search::TSOptions& options,
std::function<Actor*(int)> gen)
: options_(options) {
: options_(options),
logger_(
elf::logging::getLogger("elf::ai::tree_search::MCTSAI_T-", "")) {
ts_.reset(new TreeSearch(options_, gen));
}

Expand All @@ -62,10 +65,12 @@ class MCTSAI_T : public AI_T<typename Actor::State, typename Actor::Action> {
lastResult_ = ts_->run(s);

clock.record("MCTS");
std::cout << "[" << this->getID()
<< "] MCTSAI Result: " << lastResult_.info()
<< " Action:" << lastResult_.best_action << std::endl;
std::cout << clock.summary() << std::endl;
logger_->info(
"[{}] MCTSAI Result: {} Action: {}\n{}",
this->getID(),
lastResult_.info(),
lastResult_.best_action,
clock.summary());
} else {
lastResult_ = ts_->run(s);
}
Expand Down Expand Up @@ -125,6 +130,7 @@ class MCTSAI_T : public AI_T<typename Actor::State, typename Actor::Action> {
std::unique_ptr<TreeSearch> ts_;
size_t nextMoveNumber_ = 0;
MCTSResult lastResult_;
std::shared_ptr<spdlog::logger> logger_;

void resetTree() {
ts_->clear();
Expand Down
20 changes: 15 additions & 5 deletions src_cpp/elf/ai/tree_search/tree_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "elf/comm/primitive.h"
#include "elf/concurrency/ConcurrentQueue.h"
#include "elf/concurrency/Counter.h"
#include "elf/logging/IndexedLoggerFactory.h"
#include "elf/utils/member_check.h"

#include "tree_search_node.h"
Expand Down Expand Up @@ -67,7 +68,11 @@ class TreeSearchSingleThreadT {
using SearchTree = SearchTreeT<State, Action>;

TreeSearchSingleThreadT(int thread_id, const TSOptions& options)
: threadId_(thread_id), options_(options) {
: threadId_(thread_id),
options_(options),
logger_(elf::logging::getLogger(
"elf::ai::tree_search::TreeSearchSingleThreadT-",
"")) {
if (options_.verbose) {
std::string log_file =
options_.log_prefix + std::to_string(thread_id) + ".txt";
Expand All @@ -91,7 +96,7 @@ class TreeSearchSingleThreadT {
Node* root = search_tree.getRootNode();
if (root == nullptr || root->getStatePtr() == nullptr) {
if (stop_search == nullptr || !stop_search->load()) {
std::cout << "[" << threadId_ << "] root node is nullptr!" << std::endl;
logger_->info("[{}] root node is nullptr!", threadId_);
}
return false;
}
Expand Down Expand Up @@ -132,6 +137,8 @@ class TreeSearchSingleThreadT {
elf::concurrency::ConcurrentQueue<int> runInfoWhenStateReady_;
std::unique_ptr<std::ostream> output_;

std::shared_ptr<spdlog::logger> logger_;

MEMBER_FUNC_CHECK(reward)
template <
typename Actor,
Expand Down Expand Up @@ -325,13 +332,15 @@ class TreeSearchT {
using MCTSResult = MCTSResultT<Action>;

TreeSearchT(const TSOptions& options, std::function<Actor*(int)> actor_gen)
: options_(options), stopSearch_(false) {
: options_(options),
stopSearch_(false),
logger_(
elf::logging::getLogger("elf::ai::tree_search::TreeSearchT-", "")) {
for (int i = 0; i < options.num_threads; ++i) {
treeSearches_.emplace_back(new TreeSearchSingleThread(i, options_));
actors_.emplace_back(actor_gen(i));
}

// cout << "#Thread: " << options.num_threads << endl;
for (int i = 0; i < options.num_threads; ++i) {
TreeSearchSingleThread* th = treeSearches_[i].get();
threadPool_.emplace_back(std::thread{[i, this, th]() {
Expand Down Expand Up @@ -457,6 +466,8 @@ class TreeSearchT {
elf::concurrency::Counter<size_t> treeReady_;
elf::concurrency::Counter<size_t> countStoppedThreads_;

std::shared_ptr<spdlog::logger> logger_;

void notifySearches(int num_rollout) {
for (size_t i = 0; i < treeSearches_.size(); ++i) {
treeSearches_[i]->notifyReady(num_rollout);
Expand All @@ -483,7 +494,6 @@ class TreeSearchT {
MCTSResult chooseAction() const {
const Node* root = searchTree_.getRootNode();
if (root == nullptr) {
std::cout << "TreeSearch::root cannot be null!" << std::endl;
throw std::range_error("TreeSearch::root cannot be null!");
}

Expand Down
12 changes: 8 additions & 4 deletions src_cpp/elf/ai/tree_search/tree_search_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <nlohmann/json.hpp>

#include "elf/logging/IndexedLoggerFactory.h"
#include "elf/utils/utils.h"

using json = nlohmann::json;
Expand Down Expand Up @@ -108,12 +109,16 @@ struct EdgeInfo {
int num_visits;
float virtual_loss;

std::shared_ptr<spdlog::logger> logger_;

EdgeInfo(float probability)
: prior_probability(probability),
child_node(InvalidNodeId),
reward(0),
num_visits(0),
virtual_loss(0) {}
virtual_loss(0),
logger_(
elf::logging::getLogger("elf::ai::tree_search::EdgeInfo-", "")) {}

float getQSA() const {
return reward / num_visits;
Expand All @@ -123,9 +128,8 @@ struct EdgeInfo {
void checkValid() const {
if (virtual_loss != 0) {
// TODO: This should be a Google log (ssengupta@fb)
std::cout << "Virtual loss is not zero[" << virtual_loss << "]"
<< std::endl;
std::cout << info(true) << std::endl;
logger_->info(
"Virtual loss is not zero[{}]\n{}", virtual_loss, info(true));
assert(virtual_loss == 0);
}
}
Expand Down
21 changes: 9 additions & 12 deletions src_cpp/elf/base/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "elf/comm/comm.h"
#include "elf/concurrency/ConcurrentQueue.h"
#include "elf/concurrency/Counter.h"
#include "elf/logging/IndexedLoggerFactory.h"
#include "extractor.h"
#include "sharedmem.h"

Expand Down Expand Up @@ -133,7 +134,6 @@ class Context {
msgQueue_.push(PREPARE_TO_STOP);
completedSwitch_.waitUntilTrue();
completedSwitch_.reset();
// std::cout << " prepare to stop delivered "
// << smem_->getSharedMemOptions().info() << std::endl;
}

Expand Down Expand Up @@ -171,7 +171,6 @@ class Context {
_Msg msg;
if (msgQueue_.pop(&msg, std::chrono::microseconds(0))) {
if (msg == PREPARE_TO_STOP) {
// std::cout << " get prepare to stop signal "
// << smem_opts.info() << std::endl;

smem_->setMinBatchSize(0);
Expand All @@ -183,14 +182,12 @@ class Context {
}
}
smem_->waitBatchFillMem(server_);
// std::cout << "Receiver[" << smem_opts.getLabel() << "] Batch
// received. #batch = "
// << smem_->getEffectiveBatchSize() << std::endl;

comm::ReplyStatus batch_status =
batchClient_->sendWait(smem_.get(), {""});

// std::cout << "Receiver[" << smem_opts.getLabel() << "] Batch
// releasing. #batch = "
// << smem_->getEffectiveBatchSize() << std::endl;

Expand All @@ -203,7 +200,7 @@ class Context {
public:
using GameCallback = std::function<void(int game_idx, GameClient*)>;

Context() {
Context() : logger_(elf::logging::getLogger("elf::base::Context-", "")) {
// Wait for the derived class to add entries to extractor_.
server_ = comm_.getServer();
client_.reset(new GameClient(&comm_, this));
Expand Down Expand Up @@ -249,7 +246,6 @@ class Context {
// for (const string &key : keys) {
// LOG(INFO) << key << " ";
// }
// std::cout << std::endl;

smem2keys_[options.getRecvOptions().label] = keys;
auto anyps = extractor_.getAnyP(keys);
Expand Down Expand Up @@ -337,7 +333,7 @@ class Context {
std::thread tmp_thread([&]() {
// assert(nice(10) == 10);

std::cout << "Prepare to stop ..." << std::endl;
logger_->info("Prepare to stop ...");
client_->prepareToStop();

// First set the timeout for all collectors to be finite number.
Expand All @@ -346,22 +342,21 @@ class Context {
}

// Then stop all the threads.
std::cout << "Stop all game threads ..." << std::endl;
logger_->info("Stop all game threads ...");
client_->stopGames();

std::cout << "All games sent notification, "
<< "Waiting until they join" << std::endl;
logger_->info("All games sent notification, Waiting until they join");

for (auto& p : game_threads_) {
p.join();
}

std::cout << "Stop all collectors ..." << std::endl;
logger_->info("Stop all collectors ...");
for (auto& r : collectors_) {
r->stop();
}

std::cout << "Stop tmp pool..." << std::endl;
logger_->info("Stop tmp pool...");
tmp_thread_done = true;
});

Expand Down Expand Up @@ -393,6 +388,8 @@ class Context {
GameCallback game_cb_ = nullptr;
std::function<void()> cb_after_game_start_ = nullptr;
std::vector<std::thread> game_threads_;

std::shared_ptr<spdlog::logger> logger_;
};

template <typename S>
Expand Down
4 changes: 0 additions & 4 deletions src_cpp/elf/base/ctrl.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ class ThreadInfosT {

std::thread::id _th_label2id(const std::string& label) const {
typename ThreadStrMap::accessor elem;
// std::cout << "looking for label: " << label << std::endl;
bool found = threadStrMap_.find(elem, label);
assert(found);
return elem->second;
Expand Down Expand Up @@ -371,12 +370,9 @@ class ThreadedCtrlBaseT {

virtual ~ThreadedCtrlBaseT() {
done_ = true;
// std::cout << "ThreadedCtrlBase: Ending thread.." << std::endl;
if (thread_ != nullptr) {
thread_->join();
}

// std::cout << "ThreadedCtrlBase: thread ended.." << std::endl;
}

protected:
Expand Down
16 changes: 10 additions & 6 deletions src_cpp/elf/base/dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <iostream>
#include "../concurrency/ConcurrentQueue.h"
#include "ctrl.h"
#include "elf/logging/IndexedLoggerFactory.h"

namespace elf {

Expand All @@ -18,7 +19,10 @@ class ThreadedDispatcherT : public ThreadedCtrlBase {
using ThreadRecv = std::function<bool(const S&, R*)>;

ThreadedDispatcherT(Ctrl& ctrl, int num_games)
: ThreadedCtrlBase(ctrl, 500), num_games_(num_games) {}
: ThreadedCtrlBase(ctrl, 500),
num_games_(num_games),
logger_(
elf::logging::getLogger("elf::base::ThreadedDispatcherT-", "")) {}

void Start(ServerReply replier, ServerFirstSend first_send = nullptr) {
server_replier_ = replier;
Expand All @@ -30,7 +34,6 @@ class ThreadedDispatcherT : public ThreadedCtrlBase {
void RegGame(int game_idx) {
ctrl_.reg("game_" + std::to_string(game_idx));
ctrl_.addMailbox<S, R>();
// cout << "Register game " << game_idx << endl;
game_counter_.increment();
}

Expand Down Expand Up @@ -67,13 +70,15 @@ class ThreadedDispatcherT : public ThreadedCtrlBase {
ServerReply server_replier_ = nullptr;
ServerFirstSend server_first_send_ = nullptr;

private:
std::shared_ptr<spdlog::logger> logger_;

void before_loop() override {
// Wait for all games + this processing thread.
std::cout << "Wait all games[" << num_games_
<< "] to register their mailbox" << std::endl;
logger_->info("Wait all games[{}] to register their mailbox", num_games_);
game_counter_.waitUntilCount(num_games_);
game_counter_.reset();
std::cout << "All games [" << num_games_ << "] registered" << std::endl;
logger_->info("All games [{}] registered", num_games_);

addrs_ = ctrl_.filterPrefix(std::string("game"));
for (size_t i = 0; i < addrs_.size(); ++i) {
Expand All @@ -82,7 +87,6 @@ class ThreadedDispatcherT : public ThreadedCtrlBase {
}

void on_thread() override {
// cout << "Register Recv threads" << endl;
S msg;
if (ctrl_.peekMail(&msg, 0)) {
if (just_started_ || msg != last_msg_) {
Expand Down
Loading

0 comments on commit b92efc7

Please sign in to comment.