Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#107 from qingqing01/cudnn_conv
Browse files Browse the repository at this point in the history
fix cudnn conv bug which occurs in image classfication demo in GTX GPU
  • Loading branch information
hedaoyuan authored Sep 23, 2016
2 parents 7eb29f2 + c1c07bb commit 341486d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
10 changes: 10 additions & 0 deletions paddle/gserver/layers/CudnnConvLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ bool CudnnConvLayer::init(const LayerMap &layerMap,
biasOffset_ = numFilters_ / groups_[0];
}

batchNum_ = 0;
isSelectAlgo_ = false;
return true;
}
Expand Down Expand Up @@ -132,6 +133,11 @@ void CudnnConvLayer::reshape(int batchSize) {
getOutput().setFrameHeight(outputH_);
getOutput().setFrameWidth(outputW_);

// if the batchSize remains the same, set isSelectAlgo_ true.
// Otherwise, set isSelectAlgo_ false and select algo again.
isSelectAlgo_ = (batchSize == batchNum_);
batchNum_ = batchSize;

size_t maxWorkSpace = 0;
for (size_t i = 0; i < inputLayers_.size(); i++) {
CHECK_EQ(inputLayers_[i]->getOutput().value->getWidth(),
Expand Down Expand Up @@ -160,6 +166,10 @@ void CudnnConvLayer::reshape(int batchSize) {

maxWorkSpace = std::max(fwdLimitBytes_[i], bwdDataLimitBytes_[i]);
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_[i]);

VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_[i]
<< " / " << bwdDataAlgo_[i]
<< " / " << bwdFilterAlgo_[i];
}
}

Expand Down
4 changes: 4 additions & 0 deletions paddle/gserver/layers/CudnnConvLayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ class CudnnConvLayer : public ConvBaseLayer {
/// Is or not select conv algorihtm.
bool isSelectAlgo_;

/// batchNum is used to record batch size. If the batch size is changed,
/// the selection algorithm will be called.
int batchNum_;

public:
explicit CudnnConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {}

Expand Down

0 comments on commit 341486d

Please sign in to comment.