Skip to content

Commit

Permalink
fix wrong dist-kvstore push/pull/rsp_pull (apache#7762)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin authored and piiswrong committed Sep 7, 2017
1 parent 5bd63f6 commit 1267c6a
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 91 deletions.
4 changes: 2 additions & 2 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class KVStore {
*/
virtual void PullRowSparse(const std::vector<int>& str_keys,
const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
const int priority = 0) = 0;
int priority = 0) = 0;

/*!
* \brief pull a list of key-value pairs from the store, where each key is a string.
Expand All @@ -196,7 +196,7 @@ class KVStore {
*/
virtual void PullRowSparse(const std::vector<std::string>& str_keys,
const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
const int priority = 0) = 0;
int priority = 0) = 0;

/**
* \brief the prototype of user-defined updater
Expand Down
137 changes: 68 additions & 69 deletions src/kvstore/kvstore_dist.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,63 @@ class KVStoreDist : public KVStoreLocal {
}
}

void Init(const std::vector<int>& keys,
const std::vector<NDArray>& values) override {
/*!
 * \brief Install the user-supplied updater.
 *
 * The updater must be non-empty. On a server node it is forwarded to
 * server_ (which must be non-null); on any other node it is stored in
 * updater_.
 * \param updater the updater to install
 */
void set_updater(const Updater& updater) override {
  CHECK(updater) << "invalid updater";
  if (!IsServerNode()) {
    // not a server: keep the updater locally
    updater_ = updater;
    return;
  }
  // server: hand the updater over to the server object
  CHECK_NOTNULL(server_)->set_updater(updater);
}

/*!
 * \brief Issue a barrier on ps::kWorkerGroup through the ps-lite post office.
 */
void Barrier() override {
  auto&& post_office = ps::Postoffice::Get();
  post_office->Barrier(ps::kWorkerGroup);
}

/*!
 * \brief Send a command to ps::kServerGroup and block until the request
 *        has been waited on.
 * \param cmd_id   the command id passed to the request
 * \param cmd_body the command body passed to the request
 */
void SendCommandToServers(int cmd_id,
                          const std::string& cmd_body) override {
  CHECK_NOTNULL(ps_worker_);
  // issue the request first, then wait on the handle it returned
  const auto pending = ps_worker_->Request(cmd_id, cmd_body, ps::kServerGroup);
  ps_worker_->Wait(pending);
}

/*!
 * \brief Group size as reported by ps::NumWorkers().
 * \return the value of ps::NumWorkers()
 */
int get_group_size() const override {
  const int num_workers = ps::NumWorkers();
  return num_workers;
}

/*!
 * \brief Rank of this node as reported by ps::MyRank().
 * \return the value of ps::MyRank()
 */
int get_rank() const override {
  const int my_rank = ps::MyRank();
  return my_rank;
}

/*!
 * \brief Count how many of the nodes selected by node_id are currently
 *        reported dead by the post office.
 * \param node_id id passed to GetNodeIDs to select the nodes to watch
 * \param timeout value passed through to GetDeadNodes
 * \return number of dead nodes that are also in the watched set
 */
int get_num_dead_node(int node_id, int timeout) const override {
  auto dead_ids = ps::Postoffice::Get()->GetDeadNodes(timeout);
  const auto& watched_ids = ps::Postoffice::Get()->GetNodeIDs(node_id);
  // hash the watched ids so each membership test is a single lookup
  const std::unordered_set<int> watched(watched_ids.begin(), watched_ids.end());
  int dead_count = 0;
  for (int id : dead_ids) {
    if (watched.count(id) != 0) {
      ++dead_count;
    }
  }
  return dead_count;
}

/*!
 * \brief Run the non-worker side of the distributed store.
 *
 * Must not be called on a worker node. On a server node a
 * KVStoreDistServer is created and given the controller; on other
 * non-worker nodes no server object is created here. The server object,
 * if any, is destroyed after ps::Finalize() and server_ is reset to
 * nullptr before returning.
 * \param controller the controller handed to the server object
 */
void RunServer(const Controller& controller) override {
  CHECK(!IsWorkerNode());
  if (IsServerNode()) {
    server_ = new KVStoreDistServer();
    server_->set_controller(controller);
  }

  ps::StartAsync("mxnet_server\0");
  // the start-up barrier over all groups is skipped when recovering
  if (!ps::Postoffice::Get()->is_recovery()) {
    const auto all_groups = ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler;
    ps::Postoffice::Get()->Barrier(all_groups);
  }
  if (server_) {
    server_->Run();
  }
  ps::Finalize();
  // deleting a null pointer is a no-op, so no explicit check is needed
  delete server_;
  server_ = nullptr;
}

private:
void InitImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values) override {
CheckUnique(keys);
for (size_t i = 0; i < keys.size(); ++i) {
comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
Expand All @@ -100,15 +155,15 @@ class KVStoreDist : public KVStoreLocal {
}
}

void Push(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority) override {
void PushImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority) override {
Push_(keys, values, priority, true);
}

void Pull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority) override {
void PullImpl(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority) override {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);
Expand Down Expand Up @@ -155,9 +210,9 @@ class KVStoreDist : public KVStoreLocal {
}
}

void PullRowSparse(const std::vector<int>& keys,
const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
const int priority = 0) {
void PullRowSparseImpl(const std::vector<int>& keys,
const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
int priority = 0) override {
std::vector<int> uniq_keys;
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
Expand Down Expand Up @@ -198,66 +253,10 @@ class KVStoreDist : public KVStoreLocal {
}
}

/*!
 * \brief Install the user-supplied updater.
 *
 * The updater must be non-empty. On a server node it is forwarded to
 * server_ (which must be non-null); on any other node it is stored in
 * updater_.
 * \param updater the updater to install
 */
void set_updater(const Updater& updater) override {
  CHECK(updater) << "invalid updater";
  if (!IsServerNode()) {
    // not a server: keep the updater locally
    updater_ = updater;
    return;
  }
  // server: hand the updater over to the server object
  CHECK_NOTNULL(server_)->set_updater(updater);
}

/*!
 * \brief Issue a barrier on ps::kWorkerGroup through the ps-lite post office.
 */
void Barrier() override {
  auto&& post_office = ps::Postoffice::Get();
  post_office->Barrier(ps::kWorkerGroup);
}


/*!
 * \brief Send a command to ps::kServerGroup and block until the request
 *        has been waited on.
 * \param cmd_id   the command id passed to the request
 * \param cmd_body the command body passed to the request
 */
void SendCommandToServers(int cmd_id,
                          const std::string& cmd_body) override {
  CHECK_NOTNULL(ps_worker_);
  // issue the request first, then wait on the handle it returned
  const auto pending = ps_worker_->Request(cmd_id, cmd_body, ps::kServerGroup);
  ps_worker_->Wait(pending);
}

/*!
 * \brief Group size as reported by ps::NumWorkers().
 * \return the value of ps::NumWorkers()
 */
int get_group_size() const override {
  const int num_workers = ps::NumWorkers();
  return num_workers;
}

/*!
 * \brief Rank of this node as reported by ps::MyRank().
 * \return the value of ps::MyRank()
 */
int get_rank() const override {
  const int my_rank = ps::MyRank();
  return my_rank;
}

/*!
 * \brief Count how many of the nodes selected by node_id are currently
 *        reported dead by the post office.
 * \param node_id id passed to GetNodeIDs to select the nodes to watch
 * \param timeout value passed through to GetDeadNodes
 * \return number of dead nodes that are also in the watched set
 */
int get_num_dead_node(int node_id, int timeout) const override {
  auto dead_ids = ps::Postoffice::Get()->GetDeadNodes(timeout);
  const auto& watched_ids = ps::Postoffice::Get()->GetNodeIDs(node_id);
  // hash the watched ids so each membership test is a single lookup
  const std::unordered_set<int> watched(watched_ids.begin(), watched_ids.end());
  int dead_count = 0;
  for (int id : dead_ids) {
    if (watched.count(id) != 0) {
      ++dead_count;
    }
  }
  return dead_count;
}

/*!
 * \brief Run the non-worker side of the distributed store.
 *
 * Must not be called on a worker node. On a server node a
 * KVStoreDistServer is created and given the controller; on other
 * non-worker nodes no server object is created here. The server object,
 * if any, is destroyed after ps::Finalize() and server_ is reset to
 * nullptr before returning.
 * \param controller the controller handed to the server object
 */
void RunServer(const Controller& controller) override {
  CHECK(!IsWorkerNode());
  if (IsServerNode()) {
    server_ = new KVStoreDistServer();
    server_->set_controller(controller);
  }

  ps::StartAsync("mxnet_server\0");
  // the start-up barrier over all groups is skipped when recovering
  if (!ps::Postoffice::Get()->is_recovery()) {
    const auto all_groups = ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler;
    ps::Postoffice::Get()->Barrier(all_groups);
  }
  if (server_) {
    server_->Run();
  }
  ps::Finalize();
  // deleting a null pointer is a no-op, so no explicit check is needed
  delete server_;
  server_ = nullptr;
}

private:
void Push_(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority,
bool do_merge) {
bool do_merge) {
// first aggregate the values over keys
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray> > grouped_vals;
Expand Down Expand Up @@ -320,7 +319,7 @@ class KVStoreDist : public KVStoreLocal {
}

// pull row sparse weight into `recv_buf` based on indices given by `indices`
void PullRowSparse_(int key, NDArray *recv_buf, const NDArray& indices, int priority) {
void PullRowSparse_(const int key, NDArray *recv_buf, const NDArray& indices, int priority) {
using namespace rowsparse;
auto pull_from_servers = [this, key, recv_buf, indices]
(RunContext rctx, Engine::CallbackOnComplete cb) {
Expand Down
1 change: 1 addition & 0 deletions src/kvstore/kvstore_dist_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ class KVStoreDistServer {
auto len = unit_len * num_rows;
// concat values
response.vals.resize(len);
#pragma omp parallel for
for (size_t i = 1; i <= num_rows; i++) {
int key = DecodeKey(req_data.keys[i]);
int64_t row_id = key - master_key;
Expand Down
40 changes: 20 additions & 20 deletions src/kvstore/kvstore_local.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class KVStoreLocal : public KVStore {
void Init(const std::vector<int>& keys,
const std::vector<NDArray>& values) override {
SetKeyType(kIntKey);
Init_(keys, values);
InitImpl(keys, values);
}

void Init(const std::vector<std::string>& str_keys,
Expand All @@ -84,28 +84,28 @@ class KVStoreLocal : public KVStore {
reverse_str_key_dict_[key] = str_key;
keys[i] = key;
}
Init_(keys, values);
InitImpl(keys, values);
}

void Push(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority) override {
SetKeyType(kIntKey);
Push_(keys, values, priority);
PushImpl(keys, values, priority);
}

void Pull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority) override {
SetKeyType(kIntKey);
Pull_(keys, values, priority);
PullImpl(keys, values, priority);
}

void PullRowSparse(const std::vector<int>& keys,
const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
int priority = 0) override {
SetKeyType(kIntKey);
PullRowSparse_(keys, val_rowids, priority);
PullRowSparseImpl(keys, val_rowids, priority);
}

void Push(const std::vector<std::string>& str_keys,
Expand All @@ -114,7 +114,7 @@ class KVStoreLocal : public KVStore {
SetKeyType(kStringKey);
std::vector<int> keys(str_keys.size());
LookupKeys(str_keys, &keys);
Push_(keys, values, priority);
PushImpl(keys, values, priority);
}

void Pull(const std::vector<std::string>& str_keys,
Expand All @@ -123,21 +123,21 @@ class KVStoreLocal : public KVStore {
SetKeyType(kStringKey);
std::vector<int> keys(str_keys.size());
LookupKeys(str_keys, &keys);
Pull_(keys, values, priority);
PullImpl(keys, values, priority);
}

void PullRowSparse(const std::vector<std::string>& str_keys,
const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
const int priority = 0) override {
int priority = 0) override {
SetKeyType(kStringKey);
std::vector<int> keys(str_keys.size());
LookupKeys(str_keys, &keys);
PullRowSparse_(keys, val_rowids, priority);
PullRowSparseImpl(keys, val_rowids, priority);
}

private:
void Init_(const std::vector<int>& keys,
const std::vector<NDArray>& values) {
virtual void InitImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values) {
for (size_t i = 0; i < keys.size(); ++i) {
CHECK(local_.find(keys[i]) == local_.end())
<< "duplicate init of key " << keys[i];
Expand All @@ -146,9 +146,9 @@ class KVStoreLocal : public KVStore {
}
}

void Push_(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority) {
virtual void PushImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority) {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray> > grouped_vals;
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);
Expand Down Expand Up @@ -185,9 +185,9 @@ class KVStoreLocal : public KVStore {
}
}

void Pull_(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority) {
virtual void PullImpl(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority) {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);
Expand All @@ -200,9 +200,9 @@ class KVStoreLocal : public KVStore {
}
}

void PullRowSparse_(const std::vector<int>& keys,
const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
int priority = 0) {
virtual void PullRowSparseImpl(const std::vector<int>& keys,
const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
int priority = 0) {
std::vector<int> uniq_keys;
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
Expand Down

0 comments on commit 1267c6a

Please sign in to comment.