Skip to content

Commit

Permalink
MAINT: stats: explicitly work around f32/f64 dtype selection
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Jan 4, 2025
1 parent 8e12b27 commit c98f12d
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions scipy/stats/_continued_fraction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np

from scipy._lib._array_api import array_namespace, xp_ravel, xp_copy
from scipy._lib._array_api import (
array_namespace, xp_ravel, xp_copy, is_torch, xp_default_dtype
)
import scipy._lib._elementwise_iterative_method as eim
from scipy._lib._util import _RichResult
from scipy import special
Expand Down Expand Up @@ -291,8 +293,15 @@ def func(n, *args):
# on each callable to get the shape and dtype, then we broadcast these
# shapes, compute the result dtype, and broadcast/promote the zeroth terms
# and `*args` to this shape/dtype.

# `float32` here avoids influencing precision of resulting float type
zero = xp.asarray(0, dtype=xp.float32)
# patch up promotion: in numpy (int64, float32) -> float64, while in torch
# (int64, float32) -> float32 irrespective of the default_dtype.
dt = {'dtype': None
if is_torch(xp) and xp_default_dtype(xp) == xp.float64
else xp.float32}
zero = xp.asarray(0, **dt)

temp = eim._initialize(a, (zero,), args, complex_ok=True)
_, _, fs_a, _, shape_a, dtype_a, xp_a = temp
temp = eim._initialize(b, (zero,), args, complex_ok=True)
Expand Down

0 comments on commit c98f12d

Please sign in to comment.