Skip to content

Commit

Permalink
use CudnnHolder in conv_transpose_cudnn_op
Browse files (browse the repository at this point in the history)
  • Loading branch information
JiayiFeng committed Aug 30, 2018
1 parent 15cc912 commit d5f74b7
Showing 1 changed file with 4 additions and 13 deletions.
17 changes: 4 additions & 13 deletions paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc
Original file line number | Diff line number | Diff line change
@@ -100,9 +100,8 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));

// Allocate on GPU memory
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// Get cudnn workspace
cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes);

// ------------------- cudnn conv transpose forward ---------------------
int input_offset = input->numel() / input->dims()[0] / groups;
@@ -116,9 +115,6 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g));
}

// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};

@@ -207,10 +203,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
}

// ------------------- cudnn conv workspace ---------------------
// Already on GPU
void* cudnn_workspace = nullptr;
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// Get cudnn workspace
void* cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
// FIXME(typhoonzero): template type T may not be the same as cudnn call.
int input_offset = input->numel() / input->dims()[0] / groups;
@@ -245,9 +239,6 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
filter_grad_data + filter_offset * g));
}
}

// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};

Expand Down

0 comments on commit d5f74b7

Please sign in to comment.