Revert "Use CUTLASS GEMM for NT bmm [OSS-only] (pytorch#85894)"
This reverts commit ef58a13.

Reverted pytorch#85894 on behalf of https://github.com/DanilBaibak because it broke the internal build.
pytorchmergebot committed Oct 13, 2022
1 parent b97ae59 commit d169f95
Showing 10 changed files with 40 additions and 390 deletions.
2 changes: 0 additions & 2 deletions BUILD.bazel
@@ -429,7 +429,6 @@ cu_library(
"@cuda//:cublas",
"@cuda//:cufft",
"@cuda//:cusparse",
"@cutlass",
],
alwayslink = True,
)
@@ -1674,7 +1673,6 @@ cc_library(
] + if_cuda([
":torch_distributed_cuda",
"@cuda//:nvToolsExt",
"@cutlass",
]),
alwayslink = True,
)
6 changes: 0 additions & 6 deletions WORKSPACE
@@ -84,12 +84,6 @@ new_local_repository(
path = "third_party/eigen",
)

- new_local_repository(
-     name = "cutlass",
-     build_file = "//third_party:cutlass.BUILD",
-     path = "third_party/cutlass",
- )
-
new_local_repository(
name = "fbgemm",
build_file = "//third_party:fbgemm/BUILD.bazel",
4 changes: 3 additions & 1 deletion aten/src/ATen/CMakeLists.txt
@@ -433,7 +433,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
endif()

if(USE_CUDA AND NOT USE_ROCM)
-   list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
+   if(USE_FLASH_ATTENTION)
+     list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
+   endif()
if($ENV{ATEN_STATIC_CUDA})
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_LIBRARIES}
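Effect of this hunk: the CUTLASS include path is no longer added unconditionally for every non-ROCm CUDA build; it is now added only when USE_FLASH_ATTENTION is set, consistent with the revert dropping the CUTLASS-based NT bmm path.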
3 changes: 1 addition & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -1174,8 +1174,7 @@
dispatch:
SparseCPU: bmm_sparse_cpu
SparseCUDA: bmm_sparse_cuda
- NestedTensorCPU: bmm_nested
- NestedTensorCUDA: bmm_nested_cuda
+ NestedTensorCPU, NestedTensorCUDA: bmm_nested
tags: canonical

- func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
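For readers unfamiliar with native_functions.yaml: each `Backend: kernel` pair under `dispatch:` tells PyTorch's code generator which C++ function to register for that dispatch key, so the restored entry routes both NestedTensorCPU and NestedTensorCUDA back to the shared `bmm_nested` kernel. A minimal hand-written sketch of the equivalent registration, using the `TORCH_LIBRARY_IMPL` macro from `torch/library.h` (an illustration only; in the real build these registrations are generated from the YAML, not written by hand):

#include <torch/library.h>

// Assumed forward declaration of the shared nested-tensor kernel.
at::Tensor bmm_nested(const at::Tensor& self, const at::Tensor& mat2);

// Roughly what the combined entry
// "NestedTensorCPU, NestedTensorCUDA: bmm_nested" amounts to:
// the same kernel registered under both dispatch keys.
TORCH_LIBRARY_IMPL(aten, NestedTensorCPU, m) {
  m.impl("bmm", bmm_nested);
}
TORCH_LIBRARY_IMPL(aten, NestedTensorCUDA, m) {
  m.impl("bmm", bmm_nested);
}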
22 changes: 22 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorUtils.cpp
@@ -101,5 +101,27 @@ std::vector<Tensor> chunk_nested_tensor(const Tensor& self, int64_t chunks, int6
return splits;
}

+ std::vector<IntArrayRef> NestedTensor_get_sizes(
+     const NestedTensorImpl* self_ptr) {
+   int64_t ntensors = self_ptr->size(0);
+   std::vector<IntArrayRef> sizes(ntensors);
+   if (ntensors == 0) {
+     return sizes;
+   }
+   const Tensor& sizemat = self_ptr->get_nested_size_tensor();
+   int64_t orig_dim = sizemat.size(1);
+   // nesting scalars has empty sizes
+   if (orig_dim == 0) {
+     return sizes;
+   }
+   const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
+
+   for (const auto i : c10::irange(ntensors)) {
+     sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
+     sizemat_ptr += orig_dim;
+   }
+   return sizes;
+ }
+
} // namespace native
} // namespace at
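The function moved back into the .cpp file above converts the nested tensor's (ntensors x orig_dim) size matrix into one IntArrayRef view per constituent tensor. A short usage sketch under stated assumptions: `NestedTensor_get_sizes` and `get_nested_tensor_impl` come from NestedTensorUtils.h, while the surrounding function and variable names are hypothetical.

#include <ATen/native/nested/NestedTensorUtils.h>
#include <iostream>

// Hypothetical caller: inspect the shape of each constituent tensor of a
// nested tensor without materializing the constituents themselves.
void print_constituent_shapes(const at::Tensor& nested) {
  const auto* nt = at::native::get_nested_tensor_impl(nested);
  std::vector<at::IntArrayRef> sizes = at::native::NestedTensor_get_sizes(nt);
  for (const auto i : c10::irange(sizes.size())) {
    // Each IntArrayRef is a view into the nested size matrix, so it is
    // only valid while `nested` (and its size tensor) stays alive.
    std::cout << "tensor " << i << ": " << sizes[i] << '\n';
  }
}

Note the trade-off this hunk makes: moving the definition out of the header gives up inlining in exchange for a lighter header; includers now see only the declaration (restored in the NestedTensorUtils.h hunk below).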
24 changes: 2 additions & 22 deletions aten/src/ATen/native/nested/NestedTensorUtils.h
@@ -86,28 +86,8 @@ inline at::Tensor create_nested_view_tensor(
int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);

// The sizes of the underlying tensors
- inline std::vector<IntArrayRef> NestedTensor_get_sizes(
-     const NestedTensorImpl* self_ptr) {
-   int64_t ntensors = self_ptr->size(0);
-   std::vector<IntArrayRef> sizes(ntensors);
-   if (ntensors == 0) {
-     return sizes;
-   }
-   const Tensor& sizemat = self_ptr->get_nested_size_tensor();
-   int64_t orig_dim = sizemat.size(1);
-   // nesting scalars has empty sizes
-   if (orig_dim == 0) {
-     return sizes;
-   }
-   const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
-
-   for (const auto i : c10::irange(ntensors)) {
-     sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
-     sizemat_ptr += orig_dim;
-   }
-   return sizes;
- }
-
+ std::vector<IntArrayRef> NestedTensor_get_sizes(
+     const NestedTensorImpl* self_ptr);

TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
const NestedTensorImpl& nt);
(The remaining 4 of the 10 changed files are not shown here.)
