Skip to content

Commit

Permalink
Merge commit '3f25232aaba44aa4377c7e5ed670587a72f5886e'
Browse files Browse the repository at this point in the history
  • Loading branch information
soumith committed Aug 15, 2017
2 parents 469969e + 3f25232 commit fb5e40f
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 5 deletions.
49 changes: 49 additions & 0 deletions torch/lib/THC/THCBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,31 @@ void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, long m,
(int)batchCount));
}

#if CUDA_VERSION >= 8000
/* Strided-batched single-precision GEMM:
 *   C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i]   for i in [0, batchCount)
 * where matrix i of operand X is located at X + i * strideX.
 * Wraps cublasSgemmStridedBatched, which is only available with CUDA >= 8.0.
 * cuBLAS takes int parameters, so every long dimension is narrowed to int;
 * values that do not fit are rejected up front. */
void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
                             float alpha, const float *a, long lda, long strideA, const float *b, long ldb, long strideB,
                             float beta, float *c, long ldc, long strideC, long batchCount)
{
  /* Reject dimensions that cannot be represented as int (cuBLAS API limit). */
  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
  {
    /* NOTE: trailing space before the closing quote keeps the two string
     * literals from fusing into "batchCountwith". */
    THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount "
            "with the bound [val] <= %d", INT_MAX);
  }

  /* Fix up leading dimensions for degenerate (size-1) shapes. */
  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasSgemmStridedBatched(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
                                   (int)batchCount));
}
#endif

void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m, long n, long k,
double alpha, const double *a[], long lda, const double *b[], long ldb,
double beta, double *c[], long ldc, long batchCount)
Expand All @@ -366,6 +391,30 @@ void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m,
(int)batchCount));
}

#if CUDA_VERSION >= 8000
/* Strided-batched double-precision GEMM:
 *   C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i]   for i in [0, batchCount)
 * where matrix i of operand X is located at X + i * strideX.
 * Wraps cublasDgemmStridedBatched, which is only available with CUDA >= 8.0.
 * cuBLAS takes int parameters, so every long dimension is narrowed to int;
 * values that do not fit are rejected up front. */
void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
                             double alpha, const double *a, long lda, long strideA, const double *b, long ldb, long strideB,
                             double beta, double *c, long ldc, long strideC, long batchCount)
{
  /* Reject dimensions that cannot be represented as int (cuBLAS API limit). */
  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
  {
    /* Fixed: the message previously named "Cublas_DgemmBatched" (copy-paste
     * from the non-strided variant) and fused "batchCount"+"with" without a
     * space. */
    THError("Cublas_DgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount "
            "with the bound [val] <= %d", INT_MAX);
  }

  /* Fix up leading dimensions for degenerate (size-1) shapes. */
  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasDgemmStridedBatched(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
                                   (int)batchCount));
}
#endif

/* Inverse */
void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize) {
if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) )
Expand Down
9 changes: 8 additions & 1 deletion torch/lib/THC/THCBlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@ THC_API void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb,
THC_API void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m, long n, long k,
double alpha, const double *a[], long lda, const double *b[], long ldb,
double beta, double *c[], long ldc, long batchCount);

#if CUDA_VERSION >= 8000
/* Strided-batched GEMM (single precision): computes
 * C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i] for i in [0, batchCount),
 * where matrix i of operand X is at X + i * strideX. Requires CUDA >= 8.0,
 * since the implementation wraps cublasSgemmStridedBatched. All long
 * dimension arguments must be representable as int (cuBLAS API limit). */
THC_API void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
float alpha, const float *a, long lda, long strideA, const float *b, long ldb, long strideB,
float beta, float *c, long ldc, long strideC, long batchCount);
/* Strided-batched GEMM (double precision); same contract as the float
 * variant above, wrapping cublasDgemmStridedBatched. */
THC_API void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
double alpha, const double *a, long lda, long strideA, const double *b, long ldb, long strideB,
double beta, double *c, long ldc, long strideC, long batchCount);
#endif
/* Inverse */
THC_API void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize);
THC_API void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize);
Expand Down
4 changes: 4 additions & 0 deletions torch/lib/THC/THCNumerics.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ struct THCNumerics<char> {
static inline __host__ __device__ bool eq(char a, char b) { return a == b; }
static inline __host__ __device__ bool ne(char a, char b) { return a != b; }

static inline __host__ __device__ char neg(char a) { return -a; }
static inline __host__ __device__ char add(char a, char b) { return a + b; }
static inline __host__ __device__ char mul(char a, char b) { return a * b; }
static inline __host__ __device__ char sub(char a, char b) { return a - b; }
Expand All @@ -63,6 +64,7 @@ struct THCNumerics<short> {
static inline __host__ __device__ bool eq(short a, short b) { return a == b; }
static inline __host__ __device__ bool ne(short a, short b) { return a != b; }

static inline __host__ __device__ short neg(short a) { return -a; }
static inline __host__ __device__ short add(short a, short b) { return a + b; }
static inline __host__ __device__ short mul(short a, short b) { return a * b; }
static inline __host__ __device__ short sub(short a, short b) { return a - b; }
Expand All @@ -82,6 +84,7 @@ struct THCNumerics<int> {
static inline __host__ __device__ bool eq(int a, int b) { return a == b; }
static inline __host__ __device__ bool ne(int a, int b) { return a != b; }

static inline __host__ __device__ int neg(int a) { return -a; }
static inline __host__ __device__ int add(int a, int b) { return a + b; }
static inline __host__ __device__ int mul(int a, int b) { return a * b; }
static inline __host__ __device__ int sub(int a, int b) { return a - b; }
Expand All @@ -101,6 +104,7 @@ struct THCNumerics<long> {
static inline __host__ __device__ bool eq(long a, long b) { return a == b; }
static inline __host__ __device__ bool ne(long a, long b) { return a != b; }

static inline __host__ __device__ long neg(long a) { return -a; }
static inline __host__ __device__ long add(long a, long b) { return a + b; }
static inline __host__ __device__ long mul(long a, long b) { return a * b; }
static inline __host__ __device__ long sub(long a, long b) { return a - b; }
Expand Down
36 changes: 34 additions & 2 deletions torch/lib/THC/generic/THCTensorMathBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -537,9 +537,10 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
// Compute pointers to matrices in each batch.
#if CUDA_VERSION < 8000
size_t matrices_size = num_batches * sizeof(real*);

// Copy pointers to device.
// Copy pointers to device.
const real **d_matrices1, **d_matrices2;
real **d_result_matrices;
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1, matrices_size));
Expand All @@ -558,7 +559,6 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
(const real**)d_result_matrices, THCTensor_(data)(state,result_),
result_->stride[0], num_batches);

#ifdef THC_REAL_IS_FLOAT
THCudaBlas_SgemmBatched(
state,
Expand Down Expand Up @@ -592,6 +592,38 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
THCudaFree(state, d_matrices1);
THCudaFree(state, d_matrices2);
THCudaFree(state, d_result_matrices);

#else
#ifdef THC_REAL_IS_FLOAT
THCudaBlas_SgemmStridedBatched(
state,
transpose_batch1,
transpose_batch2,
result_->size[transpose_result ? 2 : 1],
result_->size[transpose_result ? 1 : 2],
batch1_->size[transpose_result ? 1 : 2],
alpha,
THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
beta,
THCTensor_(data)(state, result_), ldc, result_->stride[0],
num_batches);
#elif defined(THC_REAL_IS_DOUBLE)
THCudaBlas_DgemmStridedBatched(
state,
transpose_batch1,
transpose_batch2,
result_->size[transpose_result ? 2 : 1],
result_->size[transpose_result ? 1 : 2],
batch1_->size[transpose_result ? 1 : 2],
alpha,
THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
beta,
THCTensor_(data)(state, result_), ldc, result_->stride[0],
num_batches);
#endif
#endif

#elif defined(THC_REAL_IS_HALF)
// Currently no HgemmBatched in Cublas
Expand Down
8 changes: 7 additions & 1 deletion torch/lib/THC/generic/THCTensorMathPointwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(rsqrt, THCNumerics<real>::rsqrt, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( ceil, THCNumerics<real>::ceil, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(floor, THCNumerics<real>::floor, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(trunc, THCNumerics<real>::trunc, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( neg, THCNumerics<real>::neg, Real)

IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( acos, THCNumerics<real>::acos, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cosh, THCNumerics<real>::cosh, Real)
Expand All @@ -61,6 +60,13 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cinv, THCNumerics<real>::cinv, Real)

#endif

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || \
defined(THC_REAL_IS_SHORT) || defined(THC_REAL_IS_INT) || defined(THC_REAL_IS_LONG)

IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( neg, THCNumerics<real>::neg, Real)

#endif

IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( abs, THCNumerics<real>::abs, Real)

#undef IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_
Expand Down
8 changes: 7 additions & 1 deletion torch/lib/THC/generic/THCTensorMathPointwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,17 @@ THC_API void THCTensor_(trunc)(THCState *state, THCTensor *self, THCTensor *src)
THC_API void THCTensor_(frac)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(lerp)(THCState *state, THCTensor *result, THCTensor *a, THCTensor *b, real w);

THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(cinv)(THCState *state, THCTensor *self, THCTensor *src);

#endif

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || \
defined(THC_REAL_IS_SHORT) || defined(THC_REAL_IS_INT) || defined(THC_REAL_IS_LONG)

THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src);

#endif

THC_API void THCTensor_(abs)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(sign)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(clamp)(THCState *state, THCTensor *self, THCTensor *src, real min_value, real max_value);
Expand Down

0 comments on commit fb5e40f

Please sign in to comment.