
Commit c3c76c5

Run mypy in pre-commit for continuous improvement of type hints (pymc-devs#5549)
* Remove numpy Tester instance from public API
  It was making mypy follow into the tests module regardless of its config.
  See https://stackoverflow.com/a/70367929
* Move all mypy config to mypy.ini
  Because typing_copilot can't deal with the other formats.
* Fix typing problems and add pre-commit mypy step
1 parent b919153 commit c3c76c5

17 files changed (+283 -71 lines)

.pre-commit-config.yaml (+18)

@@ -13,6 +13,24 @@ repos:
     - id: requirements-txt-fixer
       exclude: ^requirements-dev\.txt$
     - id: trailing-whitespace
+- repo: https://github.com/pre-commit/mirrors-mypy
+  rev: v0.931
+  hooks:
+    - id: mypy
+      name: Run static type checks
+      language: python
+      entry: python ./scripts/run_mypy.py --verbose
+      additional_dependencies:
+        - pandas
+        - types-cachetools
+        - types-filelock
+        - types-setuptools
+        - arviz
+        - aesara==2.4.0
+        - aeppl==0.0.26
+      always_run: true
+      require_serial: true
+      pass_filenames: false
 - repo: https://github.com/PyCQA/isort
   rev: 5.10.1
   hooks:
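The hook's entry point is a repository script, ./scripts/run_mypy.py, which is not included in this diff. As a rough, hypothetical sketch only (assuming the wrapper simply shells out to mypy over the pymc package with the new mypy.ini and forwards the exit code so pre-commit fails on type errors), it could look like this:

# Hypothetical sketch only: the real scripts/run_mypy.py is not shown in this diff.
# Assumes mypy is available (the hook's mirrors-mypy repo plus additional_dependencies
# provide the environment) and that the package to check is `pymc`.
import argparse
import subprocess
import sys


def main() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    cmd = [sys.executable, "-m", "mypy", "--config-file", "mypy.ini", "pymc"]
    if args.verbose:
        print("Running:", " ".join(cmd))
    # Forward mypy's exit code so pre-commit fails when type errors are found.
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    sys.exit(main())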

mypy.ini (+12)

@@ -0,0 +1,12 @@
+# Autogenerated by typing_copilot v0.6.0
+[mypy]
+no_implicit_optional = False
+strict_optional = True
+warn_redundant_casts = False
+check_untyped_defs = False
+disallow_untyped_calls = False
+disallow_incomplete_defs = False
+disallow_untyped_defs = False
+disallow_untyped_decorators = False
+ignore_missing_imports = True
+warn_unused_ignores = False
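This autogenerated baseline relaxes most per-function checks while keeping strict Optional handling and silencing missing third-party stubs. An illustrative snippet (not part of the commit) of what this configuration does and does not flag:

# Illustrative only; not part of the commit.
from typing import Optional


def maybe_len(x: Optional[str]) -> int:
    # strict_optional = True: mypy reports unguarded use of `x` as possibly None,
    # so the None case has to be handled explicitly.
    return 0 if x is None else len(x)


def untyped(a, b):
    # disallow_untyped_defs = False: this unannotated function is accepted as-is,
    # and check_untyped_defs = False means its body is not type-checked either.
    return a + b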

pymc/__init__.py (-1)

@@ -75,7 +75,6 @@ def __set_compiler_flags():
 from pymc.smc import *
 from pymc.stats import *
 from pymc.step_methods import *
-from pymc.tests import test
 from pymc.tuning import *
 from pymc.variational import *
 from pymc.vartypes import *

pymc/backends/arviz.py (+6 -8)

@@ -82,12 +82,9 @@ class _DefaultTrace:
     `insert()` method
     """

-    trace_dict: Dict[str, np.ndarray] = {}
-    _len: Optional[int] = None
-
     def __init__(self, samples: int):
-        self._len = samples
-        self.trace_dict = {}
+        self._len: int = samples
+        self.trace_dict: Dict[str, np.ndarray] = {}

     def insert(self, k: str, v, idx: int):
         """
@@ -180,10 +177,10 @@ def __init__(
                 " one of trace, prior, posterior_predictive or predictions."
             )

-        self.coords = {**self.model.coords, **(coords or {})}
+        untyped_coords = {**self.model.coords, **(coords or {})}
         self.coords = {
             cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
-            for cname, cvals in self.coords.items()
+            for cname, cvals in untyped_coords.items()
             if cvals is not None
         }

@@ -639,7 +636,7 @@ def predictions_to_inference_data(
     """
     if inplace and not idata_orig:
         raise ValueError(
-            "Do not pass True for inplace unless passing" "an existing InferenceData as idata_orig"
+            "Do not pass True for inplace unless passing an existing InferenceData as idata_orig"
         )
     converter = InferenceDataConverter(
         trace=posterior_trace,
@@ -650,6 +647,7 @@ def predictions_to_inference_data(
         log_likelihood=False,
     )
     if hasattr(idata_orig, "posterior"):
+        assert idata_orig is not None
         converter.nchains = idata_orig.posterior.dims["chain"]
         converter.ndraws = idata_orig.posterior.dims["draw"]
     else:
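The _DefaultTrace change replaces class-level mutable defaults with annotated instance attributes assigned in __init__. A minimal standalone sketch of that pattern, with illustrative names rather than the real pymc classes:

# Illustrative names only; not the pymc classes themselves.
from typing import Dict

import numpy as np


class BadTrace:
    # Class-level mutable default: shared by every instance and typed as a
    # class attribute, which is what the diff removes.
    trace_dict: Dict[str, np.ndarray] = {}


class GoodTrace:
    def __init__(self, samples: int):
        # Annotate at the first assignment so mypy infers per-instance attributes.
        self._len: int = samples
        self.trace_dict: Dict[str, np.ndarray] = {}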

pymc/blocking.py (+16 -14)

@@ -35,6 +35,19 @@
 RaveledVars = collections.namedtuple("RaveledVars", "data, point_map_info")


+class Compose:
+    """
+    Compose two functions in a pickleable way
+    """
+
+    def __init__(self, fa: Callable[[PointType], T], fb: Callable[[RaveledVars], PointType]):
+        self.fa = fa
+        self.fb = fb
+
+    def __call__(self, x: RaveledVars) -> T:
+        return self.fa(self.fb(x))
+
+
 class DictToArrayBijection:
     """Map between a `dict`s of variables to an array space.

@@ -86,7 +99,9 @@ def rmap(
         return result

     @classmethod
-    def mapf(cls, f: Callable[[PointType], T], start_point: Optional[PointType] = None) -> T:
+    def mapf(
+        cls, f: Callable[[PointType], T], start_point: Optional[PointType] = None
+    ) -> Callable[[RaveledVars], T]:
         """Create a callable that first maps back to ``dict`` inputs and then applies a function.

         function f: DictSpace -> T to ArraySpace -> T
@@ -100,16 +115,3 @@ def mapf(cls, f: Callable[[PointType], T], start_point: Optional[PointType] = None) -> T:
         f: array -> T
         """
         return Compose(f, partial(cls.rmap, start_point=start_point))
-
-
-class Compose:
-    """
-    Compose two functions in a pickleable way
-    """
-
-    def __init__(self, fa, fb):
-        self.fa = fa
-        self.fb = fb
-
-    def __call__(self, x):
-        return self.fa(self.fb(x))
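Compose is moved above its first use and given concrete type annotations. As a side note on why it exists at all, here is a small illustrative sketch (not pymc code) of how a module-level callable class survives pickling where an equivalent lambda would not:

# Illustrative only; mirrors the idea behind pymc.blocking.Compose.
import pickle


def add_one(x: int) -> int:
    return x + 1


def double(x: int) -> int:
    return 2 * x


class Compose:
    """Compose two functions in a pickleable way."""

    def __init__(self, fa, fb):
        self.fa = fa
        self.fb = fb

    def __call__(self, x):
        return self.fa(self.fb(x))


composed = Compose(add_one, double)
assert composed(3) == 7

# A module-level class instance round-trips through pickle...
restored = pickle.loads(pickle.dumps(composed))
assert restored(3) == 7

# ...whereas the equivalent lambda cannot be pickled:
# pickle.dumps(lambda x: add_one(double(x)))  # raises a pickling error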

pymc/data.py (+7 -6)

@@ -20,7 +20,7 @@
 import warnings

 from copy import copy
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Union, cast

 import aesara
 import aesara.tensor as at
@@ -518,14 +518,15 @@ def ConstantData(
     Registers the ``value`` as a :class:`~aesara.tensor.TensorConstant` with the model.
     For more information, please reference :class:`pymc.Data`.
     """
-    return Data(
+    var = Data(
         name,
         value,
         dims=dims,
         export_index_as_coords=export_index_as_coords,
         mutable=False,
         **kwargs,
     )
+    return cast(TensorConstant, var)


 def MutableData(
@@ -541,14 +542,15 @@ def MutableData(
     Registers the ``value`` as a :class:`~aesara.compile.sharedvalue.SharedVariable`
     with the model. For more information, please reference :class:`pymc.Data`.
     """
-    return Data(
+    var = Data(
         name,
         value,
         dims=dims,
         export_index_as_coords=export_index_as_coords,
         mutable=True,
         **kwargs,
     )
+    return cast(SharedVariable, var)


 def Data(
@@ -626,9 +628,8 @@ def Data(
         value = np.array(value)

     # Add data container to the named variables of the model.
-    try:
-        model = pm.Model.get_context()
-    except TypeError:
+    model = pm.Model.get_context(error_if_none=False)
+    if model is None:
         raise TypeError(
             "No model on context stack, which is needed to instantiate a data container. "
             "Add variable inside a 'with model:' block."

pymc/distributions/continuous.py (+1 -1)

@@ -158,7 +158,7 @@ class BoundedContinuous(Continuous):
     """Base class for bounded continuous distributions"""

     # Indices of the arguments that define the lower and upper bounds of the distribution
-    bound_args_indices = None
+    bound_args_indices: Optional[List[int]] = None

     def __new__(cls, *args, **kwargs):
         transform = kwargs.get("transform", UNSET)

pymc/distributions/distribution.py (+1 -1)

@@ -186,7 +186,7 @@ class Distribution(metaclass=DistributionMeta):
     """Statistical distribution"""

     rv_class = None
-    rv_op = None
+    rv_op: RandomVariable = None

     def __new__(
         cls,

pymc/distributions/shape_utils.py (+8 -8)

@@ -422,18 +422,18 @@ class ellipsis(Enum):
     ellipsis = type(Ellipsis)

 # User-provided can be lazily specified as scalars
-Shape: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, TensorVariable, ellipsis]]]
+Shape: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, Variable, ellipsis]]]
 Dims: TypeAlias = Union[str, Sequence[Optional[Union[str, ellipsis]]]]
-Size: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, TensorVariable]]]
+Size: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, Variable]]]

 # After conversion to vectors
-WeakShape: TypeAlias = Union[TensorVariable, Tuple[Union[int, TensorVariable, ellipsis], ...]]
+WeakShape: TypeAlias = Union[TensorVariable, Tuple[Union[int, Variable, ellipsis], ...]]
 WeakDims: TypeAlias = Tuple[Optional[Union[str, ellipsis]], ...]

 # After Ellipsis were substituted
-StrongShape: TypeAlias = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]
+StrongShape: TypeAlias = Union[TensorVariable, Tuple[Union[int, Variable], ...]]
 StrongDims: TypeAlias = Sequence[Optional[str]]
-StrongSize: TypeAlias = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]
+StrongSize: TypeAlias = Union[TensorVariable, Tuple[Union[int, Variable], ...]]


 def convert_dims(dims: Optional[Dims]) -> Optional[WeakDims]:
@@ -593,9 +593,9 @@ def find_size(
         Number of support dimensions
     """

-    ndim_expected = None
-    ndim_batch = None
-    create_size = None
+    ndim_expected: Optional[int] = None
+    ndim_batch: Optional[int] = None
+    create_size: Optional[StrongSize] = None

     if shape is not None:
         if Ellipsis in shape:
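The locals in find_size start as None and are annotated with Optional[...] at their first assignment. A small illustrative example of the same pattern (ShapeLike is a made-up alias, not one of the real pymc type aliases):

# Illustrative only; ShapeLike and find_total are not pymc names.
from typing import Optional, Sequence, Union

ShapeLike = Union[int, Sequence[int]]


def find_total(shape: Optional[ShapeLike]) -> Optional[int]:
    # Annotating at the None initialization documents the intended type and
    # keeps mypy from inferring a partial or overly narrow type, mirroring the
    # ndim_expected/ndim_batch/create_size annotations above.
    total: Optional[int] = None
    if shape is not None:
        total = shape if isinstance(shape, int) else len(shape)
    return total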

pymc/initial_point.py (+13 -8)

@@ -68,7 +68,7 @@ def filter_rvs_to_jitter(step) -> Set[TensorVariable]:
         The random variables for which jitter should be added.
     """
     # TODO: implement this
-    return {}
+    return set()


 def make_initial_point_fns_per_chain(
@@ -163,12 +163,16 @@ def find_rng_nodes(variables):
         )
     ]

-    overrides = convert_str_to_rv_dict(model, overrides or {})
+    sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
+    initval_strats = {
+        **model.initial_values,
+        **sdict_overrides,
+    }

     initial_values = make_initial_point_expression(
         free_rvs=model.free_RVs,
         rvs_to_values=model.rvs_to_values,
-        initval_strategies={**model.initial_values, **(overrides or {})},
+        initval_strategies=initval_strats,
         jitter_rvs=jitter_rvs,
         default_strategy=default_strategy,
         return_transformed=return_transformed,
@@ -178,13 +182,14 @@ def find_rng_nodes(variables):
     # when calling the final seeded function
     graph = FunctionGraph(outputs=initial_values, clone=False)
     rng_nodes = find_rng_nodes(graph.outputs)
-    new_rng_nodes = []
+    new_rng_nodes: List[Union[np.random.RandomState, np.random.Generator]] = []
     for rng_node in rng_nodes:
+        rng_cls: type
         if isinstance(rng_node, at.random.var.RandomStateSharedVariable):
-            new_rng = np.random.RandomState(np.random.PCG64())
+            rng_cls = np.random.RandomState
         else:
-            new_rng = np.random.Generator(np.random.PCG64())
-        new_rng_nodes.append(aesara.shared(new_rng))
+            rng_cls = np.random.Generator
+            new_rng_nodes.append(aesara.shared(rng_cls(np.random.PCG64())))
     graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True)
     func = compile_pymc(inputs=[], outputs=graph.outputs, mode=aesara.compile.mode.FAST_COMPILE)

@@ -310,7 +315,7 @@ def make_initial_point_expression(

     initial_values.append(value)

-    all_outputs = []
+    all_outputs: List[TensorVariable] = []
     all_outputs.extend(free_rvs)
     all_outputs.extend(initial_values)
     all_outputs.extend(initial_values_transformed)
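The filter_rvs_to_jitter fix is worth spelling out: {} is an empty dict literal, not an empty set, so a function annotated as returning a Set must build its empty value with set(). A tiny illustration (not pymc code):

from typing import Set

# `{}` is an empty dict, `set()` is an empty set.
assert type({}) is dict
assert type(set()) is set


def empty_names() -> Set[str]:
    # Returning {} here would return a dict and contradict the Set[str] annotation.
    return set()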

pymc/model.py (+11 -7)

@@ -31,6 +31,7 @@
     Type,
     TypeVar,
     Union,
+    cast,
 )

 import aesara
@@ -552,8 +553,8 @@ def __init__(
         self.rng_seeder = rng_seeder

         # The sequence of model-generated RNGs
-        self.rng_seq = []
-        self._initial_values = {}
+        self.rng_seq: List[SharedVariable] = []
+        self._initial_values: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]] = {}

         if self.parent is not None:
             self.named_vars = treedict(parent=self.parent.named_vars)
@@ -721,17 +722,20 @@ def logpt(
         -------
         Logp graph(s)
         """
+        varlist: List[TensorVariable]
         if vars is None:
-            vars = self.free_RVs + self.observed_RVs + self.potentials
+            varlist = self.free_RVs + self.observed_RVs + self.potentials
         elif not isinstance(vars, (list, tuple)):
-            vars = [vars]
+            varlist = [vars]
+        else:
+            varlist = cast(List[TensorVariable], vars)

         # We need to separate random variables from potential terms, and remember their
         # original order so that we can merge them together in the same order at the end
         rv_values = {}
         potentials = []
         rv_order, potential_order = [], []
-        for i, var in enumerate(vars):
+        for i, var in enumerate(varlist):
             value_var = self.rvs_to_values.get(var)
             if value_var is not None:
                 rv_values[var] = value_var
@@ -756,7 +760,7 @@ def logpt(
         if potentials:
             potential_logps, _ = rvs_to_value_vars(potentials, apply_transforms=True)

-        logp_factors = [None] * len(vars)
+        logp_factors = [None] * len(varlist)
         for logp_order, logp in zip((rv_order + potential_order), (rv_logps + potential_logps)):
             logp_factors[logp_order] = logp

@@ -948,7 +952,7 @@ def coords(self) -> Dict[str, Union[Tuple, None]]:
         return self._coords

     @property
-    def dim_lengths(self) -> Dict[str, Tuple[Variable, ...]]:
+    def dim_lengths(self) -> Dict[str, Variable]:
         """The symbolic lengths of dimensions in the model.

         The values are typically instances of ``TensorVariable`` or ``ScalarSharedVariable``.
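In logpt, the vars parameter is no longer re-assigned with normalized values; instead a separately annotated local, varlist, holds the list form, with cast() covering the branch where the input is already a sequence. A generic sketch of the same normalization pattern, with illustrative names rather than the pymc API:

# Illustrative only; total_length is not a pymc function.
from typing import List, Sequence, Union


def total_length(items: Union[str, Sequence[str], None]) -> int:
    # Bind the normalized form to a separately annotated local instead of
    # re-assigning the parameter, so each name keeps a single, precise type.
    itemlist: List[str]
    if items is None:
        itemlist = []
    elif isinstance(items, str):
        itemlist = [items]
    else:
        itemlist = list(items)
    return sum(len(s) for s in itemlist)


assert total_length(None) == 0
assert total_length("abc") == 3
assert total_length(["ab", "c"]) == 3

Here list() copies the sequence, so no cast is needed; logpt uses cast() in its final branch because it keeps the caller's sequence instead of copying it.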

pymc/parallel_sampling.py (+3 -3)

@@ -21,7 +21,7 @@
 import traceback

 from collections import namedtuple
-from typing import Dict, Sequence
+from typing import Dict, List, Sequence

 import cloudpickle
 import numpy as np
@@ -425,8 +425,8 @@ def __init__(
         ]

         self._inactive = self._samplers.copy()
-        self._finished = []
-        self._active = []
+        self._finished: List[ProcessAdapter] = []
+        self._active: List[ProcessAdapter] = []
         self._max_active = cores

         self._in_context = False

pymc/tests/__init__.py (-4)

@@ -11,7 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from numpy.testing import Tester
-
-test = Tester().test

0 commit comments
