fix conv sum fusion bug found in VNet
ftian1 committed Dec 27, 2018
1 parent 14ac1e5 commit d0e6e15
Showing 2 changed files with 54 additions and 1 deletion.
10 changes: 10 additions & 0 deletions include/caffe/net.hpp
@@ -338,6 +338,16 @@ class Net {
const NetParameter& param,
int layer_id);

/**
* @brief Find the layers that produce (i.e. have as a top blob) the blob
* named blob_name_to_find, starting the search from layer_id, and collect
* their parameters into producer_blobs.
*/
static void GetBlobProducers(std::vector<const LayerParameter*> &producer_blobs,
const string& blob_name_to_find,
const NetParameter& param,
int layer_id);

static void GetNeedToCancelInplaceLayers(
vector<vector<const LayerParameter*>>& layer_pairs,
std::map<string, int>& specified_layer_blob_name_to_index,
45 changes: 44 additions & 1 deletion src/caffe/net.cpp
@@ -937,7 +937,27 @@ void Net<Dtype>::CompilationRuleConvSumFusion(const NetParameter& param,
switch_flag = false;
}
param_compiled->add_layer()->CopyFrom(*layer_param);
- need_to_convert_layer = layer_param;

std::vector<const LayerParameter*> another_eltwise_input_layers_params;
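// Look up the layer that produces the other Eltwise bottom (bottom(1) or
// bottom(0), depending on switch_flag); the conv+sum fusion is applied
// only when that producer is itself a Convolution layer.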
if (switch_flag) {
Net<Dtype>::GetBlobProducers(another_eltwise_input_layers_params,
child_layers_params[0]->bottom(1), param,
i + 1 < param.layer_size() ? i + 1 : i);
} else {
Net<Dtype>::GetBlobProducers(another_eltwise_input_layers_params,
child_layers_params[0]->bottom(0), param,
i + 1 < param.layer_size() ? i + 1 : i);
}

const LayerParameter& another_eltwise_layer_param =
another_eltwise_input_layers_params.size() > 0
? *(another_eltwise_input_layers_params[0])
: *layer_param;

if (another_eltwise_layer_param.type().compare("Convolution") == 0) {
need_to_convert_layer = layer_param;
}
continue;

} else {
@@ -1284,6 +1304,29 @@ void Net<Dtype>::GetBlobConsumers(
}
}

template <typename Dtype>
void Net<Dtype>::GetBlobProducers(
std::vector<const LayerParameter*>& producers_blobs,
const string& blob_name_to_find,
const NetParameter& param,
int layer_id_to_start_traversing_from) {
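// Counterpart of GetBlobConsumers(): collect every layer whose top (output)
// blob matches blob_name_to_find, i.e. the producers of that blob.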
producers_blobs.clear();

// Validate that the starting layer id is within <1..num_layers-1>
CHECK_GE(layer_id_to_start_traversing_from, 1);
CHECK_LT(layer_id_to_start_traversing_from, param.layer_size());

// Traverse through layers to search for layers that produce blob_name_to_find
for (int i = layer_id_to_start_traversing_from; i < param.layer_size(); ++i) {
// check whether any of this layer's top (output) blobs matches the given name
for (int j = 0; j < param.layer(i).top_size(); ++j) {
if (param.layer(i).top(j).compare(blob_name_to_find) == 0) {
producers_blobs.push_back(&param.layer(i));
}
}
}
}

template <typename Dtype>
void Net<Dtype>::ParseNetInplaceStatus(
std::map<string, int>& inplace_blob_name_to_index,
