Skip to content

Commit

Permalink
[Reland] Update mypy to 1.4.1 (pytorch#105227)
Browse files Browse the repository at this point in the history
This PR re-lands
- [Typing] Fix PEP 484 Violation (pytorch#105022)
- Update mypy to 1.4.1 (pytorch#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: pytorch#105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
  • Loading branch information
malfet authored and pytorchmergebot committed Jul 14, 2023
1 parent 1518d5e commit c9c4f8e
Show file tree
Hide file tree
Showing 97 changed files with 262 additions and 254 deletions.
4 changes: 2 additions & 2 deletions .ci/docker/requirements-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ librosa>=0.6.2 ; python_version < "3.11"
#Pinned versions:
#test that import:

mypy==0.960
mypy==1.4.1
# Pin MyPy version because new errors are likely to appear with each release
#Description: linter
#Pinned versions: 0.960
#Pinned versions: 1.4.1
#test that import: test_typing.py, test_type_hints.py

networkx==2.8.8
Expand Down
2 changes: 1 addition & 1 deletion .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ init_command = [
'--dry-run={{DRYRUN}}',
'numpy==1.24.3',
'expecttest==0.1.3',
'mypy==0.960',
'mypy==1.4.1',
'types-requests==2.27.25',
'types-PyYAML==6.0.7',
'types-tabulate==0.8.8',
Expand Down
2 changes: 1 addition & 1 deletion caffe2/contrib/aten/gen_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
# use faster C loader if available
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[misc]
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]


def write(filename, s):
Expand Down
6 changes: 3 additions & 3 deletions caffe2/contrib/tensorboard/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def graph_def_to_event(step, graph_def):
wall_time=step, step=step, graph_def=graph_def.SerializeToString())


@cli.command("tensorboard-graphs")
@cli.command("tensorboard-graphs") # type: ignore[arg-type, attr-defined]
@click.option("--c2-netdef", type=click.Path(exists=True, dir_okay=False),
multiple=True)
@click.option("--tf-dir", type=click.Path(exists=True))
Expand All @@ -129,7 +129,7 @@ def parse_net_def(path):
log.info("Wrote %s graphs to logdir %s", len(events), tf_dir)


@cli.command("tensorboard-events")
@cli.command("tensorboard-events") # type: ignore[arg-type, attr-defined]
@click.option("--c2-dir", type=click.Path(exists=True, file_okay=False),
help="Root directory of the Caffe2 run")
@click.option("--tf-dir", type=click.Path(writable=True),
Expand Down Expand Up @@ -209,4 +209,4 @@ def event(step, values):


if __name__ == "__main__":
cli()
cli() # type: ignore[misc]
10 changes: 5 additions & 5 deletions test/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@ def test_set_exception(self) -> None:
error_msg = "Intentional Value Error"
value_error = ValueError(error_msg)

f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
# Set exception
f.set_exception(value_error)
# Exception should throw on wait
with self.assertRaisesRegex(ValueError, "Intentional"):
f.wait()

# Exception should also throw on value
f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
f.set_exception(value_error)
with self.assertRaisesRegex(ValueError, "Intentional"):
f.value()

def cb(fut):
fut.value()

f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
f.set_exception(value_error)

with self.assertRaisesRegex(RuntimeError, "Got the following error"):
Expand All @@ -54,7 +54,7 @@ def wait_future(f):
with self.assertRaisesRegex(ValueError, "Intentional"):
f.wait()

f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
t = threading.Thread(target=wait_future, args=(f, ))
t.start()
f.set_exception(value_error)
Expand All @@ -68,7 +68,7 @@ def then_future(f):
with self.assertRaisesRegex(RuntimeError, "Got the following error"):
fut.wait()

f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
t = threading.Thread(target=then_future, args=(f, ))
t.start()
f.set_exception(value_error)
Expand Down
4 changes: 2 additions & 2 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from functools import partial
from torch import multiprocessing as mp
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON,
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest,
Expand Down Expand Up @@ -8497,7 +8497,7 @@ def _spawn_method(self, method, arg):
except RuntimeError:
pass
with mp.Pool(1) as pool:
out: list = pool.map(method, [arg])
out = pool.map(method, [arg])
self.assertTrue(out[0])

def _test_multinomial_invalid_probs(probs):
Expand Down
10 changes: 5 additions & 5 deletions tools/code_coverage/oss_coverage.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/usr/bin/env python3
import time

from package.oss.cov_json import get_json_report
from package.oss.init import initialization
from package.tool.summarize_jsons import summarize_jsons
from package.util.setting import TestPlatform
from package.util.utils import print_time
from package.oss.cov_json import get_json_report # type: ignore[import]
from package.oss.init import initialization # type: ignore[import]
from package.tool.summarize_jsons import summarize_jsons # type: ignore[import]
from package.util.setting import TestPlatform # type: ignore[import]
from package.util.utils import print_time # type: ignore[import]


def report_coverage() -> None:
Expand Down
2 changes: 1 addition & 1 deletion tools/code_coverage/package/tool/summarize_jsons.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def transform_file_name(
return file_path[file_path.find(folder) :]
# remove pytorch base folder path
if platform == TestPlatform.OSS:
from package.oss.utils import get_pytorch_folder
from package.oss.utils import get_pytorch_folder # type: ignore[import]

pytorch_foler = get_pytorch_folder()
assert file_path.startswith(pytorch_foler)
Expand Down
6 changes: 4 additions & 2 deletions tools/code_coverage/package/util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def get_raw_profiles_folder() -> str:

def detect_compiler_type(platform: TestPlatform) -> CompilerType:
if platform == TestPlatform.OSS:
from package.oss.utils import detect_compiler_type # type: ignore[misc]
from package.oss.utils import ( # type: ignore[assignment, import, misc]
detect_compiler_type,
)

cov_type = detect_compiler_type() # type: ignore[call-arg]
else:
Expand All @@ -100,7 +102,7 @@ def detect_compiler_type(platform: TestPlatform) -> CompilerType:
cov_type = detect_compiler_type()

check_compiler_type(cov_type)
return cov_type
return cov_type # type: ignore[no-any-return]


def get_test_name_from_whole_path(path: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion tools/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader # type: ignore[misc]
from yaml import Loader # type: ignore[assignment, misc]

H_NAME = "spv.h"
CPP_NAME = "spv.cpp"
Expand Down
2 changes: 1 addition & 1 deletion tools/linter/adapters/workflow_consistency_linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[misc]
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]


class LintSeverity(str, Enum):
Expand Down
2 changes: 1 addition & 1 deletion tools/lite_interpreter/gen_selected_mobile_ops_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[misc]
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]


if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {
Expand Down
2 changes: 1 addition & 1 deletion tools/setup_helpers/generate_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# use faster C loader if available
from yaml import CSafeLoader as YamlLoader
except ImportError:
from yaml import SafeLoader as YamlLoader # type: ignore[misc]
from yaml import SafeLoader as YamlLoader # type: ignore[assignment, misc]

NATIVE_FUNCTIONS_PATH = "aten/src/ATen/native/native_functions.yaml"
TAGS_PATH = "aten/src/ATen/native/tags.yaml"
Expand Down
1 change: 1 addition & 0 deletions torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
def parse_type_comment(comment: str) -> Decl: ...
def _get_upgraders_map_size() -> _int: ...
def _get_upgraders_entry_map() -> Dict[str, str]: ...
def _dump_upgraders_map() -> Dict[str, str]: ...
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ...
Expand Down
4 changes: 2 additions & 2 deletions torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def sym_int(a):
if isinstance(a, SymInt):
return a
elif isinstance(a, SymFloat):
return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type]
return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type, call-overload]
return py_int(a) # type: ignore[operator]

def sym_max(a, b):
Expand Down Expand Up @@ -1320,7 +1320,7 @@ def manager_path():
# Some type signatures pulled in from _VariableFunctions here clash with
# signatures already imported. For now these clashes are ignored; see
# PR #43339 for details.
from torch._C._VariableFunctions import * # type: ignore[misc] # noqa: F403
from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403
# Fixup segment_reduce visibility
_segment_reduce = segment_reduce
del segment_reduce
Expand Down
8 changes: 4 additions & 4 deletions torch/_export/serde/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ class _Union:
@classmethod
def create(cls, **kwargs):
assert len(kwargs) == 1
return cls(**{**{f.name: None for f in fields(cls)}, **kwargs})
return cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type]

def __post_init__(self):
assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1
assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1 # type: ignore[arg-type, misc]

@property
def value(self):
val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None)
val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type]
assert val is not None
return val

@property
def type(self):
val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None)
val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type]
assert val_type is not None
return val_type

Expand Down
9 changes: 5 additions & 4 deletions torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from typing import cast, Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import sympy

Expand Down Expand Up @@ -749,7 +749,7 @@ def __init__(self):
self.module = torch.nn.Module()

@contextmanager
def save_graph_module(self) -> None:
def save_graph_module(self) -> Iterator[None]:
saved = self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta
self.graph = torch.fx.Graph()
self.module = torch.nn.Module()
Expand All @@ -773,7 +773,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:

if vr := self.symbol_name_to_range.get(val.expr_str):
symbolic_shapes._constrain_symbol_range(
self.shape_env, sym, vr.lower, vr.upper
self.shape_env, sym, vr.lower, vr.upper # type: ignore[arg-type]
)

return self.shape_env.create_symintnode(sym, hint=val.hint)
Expand Down Expand Up @@ -855,6 +855,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
output_node.meta["val"] = tuple(
arg.meta["val"] for arg in output_node.args[0]
)
return output_node

def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
if target.__module__ == "_operator": # TODO(zhxchen17) Follow up on this.
Expand Down Expand Up @@ -1050,7 +1051,7 @@ def deserialize_multiple_outputs(self, serialized_node: Node, fx_node: torch.fx.
self.serialized_name_to_node[fx_node.name] = fx_node

def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
ret = {}
ret: Dict[str, Any] = {}
if stack_trace := metadata.get("stack_trace"):
ret["stack_trace"] = stack_trace

Expand Down
2 changes: 1 addition & 1 deletion torch/_export/serde/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_upgraders() -> Dict[str, Tuple[str, str]]:
"""Getting upgraders entry map and operator version map and merge them into one dict."""
upgraders = torch._C._get_upgraders_entry_map()
op_version_map = torch._C._get_operator_version_map()
output = defaultdict(tuple)
output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type]
for opname, entry_list in op_version_map.items():
if not entry_list:
raise RuntimeError(f"Op version map has an empty entry for opname {opname}")
Expand Down
4 changes: 2 additions & 2 deletions torch/_functorch/functional_call.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import Counter
from typing import Any, Dict, List, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
Expand All @@ -12,7 +12,7 @@ def functional_call(
module: "torch.nn.Module",
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
kwargs: Optional[Dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,
Expand Down
8 changes: 4 additions & 4 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ def _linalg_svd_meta(
A: Tensor,
full_matrices: bool = False,
compute_uv: bool = True,
driver: str = None,
driver: Optional[str] = None,
):
checkIsMatrix(A, "linalg.svd")
checkFloatingOrComplex(A, "linalg.svd")
Expand Down Expand Up @@ -1207,7 +1207,7 @@ def linalg_solve_triangular_meta(
upper: bool,
left: bool = True,
unitriangular: bool = False,
out: Tensor = None,
out: Optional[Tensor] = None,
) -> Tensor:
if out is None:
out = A.new_empty([0])
Expand Down Expand Up @@ -4755,8 +4755,8 @@ def upsample_nearest2d_backward(
grad_output: Tensor,
output_size: Sequence[Union[int, torch.types.SymInt]],
input_size: Sequence[Union[int, torch.types.SymInt]],
scales_h: float = None,
scales_w: float = None,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
):
full_output_size = upsample_common_check(
input_size, output_size, num_spatial_dims=2
Expand Down
2 changes: 1 addition & 1 deletion torch/_prims/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
def _elementwise_meta(
*args,
type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
args_with_fixed_dtypes: Tuple[TensorLikeType, ...] = None,
args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
) -> FakeTensor:
"""
Meta function for elementwise operations that produce outputs in the same dtype
Expand Down
8 changes: 4 additions & 4 deletions torch/_prims/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Sequence
from typing import Any, Callable, Dict, Optional, Sequence
from warnings import warn

import torch
Expand Down Expand Up @@ -111,7 +111,7 @@ def __torch_function__(
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Dict = None,
kwargs: Optional[Dict] = None,
):
if kwargs is None:
kwargs = {}
Expand Down Expand Up @@ -161,7 +161,7 @@ def __torch_function__(
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Dict = None,
kwargs: Optional[Dict] = None,
):
if kwargs is None:
kwargs = {}
Expand Down Expand Up @@ -374,7 +374,7 @@ def __torch_function__(
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Dict = None,
kwargs: Optional[Dict] = None,
):
if kwargs is None:
kwargs = {}
Expand Down
Loading

0 comments on commit c9c4f8e

Please sign in to comment.