Skip to content

Commit

Permalink
aarch64 fp16 math operations
Browse files Browse the repository at this point in the history
  • Loading branch information
starkat99 committed Jan 23, 2023
1 parent b8391de commit 95c9afa
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 96 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use-intrinsics = []
alloc = []

[dependencies]
cfg-if = "1.0.0"
bytemuck = { version = "1.4.1", default-features = false, features = [
"derive",
], optional = true }
Expand Down
36 changes: 18 additions & 18 deletions src/binary16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use serde::{Deserialize, Serialize};
#[cfg(feature = "zerocopy")]
use zerocopy::{AsBytes, FromBytes};

pub(crate) mod convert;
pub(crate) mod arch;

/// A 16-bit floating point type implementing the IEEE 754-2008 standard [`binary16`] a.k.a `half`
/// format.
Expand Down Expand Up @@ -57,7 +57,7 @@ impl f16 {
#[inline]
#[must_use]
pub fn from_f32(value: f32) -> f16 {
f16(convert::f32_to_f16(value))
f16(arch::f32_to_f16(value))
}

/// Constructs a 16-bit floating point value from a 32-bit floating point value.
Expand All @@ -74,7 +74,7 @@ impl f16 {
#[inline]
#[must_use]
pub const fn from_f32_const(value: f32) -> f16 {
f16(convert::f32_to_f16_fallback(value))
f16(arch::f32_to_f16_fallback(value))
}

/// Constructs a 16-bit floating point value from a 64-bit floating point value.
Expand All @@ -87,7 +87,7 @@ impl f16 {
#[inline]
#[must_use]
pub fn from_f64(value: f64) -> f16 {
f16(convert::f64_to_f16(value))
f16(arch::f64_to_f16(value))
}

/// Constructs a 16-bit floating point value from a 64-bit floating point value.
Expand All @@ -104,7 +104,7 @@ impl f16 {
#[inline]
#[must_use]
pub const fn from_f64_const(value: f64) -> f16 {
f16(convert::f64_to_f16_fallback(value))
f16(arch::f64_to_f16_fallback(value))
}

/// Converts a [`f16`] into the underlying bit representation.
Expand Down Expand Up @@ -230,7 +230,7 @@ impl f16 {
#[inline]
#[must_use]
pub fn to_f32(self) -> f32 {
convert::f16_to_f32(self.0)
arch::f16_to_f32(self.0)
}

/// Converts a [`f16`] value into a `f32` value.
Expand All @@ -244,7 +244,7 @@ impl f16 {
#[inline]
#[must_use]
pub const fn to_f32_const(self) -> f32 {
convert::f16_to_f32_fallback(self.0)
arch::f16_to_f32_fallback(self.0)
}

/// Converts a [`f16`] value into a `f64` value.
Expand All @@ -254,7 +254,7 @@ impl f16 {
#[inline]
#[must_use]
pub fn to_f64(self) -> f64 {
convert::f16_to_f64(self.0)
arch::f16_to_f64(self.0)
}

/// Converts a [`f16`] value into a `f64` value.
Expand All @@ -268,7 +268,7 @@ impl f16 {
#[inline]
#[must_use]
pub const fn to_f64_const(self) -> f64 {
convert::f16_to_f64_fallback(self.0)
arch::f16_to_f64_fallback(self.0)
}

/// Returns `true` if this value is `NaN` and `false` otherwise.
Expand Down Expand Up @@ -1006,7 +1006,7 @@ impl Add for f16 {

#[inline]
fn add(self, rhs: Self) -> Self::Output {
Self::from_f32(Self::to_f32(self) + Self::to_f32(rhs))
f16(arch::add_f16(self.0, rhs.0))
}
}

Expand Down Expand Up @@ -1056,7 +1056,7 @@ impl Sub for f16 {

#[inline]
fn sub(self, rhs: Self) -> Self::Output {
Self::from_f32(Self::to_f32(self) - Self::to_f32(rhs))
f16(arch::subtract_f16(self.0, rhs.0))
}
}

Expand Down Expand Up @@ -1106,7 +1106,7 @@ impl Mul for f16 {

#[inline]
fn mul(self, rhs: Self) -> Self::Output {
Self::from_f32(Self::to_f32(self) * Self::to_f32(rhs))
f16(arch::multiply_f16(self.0, rhs.0))
}
}

Expand Down Expand Up @@ -1156,7 +1156,7 @@ impl Div for f16 {

#[inline]
fn div(self, rhs: Self) -> Self::Output {
Self::from_f32(Self::to_f32(self) / Self::to_f32(rhs))
f16(arch::divide_f16(self.0, rhs.0))
}
}

Expand Down Expand Up @@ -1206,7 +1206,7 @@ impl Rem for f16 {

#[inline]
fn rem(self, rhs: Self) -> Self::Output {
Self::from_f32(Self::to_f32(self) % Self::to_f32(rhs))
f16(arch::remainder_f16(self.0, rhs.0))
}
}

Expand Down Expand Up @@ -1254,28 +1254,28 @@ impl RemAssign<&f16> for f16 {
impl Product for f16 {
#[inline]
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
f16::from_f32(iter.map(|f| f.to_f32()).product())
f16(arch::product_f16(iter.map(|f| f.to_bits())))
}
}

impl<'a> Product<&'a f16> for f16 {
#[inline]
fn product<I: Iterator<Item = &'a f16>>(iter: I) -> Self {
f16::from_f32(iter.map(|f| f.to_f32()).product())
f16(arch::product_f16(iter.map(|f| f.to_bits())))
}
}

impl Sum for f16 {
#[inline]
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
f16::from_f32(iter.map(|f| f.to_f32()).sum())
f16(arch::sum_f16(iter.map(|f| f.to_bits())))
}
}

impl<'a> Sum<&'a f16> for f16 {
#[inline]
fn sum<I: Iterator<Item = &'a f16>>(iter: I) -> Self {
f16::from_f32(iter.map(|f| f.to_f32()).product())
f16(arch::sum_f16(iter.map(|f| f.to_bits())))
}
}

Expand Down
Loading

0 comments on commit 95c9afa

Please sign in to comment.