Skip to content

Commit

Permalink
Add: Double precision on Sapphire Rapids
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Nov 22, 2023
1 parent 653f91a commit e5175b4
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 23 deletions.
22 changes: 16 additions & 6 deletions cpp/bench.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,18 @@ static void measure(bm::State& state, metric_at metric, metric_at baseline) {
template <typename scalar_at, typename metric_at = void>
void register_(std::string name, metric_at* distance_func, metric_at* baseline_func) {

using pair_dims_t = vectors_pair_gt<scalar_at, 1536>;
using pair_bytes_t = vectors_pair_gt<scalar_at, 1536 / sizeof(scalar_at)>;

std::size_t seconds = 10;
std::size_t threads = 1; // std::thread::hardware_concurrency(); // 1;
std::string name_dims = name + "_" + std::to_string(pair_dims_t{}.dimensions()) + "d";
std::string name_bytes = name + "_" + std::to_string(pair_bytes_t{}.size_bytes()) + "b";
std::size_t threads = std::thread::hardware_concurrency(); // 1;

using pair_dims_t = vectors_pair_gt<scalar_at, 1536>;
std::string name_dims = name + "_" + std::to_string(pair_dims_t{}.dimensions()) + "d";
bm::RegisterBenchmark(name_dims.c_str(), measure<pair_dims_t, metric_at*>, distance_func, baseline_func)
->MinTime(seconds)
->Threads(threads);

return;
using pair_bytes_t = vectors_pair_gt<scalar_at, 1536 / sizeof(scalar_at)>;
std::string name_bytes = name + "_" + std::to_string(pair_bytes_t{}.size_bytes()) + "b";
bm::RegisterBenchmark(name_bytes.c_str(), measure<pair_bytes_t, metric_at*>, distance_func, baseline_func)
->MinTime(seconds)
->Threads(threads);
Expand Down Expand Up @@ -183,6 +184,11 @@ int main(int argc, char** argv) {
register_<simsimd_f32_t>("avx512_f32_l2sq", simsimd_avx512_f32_l2sq, simsimd_accurate_f32_l2sq);
register_<simsimd_f32_t>("avx512_f32_kl", simsimd_avx512_f32_kl, simsimd_accurate_f32_kl);
register_<simsimd_f32_t>("avx512_f32_js", simsimd_avx512_f32_js, simsimd_accurate_f32_js);

register_<simsimd_f64_t>("avx512_f64_ip", simsimd_avx512_f64_ip, simsimd_serial_f64_ip);
register_<simsimd_f64_t>("avx512_f64_cos", simsimd_avx512_f64_cos, simsimd_serial_f64_cos);
register_<simsimd_f64_t>("avx512_f64_l2sq", simsimd_avx512_f64_l2sq, simsimd_serial_f64_l2sq);

#endif

register_<simsimd_f16_t>("serial_f16_ip", simsimd_serial_f16_ip, simsimd_accurate_f16_ip);
Expand All @@ -197,6 +203,10 @@ int main(int argc, char** argv) {
register_<simsimd_f32_t>("serial_f32_kl", simsimd_serial_f32_kl, simsimd_accurate_f32_kl);
register_<simsimd_f32_t>("serial_f32_js", simsimd_serial_f32_js, simsimd_accurate_f32_js);

register_<simsimd_f64_t>("serial_f64_ip", simsimd_serial_f64_ip, simsimd_serial_f64_ip);
register_<simsimd_f64_t>("serial_f64_cos", simsimd_serial_f64_cos, simsimd_serial_f64_cos);
register_<simsimd_f64_t>("serial_f64_l2sq", simsimd_serial_f64_l2sq, simsimd_serial_f64_l2sq);

register_<simsimd_i8_t>("serial_i8_cos", simsimd_serial_i8_cos, simsimd_accurate_i8_cos);
register_<simsimd_i8_t>("serial_i8_l2sq", simsimd_serial_i8_l2sq, simsimd_accurate_i8_l2sq);

Expand Down
7 changes: 5 additions & 2 deletions include/simsimd/probability.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
simsimd_##accumulator_type##_t bi = converter(b[i]); \
d += ai * SIMSIMD_LOG((ai + epsilon) / (bi + epsilon)); \
} \
return d; \
return (simsimd_f32_t)d; \
}

#define SIMSIMD_MAKE_JS(name, input_type, accumulator_type, converter, epsilon) \
Expand All @@ -46,13 +46,16 @@
d += ai * SIMSIMD_LOG((ai + epsilon) / (mi + epsilon)); \
d += bi * SIMSIMD_LOG((bi + epsilon) / (mi + epsilon)); \
} \
return d / 2; \
return (simsimd_f32_t)d / 2; \
}

#ifdef __cplusplus
extern "C" {
#endif

SIMSIMD_MAKE_KL(serial, f64, f64, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_serial_f64_kl
SIMSIMD_MAKE_JS(serial, f64, f64, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_serial_f64_js

SIMSIMD_MAKE_KL(serial, f32, f32, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_serial_f32_kl
SIMSIMD_MAKE_JS(serial, f32, f32, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_serial_f32_js

Expand Down
25 changes: 24 additions & 1 deletion include/simsimd/simsimd.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,30 @@ inline static void simsimd_find_metric_punned( //
switch (datatype) {

case simsimd_datatype_unknown_k: break;
case simsimd_datatype_f64_k: break;

// Double-precision floating-point vectors
case simsimd_datatype_f64_k:

#if SIMSIMD_TARGET_X86_AVX512
if (viable & simsimd_cap_x86_avx512_k)
switch (kind) {
case simsimd_metric_ip_k: *m = (simsimd_metric_punned_t)&simsimd_avx512_f64_ip, *c = simsimd_cap_x86_avx512_k; return;
case simsimd_metric_cos_k: *m = (simsimd_metric_punned_t)&simsimd_avx512_f64_cos, *c = simsimd_cap_x86_avx512_k; return;
case simsimd_metric_l2sq_k: *m = (simsimd_metric_punned_t)&simsimd_avx512_f64_l2sq, *c = simsimd_cap_x86_avx512_k; return;
default: break;
}
#endif
if (viable & simsimd_cap_serial_k)
switch (kind) {
case simsimd_metric_ip_k: *m = (simsimd_metric_punned_t)&simsimd_serial_f64_ip, *c = simsimd_cap_serial_k; return;
case simsimd_metric_cos_k: *m = (simsimd_metric_punned_t)&simsimd_serial_f64_cos, *c = simsimd_cap_serial_k; return;
case simsimd_metric_l2sq_k: *m = (simsimd_metric_punned_t)&simsimd_serial_f64_l2sq, *c = simsimd_cap_serial_k; return;
case simsimd_metric_js_k: *m = (simsimd_metric_punned_t)&simsimd_serial_f64_js, *c = simsimd_cap_serial_k; return;
case simsimd_metric_kl_k: *m = (simsimd_metric_punned_t)&simsimd_serial_f64_kl, *c = simsimd_cap_serial_k; return;
default: break;
}

break;

// Single-precision floating-point vectors
case simsimd_datatype_f32_k:
Expand Down
99 changes: 99 additions & 0 deletions include/simsimd/spatial.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
extern "C" {
#endif

SIMSIMD_MAKE_L2SQ(serial, f64, f64, SIMSIMD_IDENTIFY) // simsimd_serial_f64_l2sq
SIMSIMD_MAKE_IP(serial, f64, f64, SIMSIMD_IDENTIFY) // simsimd_serial_f64_ip
SIMSIMD_MAKE_COS(serial, f64, f64, SIMSIMD_IDENTIFY) // simsimd_serial_f64_cos

SIMSIMD_MAKE_L2SQ(serial, f32, f32, SIMSIMD_IDENTIFY) // simsimd_serial_f32_l2sq
SIMSIMD_MAKE_IP(serial, f32, f32, SIMSIMD_IDENTIFY) // simsimd_serial_f32_ip
SIMSIMD_MAKE_COS(serial, f32, f32, SIMSIMD_IDENTIFY) // simsimd_serial_f32_cos
Expand Down Expand Up @@ -996,6 +1000,101 @@ simsimd_avx512_i8_ip(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_
return simsimd_avx512_i8_cos(a, b, n);
}

/*
* @file x86_avx512_f32.h
* @brief x86 AVX-512 implementation of the most common similarity metrics for 32-bit floating point numbers.
* @author Ash Vardanian
*
* - Implements: L2 squared, inner product, cosine similarity.
* - Uses `f32` for storage and `f32` for accumulation.
* - Requires compiler capabilities: avx512f, avx512vl, bmi2.
*/

__attribute__((target("avx512f,avx512vl,bmi2"))) //
inline static simsimd_f32_t
simsimd_avx512_f64_l2sq(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n) {
__m512d d2_vec = _mm512_set1_pd(0);
__m512d a_vec, b_vec;

simsimd_avx512_f64_l2sq_cycle:
if (n < 8) {
__mmask8 mask = _bzhi_u32(0xFFFFFFFF, n);
a_vec = _mm512_maskz_loadu_pd(mask, a);
b_vec = _mm512_maskz_loadu_pd(mask, b);
n = 0;
} else {
a_vec = _mm512_loadu_pd(a);
b_vec = _mm512_loadu_pd(b);
a += 8, b += 8, n -= 8;
}
__m512d d_vec = _mm512_sub_pd(a_vec, b_vec);
d2_vec = _mm512_fmadd_pd(d_vec, d_vec, d2_vec);
if (n)
goto simsimd_avx512_f64_l2sq_cycle;

return (simsimd_f32_t)_mm512_reduce_add_pd(d2_vec);
}

__attribute__((target("avx512f,avx512vl,bmi2"))) //
inline static simsimd_f32_t
simsimd_avx512_f64_ip(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n) {
__m512d ab_vec = _mm512_set1_pd(0);
__m512d a_vec, b_vec;

simsimd_avx512_f64_ip_cycle:
if (n < 8) {
__mmask8 mask = _bzhi_u32(0xFFFFFFFF, n);
a_vec = _mm512_maskz_loadu_pd(mask, a);
b_vec = _mm512_maskz_loadu_pd(mask, b);
n = 0;
} else {
a_vec = _mm512_loadu_pd(a);
b_vec = _mm512_loadu_pd(b);
a += 8, b += 8, n -= 8;
}
ab_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_vec);
if (n)
goto simsimd_avx512_f64_ip_cycle;

return 1 - (simsimd_f32_t)_mm512_reduce_add_pd(ab_vec);
}

__attribute__((target("avx512f,avx512vl,bmi2"))) //
inline static simsimd_f32_t
simsimd_avx512_f64_cos(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n) {
__m512d ab_vec = _mm512_set1_pd(0);
__m512d a2_vec = _mm512_set1_pd(0);
__m512d b2_vec = _mm512_set1_pd(0);
__m512d a_vec, b_vec;

simsimd_avx512_f64_cos_cycle:
if (n < 8) {
__mmask8 mask = _bzhi_u32(0xFFFFFFFF, n);
a_vec = _mm512_maskz_loadu_pd(mask, a);
b_vec = _mm512_maskz_loadu_pd(mask, b);
n = 0;
} else {
a_vec = _mm512_loadu_pd(a);
b_vec = _mm512_loadu_pd(b);
a += 8, b += 8, n -= 8;
}
ab_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_vec);
a2_vec = _mm512_fmadd_pd(a_vec, a_vec, a2_vec);
b2_vec = _mm512_fmadd_pd(b_vec, b_vec, b2_vec);
if (n)
goto simsimd_avx512_f64_cos_cycle;

simsimd_f32_t ab = (simsimd_f32_t)_mm512_reduce_add_pd(ab_vec);
simsimd_f32_t a2 = (simsimd_f32_t)_mm512_reduce_add_pd(a2_vec);
simsimd_f32_t b2 = (simsimd_f32_t)_mm512_reduce_add_pd(b2_vec);

// Compute the reciprocal square roots of a2 and b2
__m128 rsqrts = _mm_rsqrt14_ps(_mm_set_ps(0.f, 0.f, a2 + 1.e-9f, b2 + 1.e-9f));
simsimd_f32_t rsqrt_a2 = _mm_cvtss_f32(rsqrts);
simsimd_f32_t rsqrt_b2 = _mm_cvtss_f32(_mm_shuffle_ps(rsqrts, rsqrts, _MM_SHUFFLE(0, 0, 0, 1)));
return 1 - ab * rsqrt_a2 * rsqrt_b2;
}

#endif // SIMSIMD_TARGET_X86_AVX512
#endif // SIMSIMD_TARGET_X86

Expand Down
42 changes: 32 additions & 10 deletions python/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def wrapped(A, B):
ndim = 1536

generators = {
np.float64: lambda: np.random.randn(count, ndim).astype(np.float64),
np.float32: lambda: np.random.randn(count, ndim).astype(np.float32),
np.float16: lambda: np.random.randn(count, ndim).astype(np.float16),
np.int8: lambda: np.random.randint(-100, 100, (count, ndim), np.int8),
Expand All @@ -67,6 +68,7 @@ def wrapped(A, B):
}

dtype_names = {
np.float64: "f64",
np.float32: "f32",
np.float16: "f16",
np.int8: "i8",
Expand All @@ -88,25 +90,35 @@ def wrapped(A, B):

# Benchmark functions
funcs = [
("scipy.cosine", spd.cosine, simd.cosine, [np.float32, np.float16, np.int8]),
(
"scipy.cosine",
spd.cosine,
simd.cosine,
[np.float64, np.float32, np.float16, np.int8],
),
(
"scipy.sqeuclidean",
spd.sqeuclidean,
simd.sqeuclidean,
[np.float32, np.float16, np.int8],
[np.float64, np.float32, np.float16, np.int8],
),
(
"numpy.inner",
np.inner,
simd.inner,
[np.float64, np.float32, np.float16, np.int8],
),
("numpy.inner", np.inner, simd.inner, [np.float32, np.float16, np.int8]),
(
"scipy.jensenshannon",
spd.jensenshannon,
simd.jensenshannon,
[np.float32, np.float16],
[np.float64, np.float32, np.float16],
),
(
"scipy.kl_div",
scs.kl_div,
simd.kullbackleibler,
[np.float32, np.float16],
[np.float64, np.float32, np.float16],
),
("scipy.hamming", spd.hamming, simd.hamming, [np.uint8]),
("scipy.jaccard", spd.jaccard, simd.jaccard, [np.uint8]),
Expand Down Expand Up @@ -147,25 +159,35 @@ def wrapped(A, B):

# Benchmark functions
funcs = [
("scipy.cosine", spd.cosine, simd.cosine, [np.float32, np.float16, np.int8]),
(
"scipy.cosine",
spd.cosine,
simd.cosine,
[np.float64, np.float32, np.float16, np.int8],
),
(
"scipy.sqeuclidean",
spd.sqeuclidean,
simd.sqeuclidean,
[np.float32, np.float16, np.int8],
[np.float64, np.float32, np.float16, np.int8],
),
(
"numpy.inner",
np.inner,
simd.inner,
[np.float64, np.float32, np.float16, np.int8],
),
("numpy.inner", np.inner, simd.inner, [np.float32, np.float16, np.int8]),
(
"scipy.jensenshannon",
spd.jensenshannon,
simd.jensenshannon,
[np.float32, np.float16],
[np.float64, np.float32, np.float16],
),
(
"scipy.kl_div",
scs.kl_div,
simd.kullbackleibler,
[np.float32, np.float16],
[np.float64, np.float32, np.float16],
),
("scipy.hamming", spd.hamming, simd.hamming, [np.uint8]),
("scipy.jaccard", spd.jaccard, simd.jaccard, [np.uint8]),
Expand Down
12 changes: 8 additions & 4 deletions python/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

def test_pointers_availability():
"""Tests the availability of pre-compiled functions for compatibility with USearch."""
assert simd.pointer_to_sqeuclidean("f64") != 0
assert simd.pointer_to_cosine("f64") != 0
assert simd.pointer_to_inner("f64") != 0

assert simd.pointer_to_sqeuclidean("f32") != 0
assert simd.pointer_to_cosine("f32") != 0
assert simd.pointer_to_inner("f32") != 0
Expand All @@ -26,7 +30,7 @@ def test_pointers_availability():

@pytest.mark.repeat(50)
@pytest.mark.parametrize("ndim", [3, 97, 1536])
@pytest.mark.parametrize("dtype", [np.float32, np.float16])
@pytest.mark.parametrize("dtype", [np.float64, np.float32, np.float16])
def test_dot(ndim, dtype):
"""Compares the simd.dot() function with numpy.dot(), measuring the accuracy error for f16, and f32 types."""
np.random.seed()
Expand All @@ -43,7 +47,7 @@ def test_dot(ndim, dtype):

@pytest.mark.repeat(50)
@pytest.mark.parametrize("ndim", [3, 97, 1536])
@pytest.mark.parametrize("dtype", [np.float32, np.float16])
@pytest.mark.parametrize("dtype", [np.float64, np.float32, np.float16])
def test_sqeuclidean(ndim, dtype):
"""Compares the simd.sqeuclidean() function with scipy.spatial.distance.sqeuclidean(), measuring the accuracy error for f16, and f32 types."""
np.random.seed()
Expand All @@ -58,7 +62,7 @@ def test_sqeuclidean(ndim, dtype):

@pytest.mark.repeat(50)
@pytest.mark.parametrize("ndim", [3, 97, 1536])
@pytest.mark.parametrize("dtype", [np.float32, np.float16])
@pytest.mark.parametrize("dtype", [np.float64, np.float32, np.float16])
def test_cosine(ndim, dtype):
"""Compares the simd.cosine() function with scipy.spatial.distance.cosine(), measuring the accuracy error for f16, and f32 types."""
np.random.seed()
Expand Down Expand Up @@ -164,7 +168,7 @@ def test_jaccard(ndim):


@pytest.mark.parametrize("ndim", [3, 97, 1536])
@pytest.mark.parametrize("dtype", [np.float32, np.float16])
@pytest.mark.parametrize("dtype", [np.float64, np.float32, np.float16])
def test_batch(ndim, dtype):
"""Compares the simd.simd.sqeuclidean() function with scipy.spatial.distance.sqeuclidean() for a batch of vectors, measuring the accuracy error for f16, and f32 types."""
np.random.seed()
Expand Down

0 comments on commit e5175b4

Please sign in to comment.