forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcuda_gpu_kernel_helpers.cc
136 lines (119 loc) · 4.76 KB
/
cuda_gpu_kernel_helpers.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
/* Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda_gpu_kernel_helpers.h"
#include <stdexcept>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
namespace jax {
namespace {
// Converts a CUDA runtime error code into the description string supplied by
// the CUDA runtime itself (cudaGetErrorString).
std::string ErrorToString(cudaError_t error) {
  const char* description = cudaGetErrorString(error);
  return std::string(description);
}
// Converts a cuSPARSE status code into the description string supplied by the
// cuSPARSE library itself (cusparseGetErrorString).
std::string ErrorToString(cusparseStatus_t status) {
  const char* description = cusparseGetErrorString(status);
  return std::string(description);
}
// Translates a cuSOLVER status code into a human-readable message. Any code
// not listed below is reported as "Unknown cuSolver error: " followed by its
// numeric value.
std::string ErrorToString(cusolverStatus_t status) {
  const char* text = nullptr;
  switch (status) {
    case CUSOLVER_STATUS_SUCCESS:
      text = "cuSolver success.";
      break;
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      text = "cuSolver has not been initialized";
      break;
    case CUSOLVER_STATUS_ALLOC_FAILED:
      text = "cuSolver allocation failed";
      break;
    case CUSOLVER_STATUS_INVALID_VALUE:
      text = "cuSolver invalid value error";
      break;
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      text = "cuSolver architecture mismatch error";
      break;
    case CUSOLVER_STATUS_MAPPING_ERROR:
      text = "cuSolver mapping error";
      break;
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      text = "cuSolver execution failed";
      break;
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      text = "cuSolver internal error";
      break;
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      text = "cuSolver matrix type not supported error";
      break;
    case CUSOLVER_STATUS_NOT_SUPPORTED:
      text = "cuSolver not supported error";
      break;
    case CUSOLVER_STATUS_ZERO_PIVOT:
      text = "cuSolver zero pivot error";
      break;
    case CUSOLVER_STATUS_INVALID_LICENSE:
      text = "cuSolver invalid license error";
      break;
    default:
      break;
  }
  if (text != nullptr) {
    return text;
  }
  // Unrecognized code: surface the raw value so it can be looked up.
  return absl::StrCat("Unknown cuSolver error: ", status);
}
// Translates a cuBLAS status code into a human-readable message. Any code not
// listed below is reported as "Unknown cuBlas error: " followed by its numeric
// value, the same way the cuSOLVER overload reports unlisted codes.
std::string ErrorToString(cublasStatus_t status) {
  switch (status) {
    case CUBLAS_STATUS_SUCCESS:
      return "cuBlas success";
    case CUBLAS_STATUS_NOT_INITIALIZED:
      return "cuBlas has not been initialized";
    case CUBLAS_STATUS_ALLOC_FAILED:
      return "cuBlas allocation failure";
    case CUBLAS_STATUS_INVALID_VALUE:
      return "cuBlas invalid value error";
    case CUBLAS_STATUS_ARCH_MISMATCH:
      return "cuBlas architecture mismatch";
    case CUBLAS_STATUS_MAPPING_ERROR:
      return "cuBlas mapping error";
    case CUBLAS_STATUS_EXECUTION_FAILED:
      return "cuBlas execution failed";
    case CUBLAS_STATUS_INTERNAL_ERROR:
      return "cuBlas internal error";
    case CUBLAS_STATUS_NOT_SUPPORTED:
      return "cuBlas not supported error";
    case CUBLAS_STATUS_LICENSE_ERROR:
      return "cuBlas license error";
    default:
      // Include the raw numeric code so a status not listed above can still be
      // identified from the resulting exception message.
      return absl::StrCat("Unknown cuBlas error: ", status);
  }
}
template <typename T>
void ThrowError(T status, const char* file, std::int64_t line,
const char* expr) {
throw std::runtime_error(absl::StrFormat("%s:%d: operation %s failed: %s",
file, line, expr,
ErrorToString(status)));
}
} // namespace
// Does nothing when `error` is cudaSuccess; otherwise throws a
// std::runtime_error that names the failing expression and its call site.
void ThrowIfError(cudaError_t error, const char* file, std::int64_t line,
                  const char* expr) {
  if (error == cudaSuccess) {
    return;
  }
  ThrowError(error, file, line, expr);
}
// Does nothing when `status` is CUSOLVER_STATUS_SUCCESS; otherwise throws a
// std::runtime_error that names the failing expression and its call site.
void ThrowIfError(cusolverStatus_t status, const char* file, std::int64_t line,
                  const char* expr) {
  if (status == CUSOLVER_STATUS_SUCCESS) {
    return;
  }
  ThrowError(status, file, line, expr);
}
// Does nothing when `status` is CUSPARSE_STATUS_SUCCESS; otherwise throws a
// std::runtime_error that names the failing expression and its call site.
void ThrowIfError(cusparseStatus_t status, const char* file, std::int64_t line,
                  const char* expr) {
  if (status == CUSPARSE_STATUS_SUCCESS) {
    return;
  }
  ThrowError(status, file, line, expr);
}
// Does nothing when `status` is CUBLAS_STATUS_SUCCESS; otherwise throws a
// std::runtime_error that names the failing expression and its call site.
void ThrowIfError(cublasStatus_t status, const char* file, std::int64_t line,
                  const char* expr) {
  if (status == CUBLAS_STATUS_SUCCESS) {
    return;
  }
  ThrowError(status, file, line, expr);
}
// Builds a host-side table of `batch` pointers into `buffer`. Entry i points
// `i * batch_elem_size` bytes past the start of `buffer`. The table is
// enqueued for copy into `dev_ptrs` with cudaMemcpyAsync (host-to-device) on
// `stream`. If that call reports an error, JAX_THROW_IF_ERROR raises it
// (presumably as std::runtime_error via ThrowIfError; confirm in the header).
//
// NOTE(review): assumes `dev_ptrs` refers to device memory with room for at
// least `batch` pointers. The host table is returned rather than freed here;
// presumably so the caller controls its lifetime relative to the async copy.
// Confirm against callers.
std::unique_ptr<void*[]> MakeBatchPointers(cudaStream_t stream, void* buffer,
                                          void* dev_ptrs, int batch,
                                          int batch_elem_size) {
  std::unique_ptr<void*[]> table = absl::make_unique<void*[]>(batch);
  // Advance a byte cursor instead of computing i * batch_elem_size, so no
  // intermediate int product is formed.
  char* cursor = static_cast<char*>(buffer);
  void** const table_end = table.get() + batch;
  for (void** slot = table.get(); slot != table_end; ++slot) {
    *slot = cursor;
    cursor += batch_elem_size;
  }
  const auto table_bytes = sizeof(void*) * batch;
  JAX_THROW_IF_ERROR(cudaMemcpyAsync(dev_ptrs, table.get(), table_bytes,
                                     cudaMemcpyHostToDevice, stream));
  return table;
}
} // namespace jax