Add torch.float8_e5m2 and torch.float8_e4m3 data types (pytorch#104242)
This adds two float8 variants - e5m2 and e4m3 - based on the formats proposed in https://arxiv.org/pdf/2209.05433.pdf

All Float8 operator implementations are hidden behind an `#if !defined(C10_MOBILE)` guard to keep the Android build size almost unchanged.

TODO:
 - Refactor duplicated code
 - Clean up the unbalanced pragma pop in dtype utils
 - Add a native implementation on the CUDA side

Co-authored-by: Nikita Shulga <[email protected]>
Pull Request resolved: pytorch#104242
Approved by: https://github.com/albanD
australopitek authored and pytorchmergebot committed Jul 20, 2023
1 parent 1ea153a commit a980413
Showing 36 changed files with 1,626 additions and 98 deletions.
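Before the per-file diffs, a minimal usage sketch (not part of this commit) of what the new dtypes enable at the C++ level. It assumes the CPU cast/copy and item() paths touched below cover the new types, and uses the kFloat8_e5m2 / kFloat8_e4m3fn ScalarType constants that the dispatch macros further down refer to:

#include <ATen/ATen.h>

// Hypothetical sketch, not part of this diff: narrow a float tensor to the
// new 8-bit float formats on CPU and read a value back.
void float8_roundtrip_demo() {
  at::Tensor x = at::rand({4});
  at::Tensor y = x.to(at::kFloat8_e5m2);  // narrow to e5m2 (5 exponent / 2 mantissa bits)
  at::Tensor z = y.to(at::kFloat);        // widen back to float32 (lossy round trip)
  float first = y[0].item<float>();       // item() on a float8 tensor uses the widened scalar dispatch below
  (void)z; (void)first;
}
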
18 changes: 18 additions & 0 deletions aten/src/ATen/AccumulateType.h
@@ -2,6 +2,8 @@
#include <ATen/Config.h>
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>

// Defines the accumulation type for a scalar type.
@@ -67,6 +69,14 @@ struct AccumulateType<Half, true> {
using type = float;
};
template <>
struct AccumulateType<Float8_e5m2, true> {
using type = float;
};
template <>
struct AccumulateType<Float8_e4m3fn, true> {
using type = float;
};
template <>
struct AccumulateType<float, true> {
using type = float;
};
@@ -111,6 +121,14 @@ struct AccumulateType<BFloat16, false> {
using type = float;
};
template <>
struct AccumulateType<Float8_e5m2, false> {
using type = float;
};
template <>
struct AccumulateType<Float8_e4m3fn, false> {
using type = float;
};
template <>
struct AccumulateType<c10::complex<Half>, false> {
using type = c10::complex<float>;
};
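With these specializations, reductions over float8 inputs accumulate in float on both CPU and CUDA. A small illustrative check (assuming the usual at::acc_type alias over AccumulateType; not part of this diff):

#include <ATen/AccumulateType.h>
#include <type_traits>

// Both float8 variants accumulate in float, for is_cuda = false and true.
static_assert(std::is_same<at::acc_type<c10::Float8_e5m2, /*is_cuda=*/false>, float>::value, "");
static_assert(std::is_same<at::acc_type<c10::Float8_e4m3fn, /*is_cuda=*/true>, float>::value, "");
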
4 changes: 4 additions & 0 deletions aten/src/ATen/DLConvertor.cpp
@@ -53,6 +53,10 @@ DLDataType getDLDataType(const Tensor& t) {
case ScalarType::BFloat16:
dtype.code = DLDataTypeCode::kDLBfloat;
break;
case ScalarType::Float8_e5m2:
case ScalarType::Float8_e4m3fn:
TORCH_CHECK(false, "float8 types are not supported by dlpack");
break;
case ScalarType::QInt8:
case ScalarType::QUInt8:
case ScalarType::QInt32:
83 changes: 83 additions & 0 deletions aten/src/ATen/Dispatch.h
@@ -291,6 +291,22 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))

#define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
@@ -515,6 +531,73 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
TYPE, \
NAME, \
...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
__VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
TYPE, \
NAME, \
...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
__VA_ARGS__))

#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
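The new *_AND4 / *_AND5 / *_AND6 macros follow the existing pattern: they append extra AT_DISPATCH_CASE entries to the base case list. A hedged sketch of how a CPU kernel could opt into the float8 types via the four-argument floating-point variant (my_float8_fill is a hypothetical function, not part of this diff):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical kernel: scalar_t is bound per dispatch case, including the
// two float8 types passed as the extra SCALARTYPE arguments.
void my_float8_fill(at::Tensor& t, double value) {
  AT_DISPATCH_FLOATING_TYPES_AND4(
      at::kHalf, at::kBFloat16, at::kFloat8_e5m2, at::kFloat8_e4m3fn,
      t.scalar_type(), "my_float8_fill", [&] {
        scalar_t* data = t.data_ptr<scalar_t>();
        for (int64_t i = 0; i < t.numel(); ++i) {
          data[i] = static_cast<scalar_t>(value);
        }
      });
}
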
26 changes: 26 additions & 0 deletions aten/src/ATen/NumericUtils.h
@@ -6,6 +6,8 @@

#include <c10/macros/Macros.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>

@@ -62,6 +64,22 @@ inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
return at::_isnan(static_cast<float>(val));
}

template <
typename T,
typename std::enable_if<std::is_same<T, at::Float8_e5m2>::value, int>::
type = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}

template <
typename T,
typename std::enable_if<std::is_same<T, at::Float8_e4m3fn>::value, int>::
type = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}

// std::isinf isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// This function is.
@@ -92,6 +110,14 @@ inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
return at::_isinf(static_cast<float>(val));
}

inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
return val.isinf();
}

inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val) {
return false;
}

template <typename T>
C10_HOST_DEVICE inline T exp(T x) {
static_assert(
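The asymmetry between the two _isinf overloads reflects the formats themselves: e5m2 keeps IEEE-style inf/nan encodings, while e4m3fn (the "fn" suffix) has no infinity and reserves a single encoding for nan, which is why _isinf is unconditionally false for it. A small illustrative sketch, not part of this diff, assuming the c10 float8 types convert from float as in their headers (with out-of-range e4m3fn values mapping to nan):

#include <ATen/NumericUtils.h>
#include <limits>

void float8_special_values() {
  const float inf = std::numeric_limits<float>::infinity();
  c10::Float8_e5m2 a(inf);
  c10::Float8_e4m3fn b(inf);       // no inf encoding; conversion yields nan
  bool a_is_inf = at::_isinf(a);   // true
  bool b_is_inf = at::_isinf(b);   // always false for e4m3fn (see overload above)
  bool b_is_nan = at::_isnan(b);   // true, under the conversion assumption in the lead-in
  (void)a_is_inf; (void)b_is_inf; (void)b_is_nan;
}
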
10 changes: 10 additions & 0 deletions aten/src/ATen/OpMathType.h
@@ -3,6 +3,8 @@
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>

namespace at {
@@ -21,6 +23,14 @@ struct OpMathType<at::BFloat16> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e5m2> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e4m3fn> {
using type = float;
};
template <>
struct OpMathType<c10::complex<Half>> {
using type = c10::complex<float>;
};
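As with AccumulateType, these specializations let kernels store float8 but do their arithmetic in float, through the at::opmath_type alias; the addmm change further down uses exactly this via at::opmath_type<scalar_t>. An illustrative check only (not part of this diff):

#include <ATen/OpMathType.h>
#include <type_traits>

// The math type for both float8 variants is float.
static_assert(std::is_same<at::opmath_type<c10::Float8_e5m2>, float>::value, "");
static_assert(std::is_same<at::opmath_type<c10::Float8_e4m3fn>, float>::value, "");
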
4 changes: 3 additions & 1 deletion aten/src/ATen/cpu/vec/vec_base.h
@@ -72,7 +72,9 @@ struct is_floating_point:
std::integral_constant<bool,
std::is_floating_point<T>::value ||
std::is_same<T, at::Half>::value ||
std::is_same<T, at::BFloat16>::value> {
std::is_same<T, at::BFloat16>::value ||
std::is_same<T, at::Float8_e5m2>::value ||
std::is_same<T, at::Float8_e4m3fn>::value> {
};

template<typename T>
14 changes: 13 additions & 1 deletion aten/src/ATen/native/Copy.cpp
@@ -48,6 +48,18 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
self.numel() >= MIN_SZ;
}

#if !defined(C10_MOBILE)
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kComplexHalf, kHalf, kBool, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif

// special case copy where tensor is contiguous and src is a transposed matrix
// This can be generalized to most copies, but it's trickier
void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
@@ -65,7 +77,7 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
// The code below is implemented with the assumption that sizes are equal
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.sizes().equals(src.sizes()));

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kHalf, kBool, kBFloat16, kComplexHalf, self.scalar_type(), "copy_", [&] {
_AT_DISPATCH_CP_TYPES(self.scalar_type(), "copy_", [&] {
scalar_t* sp = src.data_ptr<scalar_t>();
scalar_t* rp = self.data_ptr<scalar_t>();
scalar_t* bp = buf.data_ptr<scalar_t>();
16 changes: 13 additions & 3 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -1310,6 +1310,18 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {
return self.reshape_symint({self.sym_size(0), 1}) * vec2;
}


#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif

static void addmm_impl_cpu_(
Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
@@ -1438,9 +1450,7 @@ static void addmm_impl_cpu_(

if(!dispatched) {
// Apply BLAS routine
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
result.scalar_type(), "addmm_impl_cpu_",
[&]{
_AT_DISPATCH_ADDMM_TYPES(result.scalar_type(), "addmm_impl_cpu_", [&]{
using opmath_t = at::opmath_type<scalar_t>;
at::native::cpublas::gemm(
transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
15 changes: 13 additions & 2 deletions aten/src/ATen/native/Scalar.cpp
@@ -28,10 +28,21 @@ Scalar item(const Tensor& self) {
}
}

#if !defined(C10_MOBILE)
#define _AT_DISPATCH_SD_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_SD_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kComplexHalf, kHalf, kBool, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif

Scalar _local_scalar_dense_cpu(const Tensor& self) {
Scalar r;
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf, kHalf, kBool, kBFloat16, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
_AT_DISPATCH_SD_TYPES(self.scalar_type(), "_local_scalar_dense_cpu", [&] {
scalar_t value = *self.data_ptr<scalar_t>();
r = Scalar(value);
});
16 changes: 14 additions & 2 deletions aten/src/ATen/native/TensorCompare.cpp
@@ -369,6 +369,18 @@ Tensor isreal(const Tensor& self) {
return at::imag(self) == 0;
}


#if !defined(C10_MOBILE)
#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_FLOATING_TYPES_AND3( kHalf, kBFloat16, kFloat8_e5m2, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif


Tensor isinf(const Tensor &self) {
// Note: Integral tensor values are never infinite
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
@@ -381,7 +393,7 @@ Tensor isinf(const Tensor &self) {
(at::isinf(at::imag(self)));
}

return AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "isinf", [&]() {
return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isinf", [&]() {
return self.abs() == std::numeric_limits<scalar_t>::infinity();
});
}
Expand All @@ -397,7 +409,7 @@ Tensor isfinite(const Tensor& self) {
return at::isfinite(at::real(self)).__iand__(at::isfinite(at::imag(self)));
}

return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "isfinite", [&]() {
return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isfinite", [&]() {
return (self == self) * (self.abs() != std::numeric_limits<scalar_t>::infinity());
});
}
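Note that kFloat8_e4m3fn is absent from _AT_DISPATCH_INF_TYPES: the format has no infinity encoding, so only e5m2 is added to the isinf/isfinite dispatch here. For the e5m2 case to compile, std::numeric_limits must be specialized for the c10 float8 types, which the new headers provide. A small illustrative check (assuming those specializations; not part of this diff):

#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <limits>

void float8_limits_check() {
  // e5m2 keeps IEEE-style special values; e4m3fn trades them for extra range.
  bool e5m2_has_inf = std::numeric_limits<c10::Float8_e5m2>::has_infinity;    // expected: true
  bool e4m3_has_inf = std::numeric_limits<c10::Float8_e4m3fn>::has_infinity;  // expected: false
  (void)e5m2_has_inf; (void)e4m3_has_inf;
}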