add basic cuda support for float8 dtypes (pytorch#105807)
Summary:

Ensures that creating tensors, copying, filling with zeros, and checking for NaN work on CUDA for the `float8` dtypes. This should be enough for float8 emulation on CUDA.

Note that I skipped the mul test: it is less trivial to add (it needs a new C++ macro), and there is no use case for it yet. We can follow up on that in the future.
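
A minimal sketch of what this enables from Python (assumes a CUDA device is available and a build that already has the float8 dtypes; not part of the original commit message):

```
import torch

# Sketch only: exercises creation, copy, zero-fill, comparison, and the NaN check on CUDA.
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    x = torch.zeros(8, dtype=dtype, device="cuda")  # creation / filling with zeros
    y = x.clone()                                   # same-dtype copy on the GPU
    y.zero_()                                       # in-place zero fill
    print(dtype, torch.isnan(x).any().item())       # NaN check, expected False here
    print(dtype, torch.eq(x, y).all().item())       # comparison, expected True here
```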

Test Plan:

```
python test/test_quantization.py TestFloat8Dtype
```

Pull Request resolved: pytorch#105807
Approved by: https://github.com/ezyang, https://github.com/jerryzh168, https://github.com/albanD
vkuzo authored and pytorchmergebot committed Jul 25, 2023
1 parent 3a01c05 commit 8b34fa5
Showing 6 changed files with 60 additions and 39 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/CompareEQKernel.cu
@@ -29,7 +29,7 @@ struct CompareEqFunctor{
}

C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool,
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(kComplexHalf, kHalf, kBFloat16, kBool, kFloat8_e4m3fn, kFloat8_e5m2,
iter.common_dtype(), "compare_eq_ne_cuda", [&]() {
opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
iter, CompareEqFunctor<scalar_t>(op));
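
With the extra `kFloat8_e4m3fn`/`kFloat8_e5m2` entries, the eq/ne kernel covers the float8 dtypes on CUDA; this is also what makes `isnan` work there, since it is effectively a self-inequality check. A hedged sketch, assuming a CUDA device:

```
import torch

# 0b01111111 is the NaN bit pattern for float8_e4m3fn (the format has no infinities).
bits = torch.tensor([0b01111111], dtype=torch.uint8, device="cuda")
nan8 = bits.view(torch.float8_e4m3fn)
print(torch.eq(nan8, nan8))  # expected tensor([False], ...) because NaN != NaN
print(torch.isnan(nan8))     # expected tensor([True], ...)
```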
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/Copy.cu
@@ -31,8 +31,8 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kHalf, kBool, kBFloat16, kComplexHalf, dtype, "copy_", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(
kHalf, kBool, kBFloat16, kComplexHalf, kFloat8_e4m3fn, kFloat8_e5m2, dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
}
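
A short sketch of the copies this dispatch change covers (assumes a CUDA device; the float-to-float8 cast shown is the path the device-parametrized `test_cast_to_float8` exercises on CUDA):

```
import torch

# Sketch only: float8 copies and casts on the GPU.
x = torch.rand(4, device="cuda")
x8 = x.to(torch.float8_e4m3fn)  # fp32 -> float8 cast on CUDA
y8 = x8.clone()                 # same-dtype device-to-device copy
print(y8.float())               # float8 -> fp32 for inspection
```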
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/FillKernel.cu
@@ -19,7 +19,7 @@ struct FillFunctor {
};

void fill_kernel_cuda(TensorIterator& iter, const Scalar& value) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kHalf, kBFloat16, iter.dtype(), "fill_cuda", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(kComplexHalf, kBool, kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2, iter.dtype(), "fill_cuda", [&]() {
gpu_kernel(iter, FillFunctor<scalar_t>(value.to<scalar_t>()));
});
}
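
The fill dispatch is what zero-initialization of float8 CUDA tensors goes through; a minimal, hedged sketch (assumes a CUDA device):

```
import torch

# Sketch only: exercises the fill kernel for a float8 dtype on the GPU.
x8 = torch.empty(4, dtype=torch.float8_e5m2, device="cuda")
x8.zero_()        # zero fill, the case the commit message calls out
x8.fill_(0.5)     # scalar fill; 0.5 is exactly representable in e5m2
print(x8.float()) # expected tensor([0.5000, 0.5000, 0.5000, 0.5000], ...)
```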
5 changes: 4 additions & 1 deletion c10/util/Float8_e4m3fn.h
@@ -32,6 +32,7 @@
#include <intrin.h>
#endif

#include <climits>
#include <cstdint>
#include <cstring>
#include <iosfwd>
@@ -102,7 +103,9 @@ inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
_BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
#else
uint32_t renorm_shift = __builtin_clz(nonsign);
// Note: zero is not a supported input into `__builtin_clz`
uint32_t renorm_shift =
nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
#endif
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
/*
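
The `__builtin_clz` guard matters because `__builtin_clz(0)` is undefined behavior; from Python, the affected case is decoding the ±0 bit patterns (a hedged sketch, CPU is enough here):

```
import torch

# Sketch: view the +0 and -0 bit patterns as float8_e4m3fn and decode to fp32,
# which exercises fp8e4m3fn_to_fp32_value with nonsign == 0.
bits = torch.tensor([0b00000000, 0b10000000], dtype=torch.uint8)
zeros8 = bits.view(torch.float8_e4m3fn)
print(zeros8.float())  # expected to print zero and negative zero
```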
74 changes: 40 additions & 34 deletions test/quantization/core/experimental/test_float8.py
@@ -1,12 +1,8 @@
# Owner(s): ["oncall: quantization"]

import torch
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase

# Masks for float8 simulation

@@ -96,63 +92,73 @@ class TestFloat8Dtype(TestCase):
"""

@parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
def test_creation_with_zeros(self, dtype):
x = torch.zeros(8, dtype=torch.float)
x8 = torch.zeros(8, dtype=dtype)
def test_creation_with_zeros(self, dtype, device):
x = torch.zeros(8, dtype=torch.float, device=device)
x8 = torch.zeros(8, dtype=dtype, device=device)
self.assertEqual(x, x8.float())

"""
Numerical test of float8 conversion
"""

@parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
def test_cast_to_float8(self, dtype):
x = torch.rand((100, 100)) * FP8_MAX[dtype]
def test_cast_to_float8(self, dtype, device):
x = torch.rand((100, 100), device=device) * FP8_MAX[dtype]
x = torch.cat((x, -x))
x8 = x.to(dtype)
x8_simulated = simulateFp8Precision(x, dtype)
self.assertEqual(x8_simulated, x8.float())

"""
Test of mul implementation
"""

@parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
def test_mul(self, dtype):
shape = (10, 10)
a = torch.randn(shape)
a8_simulated = simulateFp8Precision(a, dtype)
a8 = a.to(dtype)
b = torch.randn(shape)
b8_simulated = simulateFp8Precision(b, dtype)
b8 = b.to(dtype)
mul8 = a8 * b8
mul8_simulated = (a8_simulated * b8_simulated).to(dtype)
self.assertEqual(mul8, mul8_simulated)

"""
Test special numbers
"""

@parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
def test_special_numbers(self, dtype):
def compare_binary_with_decimal(binary, decimal, number_name, dtype):
def test_special_numbers(self, dtype, device):
def compare_binary_with_decimal(binary, decimal, number_name, dtype, device):
bits_int = int(binary, 2)
tensor_int = torch.tensor([bits_int], dtype=torch.uint8)
tensor_int = torch.tensor([bits_int], dtype=torch.uint8, device=device)
tensor_fp8 = tensor_int.view(dtype)
if number_name == "nan":
assert tensor_fp8.isnan()
else:
tensor_fp32 = tensor_fp8.float()
ref_tensor_fp32 = torch.tensor([decimal], dtype=torch.float)
ref_tensor_fp32 = torch.tensor(
[decimal], dtype=torch.float, device=device
)
self.assertEqual(tensor_fp32, ref_tensor_fp32)

for number in SPECIAL_NUMBERS[dtype]:
compare_binary_with_decimal(*number, dtype)
compare_binary_with_decimal(*number, dtype, device)


instantiate_device_type_tests(TestFloat8Dtype, globals())


instantiate_parametrized_tests(TestFloat8Dtype)
class TestFloat8DtypeCPUOnly(TestCase):

"""
Test of mul implementation
# Note: this is cpu-only for now because adding it to CUDA requires
adding yet another C++ dtype macro, and there is no use case yet for unscaled
float8 multiplication - doesn't seem worth it.
"""

@parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
def test_mul(self, dtype):
shape = (10, 10)
a = torch.randn(shape)
a8_simulated = simulateFp8Precision(a, dtype)
a8 = a.to(dtype)
b = torch.randn(shape)
b8_simulated = simulateFp8Precision(b, dtype)
b8 = b.to(dtype)
mul8 = a8 * b8
mul8_simulated = (a8_simulated * b8_simulated).to(dtype)
self.assertEqual(mul8, mul8_simulated)


instantiate_device_type_tests(TestFloat8DtypeCPUOnly, globals(), only_for="cpu")

if __name__ == "__main__":
run_tests()
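
For context on the class names imported in `test_quantization.py` below: `instantiate_device_type_tests` generates one concrete class per device type (e.g. `TestFloat8DtypeCPU`, `TestFloat8DtypeCUDA`) and passes the matching `device` string into each test. A hedged sketch of the pattern, using an illustrative class name (`TestExample` is not from this PR):

```
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase

class TestExample(TestCase):
    # `device` is supplied by the device-type test framework ("cpu", "cuda", ...).
    def test_zeros_roundtrip(self, device):
        x8 = torch.zeros(8, dtype=torch.float8_e5m2, device=device)
        self.assertEqual(x8.float(), torch.zeros(8, device=device))

# Generates TestExampleCPU, TestExampleCUDA, ... in this module's namespace.
instantiate_device_type_tests(TestExample, globals())

if __name__ == "__main__":
    run_tests()
```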
12 changes: 12 additions & 0 deletions test/test_quantization.py
@@ -148,6 +148,18 @@

# Experimental functionality
from quantization.core.experimental.test_bits import TestBits # noqa: F401
try:
from quantization.core.experimental.test_float8 import TestFloat8DtypeCPU # noqa: F401
except ImportError as e:
logging.warning(e)
try:
from quantization.core.experimental.test_float8 import TestFloat8DtypeCUDA # noqa: F401
except ImportError as e:
logging.warning(e)
try:
from quantization.core.experimental.test_float8 import TestFloat8DtypeCPUOnlyCPU # noqa: F401
except ImportError as e:
logging.warning(e)

if __name__ == '__main__':
run_tests()
