Add bfloat16 support to JAX. (jax-ml#1720)
bfloat16 support is still immature, but this PR adds initial support.

Fixes jax-ml#76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.

The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must:

* implement our own type promotion operators that understand bfloat16 types, and
* wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
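
For concreteness, here is a minimal sketch of the promotion behavior this implements (not part of the diff; it assumes a JAX build containing this change, and follows the _bfloat16_type_promotions table in jax/dtypes.py below):

    import numpy as onp
    from jax import dtypes

    # bfloat16 absorbs bool and 8-bit integers...
    assert dtypes.promote_types(dtypes.bfloat16, onp.bool_) == onp.dtype(dtypes.bfloat16)
    assert dtypes.promote_types(dtypes.bfloat16, onp.int8) == onp.dtype(dtypes.bfloat16)
    # ...but bfloat16 paired with float16 widens to float32, since neither
    # 16-bit format can exactly represent the other.
    assert dtypes.promote_types(dtypes.bfloat16, onp.float16) == onp.dtype('float32')
    # 32- and 64-bit integers promote to float64, as in NumPy.
    assert dtypes.promote_types(dtypes.bfloat16, onp.int32) == onp.dtype('float64')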
hawkinsp authored Nov 21, 2019
1 parent b7d11ab commit ee36818
Showing 8 changed files with 418 additions and 251 deletions.
7 changes: 4 additions & 3 deletions jax/abstract_arrays.py
@@ -198,10 +198,11 @@ def zeros_like_array(x):
   dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
   return onp.broadcast_to(onp.array(0, dtype), onp.shape(x))
 
-array_types = {onp.ndarray, onp.float64, onp.float32, onp.float16,
+array_types = {onp.ndarray, onp.bool_,
+               onp.int8, onp.int16, onp.int32, onp.int64,
+               onp.uint8, onp.uint16, onp.uint32, onp.uint64,
+               dtypes.bfloat16, onp.float16, onp.float32, onp.float64,
                onp.complex64, onp.complex128,
-               onp.int64, onp.int32, onp.int16, onp.int8,
-               onp.bool_, onp.uint64, onp.uint32, onp.uint16, onp.uint8,
                onp.longlong}
 
 for t in array_types:
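As an illustrative sketch of what this registration enables (again assuming a build with this change): NumPy arrays with the bfloat16 dtype become recognized JAX array types, and the zeros_like_array recipe shown above handles them like any other dtype:

    import numpy as onp
    from jax import dtypes

    x = onp.array([1, 2, 3], dtype=dtypes.bfloat16)
    # Same recipe as zeros_like_array above: a scalar zero of x's dtype,
    # broadcast to x's shape.
    zeros = onp.broadcast_to(onp.array(0, x.dtype), onp.shape(x))
    print(zeros.dtype)  # bfloat16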
81 changes: 76 additions & 5 deletions jax/dtypes.py
@@ -12,24 +12,51 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Array type functions.
+#
+# JAX dtypes differ from NumPy in both:
+# a) their type promotion rules, and
+# b) the set of supported types (e.g., bfloat16),
+# so we need our own implementation that deviates from NumPy in places.
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 from distutils.util import strtobool
+import functools
 import os
 
 import numpy as onp
 import six
 
-from .config import flags
 from . import util
+from .config import flags
 from .lib import xla_client
 
 FLAGS = flags.FLAGS
 flags.DEFINE_bool('jax_enable_x64',
                   strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
                   'Enable 64-bit types to be used.')
 
+# bfloat16 support
+bfloat16 = xla_client.bfloat16
+_bfloat16_dtype = onp.dtype(bfloat16)
+
+class _bfloat16_finfo(object):
+  bits = 16
+  eps = bfloat16(float.fromhex("0x1p-7"))
+  epsneg = bfloat16(float.fromhex("0x1p-8"))
+  machep = -7
+  negep = -8
+  max = bfloat16(float.fromhex("0x1.FEp127"))
+  min = -max
+  nexp = 8
+  nmant = 7
+  iexp = nexp
+  precision = 2
+  resolution = 10 ** -2
+  tiny = bfloat16(float.fromhex("0x1p-126"))
+
 # Default types.

@@ -96,12 +123,56 @@ def coerce_to_array(x):
   return onp.array(x, dtype) if dtype else onp.array(x)
 
 iinfo = onp.iinfo
-finfo = onp.finfo
+
+def finfo(dtype):
+  # Since NumPy doesn't consider bfloat16 a floating-point type, we have to
+  # provide an alternative implementation of finfo that does so.
+  if onp.result_type(dtype) == _bfloat16_dtype:
+    return _bfloat16_finfo
+  else:
+    return onp.finfo(dtype)
+
+
+def issubdtype(a, b):
+  if a == bfloat16:
+    return b in [onp.floating, onp.inexact, onp.number]
+  return onp.issubdtype(a, b)
 
 can_cast = onp.can_cast
-issubdtype = onp.issubdtype
 issubsctype = onp.issubsctype
-promote_types = onp.promote_types
+
+_bfloat16_type_promotions = {
+  onp.dtype('bool'): onp.dtype(bfloat16),
+  onp.dtype(bfloat16): onp.dtype(bfloat16),
+  onp.dtype('float16'): onp.dtype('float32'),
+  onp.dtype('float32'): onp.dtype('float32'),
+  onp.dtype('float64'): onp.dtype('float64'),
+  onp.dtype('complex64'): onp.dtype('complex64'),
+  onp.dtype('complex128'): onp.dtype('complex128'),
+  onp.dtype('int8'): onp.dtype(bfloat16),
+  onp.dtype('int16'): onp.dtype('float32'),
+  onp.dtype('int32'): onp.dtype('float64'),
+  onp.dtype('int64'): onp.dtype('float64'),
+  onp.dtype('uint8'): onp.dtype(bfloat16),
+  onp.dtype('uint16'): onp.dtype('float32'),
+  onp.dtype('uint32'): onp.dtype('float64'),
+  onp.dtype('uint64'): onp.dtype('float64'),
+}
+
+def promote_types(a, b):
+  a = onp.dtype(a)
+  b = onp.dtype(b)
+  if b == _bfloat16_dtype:
+    a, b = b, a
+
+  if a == _bfloat16_dtype:
+    try:
+      return _bfloat16_type_promotions[b]
+    except KeyError:
+      raise TypeError("invalid type promotion of bfloat16 type and {}"
+                      .format(b))
+
+  return onp.promote_types(a, b)
 
 
 def is_python_scalar(x):
@@ -138,4 +209,4 @@ def result_type(*args):
     (scalars if is_python_scalar(x) else dtypes).append(dtype(x))
   array_priority = max(map(_dtype_priority, dtypes)) if dtypes else -1
   dtypes += [x for x in scalars if _dtype_priority(x) > array_priority]
-  return canonicalize_dtype(onp.result_type(*dtypes))
+  return canonicalize_dtype(functools.reduce(promote_types, dtypes))
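A quick sanity check of the new helpers (an illustrative sketch, not part of the diff; it assumes a build with this change, and the values follow from the _bfloat16_finfo constants above):

    import numpy as onp
    from jax import dtypes

    info = dtypes.finfo(dtypes.bfloat16)
    # eps is 2**-7 because bfloat16 keeps 7 explicit mantissa bits (nmant = 7).
    print(float(info.eps))  # 0.0078125
    # max is 0x1.FEp127, about 3.39e38: bfloat16 shares float32's 8-bit exponent.
    print(float(info.max))
    # NumPy's issubdtype doesn't know bfloat16 is a float; JAX's version does.
    assert dtypes.issubdtype(dtypes.bfloat16, onp.floating)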
51 changes: 29 additions & 22 deletions jax/lax/lax.py
@@ -351,17 +351,23 @@ def convert_element_type(operand, new_dtype):
"""
new_dtype = dtypes.canonicalize_dtype(new_dtype)
old_dtype = dtypes.canonicalize_dtype(_dtype(operand))
if old_dtype != new_dtype:
if (dtypes.issubdtype(old_dtype, onp.complexfloating) and
not dtypes.issubdtype(new_dtype, onp.complexfloating)):
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, onp.ComplexWarning, stacklevel=2)
operand = real(operand)
old_dtype = _dtype(operand)
return convert_element_type_p.bind(
operand, new_dtype=new_dtype, old_dtype=old_dtype)
else:
if old_dtype == new_dtype:
return operand
if (dtypes.issubdtype(old_dtype, onp.complexfloating) and
not dtypes.issubdtype(new_dtype, onp.complexfloating)):
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, onp.ComplexWarning, stacklevel=2)
operand = real(operand)
old_dtype = _dtype(operand)
# TODO(b/143311238, b/142974574): work around bfloat16 conversion bugs by
# introducing an intermediate cast via float32.
if ((old_dtype == dtypes.bfloat16 and new_dtype != onp.float32) or
(new_dtype == dtypes.bfloat16 and old_dtype != onp.float32)):
operand = convert_element_type_p.bind(
operand, new_dtype=onp.float32, old_dtype=old_dtype)
old_dtype = onp.float32
return convert_element_type_p.bind(
operand, new_dtype=new_dtype, old_dtype=old_dtype)

def bitcast_convert_type(operand, new_dtype):
"""Elementwise bitcast.
@@ -1377,7 +1383,19 @@ def reciprocal(x):
r"""Elementwise reciprocal: :math:`1 \over x`."""
return div(_const(x, 1), x)

def _upcast_fp16_for_computation(f):
@functools.wraps(f)
def f_wrapped(x):
dtype = _dtype(x)
if dtype == onp.float16 or dtype == dtypes.bfloat16:
return convert_element_type(
f(convert_element_type(x, onp.float32)), dtype)
return f(x)

return f_wrapped

@api.jit
@_upcast_fp16_for_computation
def tan(x):
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
return div(sin(x), cos(x))
@@ -1401,17 +1419,6 @@ def atan(x):
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`."""
return atan2(x, _const(x, 1))

def _upcast_fp16_for_computation(f):
@functools.wraps(f)
def f_wrapped(x):
dtype = _dtype(x)
if dtype == onp.float16:
return convert_element_type(
f(convert_element_type(x, onp.float32)), dtype)
return f(x)

return f_wrapped

@api.jit
@_upcast_fp16_for_computation
def sinh(x):
@@ -1586,7 +1593,7 @@ def _brcast_to(x, shape):
   return broadcast(x, shape)
 
 
-_float = {onp.floating}
+_float = {onp.floating, dtypes.bfloat16}
 _complex = {onp.complexfloating}
 _complex_elem_types = {onp.float32, onp.float64}
 _int = {onp.integer}
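The _upcast_fp16_for_computation decorator moved and extended above is the pattern this PR uses for 16-bit floats: evaluate transcendental functions in float32, then round once back to the input dtype. A standalone sketch of the same idea in plain NumPy (using float16 as a stand-in, since stock NumPy has no bfloat16):

    import functools
    import numpy as onp

    def upcast_low_precision_for_computation(f):
      # Sketch of the decorator above: compute in float32 when the input is a
      # low-precision float, then cast the result back to the input dtype.
      @functools.wraps(f)
      def f_wrapped(x):
        dtype = onp.result_type(x)
        if dtype == onp.float16:  # the JAX version also checks dtypes.bfloat16
          return f(x.astype(onp.float32)).astype(dtype)
        return f(x)
      return f_wrapped

    @upcast_low_precision_for_computation
    def tan_reference(x):
      # Evaluating sin/cos at float32 precision means the result is rounded
      # to 16 bits only once, at the end, instead of after every
      # intermediate operation.
      return onp.sin(x) / onp.cos(x)

    print(tan_reference(onp.float16(1.0)))  # ~1.557, returned as float16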
16 changes: 11 additions & 5 deletions jax/lax_reference.py
@@ -25,6 +25,8 @@

 from six.moves import builtins
 
+from . import dtypes
+
 _slice = builtins.slice
 _max = builtins.max
 _min = builtins.min
@@ -88,7 +90,7 @@ def complex(x, y):
 mul = onp.multiply
 
 def div(lhs, rhs):
-  if onp.issubdtype(onp.result_type(lhs), onp.integer):
+  if dtypes.issubdtype(dtypes.result_type(lhs), onp.integer):
     quotient = onp.floor_divide(lhs, rhs)
     select = onp.logical_and(onp.sign(lhs) != onp.sign(rhs),
                              onp.remainder(lhs, rhs) != 0)
@@ -176,7 +178,11 @@ def dot_general(lhs, rhs, dimension_numbers):
   not_none = lambda x: x is not None
   out_axis_ids = filter(not_none,
                         batch_ids + lhs_out_axis_ids + rhs_out_axis_ids)
-  return onp.einsum(lhs, lhs_axis_ids, rhs, rhs_axis_ids, out_axis_ids)
+  assert lhs.dtype == rhs.dtype
+  dtype = onp.float32 if lhs.dtype == dtypes.bfloat16 else None
+  out = onp.einsum(lhs, lhs_axis_ids, rhs, rhs_axis_ids, out_axis_ids,
+                   dtype=dtype)
+  return out.astype(dtypes.bfloat16) if lhs.dtype == dtypes.bfloat16 else out
 
 def broadcast(operand, sizes):
   return onp.broadcast_to(operand, sizes + onp.shape(operand))
@@ -352,15 +358,15 @@ def _make_reducer(py_binop, init_val):
   monoid_record = _monoids.get(getattr(py_binop, '__name__'))
   if monoid_record:
     reducer, monoid_identity = monoid_record
-    if init_val == monoid_identity(onp.result_type(init_val)):
+    if init_val == monoid_identity(dtypes.result_type(init_val)):
       return reducer
   return _reducer_from_pyfunc(py_binop, init_val)
 
 def _get_max_identity(dt):
-  return -onp.inf if onp.issubdtype(dt, onp.floating) else onp.iinfo(dt).min
+  return -onp.inf if dtypes.issubdtype(dt, onp.floating) else onp.iinfo(dt).min
 
 def _get_min_identity(dt):
-  return onp.inf if onp.issubdtype(dt, onp.floating) else onp.iinfo(dt).max
+  return onp.inf if dtypes.issubdtype(dt, onp.floating) else onp.iinfo(dt).max
 
 def _identity_getter(op):
   return lambda dtype: onp.asarray(op.identity, dtype=dtype)
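The dot_general change above illustrates the commit message's second bullet: NumPy reference implementations can't compute directly in bfloat16, so they cast through float32. A minimal standalone sketch of that pattern (using float16 as a stand-in, since plain NumPy has no bfloat16):

    import numpy as onp

    def reference_matmul_low_precision(lhs, rhs):
      # Same pattern as the dot_general change above: have einsum accumulate
      # in float32, then cast the result back to the low-precision dtype.
      assert lhs.dtype == rhs.dtype
      low_precision = lhs.dtype == onp.float16  # the diff tests dtypes.bfloat16
      dtype = onp.float32 if low_precision else None
      out = onp.einsum('ij,jk->ik', lhs, rhs, dtype=dtype)
      return out.astype(lhs.dtype) if low_precision else out

    a = onp.ones((2, 3), onp.float16)
    b = onp.ones((3, 4), onp.float16)
    print(reference_matmul_low_precision(a, b).dtype)  # float16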