CUDA support in the CSR layout: CUDA addmm/matvec (pytorch#59012)
Summary: Pull Request resolved: pytorch#59012

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D28719631

Pulled By: bhosmer

fbshipit-source-id: 43e2004a61e114aeb0a7c6ad8a25fedda238c6da
aocsa authored and facebook-github-bot committed Jun 2, 2021
1 parent 3efefc4 commit 2d8f0d9
Showing 7 changed files with 253 additions and 113 deletions.
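
For context, here is a minimal usage sketch of what this commit enables: torch.addmm with a sparse CSR operand on a CUDA device. The shapes and values below are illustrative only and assume a CUDA-enabled build that includes this change.

    import torch

    # 2x3 CSR matrix with 4 stored values, placed on the GPU
    crow_indices = torch.tensor([0, 2, 4], dtype=torch.int64)
    col_indices  = torch.tensor([0, 2, 1, 2], dtype=torch.int64)
    values       = torch.tensor([1., 2., 3., 4.])
    mat1 = torch.sparse_csr_tensor(crow_indices, col_indices, values,
                                   size=(2, 3), device='cuda')

    mat2 = torch.randn(3, 5, device='cuda')   # dense right-hand side
    bias = torch.randn(2, 5, device='cuda')   # 'self' argument of addmm

    # out = beta * bias + alpha * (mat1 @ mat2), with mat1 in CSR layout on CUDA
    out = torch.addmm(bias, mat1, mat2, beta=0.5, alpha=2.0)
    print(out.shape)  # torch.Size([2, 5])
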
9 changes: 5 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -2884,15 +2884,15 @@
structured_delegate: mm.out
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseCsrCPU: _sparse_mm
SparseCPU, SparseCUDA, SparseCsrCPU, SparseCsrCUDA: _sparse_mm

- func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU: mm_out_cpu
CUDA: mm_out_cuda
SparseCPU, SparseCUDA: _sparse_mm_out
SparseCsrCPU: _sparse_csr_mm_out
SparseCsrCPU, SparseCsrCUDA: _sparse_csr_mm_out

- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor

@@ -2978,7 +2978,7 @@
variants: function, method
dispatch:
CPU, CUDA: mv
SparseCPU, SparseCUDA, SparseCsrCPU: mv_sparse
SparseCPU, SparseCUDA, SparseCsrCPU, SparseCsrCUDA: mv_sparse

- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
@@ -4688,14 +4688,15 @@
SparseCPU: addmm_out_sparse_dense_cpu
SparseCUDA: addmm_out_sparse_dense_cuda
SparseCsrCPU: addmm_out_sparse_csr_dense_cpu
SparseCsrCUDA: addmm_out_sparse_csr_dense_cuda

- func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
structured_delegate: addmm.out
variants: function, method
dispatch:
SparseCPU: addmm_sparse_dense_cpu
SparseCUDA: addmm_sparse_dense_cuda
SparseCsrCPU: addmm_sparse_csr_dense_cpu
SparseCsrCPU, SparseCsrCUDA: addmm_sparse_csr_dense

- func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
structured_delegate: addmm.out
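
The dispatch additions in this file register the SparseCsrCUDA key for mm, mv, and addmm. The mv entry is what backs the "matvec" half of the commit title; a hedged sketch of the corresponding user-facing call (hypothetical values, CUDA device assumed):

    import torch

    # Same 2x3 CSR matrix as in the earlier sketch, on the GPU
    A = torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
                                torch.tensor([0, 2, 1, 2]),
                                torch.tensor([1., 2., 3., 4.]),
                                size=(2, 3), device='cuda')
    x = torch.randn(3, device='cuda')

    # Dispatches to mv_sparse for the SparseCsrCUDA key after this change
    y = torch.mv(A, x)
    print(y.shape)  # torch.Size([2])
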
28 changes: 13 additions & 15 deletions aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -97,16 +97,16 @@ Tensor& addmm_out_sparse_csr_dense_cpu(
TORCH_INTERNAL_ASSERT(sparse.is_sparse_csr());
Tensor t = *expand_size(self, {sparse.size(0), dense.size(1)}, "addmm_out_sparse_csr");

TORCH_INTERNAL_ASSERT(t.device().type() == kCPU);
TORCH_CHECK(!t.is_cuda(), "Expected all tensors to be on the same device. addmm expected 't' to be CPU tensor, but got CUDA tensor");
TORCH_CHECK(
r.device().type() == kCPU,
"addmm: expected 'out' to be CPU tensor, but got CUDA tensor");
!r.is_cuda(),
"Expected all tensors to be on the same device. addmm: expected 'out' to be CPU tensor, but got CUDA tensor");
TORCH_CHECK(
sparse.device().type() == kCPU,
"addmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor");
!sparse.is_cuda(),
"Expected all tensors to be on the same device. addmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor");
TORCH_CHECK(
dense.device().type() == kCPU,
"addmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor");
!dense.is_cuda(),
"Expected all tensors to be on the same device. addmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor");

TORCH_CHECK(
sparse.dim() == 2,
@@ -135,18 +135,16 @@ Tensor& addmm_out_sparse_csr_dense_cpu(
dim_j,
", got ",
dense.size(0));
TORCH_CHECK(
sparse.size(1) == dim_j,
"addmm: Expected sparse matrix (op1) size(1)=",
dim_j,
", got ",
sparse.size(1));

resize_output(r, {dim_i, dim_k});
auto col_indices = sparse.col_indices();
auto crow_indices = sparse.crow_indices();
auto values = sparse.values();
int64_t nnz = sparse._nnz();

if (nnz == 0) {
at::mul_out(r, t, at::scalar_tensor(beta, r.options()));
return r;
}
// Do not use MKL for Windows due to linking issues with sparse MKL routines.
if (at::hasMKL() && is_mkl_supported() && is_square_or_vec(dim_i, dim_j, dim_k)) {
AT_DISPATCH_FLOATING_TYPES(values.scalar_type(), "addmm_sparse_dense", [&] {
@@ -172,7 +170,7 @@ Tensor& addmm_out_sparse_csr_dense_cpu(
return r;
}

Tensor addmm_sparse_csr_dense_cpu(
Tensor addmm_sparse_csr_dense(
const Tensor& self,
const SparseCsrTensor& sparse,
const Tensor& dense,
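
The nnz == 0 early return added in this file means that when the CSR operand stores no values, addmm reduces to out = beta * self; the alpha * (mat1 @ mat2) term contributes nothing and the MKL/reference kernels are skipped. A hedged illustration of that behavior (hypothetical shapes; the same fast path exists on the new CUDA side):

    import torch

    m, k, n = 2, 3, 4
    # CSR matrix of shape (m, k) with zero stored elements
    empty_csr = torch.sparse_csr_tensor(torch.zeros(m + 1, dtype=torch.int64),
                                        torch.tensor([], dtype=torch.int64),
                                        torch.tensor([]),
                                        size=(m, k))
    bias  = torch.randn(m, n)
    dense = torch.randn(k, n)

    out = torch.addmm(bias, empty_csr, dense, beta=0.5, alpha=2.0)
    assert torch.allclose(out, 0.5 * bias)  # matmul term vanishes when nnz == 0
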
8 changes: 4 additions & 4 deletions aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -890,10 +890,10 @@ Tensor& s_addmm_out_sparse_dense_cpu(
const Scalar& alpha
) {
// TODO: This error message seems awfully opaque
AT_ASSERT(!t.is_cuda());
TORCH_CHECK(!r.is_cuda(), "addmm: expected 'out' to be CPU tensor, but got CUDA tensor");
TORCH_CHECK(!sparse_.is_cuda(), "addmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor");
TORCH_CHECK(!dense.is_cuda(), "addmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor");
TORCH_CHECK(!t.is_cuda(), "Expected all tensors to be on the same device. addmm expected 't' to be CPU tensor, but got CUDA tensor");
TORCH_CHECK(!r.is_cuda(), "Expected all tensors to be on the same device. addmm: expected 'out' to be CPU tensor, but got CUDA tensor");
TORCH_CHECK(!sparse_.is_cuda(), "Expected all tensors to be on the same device. addmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor");
TORCH_CHECK(!dense.is_cuda(), "Expected all tensors to be on the same device. addmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor");

TORCH_CHECK(sparse_.sparse_dim() == 2, "addmm: matrices expected, got ", sparse_.sparse_dim(), "D tensor");
TORCH_CHECK(sparse_.dense_dim() == 0, "addmm: scalar values expected, got ", sparse_.dense_dim(), "D values");
142 changes: 73 additions & 69 deletions aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
@@ -1,3 +1,5 @@
#include <ATen/native/sparse/cuda/SparseCUDATensorMath.cuh>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/NativeFunctions.h>
@@ -49,79 +51,85 @@ namespace {
}
}

void s_addmm_out_csr_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, Tensor& crow_indices, Tensor& col_indices, Tensor& values, const Tensor& dense) {
TORCH_INTERNAL_ASSERT(nnz > 0);

// No half support, so we don't have to use CUDATypeConversion
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
values.scalar_type(), "addmm_sparse_cuda", [&] {
scalar_t cast_beta = beta.to<scalar_t>();
scalar_t cast_alpha = alpha.to<scalar_t>();
Tensor r__;
if (cast_beta == scalar_t(0)) {
r_.zero_();
} else if (!is_same_tensor(t, r_)) {
r_.copy_(t);
}
if(r_.stride(0) == 1 && r_.stride(1) == r_.size(0)) {
r__ = r_;
} else {
// Note: This storage arrangement is preferred because most CUDA kernels handle only contiguous tensors
r__ = r_.transpose(0, 1).clone(at::MemoryFormat::Contiguous);
r__.transpose_(0, 1);
}
Tensor dense_;
char transpose_dense;
if(dense.stride(0) == 1 && dense.stride(1) == dense.size(0)) {
transpose_dense = 'n';
dense_ = dense;
} else if(dense.stride(1) == 1 && dense.stride(0) == dense.size(1)) {
transpose_dense = 't';
dense_ = dense;
} else {
transpose_dense = 't';
dense_ = dense.contiguous();
}

sparse::cuda::csrmm2(
'n',
transpose_dense,
m,
n,
k,
nnz,
cast_alpha,
values.data_ptr<scalar_t>(),
crow_indices.data_ptr<int32_t>(),
col_indices.data_ptr<int32_t>(),
dense_.data_ptr<scalar_t>(),
(transpose_dense == 'n' ? dense_.stride(1) : dense_.stride(0)),
cast_beta,
r__.data_ptr<scalar_t>(),
r__.stride(1));

if (!is_same_tensor(r__, r_)) {
r_.copy_(r__);
}
}
);
}

// NB: Deleted spaddcmul (aka addcmul_, but not actually wired up), spaddcdiv (not
// wired at all)

template <typename scalar_t>
void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, Tensor& indices, Tensor& values, const Tensor& dense) {
scalar_t cast_beta = beta.to<scalar_t>();
scalar_t cast_alpha = alpha.to<scalar_t>();
Tensor rowIndices = indices.select(0, 0);
Tensor colIndices = indices.select(0, 1);
Tensor csr = _to_csr_int(rowIndices, m, nnz);
Tensor colIndicesInt = at::empty({colIndices.size(0)}, indices.options().dtype(kInt));
colIndicesInt.copy_(colIndices);

Tensor r__;
if (cast_beta == scalar_t(0)) {
r_.zero_();
} else if (!is_same_tensor(t, r_)) {
r_.copy_(t);
}

if(r_.stride(0) == 1 && r_.stride(1) == r_.size(0)) {
r__ = r_;
} else {
// TODO: how... strange
r__ = r_.transpose(0, 1).clone(at::MemoryFormat::Contiguous);
r__.transpose_(0, 1);
}

if (nnz > 0) {
Tensor dense_;
char transpose_dense;
if(dense.stride(0) == 1 && dense.stride(1) == dense.size(0)) {
transpose_dense = 'n';
dense_ = dense;
} else if(dense.stride(1) == 1 && dense.stride(0) != dense.size(1)) {
transpose_dense = 't';
dense_ = dense;
} else {
transpose_dense = 't';
dense_ = dense.contiguous();
}

sparse::cuda::csrmm2(
'n',
transpose_dense,
m,
n,
k,
nnz,
cast_alpha,
values.data_ptr<scalar_t>(),
csr.data_ptr<int32_t>(),
colIndicesInt.data_ptr<int32_t>(),
dense_.data_ptr<scalar_t>(),
(transpose_dense == 'n' ? dense_.stride(1) : dense_.stride(0)),
cast_beta,
r__.data_ptr<scalar_t>(),
r__.stride(1));
}
if (!is_same_tensor(r__, r_)) {
r_.copy_(r__);
}
Tensor crow_indices = _to_csr_int(rowIndices, m, nnz);
Tensor col_indices = at::empty({colIndices.size(0)}, indices.options().dtype(kInt));
col_indices.copy_(colIndices);
s_addmm_out_csr_sparse_dense_cuda_worker(nnz, m, n, k, r_, beta, t, alpha, crow_indices, col_indices, values, dense);
}

// --------------------------------------------------------------------
// addmm(Tensor, SparseTensor, Tensor, Scalar, Scalar) [broadcasts]
// --------------------------------------------------------------------

Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseTensor& sparse_, const Tensor& dense, const Scalar& beta, const Scalar& alpha) {
TORCH_CHECK(t.is_cuda(), "addmm: expected 'self' to be CUDA, but got CPU");
TORCH_CHECK(r_.is_cuda(), "addmm: expected 'out' to be CUDA, but got CPU");
TORCH_CHECK(sparse_.is_cuda(), "addmm: expected 'mat1' to be CUDA, but got CPU");
TORCH_CHECK(dense.is_cuda(), "addmm: expected 'mat2' to be CUDA, but got CPU");
TORCH_CHECK(t.is_cuda(), "Expected all tensors to be on the same device. addmm: expected 'self' to be CUDA, but got CPU");
TORCH_CHECK(r_.is_cuda(), "Expected all tensors to be on the same device. addmm: expected 'out' to be CUDA, but got CPU");
TORCH_CHECK(sparse_.is_cuda(), "Expected all tensors to be on the same device. addmm: expected 'mat1' to be CUDA, but got CPU");
TORCH_CHECK(dense.is_cuda(), "Expected all tensors to be on the same device. addmm: expected 'mat2' to be CUDA, but got CPU");

TORCH_CHECK(cuda::check_device({sparse_, r_, t, dense}));

@@ -148,15 +156,11 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT
int64_t nnz = sparse._nnz();
Tensor indices = sparse._indices();
Tensor values = sparse._values();


// No half support, so we don't have to use CUDATypeConversion
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
values.scalar_type(), "addmm_sparse_cuda", [&] {
s_addmm_out_sparse_dense_cuda_worker<scalar_t>(nnz, m, n, k, r_, beta, t, alpha, indices, values, dense);
}
);

if (nnz == 0) {
at::mul_out(r_, t, at::scalar_tensor(beta, r_.options()));
return r_;
}
s_addmm_out_sparse_dense_cuda_worker(nnz, m, n, k, r_, beta, t, alpha, indices, values, dense);
return r_;
}

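
The COO path in this file now converts its row indices into CSR row pointers (via _to_csr_int) and int32 column indices, then delegates to the shared s_addmm_out_csr_sparse_dense_cuda_worker. As a conceptual sketch only — the actual kernel calls into cuSPARSE's coo2csr rather than doing this in Python — the row-pointer compression looks like:

    import torch

    def coo_rows_to_crow(row_indices, num_rows):
        # Conceptual equivalent of _to_csr_int: compress sorted COO row indices
        # into CSR row pointers crow, where crow[i+1] - crow[i] = nnz in row i.
        counts = torch.bincount(row_indices, minlength=num_rows)
        crow = torch.cat([torch.zeros(1, dtype=torch.int64), counts.cumsum(0)])
        return crow.to(torch.int32)

    rows = torch.tensor([0, 0, 1, 2, 2, 2])    # sorted COO row indices, nnz = 6
    print(coo_rows_to_crow(rows, num_rows=4))  # tensor([0, 2, 3, 6, 6], dtype=torch.int32)
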
13 changes: 13 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cuh
@@ -0,0 +1,13 @@
#pragma once

#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <c10/macros/Macros.h>

namespace at { namespace native {

void s_addmm_out_csr_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, Tensor& crow_indices, Tensor& col_indices, Tensor& values, const Tensor& dense);

void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, Tensor& indices, Tensor& values, const Tensor& dense);

}} // namespace at::native
72 changes: 71 additions & 1 deletion aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
@@ -19,7 +19,9 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAUtils.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include <ATen/native/sparse/cuda/SparseCUDABlas.cuh>
#include <ATen/native/sparse/cuda/SparseCUDATensorMath.cuh>

#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
@@ -33,6 +35,74 @@ using namespace at::sparse_csr;
// certain utility functions are usable from sparse COO.
using namespace at::sparse;

Tensor& addmm_out_sparse_csr_dense_cuda(
const Tensor& self,
const SparseCsrTensor& sparse,
const Tensor& dense,
const Scalar& beta,
const Scalar& alpha,
Tensor& r)
{

TORCH_INTERNAL_ASSERT(sparse.is_sparse_csr());
Tensor t = *expand_size(self, {sparse.size(0), dense.size(1)}, "addmm_out_sparse_csr");

TORCH_CHECK(t.is_cuda(), "Expected all tensors to be on the same device. addmm expected 't' to be CUDA tensor");
TORCH_CHECK(
r.is_cuda(),
"Expected all tensors to be on the same device. addmm: expected 'out' to be CUDA tensor, but got CPU tensor");
TORCH_CHECK(
sparse.is_cuda(),
"Expected all tensors to be on the same device. addmm: expected 'mat1' to be a CUDA tensor, but got a CPU tensor");
TORCH_CHECK(
dense.is_cuda(),
"Expected all tensors to be on the same device. addmm: expected 'mat2' to be a CUDA tensor, but got a CPU tensor");

TORCH_CHECK(
sparse.dim() == 2,
"addmm: 2-D matrices expected, got ",
sparse.dim(),
"D tensor");
TORCH_CHECK(
dense.dim() == 2,
"addmm: 2-D matrices expected, got ",
dense.dim(),
"D tensor");

TORCH_CHECK(
r.is_contiguous(),
"out argument must be contiguous, but got: ",
r.suggest_memory_format());

// mxk * kxn = mxn
int64_t m = sparse.size(0);
int64_t k = sparse.size(1);
int64_t n = dense.size(1);

TORCH_CHECK(
dense.size(0) == k,
"addmm: Expected dense matrix (dense) size(0)=",
k,
", got ",
dense.size(0));

resize_output(r, {m, n});
int64_t nnz = sparse._nnz();

if (nnz == 0) {
at::mul_out(r, t, at::scalar_tensor(beta, r.options()));
return r;
}
// TODO: Check if cusparseSpMM can use 64-bit indices
// https://docs.nvidia.com/cuda/cusparse/index.html
auto col_indices = sparse.col_indices().to(at::kInt);
auto crow_indices = sparse.crow_indices().to(at::kInt);
auto values = sparse.values();

s_addmm_out_csr_sparse_dense_cuda_worker(nnz, m, n, k, r, beta, t, alpha, crow_indices, col_indices, values, dense);
return r;
}

Tensor& add_out_dense_sparse_csr_cuda(
Tensor& output,
const Tensor& dense,
@@ -62,7 +132,7 @@ Tensor& add_out_dense_sparse_csr_cuda(
dense.sizes(),
" while other has size ",
src.sizes(),
" (FYI: op2-sparse addition does not currently support broadcasting)");
" (FYI: dense-sparse addition does not currently support broadcasting)");

auto commonDtype = promoteTypes(dense.scalar_type(), src.scalar_type());
TORCH_CHECK(
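
As a sanity check of the addmm_out_sparse_csr_dense_cuda kernel registered in this file, the CSR result can be compared against an ordinary dense addmm on the same data. A hedged sketch with hypothetical values, assuming a CUDA build that includes this commit:

    import torch

    crow = torch.tensor([0, 1, 3])
    col  = torch.tensor([1, 0, 2])
    vals = torch.tensor([5., 6., 7.])
    A_csr   = torch.sparse_csr_tensor(crow, col, vals, size=(2, 3), device='cuda')
    A_dense = torch.tensor([[0., 5., 0.],
                            [6., 0., 7.]], device='cuda')  # same matrix, strided layout

    B = torch.randn(3, 4, device='cuda')
    C = torch.randn(2, 4, device='cuda')

    out_csr   = torch.addmm(C, A_csr,   B)
    out_dense = torch.addmm(C, A_dense, B)
    assert torch.allclose(out_csr, out_dense, atol=1e-5)
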