forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CUDA support in the CSR layout: CUDA addmm/matvec (pytorch#59012)
Summary: Pull Request resolved: pytorch#59012 Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D28719631 Pulled By: bhosmer fbshipit-source-id: 43e2004a61e114aeb0a7c6ad8a25fedda238c6da
- Loading branch information
1 parent
3efefc4
commit 2d8f0d9
Showing
7 changed files
with
253 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#pragma once | ||
|
||
#include <ATen/cuda/detail/TensorInfo.cuh> | ||
#include <ATen/cuda/CUDAApplyUtils.cuh> | ||
#include <c10/macros/Macros.h> | ||
|
||
namespace at { namespace native { | ||
|
||
void s_addmm_out_csr_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, Tensor& crow_indices, Tensor& col_indices, Tensor& values, const Tensor& dense); | ||
|
||
void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, Tensor& indices, Tensor& values, const Tensor& dense); | ||
|
||
}} // namespace at::native |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.