Skip to content

Commit 0069e5f

Browse files
committed
Bump PyTensor dependency
1 parent ddd1d4b commit 0069e5f

11 files changed

+20
-21
lines changed

conda-envs/environment-dev.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies:
1414
- numpy>=1.15.0
1515
- pandas>=0.24.0
1616
- pip
17-
- pytensor>=2.14.1,<2.15
17+
- pytensor>=2.15.0,<2.16
1818
- python-graphviz
1919
- networkx
2020
- scipy>=1.4.1

conda-envs/environment-docs.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- numpy>=1.15.0
1313
- pandas>=0.24.0
1414
- pip
15-
- pytensor>=2.14.1,<2.15
15+
- pytensor>=2.15.0,<2.16
1616
- python-graphviz
1717
- scipy>=1.4.1
1818
- typing-extensions>=3.7.4

conda-envs/environment-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies:
1717
- numpy>=1.15.0
1818
- pandas>=0.24.0
1919
- pip
20-
- pytensor>=2.14.1,<2.15
20+
- pytensor>=2.15.0,<2.16
2121
- python-graphviz
2222
- networkx
2323
- scipy>=1.4.1

conda-envs/windows-environment-dev.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies:
1414
- numpy>=1.15.0
1515
- pandas>=0.24.0
1616
- pip
17-
- pytensor>=2.14.1,<2.15
17+
- pytensor>=2.15.0,<2.16
1818
- python-graphviz
1919
- networkx
2020
- scipy>=1.4.1

conda-envs/windows-environment-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies:
1717
- numpy>=1.15.0
1818
- pandas>=0.24.0
1919
- pip
20-
- pytensor>=2.14.1,<2.15
20+
- pytensor>=2.15.0,<2.16
2121
- python-graphviz
2222
- networkx
2323
- scipy>=1.4.1

pymc/logprob/rewriting.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@
6666
SequenceDB,
6767
TopoDB,
6868
)
69+
from pytensor.tensor.basic import Alloc
6970
from pytensor.tensor.elemwise import DimShuffle, Elemwise
70-
from pytensor.tensor.extra_ops import BroadcastTo
7171
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
7272
from pytensor.tensor.rewriting.basic import register_canonicalize
7373
from pytensor.tensor.rewriting.shape import ShapeFeature
@@ -283,7 +283,7 @@ def request_measurable(self, vars: Sequence[Variable]) -> List[Variable]:
283283

284284

285285
@register_canonicalize
286-
@node_rewriter((Elemwise, BroadcastTo, DimShuffle) + subtensor_ops)
286+
@node_rewriter((Elemwise, Alloc, DimShuffle) + subtensor_ops)
287287
def local_lift_DiracDelta(fgraph, node):
288288
r"""Lift basic `Op`\s through `DiracDelta`\s."""
289289

pymc/logprob/tensor.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@
4141
from pytensor import tensor as pt
4242
from pytensor.graph.op import compute_test_value
4343
from pytensor.graph.rewriting.basic import node_rewriter
44-
from pytensor.tensor.basic import Join, MakeVector
44+
from pytensor.tensor.basic import Alloc, Join, MakeVector
4545
from pytensor.tensor.elemwise import DimShuffle
46-
from pytensor.tensor.extra_ops import BroadcastTo
4746
from pytensor.tensor.random.op import RandomVariable
4847
from pytensor.tensor.random.rewriting import (
4948
local_dimshuffle_rv_lift,
@@ -59,9 +58,9 @@
5958
from pymc.logprob.utils import check_potential_measurability
6059

6160

62-
@node_rewriter([BroadcastTo])
61+
@node_rewriter([Alloc])
6362
def naive_bcast_rv_lift(fgraph, node):
64-
"""Lift a ``BroadcastTo`` through a ``RandomVariable`` ``Op``.
63+
"""Lift an ``Alloc`` through a ``RandomVariable`` ``Op``.
6564
6665
XXX: This implementation simply broadcasts the ``RandomVariable``'s
6766
parameters, which won't always work (e.g. multivariate distributions).
@@ -73,7 +72,7 @@ def naive_bcast_rv_lift(fgraph, node):
7372
"""
7473

7574
if not (
76-
isinstance(node.op, BroadcastTo)
75+
isinstance(node.op, Alloc)
7776
and node.inputs[0].owner
7877
and isinstance(node.inputs[0].owner.op, RandomVariable)
7978
):
@@ -93,7 +92,7 @@ def naive_bcast_rv_lift(fgraph, node):
9392
return None
9493

9594
if not bcast_shape:
96-
# The `BroadcastTo` is broadcasting a scalar to a scalar (i.e. doing nothing)
95+
# The `Alloc` is broadcasting a scalar to a scalar (i.e. doing nothing)
9796
assert rv_var.ndim == 0
9897
return [rv_var]
9998

requirements-dev.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ numpydoc
1818
pandas>=0.24.0
1919
polyagamma
2020
pre-commit>=2.8.0
21-
pytensor>=2.14.1,<2.15
21+
pytensor>=2.15.0,<2.16
2222
pytest-cov>=2.5
2323
pytest>=3.0
2424
scipy>=1.4.1

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ cloudpickle
44
fastprogress>=0.2.0
55
numpy>=1.15.0
66
pandas>=0.24.0
7-
pytensor>=2.14.1,<2.15
7+
pytensor>=2.15.0,<2.16
88
scipy>=1.4.1
99
typing-extensions>=3.7.4

tests/logprob/test_tensor.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from pytensor.graph import RewriteDatabaseQuery
4343
from pytensor.graph.rewriting.basic import in2out
4444
from pytensor.graph.rewriting.utils import rewrite_graph
45-
from pytensor.tensor.extra_ops import BroadcastTo
45+
from pytensor.tensor.basic import Alloc
4646
from scipy import stats as st
4747

4848
from pymc.logprob.basic import conditional_logp, logp
@@ -52,12 +52,12 @@
5252

5353

5454
def test_naive_bcast_rv_lift():
55-
r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
55+
r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `Alloc`\s."""
5656
X_rv = pt.random.normal()
57-
Z_at = BroadcastTo()(X_rv, ())
57+
Z_at = Alloc()(X_rv, *())
5858

5959
# Make sure we're testing what we intend to test
60-
assert isinstance(Z_at.owner.op, BroadcastTo)
60+
assert isinstance(Z_at.owner.op, Alloc)
6161

6262
res = rewrite_graph(Z_at, custom_rewrite=in2out(naive_bcast_rv_lift), clone=False)
6363
assert res is X_rv

tests/logprob/test_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ def test_dirac_delta():
180180
@pytest.mark.parametrize(
181181
"dist_params, obs",
182182
[
183-
((np.array(0, dtype=np.float64),), np.array([0, 0.5, 1, -1], dtype=np.float64)),
184-
((np.array([0, 0], dtype=np.int64),), np.array(0, dtype=np.int64)),
183+
((np.array([0, 0, 0, 0], dtype=np.float64),), np.array([0, 0.5, 1, -1], dtype=np.float64)),
184+
((np.array(0, dtype=np.int64),), np.array(0, dtype=np.int64)),
185185
],
186186
)
187187
def test_dirac_delta_logprob(dist_params, obs):

0 commit comments

Comments
 (0)