Skip to content

Commit

Permalink
Migrate thnn_conv_depthwise2d from THC to ATen (pytorch#62006)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#62006

Closes pytorchgh-24646, pytorchgh-24647

There is no `TensorIterator` equivalent to these kernels so this is just
migrating the existing kernels over to the ATen style.

I've benchmarked for contiguous tensors with this script:
```
import torch
shape = (10, 10, 100, 100)
x = torch.randn(*shape, device='cuda')
w = torch.randn((10, 1, 5, 5), device='cuda')

for _ in range(100):
    torch.nn.functional.conv2d(x, w, groups=10)
```

and similarly for backwards. I see these as the same to within measurement error.

|                   | Master Forward (us) | This PR Forward (us) |
|------------------:|:-------------------:|:--------------------:|
|           Forward |        133.5        |         133.6        |
|  Backward (input) |        1,102        |         1,119        |
| Backward (weight) |        2,220        |         2,217        |

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D29883676

Pulled By: ngimel

fbshipit-source-id: 9b2ac62cdd8a84e1a23ffcd66035b2b2fe2374d8
  • Loading branch information
peterbell10 authored and facebook-github-bot committed Jul 27, 2021
1 parent 9df6051 commit de3a4eb
Show file tree
Hide file tree
Showing 14 changed files with 608 additions and 892 deletions.
1 change: 0 additions & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,6 @@ filegroup(
"aten/src/THCUNN/SoftShrink.cu.cc",
"aten/src/THCUNN/SpatialClassNLLCriterion.cu.cc",
"aten/src/THCUNN/SpatialConvolutionMM.cu.cc",
"aten/src/THCUNN/SpatialDepthwiseConvolution.cu.cc",
"aten/src/THCUNN/Tanh.cu.cc",
],
)
Expand Down
248 changes: 0 additions & 248 deletions aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions aten/src/ATen/native/ConvUtils.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include <ATen/detail/CUDAHooksInterface.h>

namespace at { namespace native {

Expand Down Expand Up @@ -83,12 +84,12 @@ static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {

static inline bool cudnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
// disable NHWC for float64 input.
if (!detail::getCUDAHooks().compiledWithCuDNN() ||
if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
input.scalar_type() == at::kDouble ||
weight.scalar_type() == at::kDouble) {
return false;
}
long cudnn_version = detail::getCUDAHooks().versionCuDNN();
long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
auto input_memory_format = input.suggest_memory_format();
auto weight_memory_format = weight.suggest_memory_format();

Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ at::Tensor _convolution(
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
} else {
if (input.ndimension() == 4) {
output = at::thnn_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias, stride, padding, dilation);
output = at::_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias, stride, padding, dilation);
}
else {
TORCH_CHECK(input.ndimension() == 5);
Expand Down
23 changes: 0 additions & 23 deletions aten/src/ATen/native/LegacyNNDefinitions.cpp

This file was deleted.

Loading

0 comments on commit de3a4eb

Please sign in to comment.