[Bugfix] Fix that half-precision SpMM produces incorrect results (dmlc/dgl#4842)

* update accumulator

* rename half to __half

* add bfloat16

* simplify code

* fix another case

* add unit test

* disable half-precision SpMMCoo

* fix lint
yaox12 authored Nov 10, 2022
1 parent 9699b93 commit a8f9d5e
Showing 5 changed files with 240 additions and 20 deletions.
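
Why the fix works: before this change, the CUDA reduction kernels kept their per-thread running sum in the output element type, so for fp16/bf16 every intermediate addition was rounded back to 10 (or 7) mantissa bits and long reductions drifted noticeably. The commit switches the local accumulator to float and casts back to the output type once, at the end. The standalone snippet below is a minimal sketch of that effect (not part of the commit; the file name and build line are illustrative) and compiles with nvcc against cuda_fp16.h:

// Standalone sketch (not part of this commit): the effect of rounding the
// running sum to __half after every addition versus accumulating in float
// and rounding once at the end. Build with: nvcc fp16_accum_demo.cu
#include <cstdio>
#include <cuda_fp16.h>

int main() {
  const int n = 4096;
  __half half_accum = __float2half_rn(0.f);  // old behaviour: per-step rounding
  float float_accum = 0.f;                   // new behaviour: fp32 accumulator

  for (int i = 0; i < n; ++i) {
    const __half val = __float2half_rn(1.f);
    // Emulate a native __half add: convert, add, round back to 10 mantissa bits.
    half_accum = __float2half_rn(__half2float(half_accum) + __half2float(val));
    float_accum += __half2float(val);
  }

  // fp16 represents integers exactly only up to 2048, so the half accumulator
  // stalls there, while the fp32 accumulator reaches the true sum of 4096.
  printf("half accumulator : %.1f\n", __half2float(half_accum));
  printf("float accumulator: %.1f\n", __half2float(__float2half_rn(float_accum)));
  return 0;
}
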
144 changes: 138 additions & 6 deletions src/array/cuda/functor.cuh
@@ -178,10 +178,32 @@ template <typename Idx, typename DType, bool atomic = false>
struct Sum : _Sum<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Sum<Idx, half, atomic> : _Sum<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() {
struct Sum<Idx, __half, atomic> : _Sum<Idx, __half, atomic> {
static constexpr __host__ __device__ __forceinline__ __half zero() {
return __float2half_rn(0.);
}
static __device__ __forceinline__ void Call(
__half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__half val, Idx uid, Idx eid) {
_Sum<Idx, __half, atomic>::Call(
out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
}
static __device__ __forceinline__ void Call(
__half *out_buf, Idx *arg_buf, __half val, Idx id) {
_Sum<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);
}
// sometimes we have to use float in reduction for better precision
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__half val, Idx uid, Idx eid) {
_Sum<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
static_cast<float>(val), uid, eid);
}
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_buf, __half val, Idx id) {
_Sum<Idx, float, atomic>::Call(out_buf, arg_buf,
static_cast<float>(val), id);
}
};

#if BF16_ENABLED
@@ -190,6 +212,28 @@ struct Sum<Idx, __nv_bfloat16, atomic> : _Sum<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(0.);
}
static __device__ __forceinline__ void Call(
__nv_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__nv_bfloat16 val, Idx uid, Idx eid) {
_Sum<Idx, __nv_bfloat16, atomic>::Call(
out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
}
static __device__ __forceinline__ void Call(
__nv_bfloat16 *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
_Sum<Idx, __nv_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);
}
// sometimes we have to use float in reduction for better precision
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__nv_bfloat16 val, Idx uid, Idx eid) {
_Sum<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
static_cast<float>(val), uid, eid);
}
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
_Sum<Idx, float, atomic>::Call(out_buf, arg_buf,
static_cast<float>(val), id);
}
};
#endif // BF16_ENABLED

@@ -239,10 +283,32 @@ template <typename Idx, typename DType, bool atomic = false>
struct Max : _Max<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() {
struct Max<Idx, __half, atomic> : _Max<Idx, __half, atomic> {
static constexpr __host__ __device__ __forceinline__ __half zero() {
return __float2half_rn(-6.550400e+04f);
}
static __device__ __forceinline__ void Call(
__half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__half val, Idx uid, Idx eid) {
_Max<Idx, __half, atomic>::Call(
out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
}
static __device__ __forceinline__ void Call(
__half *out_buf, Idx *arg_buf, __half val, Idx id) {
_Max<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);
}
// sometimes we have to use float in reduction for better precision
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__half val, Idx uid, Idx eid) {
_Max<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
static_cast<float>(val), uid, eid);
}
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_buf, __half val, Idx id) {
_Max<Idx, float, atomic>::Call(out_buf, arg_buf,
static_cast<float>(val), id);
}
};

#if BF16_ENABLED
@@ -251,6 +317,28 @@ struct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(-std::numeric_limits<float>::infinity());
}
static __device__ __forceinline__ void Call(
__nv_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__nv_bfloat16 val, Idx uid, Idx eid) {
_Max<Idx, __nv_bfloat16, atomic>::Call(
out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
}
static __device__ __forceinline__ void Call(
__nv_bfloat16 *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
_Max<Idx, __nv_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);
}
// sometimes we have to use float in reduction for better precision
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__nv_bfloat16 val, Idx uid, Idx eid) {
_Max<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
static_cast<float>(val), uid, eid);
}
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
_Max<Idx, float, atomic>::Call(out_buf, arg_buf,
static_cast<float>(val), id);
}
};
#endif // BF16_ENABLED

@@ -300,10 +388,32 @@ template <typename Idx, typename DType, bool atomic = false>
struct Min : _Min<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() {
struct Min<Idx, __half, atomic> : _Min<Idx, __half, atomic> {
static constexpr __host__ __device__ __forceinline__ __half zero() {
return __float2half_rn(6.550400e+04f);
}
static __device__ __forceinline__ void Call(
__half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__half val, Idx uid, Idx eid) {
_Min<Idx, __half, atomic>::Call(
out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
}
static __device__ __forceinline__ void Call(
__half *out_buf, Idx *arg_buf, __half val, Idx id) {
_Min<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);
}
// sometimes we have to use float in reduction for better precision
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__half val, Idx uid, Idx eid) {
_Min<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
static_cast<float>(val), uid, eid);
}
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_buf, __half val, Idx id) {
_Min<Idx, float, atomic>::Call(out_buf, arg_buf,
static_cast<float>(val), id);
}
};

#if BF16_ENABLED
@@ -312,6 +422,28 @@ struct Min<Idx, __nv_bfloat16, atomic> : _Min<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(std::numeric_limits<float>::infinity());
}
static __device__ __forceinline__ void Call(
__nv_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__nv_bfloat16 val, Idx uid, Idx eid) {
_Min<Idx, __nv_bfloat16, atomic>::Call(
out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
}
static __device__ __forceinline__ void Call(
__nv_bfloat16 *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
_Min<Idx, __nv_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);
}
// sometimes we have to use float in reduction for better precision
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
__nv_bfloat16 val, Idx uid, Idx eid) {
_Min<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
static_cast<float>(val), uid, eid);
}
static __device__ __forceinline__ void Call(
float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
_Min<Idx, float, atomic>::Call(out_buf, arg_buf,
static_cast<float>(val), id);
}
};
#endif // BF16_ENABLED
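
How the new overloads above are selected (a sketch with hypothetical names, not the DGL code): each reducer now exposes Call overloads that take either a DType* or a float* output buffer, so a kernel only has to declare its local accumulator as accum_dtype<DType>::type and ordinary C++ overload resolution routes fp16/bf16 values through the float path, with a single cast back when the result is written out.

// Minimal dispatch sketch (hypothetical names; __hadd needs sm_53+).
#include <cuda_fp16.h>

template <typename DType> struct AccumOf { typedef float type; };
template <> struct AccumOf<double> { typedef double type; };

struct SumHalf {
  static __device__ __forceinline__ void Call(__half* out, __half val) {
    *out = __hadd(*out, val);       // rounds to fp16 on every addition
  }
  static __device__ __forceinline__ void Call(float* out, __half val) {
    *out += __half2float(val);      // keeps fp32 precision while reducing
  }
};

template <typename DType, typename ReduceOp>
__global__ void RowSum(const DType* in, DType* out, int n) {
  typename AccumOf<DType>::type acc = 0;  // float when DType is __half
  for (int i = 0; i < n; ++i)
    ReduceOp::Call(&acc, in[i]);          // resolves to the float* overload
  *out = static_cast<DType>(acc);         // one rounding at the very end
}

// Example launch: RowSum<__half, SumHalf><<<1, 1>>>(d_in, d_out, n);
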

5 changes: 3 additions & 2 deletions src/array/cuda/segment_reduce.cuh
@@ -16,6 +16,7 @@
namespace dgl {

using namespace cuda;
using namespace runtime;

namespace aten {
namespace cuda {
@@ -32,12 +33,12 @@ __global__ void SegmentReduceKernel(
for (int row = blockIdx.x; row < n; row += gridDim.x) {
int col = blockIdx.y * blockDim.x + threadIdx.x;
while (col < dim) {
DType local_accum = ReduceOp::zero();
typename accum_dtype<DType>::type local_accum = ReduceOp::zero();
IdType local_arg = -1;
for (IdType i = offsets[row]; i < offsets[row + 1]; ++i) {
ReduceOp::Call(&local_accum, &local_arg, feat[i * dim + col], i);
}
out[row * dim + col] = local_accum;
out[row * dim + col] = static_cast<DType>(local_accum);
if (ReduceOp::require_arg) arg[row * dim + col] = local_arg;
col += gridDim.y * blockDim.x;
}
34 changes: 23 additions & 11 deletions src/array/cuda/spmm.cuh
@@ -146,7 +146,7 @@ void _Transpose(const DType* in, DType* out, int row, int col) {
* @note cuBLAS has no geam API for half data type, fallback to our kernel.
*/
template <>
void _Transpose<half>(const half* in, half* out, int row, int col) {
void _Transpose<__half>(const __half* in, __half* out, int row, int col) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = FindNumThreads(row);
int nb = col;
@@ -553,7 +553,7 @@ __global__ void SpMMCsrKernel(
while (ty < num_rows) {
int tx = blockIdx.y * blockDim.x + threadIdx.x;
while (tx < out_len) {
DType local_accum = ReduceOp::zero();
typename accum_dtype<DType>::type local_accum = ReduceOp::zero();
Idx local_argu = 0, local_arge = 0;
const int lhs_add = UseBcast ? ubcast_off[tx] : tx;
const int rhs_add = UseBcast ? ebcast_off[tx] : tx;
@@ -573,7 +573,7 @@ __global__ void SpMMCsrKernel(
// Separate kernel `SpMMCmpCsrHeteroKernel` is used for max- and
// min-reducer. It does not affect the output on homogeneous graph as
// `out` is initialized to zero.
out[ty * out_len + tx] += local_accum;
out[ty * out_len + tx] += static_cast<DType>(local_accum);
if (ReduceOp::require_arg && BinaryOp::use_lhs)
arg_u[ty * out_len + tx] = local_argu;
if (ReduceOp::require_arg && BinaryOp::use_rhs)
@@ -610,7 +610,9 @@ __global__ void SpMMCmpCsrHeteroKernel(
while (ty < num_rows) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
while (tx < out_len) {
DType new_out = out[ty * out_len + tx]; // ReduceOp::zero();
using accum_type = typename accum_dtype<DType>::type;
accum_type local_accum = static_cast<accum_type>(
out[ty * out_len + tx]); // ReduceOp::zero();
Idx local_argu = 0, local_arge = 0;
const int lhs_add = UseBcast ? ubcast_off[tx] : tx;
const int rhs_add = UseBcast ? ebcast_off[tx] : tx;
@@ -622,10 +624,12 @@
const DType* eoff =
BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
DType tmp_out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
ReduceOp::Call(&new_out, &local_argu, &local_arge, tmp_out, cid, eid);
ReduceOp::Call(
&local_accum, &local_argu, &local_arge, tmp_out, cid, eid);
}
// Update output only when max/min values are different from the original
// output
DType new_out = static_cast<DType>(local_accum);
if (out[ty * out_len + tx] != new_out) {
out[ty * out_len + tx] = new_out;
if (ReduceOp::require_arg && BinaryOp::use_lhs) {
@@ -663,12 +667,20 @@ template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
void SpMMCoo(
const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge) {
#if defined(CUDART_VERSION) && CUDART_VERSION <= 10000
if (std::is_same<DType, half>::value)
LOG(FATAL) << "SpMMCoo requires atomicCAS, which is not supported "
<< "for float16 in CUDA 10.0. Please upgrade your CUDA "
<< "to later versions.";
#endif
/**
* TODO(Xin): Disable half precision for SpMMCoo due to the round-off error.
* We should use fp32 for the accumulation but it's hard to modify the
* current implementation.
*/
#if BF16_ENABLED
if (std::is_same<DType, __half>::value ||
std::is_same<DType, __nv_bfloat16>::value)
#else
if (std::is_same<DType, __half>::value)
#endif // BF16_ENABLED
LOG(FATAL) << "SpMMCoo doesn't support half precision for now. "
<< "Please use SpMMCsr instead by allowing the graph "
<< "materialize CSR/CSC formats.";
const Idx *row = coo.row.Ptr<Idx>(), *col = coo.col.Ptr<Idx>(),
*edge_map = coo.data.Ptr<Idx>();
const DType *ufeat_data = ufeat.Ptr<DType>(),
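
For context on why SpMMCoo is disabled for half precision rather than fixed the same way: in the COO kernels each edge is processed by its own thread and its contribution is scattered into the output with an atomic add in the output's own precision, so there is no thread-local accumulator whose type could simply be widened; an fp32 accumulation would need a separate float-typed output buffer plus a cast kernel. A simplified sketch of that scatter pattern (hypothetical, not the DGL kernel; atomicAdd on __half requires sm_70+):

// Simplified copy_u + sum over a COO matrix (hypothetical sketch).
#include <cuda_fp16.h>

template <typename Idx, typename DType>
__global__ void SpMMCooCopyUSumSketch(
    const Idx* row, const Idx* col, const DType* ufeat,
    DType* out, int64_t nnz, int64_t dim) {
  int64_t e = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (e >= nnz) return;
  const Idx src = row[e], dst = col[e];
  for (int64_t k = 0; k < dim; ++k) {
    // The partial sum is rounded to DType every time an edge lands here, so
    // with __half outputs the round-off error grows with the in-degree.
    atomicAdd(&out[dst * dim + k], ufeat[src * dim + k]);
  }
}
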
39 changes: 38 additions & 1 deletion src/runtime/cuda/cuda_common.h
@@ -118,10 +118,17 @@ struct cuda_dtype {
};

template <>
struct cuda_dtype<half> {
struct cuda_dtype<__half> {
static constexpr cudaDataType_t value = CUDA_R_16F;
};

#if BF16_ENABLED
template <>
struct cuda_dtype<__nv_bfloat16> {
static constexpr cudaDataType_t value = CUDA_R_16BF;
};
#endif // BF16_ENABLED

template <>
struct cuda_dtype<float> {
static constexpr cudaDataType_t value = CUDA_R_32F;
@@ -132,6 +139,36 @@ struct cuda_dtype<double> {
static constexpr cudaDataType_t value = CUDA_R_64F;
};

/*
* \brief Accumulator type for SpMM.
*/
template <typename T>
struct accum_dtype {
typedef float type;
};

template <>
struct accum_dtype<__half> {
typedef float type;
};

#if BF16_ENABLED
template <>
struct accum_dtype<__nv_bfloat16> {
typedef float type;
};
#endif // BF16_ENABLED

template <>
struct accum_dtype<float> {
typedef float type;
};

template <>
struct accum_dtype<double> {
typedef double type;
};

#if CUDART_VERSION >= 11000
/**
* @brief Cast index data type to cusparseIndexType_t.
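
The accum_dtype trait above maps each element type to the type the kernels accumulate in: float for __half, __nv_bfloat16, and float, and double for double. A quick compile-time check of that mapping (a sketch; the include path and the dgl::runtime namespace are inferred from the diff above):

// Compile-time sanity check of the accumulator mapping (sketch only).
#include <cuda_fp16.h>
#include <type_traits>
#include "cuda_common.h"  // assumed include path

using namespace dgl::runtime;

static_assert(std::is_same<accum_dtype<__half>::type, float>::value,
              "fp16 accumulates in fp32");
static_assert(std::is_same<accum_dtype<float>::type, float>::value,
              "fp32 accumulates in fp32");
static_assert(std::is_same<accum_dtype<double>::type, double>::value,
              "fp64 accumulates in fp64");
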
38 changes: 38 additions & 0 deletions tests/compute/test_sparse.py
@@ -175,6 +175,44 @@ def test_spmm(idtype, g, shp, msg, reducer):
g.dstdata.pop("v")


@unittest.skipIf(
dgl.backend.backend_name != "pytorch",
reason="Only support PyTorch for now."
)
@unittest.skipIf(
F._default_context_str == "cpu",
reason="Don't support half precision on CPU."
)
@parametrize_idtype
@pytest.mark.parametrize(
"dtype, rtol, atol",
[(torch.float16, 1e-3, 0.5), (torch.bfloat16, 4e-3, 2.)]
)
def test_half_spmm(idtype, dtype, rtol, atol):
if LooseVersion(torch.version.cuda) < LooseVersion("11.0") \
and dtype == torch.bfloat16:
pytest.skip("BF16 requires CUDA >= 11.0.")

# make sure the spmm result is < 512 to match the rtol/atol we set.
g = dgl.graph((torch.arange(900), torch.tensor([0] * 900)),
idtype=idtype, device=F.ctx())
feat_fp32 = torch.rand((g.num_src_nodes(), 32)).to(0)
feat_half = feat_fp32.to(dtype)

# test SpMMCSR
g = g.formats(['csc'])
res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]
res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()
assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)

# test SpMMCOO
# TODO(Xin): half-precision SpMMCoo is temporarily disabled.
# g = g.formats(['coo'])
# res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]
# res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()
# assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)


@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shp", sddmm_shapes)
@pytest.mark.parametrize("lhs_target", ["u", "v", "e"])
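
On the tolerances chosen in test_half_spmm above (a back-of-the-envelope sketch, not part of the test): the 900-edge star graph keeps the copy_u_sum output near 450, i.e. below 512 as the in-test comment requires, where one fp16 ulp is 0.25 and one bf16 ulp is 2.0; the final cast back from the fp32 accumulator therefore contributes at most half an ulp, and the remaining rtol/atol slack absorbs the quantization of the inputs themselves.

// Ulp spacing of fp16 (10 explicit mantissa bits) and bf16 (7 bits) at a
// typical copy_u_sum output of ~450, compared with the test's atol values.
#include <cmath>
#include <cstdio>

double ulp_at(double x, int mantissa_bits) {
  return std::ldexp(1.0, std::ilogb(x) - mantissa_bits);
}

int main() {
  const double x = 450.0;  // expected sum of 900 uniform [0, 1) features
  std::printf("fp16 ulp at %.0f: %.3f (test atol 0.5)\n", x, ulp_at(x, 10));
  std::printf("bf16 ulp at %.0f: %.3f (test atol 2.0)\n", x, ulp_at(x, 7));
  return 0;
}
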
