Skip to content

Commit

Permalink
remove deprecated device getter from torch.testing (pytorch#87971)
Browse files Browse the repository at this point in the history
See pytorch#87969 or pytorch#86586 for the reasoning.

Pull Request resolved: pytorch#87971
Approved by: https://github.com/mruberry
  • Loading branch information
pmeier authored and pytorchmergebot committed Nov 2, 2022
1 parent 554cdc9 commit a360be5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 17 deletions.
5 changes: 1 addition & 4 deletions torch/testing/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

__all__ = [
"assert_allclose",
"get_all_device_types",
"make_non_contiguous",
]

Expand Down Expand Up @@ -89,13 +88,11 @@ def assert_allclose(
)

# Deprecate and expose all dtype getters
for name in _legacy.__all_dtype_getters__:
for name in _legacy.__all__:
fn = getattr(_legacy, name)
globals()[name] = warn_deprecated(getter_instructions)(fn)
__all__.append(name)

get_all_device_types = warn_deprecated(getter_instructions)(_legacy.get_all_device_types)


@warn_deprecated(
"Depending on the use case there a different replacement options:\n\n"
Expand Down
8 changes: 5 additions & 3 deletions torch/testing/_internal/common_device_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
TEST_CUSPARSE_GENERIC, TEST_HIPSPARSE_GENERIC
from torch.testing._internal.common_dtype import get_all_dtypes

# The implementation should be moved here as soon as the deprecation period is over.
from torch.testing._legacy import get_all_device_types # noqa: F401

try:
import psutil # type: ignore[import]
HAS_PSUTIL = True
Expand Down Expand Up @@ -1325,3 +1322,8 @@ def skipMeta(fn):

def skipXLA(fn):
return skipXLAIf(True, "Marked as skipped for XLA")(fn)

# TODO: the "all" in the name isn't true anymore for quite some time as we have also have for example XLA and MPS now.
# This should probably enumerate all available device type test base classes.
def get_all_device_types() -> List[str]:
return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
11 changes: 1 addition & 10 deletions torch/testing/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch

__all_dtype_getters__ = [
__all__ = [
"_validate_dtypes",
"_dispatch_dtypes",
"all_types",
Expand All @@ -32,11 +32,6 @@
"integral_types_and",
]

__all__ = [
*__all_dtype_getters__,
"get_all_device_types",
]

# Functions and classes for describing the dtypes a function supports
# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros

Expand Down Expand Up @@ -152,7 +147,3 @@ def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dt

def get_all_qint_dtypes() -> List[torch.dtype]:
return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]


def get_all_device_types() -> List[str]:
return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']

0 comments on commit a360be5

Please sign in to comment.