Skip to content

Commit

Permalink
Add bfloat16 tensor loading support (guillaume-be#396)
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be authored Jun 25, 2023
1 parent 7b1ab24 commit a74d023
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion utils/convert_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright 2019-2023 Guillaume Becquin
# Copyright 2023 https://github.com/starkat99/half-rs
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import numpy as np
import subprocess
Expand All @@ -7,6 +20,23 @@
from pathlib import Path
from torch import Tensor


def get_bf16_repr(input_tensor: torch.Tensor) -> np.ndarray:
"""Convert a bfloat16 tensor to an equivalent byte representation in Numpy.
This is a vectorized implementation inspired from https://github.com/starkat99/half-rs/blob/main/src/bfloat/convert.rs
(shared under Apache 2.0 license at https://github.com/starkat99/half-rs/blob/main/LICENSES/Apache-2.0.txt)
"""
v_fp32 = input_tensor.cpu().float().numpy()
byte_array = np.frombuffer(v_fp32.tobytes(), dtype=np.uint32)
nan_value = np.logical_or(np.right_shift(byte_array, 16), 0x0040)
nan_mask = np.logical_and(byte_array, 0x7FFF_FFFF) > 0x7F80_0000
round_bit = 0x0000_8000
output_val = np.right_shift(byte_array, 16)
threshold_mask = (np.logical_and(byte_array, round_bit) != 0) & (np.logical_and(byte_array, (3*round_bit-1)) != 0)
output = np.where(nan_mask, nan_value, np.where(threshold_mask, output_val+1, output_val)).astype(np.uint16)
return output


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -64,7 +94,10 @@
if args.suffix:
k = k.split(".")[-1]
if isinstance(v, Tensor):
tensor = v.cpu().numpy()
if v.dtype == torch.bfloat16:
tensor = get_bf16_repr(v)
else:
tensor = v.cpu().numpy()
if args.dtype is not None:
nps[k] = np.ascontiguousarray(tensor.astype(np.dtype(args.dtype)))
else:
Expand Down

0 comments on commit a74d023

Please sign in to comment.