From c98f12df5e03d51eeb1a9becda5cd8ac9f58e130 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 1 Jan 2025 13:06:29 +0200 Subject: [PATCH] MAINT: stats: explicitly work around f32/f64 dtype selection --- scipy/stats/_continued_fraction.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/scipy/stats/_continued_fraction.py b/scipy/stats/_continued_fraction.py index 4b08c389f1c8..7e02fa66a253 100644 --- a/scipy/stats/_continued_fraction.py +++ b/scipy/stats/_continued_fraction.py @@ -1,6 +1,8 @@ import numpy as np -from scipy._lib._array_api import array_namespace, xp_ravel, xp_copy +from scipy._lib._array_api import ( + array_namespace, xp_ravel, xp_copy, is_torch, xp_default_dtype +) import scipy._lib._elementwise_iterative_method as eim from scipy._lib._util import _RichResult from scipy import special @@ -291,8 +293,15 @@ def func(n, *args): # on each callable to get the shape and dtype, then we broadcast these # shapes, compute the result dtype, and broadcast/promote the zeroth terms # and `*args` to this shape/dtype. + # `float32` here avoids influencing precision of resulting float type - zero = xp.asarray(0, dtype=xp.float32) + # patch up promotion: in numpy (int64, float32) -> float64, while in torch + # (int64, float32) -> float32 irrespective of the default_dtype. + dt = {'dtype': None + if is_torch(xp) and xp_default_dtype(xp) == xp.float64 + else xp.float32} + zero = xp.asarray(0, **dt) + temp = eim._initialize(a, (zero,), args, complex_ok=True) _, _, fs_a, _, shape_a, dtype_a, xp_a = temp temp = eim._initialize(b, (zero,), args, complex_ok=True)