step 0 of cuDNN v8 convolution API integration (pytorch#51390)
Summary:
This PR is step 0 of adding PyTorch convolution bindings that use the cuDNN frontend. The cuDNN frontend is the recommended way of using the cuDNN v8 API. It is supposed to have faster release cycles, so that, for example, if people find a bug in a specific kernel, they can report it, that kernel will be blocked in the cuDNN frontend, and frameworks can pick up the fix simply by updating the submodule, without waiting for a whole cuDNN release.

The work is not complete, and this PR is only step 0.

**What this PR does:**
- Add cudnn-frontend as a submodule.
- Modify cmake to build that submodule.
- Add bindings for convolution forward in `Conv_v8.cpp`, which are disabled by default behind a macro (see the condensed gating sketch after this list).
- Tested manually by enabling the macro and running `test_nn.py`. All tests pass except those mentioned below.
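
For orientation, the compile-time gate works roughly as follows. This is a condensed sketch of what the diff below adds in `Macros.h`, `Conv_v7.cpp`, and `Conv_v8.cpp`; the parameter lists are elided here, and the full signatures are in the diff.

```cpp
// Macros.h: HAS_CUDNN_V8() is true only when the build flag is set and the
// cuDNN headers are new enough (8000 is a placeholder threshold, per the note
// in the diff).
#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
#define HAS_CUDNN_V8() true
#else
#define HAS_CUDNN_V8() false
#endif

// Conv_v7.cpp keeps the existing v7 forward implementation when v8 is off...
#if !HAS_CUDNN_V8()
void raw_cudnn_convolution_forward_out(/* unchanged v7 signature */);
#endif

// ...and Conv_v8.cpp provides the same entry point through the cuDNN frontend
// when v8 is on.
#if HAS_CUDNN_V8()
void raw_cudnn_convolution_forward_out(/* same signature, v8 implementation */);
#endif
```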

**What this PR doesn't do:**
- Convolution forward only, no backward. Backward will continue to use the v7 API.
- No 64-bit indexing support for some configurations. This is a known issue in cuDNN and will be fixed in a later cuDNN version. PyTorch will not implement any workaround for this issue; instead, the v8 API should be disabled on problematic cuDNN versions.
- No testing beyond PyTorch's unit tests.
  - Not tested for correctness on real models.
  - Not benchmarked for performance.
- The benchmark cache is not thread-safe. (This is marked as `FIXME` in the code and will be fixed in a follow-up PR; a possible shape of the fix is sketched after this list.)
- cuDNN benchmark mode is not supported.
- There are failing tests, which will be resolved later:
  ```
  FAILED test/test_nn.py::TestNNDeviceTypeCUDA::test_conv_cudnn_nhwc_cuda_float16 - AssertionError: False is not true : Tensors failed to compare as equal!With rtol=0.001 and atol=1e-05, found 32 element(s) (out of 32) whose difference(s) exceeded the margin of error (in...
  FAILED test/test_nn.py::TestNNDeviceTypeCUDA::test_conv_cudnn_nhwc_cuda_float32 - AssertionError: False is not true : Tensors failed to compare as equal!With rtol=1.3e-06 and atol=1e-05, found 32 element(s) (out of 32) whose difference(s) exceeded the margin of error (...
  FAILED test/test_nn.py::TestNNDeviceTypeCUDA::test_conv_large_cuda - RuntimeError: CUDNN_BACKEND_OPERATION: cudnnFinalize Failed cudnn_status: 9
  FAILED test/test_nn.py::TestNN::test_Conv2d_depthwise_naive_groups_cuda - AssertionError: False is not true : Tensors failed to compare as equal!With rtol=0 and atol=1e-05, found 64 element(s) (out of 64) whose difference(s) exceeded the margin of error (including 0 an...
  FAILED test/test_nn.py::TestNN::test_Conv2d_deterministic_cudnn - RuntimeError: not supported yet
  FAILED test/test_nn.py::TestNN::test_ConvTranspose2d_groups_cuda_fp32 - RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
  FAILED test/test_nn.py::TestNN::test_ConvTranspose2d_groups_cuda_tf32 - RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
  ```
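
On the thread-safety limitation above: the planned fix is to reuse the benchmark cache from `Conv_v7.cpp`, but a minimal sketch of the general idea is shown below. This is an illustration only; `ThreadSafeCache` and its interface are hypothetical and not part of this PR.

```cpp
#include <mutex>
#include <unordered_map>

// Minimal sketch: serialize lookups and insertions on the engine-config cache
// with a mutex. In Conv_v8.cpp, Key would be CacheKey, Value would be
// cudnn_frontend::ManagedOpaqueDescriptor, and Hash/Equal would be
// ParamsHash<CacheKey>/ParamsEqual<CacheKey>.
template <typename Key, typename Value, typename Hash, typename Equal>
struct ThreadSafeCache {
  std::mutex mutex;
  std::unordered_map<Key, Value, Hash, Equal> map;

  // Copies the cached value into *out and returns true if the key is present.
  bool find(const Key& key, Value* out) {
    std::lock_guard<std::mutex> guard(mutex);
    auto it = map.find(key);
    if (it == map.end()) return false;
    *out = it->second;
    return true;
  }

  void insert(const Key& key, const Value& value) {
    std::lock_guard<std::mutex> guard(mutex);
    map[key] = value;
  }
};
```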

Although this is not a complete implementation of the cuDNN v8 API bindings, I would like to merge it first. That allows the remaining work to land as small, incremental changes, which are easier to develop and review.

Pull Request resolved: pytorch#51390

Reviewed By: malfet

Differential Revision: D28513167

Pulled By: ngimel

fbshipit-source-id: 9cc20c9dec5bbbcb1f94ac9e0f59b10c34f62740
zasdfgbnm authored and facebook-github-bot committed May 19, 2021
1 parent 954d39b commit 6c70cbe
Showing 10 changed files with 217 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -130,6 +130,9 @@
ignore = dirty
path = third_party/tensorpipe
url = https://github.com/pytorch/tensorpipe.git
[submodule "third_party/cudnn_frontend"]
path = third_party/cudnn_frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
[submodule "third_party/kineto"]
path = third_party/kineto
url = https://github.com/pytorch/kineto
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -194,6 +194,9 @@ cmake_dependent_option(
cmake_dependent_option(
USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
"USE_CUDNN" OFF)
cmake_dependent_option(
USE_EXPERIMENTAL_CUDNN_V8_API "Use experimental cuDNN v8 API" OFF
"USE_CUDNN" OFF)
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
option(USE_KINETO "Use Kineto profiling library" ON)
option(USE_CUPTI_SO "Use CUPTI as a shared library" OFF)
6 changes: 6 additions & 0 deletions aten/src/ATen/native/cudnn/Conv_v7.cpp
@@ -2,6 +2,8 @@

#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/Macros.h>

#include <limits>
#include <vector>
#include <sstream>
@@ -614,6 +616,8 @@ if (args.params.dataType == CUDNN_DATA_FLOAT) {
//
// ---------------------------------------------------------------------

#if !HAS_CUDNN_V8()

void raw_cudnn_convolution_forward_out_32bit(
const Tensor& output, const Tensor& input, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
@@ -665,6 +669,8 @@ void raw_cudnn_convolution_forward_out(
split_batch_dim_to_32bit_out(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 256, raw_cudnn_convolution_forward_out_32bit);
}

#endif // !HAS_CUDNN_V8()

// ---------------------------------------------------------------------
//
// Convolution backward / Transposed convolution forward
178 changes: 175 additions & 3 deletions aten/src/ATen/native/cudnn/Conv_v8.cpp
@@ -1,5 +1,177 @@
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED() && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
// Coming soon
#endif // AT_CUDNN_ENABLED and CUDNN_VERSION
#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/Macros.h>

#if HAS_CUDNN_V8()

#include <ATen/cudnn/cudnn-wrapper.h>
#include <cudnn_frontend.h>
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/cudnn/ConvShared.h>
#include <ATen/native/utils/ParamsHash.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/TensorUtils.h>

#include <unordered_map>

namespace at { namespace native{

namespace {

uint8_t getAlignment(const Tensor &t) {
// alignment is in bytes
uint8_t alignment = 1;
uint64_t address = reinterpret_cast<uint64_t>(t.data_ptr());
while (address % alignment == 0 && alignment < 16) alignment *= 2;
return alignment;
}

cudnn_frontend::Tensor getTensorDescriptor(const Tensor &t, int64_t id, uint8_t alignment) {
auto shape = t.sizes();
auto strides = t.strides();
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(alignment)
.setDataType(getCudnnDataType(t))
.build();
}

cudnn_frontend::ConvDesc_v8 getConvDescriptor(cudnnDataType_t dataType, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation) {
uint64_t convDim = stride.size();
return cudnn_frontend::ConvDescBuilder()
.setDataType(dataType)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, stride.data())
.setPrePadding(convDim, padding.data())
.setPostPadding(convDim, padding.data())
.setDilation(convDim, dilation.data())
.build();
}

void filterEngineConfigs(
cudnn_frontend::EngineConfigList &from,
cudnn_frontend::EngineConfigList &to,
bool deterministic, bool allow_tf32, c10::ScalarType scalar_type)
{
auto filter = [=](cudnnBackendDescriptor_t c) {
if (deterministic) {
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c)) return true;
}
if (scalar_type == kFloat || !allow_tf32) {
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) return true;
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) return true;
}
return false;
};
cudnn_frontend::filter(from, to, filter);
}

struct CacheKey {
ConvolutionParams params;
uint8_t input_alignment;
uint8_t weight_alignment;
uint8_t output_alignment;
};

// FIXME: make this thread-safe by reusing the benchmark cache in Conv_v7.cpp
std::unordered_map<CacheKey, cudnn_frontend::ManagedOpaqueDescriptor, ParamsHash<CacheKey>, ParamsEqual<CacheKey>> engine_cache;

}

void raw_cudnn_convolution_forward_out(
const Tensor& output, const Tensor& input, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, bool allow_tf32)
{
TORCH_CHECK(!benchmark, "not supported yet");
if (output.numel() == 0) {
return;
}

cudnnHandle_t handle = getCudnnHandle();

CacheKey key;
setConvolutionParams(&key.params, input, weight, padding, stride, dilation, groups, deterministic, allow_tf32);
key.input_alignment = getAlignment(input);
key.output_alignment = getAlignment(output);
key.weight_alignment = getAlignment(weight);

auto run = [&](cudnn_frontend::ManagedOpaqueDescriptor cfg) {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(cfg)
.build();

auto workspace_size = plan.getWorkspaceSize();
auto workspace = at::empty({workspace_size}, input.options().dtype(kByte));
void *data_ptrs[] = {input.data_ptr(), output.data_ptr(), weight.data_ptr()};
// std::cout << plan.describe() << " requires workspace " << workspace_size << std::endl;
int64_t uids[] = {'x', 'y', 'w'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data_ptr())
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
AT_CUDNN_CHECK(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
};

auto search = engine_cache.find(key);
if (search != engine_cache.end()) {
run(search->second);
return;
}

auto op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(getTensorDescriptor(input, 'x', key.input_alignment))
.setyDesc(getTensorDescriptor(output, 'y', key.output_alignment))
.setwDesc(getTensorDescriptor(weight, 'w', key.weight_alignment))
.setcDesc(getConvDescriptor(key.params.dataType, padding, stride, dilation))
.build();
// std::cout << op.describe() << std::endl;

std::array<cudnn_frontend::Operation const *, 1> ops = {&op};

auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(1, ops.data())
.build();
// std::cout << opGraph.describe() << std::endl;

auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(opGraph)
.setHeurMode(CUDNN_HEUR_MODE_INSTANT)
.build();
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(opGraph)
.setOperation(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.build();

auto& engine_configs = heuristics.getEngineConfig(heuristics.getEngineConfigCount());
auto& fallback_list = fallback.getFallbackList();

cudnn_frontend::EngineConfigList filtered_configs;
filterEngineConfigs(engine_configs, filtered_configs, deterministic, allow_tf32, input.scalar_type());
filterEngineConfigs(fallback_list, filtered_configs, deterministic, allow_tf32, input.scalar_type());

for (auto &cfg : filtered_configs) {
try {
run(cfg);
engine_cache[key] = cfg;
return;
} catch (cudnn_frontend::cudnnException &e) {} catch(CuDNNError &e) {}
}
TORCH_CHECK(false, "Unable to find an engine to execute this computation");
}

}} // at::native

#endif // HAS_CUDNN_V8
#endif // AT_CUDNN_ENABLED
12 changes: 12 additions & 0 deletions aten/src/ATen/native/cudnn/Macros.h
@@ -0,0 +1,12 @@
#pragma once

#include <ATen/cudnn/cudnn-wrapper.h>

// Note: The version below should not actually be 8000. Instead, it should
// be whichever cuDNN version the v8 API works correctly with in PyTorch.
// It is set to 8000 for now for convenience of debugging.
#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
#define HAS_CUDNN_V8() true
#else
#define HAS_CUDNN_V8() false
#endif
9 changes: 9 additions & 0 deletions caffe2/CMakeLists.txt
@@ -1304,6 +1304,15 @@ elseif(USE_ROCM)
target_compile_definitions(torch_hip PRIVATE "-DTORCH_HIP_BUILD_MAIN_LIB")
endif()

if(USE_EXPERIMENTAL_CUDNN_V8_API)
if(BUILD_SPLIT_CUDA)
target_compile_definitions(torch_cuda_cu PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
target_compile_definitions(torch_cuda_cpp PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
elseif(USE_CUDA)
target_compile_definitions(torch_cuda PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
endif()
endif()

set(EXPERIMENTAL_SINGLE_THREAD_POOL "0" CACHE STRING
"Experimental option to use a single thread pool for inter- and intra-op parallelism")
if("${EXPERIMENTAL_SINGLE_THREAD_POOL}")
6 changes: 6 additions & 0 deletions cmake/Dependencies.cmake
@@ -1189,6 +1189,12 @@ if(USE_CUDA)
endif()
endif()

# ---[ cuDNN
if(USE_CUDNN)
set(CUDNN_FRONTEND_INCLUDE_DIR ${CMAKE_CURRENT_LIST_DIR}/../third_party/cudnn_frontend/include)
include_directories(${CUDNN_FRONTEND_INCLUDE_DIR})
endif()

# ---[ HIP
if(USE_ROCM)
include(${CMAKE_CURRENT_LIST_DIR}/public/LoadHIP.cmake)
1 change: 1 addition & 0 deletions cmake/Summary.cmake
@@ -74,6 +74,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " Split CUDA : ${BUILD_SPLIT_CUDA}")
message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}")
message(STATUS " USE_CUDNN : ${USE_CUDNN}")
message(STATUS " USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}")
message(STATUS " CUDA version : ${CUDA_VERSION}")
if(${USE_CUDNN})
message(STATUS " cuDNN version : ${CUDNN_VERSION}")
2 changes: 1 addition & 1 deletion setup.py
@@ -332,7 +332,7 @@ def not_exists_or_empty(folder):
print('Please run:\n\tgit submodule update --init --recursive')
sys.exit(1)
for folder in folders:
check_for_files(folder, ["CMakeLists.txt", "Makefile", "setup.py", "LICENSE"])
check_for_files(folder, ["CMakeLists.txt", "Makefile", "setup.py", "LICENSE", "LICENSE.txt"])
check_for_files(os.path.join(third_party_path, 'fbgemm', 'third_party',
'asmjit'), ['CMakeLists.txt'])
check_for_files(os.path.join(third_party_path, 'onnx', 'third_party',
1 change: 1 addition & 0 deletions third_party/cudnn_frontend
Submodule cudnn_frontend added at 51e60d
