[BE] wrap deprecated function/class with `typing_extensions.deprecated` (pytorch#127689)

Use `typing_extensions.deprecated` for the deprecation annotation where possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` calls that are missing an explicit category.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
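
For illustration, a minimal sketch of the two patterns applied throughout this PR; `old_fn`, `new_fn`, and `legacy_arg` are hypothetical names, not code from this change (`typing_extensions.deprecated` needs typing-extensions >= 4.5, hence the version bumps below):

```python
import warnings

from typing_extensions import deprecated


def new_fn():
    return 42


# Pattern 1: an entire function (or class) is deprecated -- wrap it with the
# decorator, which emits the warning each time the wrapped API is called.
@deprecated(
    "`old_fn` is deprecated, please use `new_fn` instead",
    category=FutureWarning,
)
def old_fn():
    return new_fn()


# Pattern 2: only a specific code path is deprecated -- keep warnings.warn,
# but pass an explicit FutureWarning category instead of the default UserWarning.
def fn(legacy_arg=None):
    if legacy_arg is not None:
        warnings.warn(
            "`legacy_arg` is deprecated and will be removed in a future release.",
            FutureWarning,
            stacklevel=2,
        )
    return new_fn()
```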

Resolves pytorch#126888

This PR is split from PR pytorch#126898.

------

Pull Request resolved: pytorch#127689
Approved by: https://github.com/Skylion007
XuehaiPan authored and pytorchmergebot committed Jun 2, 2024
1 parent c1dd3a6 commit 67ef268
Showing 97 changed files with 763 additions and 458 deletions.
2 changes: 1 addition & 1 deletion .github/requirements/conda-env-Linux-X64.txt
@@ -5,4 +5,4 @@ ninja=1.10.2
 numpy=1.23.3
 pyyaml=6.0
 setuptools=68.2.2
-typing-extensions=4.3.0
+typing-extensions=4.9.0
2 changes: 1 addition & 1 deletion .github/requirements/conda-env-iOS.txt
@@ -4,4 +4,4 @@ ninja=1.10.2
 numpy=1.23.3
 pyyaml=6.0
 setuptools=68.2.2
-typing-extensions=4.3.0
+typing-extensions=4.9.0
2 changes: 1 addition & 1 deletion .github/requirements/conda-env-macOS-ARM64
@@ -2,7 +2,7 @@ numpy=1.22.3
 pyyaml=6.0
 setuptools=61.2.0
 cmake=3.22.*
-typing-extensions=4.3.0
+typing-extensions=4.9.0
 dataclasses=0.8
 pip=22.2.2
 pillow=10.0.1
2 changes: 1 addition & 1 deletion .github/requirements/conda-env-macOS-X64
@@ -4,7 +4,7 @@ numpy=1.21.2
 pyyaml=5.3
 setuptools=46.0.0
 cmake=3.22.*
-typing-extensions=4.3.0
+typing-extensions=4.9.0
 dataclasses=0.8
 pip=22.2.2
 pillow=10.0.1
2 changes: 1 addition & 1 deletion test/distributed/_tensor/test_api.py
@@ -237,7 +237,7 @@ def output_fn(outputs, device_mesh):
             assert isinstance(outputs, DTensor)
             return outputs.to_local()
 
-        with self.assertWarnsRegex(UserWarning, "Deprecating"):
+        with self.assertWarnsRegex(FutureWarning, "Deprecating"):
             replica_module = distribute_module(
                 module_to_replicate,
                 device_mesh,
2 changes: 1 addition & 1 deletion test/distributed/fsdp/test_fsdp_optim_state.py
@@ -1436,7 +1436,7 @@ def should_check_method(method_name: str):
         def get_warning_context():
             warning_regex = "`optim_input` argument is deprecated"
             return self.assertWarnsRegex(
-                expected_warning=UserWarning, expected_regex=warning_regex
+                expected_warning=FutureWarning, expected_regex=warning_regex
             )
 
         self._run_on_all_optim_state_apis(
6 changes: 4 additions & 2 deletions test/functorch/test_eager_transforms.py
@@ -3258,7 +3258,7 @@ def test_deprecation_vmap(self, device):
         x = torch.randn(3, device=device)
 
         # functorch version of the API is deprecated
-        with self.assertWarnsRegex(UserWarning, "Please use torch.vmap"):
+        with self.assertWarnsRegex(FutureWarning, "Please use `torch.vmap`"):
             vmap(torch.sin)
 
         # the non-functorch version is not deprecated
@@ -3276,7 +3276,9 @@ def test_deprecation_transforms(self, device, transform):
         new_api = getattr(torch.func, transform)
 
         # functorch version of the API is deprecated
-        with self.assertWarnsRegex(UserWarning, f"Please use torch.func.{transform}"):
+        with self.assertWarnsRegex(
+            FutureWarning, f"Please use `torch.func.{transform}`"
+        ):
             api(torch.sin)
 
         # the non-functorch version is not deprecated
2 changes: 1 addition & 1 deletion test/nn/test_init.py
@@ -521,7 +521,7 @@ def fn():
             init.normal(x)
 
         with self.assertWarnsRegex(
-            UserWarning,
+            FutureWarning,
            "deprecated",
             msg="methods not suffixed with underscore should be deprecated",
         ):
9 changes: 5 additions & 4 deletions test/nn/test_module_hooks.py
@@ -1387,7 +1387,8 @@ def forward(self, l):
         m.register_backward_hook(noop)
 
         with self.assertWarnsRegex(
-            UserWarning, "does not take as input a single Tensor or a tuple of Tensors"
+            FutureWarning,
+            "does not take as input a single Tensor or a tuple of Tensors",
         ):
             m([a, b])
 
@@ -1400,7 +1401,7 @@ def forward(self, a, b):
         m.register_backward_hook(noop)
 
         with self.assertWarnsRegex(
-            UserWarning, "does not return a single Tensor or a tuple of Tensors"
+            FutureWarning, "does not return a single Tensor or a tuple of Tensors"
         ):
             m(a, b)
 
@@ -1413,7 +1414,7 @@ def forward(self, a, b):
         m.register_backward_hook(noop)
 
         with self.assertWarnsRegex(
-            UserWarning, "outputs are generated by different autograd Nodes"
+            FutureWarning, "outputs are generated by different autograd Nodes"
         ):
             m(a, b)
 
@@ -1426,7 +1427,7 @@ def forward(self, a):
         m.register_backward_hook(noop)
 
         with self.assertWarnsRegex(
-            UserWarning, "the forward contains multiple autograd Nodes"
+            FutureWarning, "the forward contains multiple autograd Nodes"
         ):
             m(a)
4 changes: 2 additions & 2 deletions test/test_autocast.py
@@ -255,8 +255,8 @@ def test_generic_autocast(self):
 
     def test_cpu_autocast_deprecated_warning(self):
         with self.assertWarnsRegex(
-            DeprecationWarning,
-            r"torch.cpu.amp.autocast\(args...\) is deprecated. Please use torch.amp.autocast\('cpu', args...\) instead.",
+            FutureWarning,
+            r"`torch.cpu.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cpu', args...\)` instead.",
         ):
             with torch.cpu.amp.autocast():
                 _ = torch.ones(10)
10 changes: 5 additions & 5 deletions test/test_autograd.py
@@ -154,7 +154,7 @@ def hook(*args):
 
     def test_grad_mode_class_decoration(self):
         # Decorating class is deprecated and should not be used
-        with self.assertWarnsRegex(UserWarning, "Decorating classes is deprecated"):
+        with self.assertWarnsRegex(FutureWarning, "Decorating classes is deprecated"):
 
             @torch.no_grad()
             class Foo:
@@ -5937,13 +5937,13 @@ def fn(inputs):
         b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
 
         with self.assertWarnsRegex(
-            UserWarning, "get_numerical_jacobian was part of PyTorch's private API"
+            FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API"
         ):
             jacobian = get_numerical_jacobian(fn, (a, b), target=a, eps=1e-6)
         self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))
 
         with self.assertWarnsRegex(
-            UserWarning, "get_numerical_jacobian was part of PyTorch's private API"
+            FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API"
         ):
             jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6)
         self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))
@@ -5963,7 +5963,7 @@ def fn(x, y):
 
         outputs = fn(a, b)
         with self.assertWarnsRegex(
-            UserWarning, "get_analytical_jacobian was part of PyTorch's private API"
+            FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API"
         ):
             (
                 jacobians,
@@ -5991,7 +5991,7 @@ def backward(ctx, grad_out):
 
         outputs = NonDetFunc.apply(a, 1e-6)
         with self.assertWarnsRegex(
-            UserWarning, "get_analytical_jacobian was part of PyTorch's private API"
+            FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API"
         ):
             (
                 jacobians,
8 changes: 4 additions & 4 deletions test/test_cuda.py
@@ -1820,10 +1820,10 @@ def backward(ctx, grad):
             return grad, grad
 
         self.assertRegex(
-            str(w[0].message), r"torch.cuda.amp.custom_fwd\(args...\) is deprecated."
+            str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated."
         )
         self.assertRegex(
-            str(w[1].message), r"torch.cuda.amp.custom_bwd\(args...\) is deprecated."
+            str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated."
         )
 
         mymm = MyMM.apply
@@ -2016,8 +2016,8 @@ def test_autocast_checkpointing(self):
 
     def test_cuda_autocast_deprecated_warning(self):
         with self.assertWarnsRegex(
-            DeprecationWarning,
-            r"torch.cuda.amp.autocast\(args...\) is deprecated. Please use torch.amp.autocast\('cuda', args...\) instead.",
+            FutureWarning,
+            r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.",
         ):
             with torch.cuda.amp.autocast():
                 _ = torch.ones(10)
2 changes: 1 addition & 1 deletion test/test_prims.py
@@ -338,7 +338,7 @@ def test_mul_complex(self):
             prims.mul(torch.randn(2), 1 + 1j)
 
     def test_check_deprecation_warning(self):
-        with self.assertWarnsRegex(DeprecationWarning, 'will be removed in the future'):
+        with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
             torch._prims_common.check(True, lambda: 'message')
2 changes: 1 addition & 1 deletion test/test_pytree.py
@@ -723,7 +723,7 @@ def __init__(self, x, y):
                 self.y = y
 
         with self.assertWarnsRegex(
-            UserWarning, "torch.utils._pytree._register_pytree_node"
+            FutureWarning, "torch.utils._pytree._register_pytree_node"
         ):
             py_pytree._register_pytree_node(
                 DummyType,
2 changes: 1 addition & 1 deletion test/test_stateless.py
@@ -901,7 +901,7 @@ def test_stateless_functional_call_warns(self):
         m = torch.nn.Linear(1, 1)
         params = dict(m.named_parameters())
         x = torch.randn(3, 1)
-        with self.assertWarnsRegex(UserWarning, "Please use torch.func.functional_call"):
+        with self.assertWarnsRegex(FutureWarning, "Please use `torch.func.functional_call`"):
             stateless.functional_call(m, params, x)
 
 class TestPythonOptimizeMode(TestCase):
4 changes: 2 additions & 2 deletions test/test_torch.py
@@ -6198,8 +6198,8 @@ def test_grad_scaler_deprecated_warning(self, device):
         GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler
 
         with self.assertWarnsRegex(
-            UserWarning,
-            rf"torch.{device.type}.amp.GradScaler\(args...\) is deprecated.",
+            FutureWarning,
+            rf"`torch.{device.type}.amp.GradScaler\(args...\)` is deprecated.",
         ):
             _ = GradScaler(init_scale=2.0)
11 changes: 0 additions & 11 deletions torch/__init__.py
@@ -1996,17 +1996,6 @@ def _register_device_module(device_type, module):
 from torch.func import vmap
 
 
-# The function _sparse_coo_tensor_unsafe is removed from PyTorch
-# Python API (v. 1.13), here we temporarily provide its replacement
-# with a deprecation warning.
-# TODO: remove the function for PyTorch v 1.15.
-def _sparse_coo_tensor_unsafe(*args, **kwargs):
-    import warnings
-    warnings.warn('torch._sparse_coo_tensor_unsafe is deprecated, '
-                  'use torch.sparse_coo_tensor(..., check_invariants=False) instead.')
-    kwargs['check_invariants'] = False
-    return torch.sparse_coo_tensor(*args, **kwargs)
-
 # Register MPS specific decomps
 torch.backends.mps._init()
10 changes: 7 additions & 3 deletions torch/_dynamo/eval_frame.py
@@ -798,7 +798,9 @@ def guard_export_print(guards):
         warnings.warn(
             "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. "
             "If you don't migrate, we may break your explain call in the future if your user defined kwargs "
-            "conflict with future kwargs added to explain(f)."
+            "conflict with future kwargs added to explain(f).",
+            FutureWarning,
+            stacklevel=2,
         )
         return inner(*extra_args, **extra_kwargs)
     else:
@@ -941,7 +943,7 @@ def check_signature_rewritable(graph):
         tb = "".join(traceback.format_list(stack))
         extra = ""
         if len(user_stacks) > 1:
-            extra = f"(elided {len(user_stacks)-1} more accesses)"
+            extra = f"(elided {len(user_stacks) - 1} more accesses)"
         msg = f"{source.name()}, accessed at:\n{tb}{extra}"
         # TODO: option to print ALL of the stack traces at once
         input_errors.append(msg)
@@ -1476,7 +1478,9 @@ def graph_with_interpreter(*args):
         warnings.warn(
             "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. "
             "If you don't migrate, we may break your export call in the future if your user defined kwargs "
-            "conflict with future kwargs added to export(f)."
+            "conflict with future kwargs added to export(f).",
+            FutureWarning,
+            stacklevel=2,
         )
         return inner(*extra_args, **extra_kwargs)
     else:
26 changes: 13 additions & 13 deletions torch/_functorch/deprecated.py
@@ -1,3 +1,12 @@
+"""
+The APIs in this file are exposed as `functorch.*`. They are thin wrappers
+around the torch.func.* APIs that have deprecation warnings -- we're trying
+to move people to the torch.func.* equivalents.
+
+NB: We don't use *args, **kwargs in the signatures because that changes the
+documentation.
+"""
+
 import textwrap
 import warnings
 from typing import Any, Callable, Optional, Tuple, Union
@@ -9,25 +18,16 @@
 from torch._functorch.eager_transforms import argnums_t
 from torch._functorch.vmap import in_dims_t, out_dims_t
 
-"""
-The APIs in this file are exposed as `functorch.*`. They are thin wrappers
-around the torch.func.* APIs that have deprecation warnings -- we're trying
-to move people to the torch.func.* equivalents.
-NB: We don't use *args, **kwargs in the signatures because that changes the
-documentation.
-"""
-
-
 def get_warning(api, new_api=None, replace_newlines=False):
     if new_api is None:
         new_api = f"torch.func.{api}"
     warning = (
         f"We've integrated functorch into PyTorch. As the final step of the \n"
-        f"integration, functorch.{api} is deprecated as of PyTorch \n"
+        f"integration, `functorch.{api}` is deprecated as of PyTorch \n"
         f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n"
-        f"Please use {new_api} instead; see the PyTorch 2.0 release notes \n"
-        f"and/or the torch.func migration guide for more details \n"
+        f"Please use `{new_api}` instead; see the PyTorch 2.0 release notes \n"
+        f"and/or the `torch.func` migration guide for more details \n"
         f"https://pytorch.org/docs/main/func.migrating.html"
     )
     if replace_newlines:
@@ -37,7 +37,7 @@ def get_warning(api, new_api=None, replace_newlines=False):
 
 def warn_deprecated(api, new_api=None):
     warning = get_warning(api, new_api, replace_newlines=True)
-    warnings.warn(warning, stacklevel=2)
+    warnings.warn(warning, FutureWarning, stacklevel=3)
 
 
 def setup_docs(functorch_api, torch_func_api=None, new_api_name=None):
5 changes: 3 additions & 2 deletions torch/_functorch/pytree_hacks.py
@@ -16,7 +16,8 @@
 with warnings.catch_warnings():
     warnings.simplefilter("always")
     warnings.warn(
-        "torch._functorch.pytree_hacks is deprecated and will be removed in a future release. "
-        "Please use torch.utils._pytree instead.",
+        "`torch._functorch.pytree_hacks` is deprecated and will be removed in a future release. "
+        "Please use `torch.utils._pytree` instead.",
+        DeprecationWarning,
         stacklevel=2,
     )
3 changes: 2 additions & 1 deletion torch/_inductor/ir.py
@@ -2750,7 +2750,8 @@ def make_indexer(self):
         """A closure containing math to read a given element"""
 
         def indexer(index):
-            assert len(index) == len(self.stride) == len(self.size)
+            assert len(index) == len(self.stride)
+            assert len(index) == len(self.size)
             result = self.offset
             for idx, stride, sz in zip(index, self.stride, self.size):
                 if sz != 1:
9 changes: 5 additions & 4 deletions torch/_library/abstract_impl.py
@@ -1,7 +1,7 @@
 import contextlib
 import functools
-import warnings
 from typing import Callable, Optional
+from typing_extensions import deprecated
 
 import torch
 from torch._library.utils import Kernel, RegistrationHandle
@@ -124,10 +124,11 @@ def __init__(self, _fake_mode, _op):
         self._shape_env = _fake_mode.shape_env
         self._op = _op
 
+    @deprecated(
+        "`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead",
+        category=FutureWarning,
+    )
     def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
-        warnings.warn(
-            "create_unbacked_symint is deprecated, please use new_dynamic_size instead"
-        )
         return self.new_dynamic_size(min=min, max=max)
 
     def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt:
(Diffs for the remaining changed files not shown.)
