def get_bf16_repr(input_tensor: torch.Tensor) -> np.ndarray:
    """Convert a bfloat16 tensor to an equivalent byte representation in Numpy.

    Numpy has no native bfloat16 dtype, so each element is returned as the
    16-bit pattern of its bfloat16 encoding, stored in a ``uint16`` array.

    This is a vectorized implementation inspired from https://github.com/starkat99/half-rs/blob/main/src/bfloat/convert.rs
    (shared under Apache 2.0 license at https://github.com/starkat99/half-rs/blob/main/LICENSES/Apache-2.0.txt)

    Args:
        input_tensor: Tensor to encode. It is moved to CPU and upcast to
            float32 before its bits are inspected.

    Returns:
        A 1-D ``np.uint16`` array with one bfloat16 bit pattern per element,
        in the order produced by ``ndarray.tobytes()`` on the float32 copy
        (the input's shape is not preserved).
    """
    v_fp32 = input_tensor.cpu().float().numpy()
    # Reinterpret the float32 payload as raw 32-bit words.
    bits = np.frombuffer(v_fp32.tobytes(), dtype=np.uint32)

    # bfloat16 is the upper half of the float32 encoding.
    high_half = np.right_shift(bits, 16)

    # NaN: exponent all ones and a non-zero mantissa. Keep the high mantissa
    # bits and force the most significant mantissa bit so that truncation can
    # never turn the NaN into an infinity.
    # These must be *bitwise* operations: the logical_* ufuncs collapse their
    # operands to booleans, which disables the NaN test and makes the rounding
    # test fire for every non-zero value.
    nan_mask = np.bitwise_and(bits, 0x7FFF_FFFF) > 0x7F80_0000
    nan_value = np.bitwise_or(high_half, 0x0040)

    # Round to nearest, ties to even: round up when the first discarded bit is
    # set and either a lower discarded bit or the lowest kept bit is set.
    round_bit = 0x0000_8000
    round_up = (np.bitwise_and(bits, round_bit) != 0) & (
        np.bitwise_and(bits, 3 * round_bit - 1) != 0
    )

    rounded = np.where(round_up, high_half + 1, high_half)
    return np.where(nan_mask, nan_value, rounded).astype(np.uint16)