[cuDNN V8 API] (reopen) Allow the number of kernels profiled under torch.backends.cudnn.benchmark = True to be limited (pytorch#77002)

(reopening due to botched merge)
The cuDNN V8 API (main support merged in pytorch#60755) potentially exposes many more kernels with benchmark = True. While these additional kernels can improve performance, it is often unnecessary to run every kernel returned by the heuristic, and doing so can degrade the user experience by making the first model iteration very slow.

To alleviate this, this PR introduces torch.backends.cudnn.benchmark_limit. benchmark_limit specifies the maximum number of working cuDNN kernels to try for a given workload, with a default of 10 (similar to what TensorFlow does). Setting benchmark_limit = 0 preserves the previous behavior of trying every kernel returned by the heuristic.
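For illustration (not part of the commit), a minimal sketch of the new knob in use; the convolution and shapes below are arbitrary and assume a CUDA build with the v8 API enabled:

import torch
import torch.nn as nn

torch.backends.cudnn.benchmark = True      # autotune convolution kernels
torch.backends.cudnn.benchmark_limit = 10  # the default: profile at most 10 working kernels
# torch.backends.cudnn.benchmark_limit = 0 would restore the exhaustive search

conv = nn.Conv2d(8, 16, kernel_size=3).cuda()
x = torch.randn(1, 8, 32, 32, device="cuda")
y = conv(x)  # the first call for a given shape triggers the (now bounded) benchmark sweep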

CC @ptrblck @ngimel @xwang233
Pull Request resolved: pytorch#77002
Approved by: https://github.com/ngimel
eqy authored and pytorchmergebot committed May 24, 2022
1 parent c500897 commit c274f2a
Showing 8 changed files with 67 additions and 5 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/Context.cpp
@@ -140,6 +140,14 @@ void Context::setBenchmarkCuDNN(bool b) {
benchmark_cudnn = b;
}

int Context::benchmarkLimitCuDNN() const {
return benchmark_limit_cudnn;
}

void Context::setBenchmarkLimitCuDNN(int b) {
benchmark_limit_cudnn = b;
}

bool Context::allowTF32CuBLAS() const {
static bool allow_tf32_cublas_override = c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true;
return allow_tf32_cublas_override || float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
3 changes: 3 additions & 0 deletions aten/src/ATen/Context.h
@@ -124,6 +124,8 @@ class TORCH_API Context {
void setUserEnabledMkldnn(bool e);
bool benchmarkCuDNN() const;
void setBenchmarkCuDNN(bool);
int benchmarkLimitCuDNN() const;
void setBenchmarkLimitCuDNN(int);
bool deterministicCuDNN() const;
void setDeterministicCuDNN(bool);

@@ -251,6 +253,7 @@ class TORCH_API Context {
bool _deterministic_algorithms_warn_only = false;
bool benchmark_cudnn = false;
Float32MatmulPrecision float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
int benchmark_limit_cudnn = 10;
bool allow_tf32_cudnn = true;
bool allow_fp16_reduction_cublas = true;
bool enabled_mkldnn = true;
6 changes: 4 additions & 2 deletions aten/src/ATen/native/cudnn/Conv_v8.cpp
@@ -359,7 +359,8 @@ auto get_plans_from_find(const cudnnHandle_t handle, const cudnnBackendDescripto
cudnn_frontend::executionPlans_t valid_plans;
c10::DeviceGuard g(x.options().device());
at::DataPtr workspace_ptr;
generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr);
auto benchmark_limit = at::globalContext().benchmarkLimitCuDNN();
generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr, benchmark_limit);
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
@@ -389,7 +390,8 @@ auto get_plans_from_find_fused(const cudnnHandle_t handle,
cudnn_frontend::executionPlans_t valid_plans;
c10::DeviceGuard g(x.options().device());
at::DataPtr workspace_ptr;
generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr);
auto benchmark_limit = at::globalContext().benchmarkLimitCuDNN();
generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr, benchmark_limit);
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setDataPointers(5, data_ptrs)
.setUids(5, uids)
8 changes: 8 additions & 0 deletions docs/source/backends.rst
@@ -78,6 +78,14 @@ torch.backends.cudnn
A :class:`bool` that, if True, causes cuDNN to benchmark multiple convolution algorithms
and select the fastest.

.. attribute:: torch.backends.cudnn.benchmark_limit

An :class:`int` that specifies the maximum number of cuDNN convolution algorithms to try when
`torch.backends.cudnn.benchmark` is True. Set `benchmark_limit` to zero to try every
available algorithm. Note that this setting only affects convolutions dispatched via the
cuDNN v8 API.


torch.backends.mps
^^^^^^^^^^^^^^^^^^
.. automodule:: torch.backends.mps
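A hedged sketch of the documented semantics above (illustrative, not from the commit):

import torch

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 0  # zero lifts the cap: every algorithm the heuristic returns is tried
# Per the docs above, this only affects convolutions dispatched via the cuDNN v8 API.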
4 changes: 4 additions & 0 deletions torch/CMakeLists.txt
@@ -142,6 +142,10 @@ if(USE_ROCM)
list(APPEND TORCH_PYTHON_INCLUDE_DIRECTORIES ${roctracer_INCLUDE_DIRS})
endif()

if(USE_EXPERIMENTAL_CUDNN_V8_API)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_EXPERIMENTAL_CUDNN_V8_API)
endif()

if(USE_CUDNN OR USE_ROCM)
list(APPEND TORCH_PYTHON_SRCS
${TORCH_SRC_DIR}/csrc/cuda/shared/cudnn.cpp
2 changes: 2 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -670,6 +670,8 @@ def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn
def _set_mkldnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledMkldnn
def _get_cudnn_benchmark() -> _bool: ... # THPModule_benchmarkCuDNN
def _set_cudnn_benchmark(arg: _bool) -> None: ... # THPModule_setBenchmarkCuDNN
def _get_cudnn_benchmark_limit() -> _int: ... # THPModule_benchmarkLimitCuDNN
def _set_cudnn_benchmark_limit(arg: _int) -> None: ... # THPModule_setBenchmarkLimitCuDNN
def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN
def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN
def _get_deterministic_algorithms() -> _bool: ... # THPModule_deterministicAlgorithms
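These stubs correspond to private bindings; a round-trip sketch (torch._C functions are internal and may change, so prefer the public attribute shown last):

import torch

print(torch._C._get_cudnn_benchmark_limit())  # 10 by default (set in Context.h above)
torch._C._set_cudnn_benchmark_limit(4)
assert torch._C._get_cudnn_benchmark_limit() == 4
torch.backends.cudnn.benchmark_limit = 4  # equivalent public spelling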
10 changes: 7 additions & 3 deletions torch/backends/cudnn/__init__.py
@@ -84,15 +84,18 @@ def is_acceptable(tensor):
return True


def set_flags(_enabled=None, _benchmark=None, _deterministic=None, _allow_tf32=None):
def set_flags(_enabled=None, _benchmark=None, _benchmark_limit=None, _deterministic=None, _allow_tf32=None):
orig_flags = (torch._C._get_cudnn_enabled(),
torch._C._get_cudnn_benchmark(),
torch._C._get_cudnn_benchmark_limit(),
torch._C._get_cudnn_deterministic(),
torch._C._get_cudnn_allow_tf32())
if _enabled is not None:
torch._C._set_cudnn_enabled(_enabled)
if _benchmark is not None:
torch._C._set_cudnn_benchmark(_benchmark)
if _benchmark_limit is not None:
torch._C._set_cudnn_benchmark_limit(_benchmark_limit)
if _deterministic is not None:
torch._C._set_cudnn_deterministic(_deterministic)
if _allow_tf32 is not None:
@@ -101,9 +104,9 @@ def set_flags(_enabled=None, _benchmark=None, _deterministic=None, _allow_tf32=N


@contextmanager
def flags(enabled=False, benchmark=False, deterministic=False, allow_tf32=True):
def flags(enabled=False, benchmark=False, benchmark_limit=10, deterministic=False, allow_tf32=True):
with __allow_nonbracketed_mutation():
orig_flags = set_flags(enabled, benchmark, deterministic, allow_tf32)
orig_flags = set_flags(enabled, benchmark, benchmark_limit, deterministic, allow_tf32)
try:
yield
finally:
@@ -123,6 +126,7 @@ def __init__(self, m, name):
enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
deterministic = ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic)
benchmark = ContextProp(torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark)
benchmark_limit = ContextProp(torch._C._get_cudnn_benchmark_limit, torch._C._set_cudnn_benchmark_limit)
allow_tf32 = ContextProp(torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32)

# This is the sys.modules replacement trick, see
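With benchmark_limit threaded through set_flags() and flags(), the usual scoped-override pattern applies; a sketch with illustrative values:

import torch

with torch.backends.cudnn.flags(enabled=True, benchmark=True, benchmark_limit=4):
    pass  # convolutions in this scope profile at most 4 candidate kernels
# On exit, the saved orig_flags are restored, so the global setting is unchanged.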
31 changes: 31 additions & 0 deletions torch/csrc/Module.cpp
@@ -7,6 +7,12 @@

#include <ATen/ATen.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/cuda/CUDAConfig.h>
#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/Macros.h>

#endif
#include <ATen/DLConvertor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/LinalgBackend.h>
@@ -520,6 +526,29 @@ PyObject *THPModule_benchmarkCuDNN(PyObject *_unused, PyObject *noargs)
Py_RETURN_FALSE;
}

PyObject *THPModule_setBenchmarkLimitCuDNN(PyObject *_unused, PyObject *arg)
{
THPUtils_assert(THPUtils_checkLong(arg), "set_benchmark_limit_cudnn expects an int, "
"but got %s", THPUtils_typename(arg));
auto benchmark_limit = static_cast<int>(THPUtils_unpackLong(arg));
#if defined(USE_ROCM)
TORCH_WARN_ONCE("cuDNN Benchmark limit is not supported in MIOpen and will have no effect.");
#endif
#if AT_CUDNN_ENABLED()
#if HAS_CUDNN_V8()
at::globalContext().setBenchmarkLimitCuDNN(benchmark_limit);
#else
TORCH_WARN_ONCE("cuDNN Benchmark limit is not supported with cuDNN v7 API and will have no effect.");
#endif
#endif
Py_RETURN_NONE;
}

PyObject *THPModule_benchmarkLimitCuDNN(PyObject *_unused, PyObject *noargs)
{
return THPUtils_packInt32(at::globalContext().benchmarkLimitCuDNN());
}

PyObject *THPModule_setAllowTF32CuBLAS(PyObject *_unused, PyObject *arg)
{
THPUtils_assert(PyBool_Check(arg), "set_allow_tf32_cublas expects a bool, "
@@ -692,6 +721,8 @@ static PyMethodDef TorchMethods[] = {
{"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr},
{"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
{"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr},
{"_get_cudnn_benchmark_limit", THPModule_benchmarkLimitCuDNN, METH_NOARGS, nullptr},
{"_set_cudnn_benchmark_limit", THPModule_setBenchmarkLimitCuDNN, METH_O, nullptr},
{"_get_cudnn_deterministic", THPModule_deterministicCuDNN, METH_NOARGS, nullptr},
{"_set_cudnn_deterministic", THPModule_setDeterministicCuDNN, METH_O, nullptr},
{"_get_deterministic_algorithms", THPModule_deterministicAlgorithms, METH_NOARGS, nullptr},
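Viewed from Python, the binding's guards behave as follows (a sketch; the exact warning text is in the code above):

import torch

torch.backends.cudnn.benchmark_limit = 7  # takes effect only on cuDNN v8 API builds
# ROCm builds:     warn once that MIOpen does not support the limit; the call is a no-op.
# cuDNN v7 builds: warn once that the v7 API does not support the limit; the call is a no-op.
# Non-integer arguments are rejected by the THPUtils_checkLong type check.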
