Skip to content

Commit

Permalink
Merge pull request pytorch#70 from jma127/master
Browse files Browse the repository at this point in the history
Big refactor from FAIR team
  • Loading branch information
jma127 authored Jul 20, 2018
2 parents 1b6859f + d5f320d commit da2f20c
Show file tree
Hide file tree
Showing 69 changed files with 2,699 additions and 1,673 deletions.
19 changes: 10 additions & 9 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
[submodule "third_party/concurrentqueue"]
path = third_party/concurrentqueue
url = https://github.com/cameron314/concurrentqueue.git
[submodule "third_party/cppzmq"]
path = third_party/cppzmq
url = https://github.com/zeromq/cppzmq.git
[submodule "third_party/googletest"]
path = third_party/googletest
url = https://github.com/google/googletest.git
[submodule "third_party/json"]
path = third_party/json
url = https://github.com/nlohmann/json.git
[submodule "third_party/pybind11"]
path = third_party/pybind11
url = https://github.com/pybind/pybind11.git
[submodule "third_party/spdlog"]
path = third_party/spdlog
url = https://github.com/gabime/spdlog.git
[submodule "third_party/json"]
path = third_party/json
url = https://github.com/nlohmann/json.git
[submodule "third_party/googletest"]
path = third_party/googletest
url = https://github.com/google/googletest.git
[submodule "third_party/tbb"]
path = third_party/tbb
url = https://github.com/01org/tbb.git
ignore = untracked
[submodule "third_party/cppzmq"]
path = third_party/cppzmq
url = https://github.com/zeromq/cppzmq.git
22 changes: 0 additions & 22 deletions scripts/elfgames/go/analysis.sh

This file was deleted.

12 changes: 6 additions & 6 deletions src_cpp/elf/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ set(ELF_SOURCES
options/Pybind.cc
)

set(ELF_TEST_SOURCES
options/OptionMapTest.cc
options/OptionSpecTest.cc
)
# set(ELF_TEST_SOURCES
# options/OptionMapTest.cc
# options/OptionSpecTest.cc
# )

# Main ELF library

Expand All @@ -26,10 +26,10 @@ target_compile_definitions(elf PUBLIC
GIT_COMMIT_HASH=${GIT_COMMIT_HASH}
GIT_STAGED=${GIT_STAGED_STRING}
)

target_link_libraries(elf PUBLIC
#${Boost_LIBRARIES}
concurrentqueue
cppzmq
nlohmann_json
pybind11
$<BUILD_INTERFACE:${PYTHON_LIBRARIES}>
Expand All @@ -40,7 +40,7 @@ target_link_libraries(elf PUBLIC
# Tests

enable_testing()
add_cpp_tests(test_cpp_elf_ elf ${ELF_TEST_SOURCES})
# add_cpp_tests(test_cpp_elf_ elf ${ELF_TEST_SOURCES})

# Python bindings

Expand Down
4 changes: 2 additions & 2 deletions src_cpp/elf/ai/ai.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ class AIClientT : public AI_T<S, A> {
status == comm::ReplyStatus::UNKNOWN;
}

virtual bool act_batch(
bool act_batch(
const std::vector<const S*>& batch_s,
const std::vector<A*>& batch_a) {
const std::vector<A*>& batch_a) override {
std::vector<elf::FuncsWithState> funcs_s =
client_->BindStateToFunctions(targets_, batch_s);
std::vector<elf::FuncsWithState> funcs_a =
Expand Down
23 changes: 12 additions & 11 deletions src_cpp/elf/ai/tree_search/tree_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class TreeSearchSingleThreadT {
: threadId_(thread_id), options_(options) {
if (options_.verbose) {
std::string log_file =
"tree_search_" + std::to_string(thread_id) + ".txt";
options_.log_prefix + std::to_string(thread_id) + ".txt";
output_.reset(new std::ofstream(log_file));
}
}
Expand Down Expand Up @@ -374,10 +374,7 @@ class TreeSearchT {
return searchTree_.printTree();
}

MCTSResult runPolicyOnly(const State& /*root_state*/) {
// TODO Policy only doesn't work.
assert(false);
/*
MCTSResult runPolicyOnly(const State& root_state) {
if (actors_.empty() || treeSearches_.empty()) {
throw std::range_error(
"TreeSearch::runPolicyOnly works when there is at least one thread");
Expand All @@ -386,15 +383,17 @@ class TreeSearchT {

// Some hack here.
Node* root = searchTree_.getRootNode();
treeSearches_[0]->visit(*actors_[0], root);

// return StrongestPrior(root->getStateActions());
*/
if (!root->isVisited()) {
NodeResponseT<Action> resp;
actors_[0]->evaluate(*root->getStatePtr(), &resp);
root->setEvaluation(resp);
}

MCTSResult result;
// result.action_rank_method = MCTSResult::PRIOR;
// result.addActions(root->getStateActions());

result.action_rank_method = MCTSResult::PRIOR;
result.addActions(root->getStateActions());
result.root_value = root->getValue();
return result;
}

Expand Down Expand Up @@ -490,6 +489,8 @@ class TreeSearchT {

// Pick the best solution.
MCTSResult result;
result.root_value = root->getValue();

// MCTSResult result2;
if (options_.pick_method == "strongest_prior") {
result.action_rank_method = MCTSResult::PRIOR;
Expand Down
10 changes: 7 additions & 3 deletions src_cpp/elf/ai/tree_search/tree_search_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ struct StateTrait {
return s1 == s2;
}

static bool
moves_since(const S& s, size_t* next_move_number, std::vector<A>* moves) {
static bool moves_since(
const S& /*s*/,
size_t* /*next_move_number*/,
std::vector<A>* /*moves*/) {
// By default it is not provided.
return false;
}
Expand All @@ -84,7 +86,7 @@ struct ActionTrait {
template <typename Actor>
struct ActorTrait {
public:
static std::string to_string(const Actor& a) {
static std::string to_string(const Actor&) {
return "";
}
};
Expand Down Expand Up @@ -213,6 +215,7 @@ struct MCTSResultT {
enum RankCriterion { MOST_VISITED = 0, PRIOR = 1, UNIFORM_RANDOM };

Action best_action;
float root_value;
float max_score;
EdgeInfo best_edge_info;
MCTSPolicy<Action> mcts_policy;
Expand All @@ -224,6 +227,7 @@ struct MCTSResultT {
// action_edges [email protected]
MCTSResultT()
: best_action(ActionTrait<Action>::default_value()),
root_value(0.0),
max_score(std::numeric_limits<float>::lowest()),
best_edge_info(0),
total_visits(0),
Expand Down
8 changes: 8 additions & 0 deletions src_cpp/elf/ai/tree_search/tree_search_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ struct TSOptions {
bool persistent_tree = false;
float root_epsilon = 0.0;
float root_alpha = 0.0;
std::string log_prefix = "";

// [TODO] Not a good design.
// string pick_method = "strongest_prior";
Expand All @@ -102,6 +103,7 @@ struct TSOptions {
ss << "Maximal #moves (0 = no constraint): " << max_num_moves
<< std::endl;
ss << "Seed: " << seed << std::endl;
ss << "Log Prefix: " << log_prefix << std::endl;
ss << "#Threads: " << num_threads << std::endl;
ss << "#Rollout per thread: " << num_rollouts_per_thread
<< ", #rollouts per batch: " << num_rollouts_per_batch << std::endl;
Expand Down Expand Up @@ -156,6 +158,9 @@ struct TSOptions {
if (t1.pick_method != t2.pick_method) {
return false;
}
if (t1.log_prefix != t2.log_prefix) {
return false;
}
if (t1.root_epsilon != t2.root_epsilon) {
return false;
}
Expand All @@ -181,6 +186,7 @@ struct TSOptions {
JSON_SAVE(j, seed);
JSON_SAVE(j, persistent_tree);
JSON_SAVE(j, pick_method);
JSON_SAVE(j, log_prefix);
JSON_SAVE(j, root_epsilon);
JSON_SAVE(j, root_alpha);
JSON_SAVE(j, virtual_loss);
Expand All @@ -198,6 +204,7 @@ struct TSOptions {
JSON_LOAD(opt, j, seed);
JSON_LOAD(opt, j, persistent_tree);
JSON_LOAD(opt, j, pick_method);
JSON_LOAD(opt, j, log_prefix);
JSON_LOAD(opt, j, root_epsilon);
JSON_LOAD(opt, j, root_alpha);
JSON_LOAD(opt, j, virtual_loss);
Expand All @@ -213,6 +220,7 @@ struct TSOptions {
verbose,
persistent_tree,
pick_method,
log_prefix,
virtual_loss,
verbose_time,
alg_opt,
Expand Down
16 changes: 11 additions & 5 deletions src_cpp/elf/base/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class Context {

void start() {
th_.reset(new std::thread([&]() {
assert(nice(10) == 10);
// assert(nice(10) == 10);
collectAndSendBatch();
}));
}
Expand Down Expand Up @@ -183,11 +183,17 @@ class Context {
}
}
smem_->waitBatchFillMem(server_);
// LOG(INFO) << "Receiver: Batch received. #batch = "
// << batch.size() << std::endl;
// std::cout << "Receiver[" << smem_opts.getLabel() << "] Batch
// received. #batch = "
// << smem_->getEffectiveBatchSize() << std::endl;

comm::ReplyStatus batch_status =
batchClient_->sendWait(smem_.get(), {""});

// std::cout << "Receiver[" << smem_opts.getLabel() << "] Batch
// releasing. #batch = "
// << smem_->getEffectiveBatchSize() << std::endl;

// LOG(INFO) << "Receiver: Release batch" << std::endl;
smem_->waitReplyReleaseBatch(server_, batch_status);
}
Expand Down Expand Up @@ -280,7 +286,7 @@ class Context {
auto* client = getClient();
for (int i = 0; i < num_games_; ++i) {
game_threads_.emplace_back([i, client, this]() {
assert(nice(19) == 19);
// assert(nice(19) == 19);
client->start();
game_cb_(i, client);
client->End();
Expand Down Expand Up @@ -329,7 +335,7 @@ class Context {
std::atomic<bool> tmp_thread_done(false);

std::thread tmp_thread([&]() {
assert(nice(10) == 10);
// assert(nice(10) == 10);

std::cout << "Prepare to stop ..." << std::endl;
client_->prepareToStop();
Expand Down
Loading

0 comments on commit da2f20c

Please sign in to comment.