
Commit 1e1b662
update by comment
Yancey1989 committed Sep 13, 2018
1 parent b084dfa commit 1e1b662
Showing 4 changed files with 5 additions and 23 deletions.
paddle/fluid/framework/details/all_reduce_op_handle.cc (1 addition, 5 deletions)

@@ -46,11 +46,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
 #endif
 
 void AllReduceOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
+  platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
 
   if (NoDummyInputSize() == 1) {
     return;  // No need to all reduce when GPU count = 1;
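Note on this change (the same consolidation appears in broadcast_op_handle.cc below): platform::RecordEvent is an RAII guard, and the removed code declared it inside the if/else branches, so the guard was destroyed as soon as its branch closed and the event never covered the actual work in RunImpl. A single declaration at function scope spans the whole method. Dropping the nullptr fallback also suggests dev_ctxes_ is assumed non-empty here. Below is a minimal standalone sketch of the scoping difference, with a hypothetical ScopedEvent standing in for platform::RecordEvent; it is not Paddle code.

```cpp
#include <cstdio>
#include <string>

// Hypothetical stand-in for platform::RecordEvent: prints on begin/end so the
// covered region is visible.
struct ScopedEvent {
  explicit ScopedEvent(std::string name) : name_(std::move(name)) {
    std::printf("begin %s\n", name_.c_str());
  }
  ~ScopedEvent() { std::printf("end   %s\n", name_.c_str()); }
  std::string name_;
};

void DoWork() { std::printf("  ...all-reduce work...\n"); }

int main() {
  // Old pattern: the guard lives only inside the if-branch, so the event
  // closes before any work runs.
  {
    if (true) {
      ScopedEvent e("all_reduce");
    }  // "end all_reduce" prints here
    DoWork();
  }
  // New pattern: one declaration at function scope covers the whole body.
  {
    ScopedEvent e("all_reduce");
    DoWork();  // runs inside the event
  }
  return 0;
}
```

Running the sketch prints "end all_reduce" before the work in the old pattern and after it in the new one.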
paddle/fluid/framework/details/broadcast_op_handle.cc (1 addition, 5 deletions)

@@ -22,11 +22,7 @@ namespace framework {
 namespace details {
 
 void BroadcastOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
+  platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
 
   if (places_.size() == 1) return;
 
paddle/fluid/framework/details/data_balance_op_handle.cc (0 additions, 6 deletions)

@@ -87,12 +87,6 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
 }
 
 void DataBalanceOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
-
   PADDLE_ENFORCE_GT(places_.size(), 1,
                     "Data balance can only be enabled when the number of "
                     "places to run larger than 1.");
paddle/fluid/framework/details/multi_devices_graph_pass.cc (3 additions, 7 deletions)

@@ -431,10 +431,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
           CreateReduceOp(&result, g_name, cur_device_id);
           graph->Get<ShardedVarDevice>(kShardedVarDevice)
               .emplace(g_name, cur_device_id);
-          if (!is_dist_train) {
-            // will send gradients directly when distributed training
-            bcast_var_name_set[cur_device_id].emplace(p_name);
-          }
+          bcast_var_name_set[cur_device_id].emplace(p_name);
           break;
         case BuildStrategy::ReduceStrategy::kAllReduce:
           if (IsSparseGradient(g_name)) {

@@ -461,9 +458,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   use_gpu = nccl_ctxs_ != nullptr;
 #endif
 
-  if ((use_gpu &&
-       strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
-      is_dist_train) {
+  if (use_gpu && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce &&
+      !is_dist_train) {
     // Insert BCast Ops
     for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
       auto &to_bcast_set = bcast_var_name_set[dev_id];
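Reading the two hunks together: under the kReduce strategy the gradient's device is now always recorded in bcast_var_name_set, while the broadcast ops are inserted only when use_gpu and kReduce hold and the build is not distributed training (the removed comment notes that distributed training sends gradients directly). A small truth-table sketch of the predicate change follows; the names mirror the diff, and this is illustration rather than Paddle code.

```cpp
#include <cstdio>

int main() {
  std::printf("use_gpu kReduce is_dist_train : before after\n");
  for (int use_gpu = 0; use_gpu <= 1; ++use_gpu) {
    for (int k_reduce = 0; k_reduce <= 1; ++k_reduce) {
      for (int dist = 0; dist <= 1; ++dist) {
        bool before = (use_gpu && k_reduce) || dist;  // old condition
        bool after = use_gpu && k_reduce && !dist;    // new condition
        std::printf("%7d %7d %13d : %6d %5d\n", use_gpu, k_reduce, dist,
                    static_cast<int>(before), static_cast<int>(after));
      }
    }
  }
  return 0;
}
```

The rows disagree exactly when is_dist_train is set: the old condition entered the broadcast-insertion loop in that case, and the new one skips it.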
