Skip to content

Commit

Permalink
[tune] Remove checkpoint_dir and reporter deprecation notices (ra…
Browse files Browse the repository at this point in the history
…y-project#42698)

These APIs were hard-deprecated in Ray 2.7, and the utility functions to detect them are incorrectly flagging some cases, causing users to run into DeprecationWarning errors when they shouldn't. This PR removes the deprecated APIs and these unnecessary checks. 

---------

Signed-off-by: Justin Yu <[email protected]>
  • Loading branch information
justinvyu authored Jan 26, 2024
1 parent 8579395 commit e22b4dc
Show file tree
Hide file tree
Showing 8 changed files with 7 additions and 174 deletions.
2 changes: 1 addition & 1 deletion python/ray/train/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def train_func(config):
# stdout messages and the results directory.
train_func.__name__ = trainer_cls.__name__

trainable_cls = wrap_function(train_func, warn=False)
trainable_cls = wrap_function(train_func)
has_base_dataset = bool(self.datasets)
if has_base_dataset:
from ray.data.context import DataContext
Expand Down
8 changes: 0 additions & 8 deletions python/ray/tune/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,6 @@ py_test(
tags = ["team:ml", "exclusive"],
)

py_test(
name = "test_function_api_legacy",
size = "small",
srcs = ["tests/test_function_api_legacy.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive"],
)

py_test(
name = "test_integration_pytorch_lightning",
size = "small",
Expand Down
6 changes: 3 additions & 3 deletions python/ray/tune/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool =
logger.debug("Detected class for trainable.")
elif isinstance(trainable, FunctionType) or isinstance(trainable, partial):
logger.debug("Detected function for trainable.")
trainable = wrap_function(trainable, warn=warn)
trainable = wrap_function(trainable)
elif callable(trainable):
logger.info("Detected unknown callable for trainable. Converting to class.")
trainable = wrap_function(trainable, warn=warn)
trainable = wrap_function(trainable)

if not issubclass(trainable, Trainable):
raise TypeError("Second argument must be convertable to Trainable", trainable)
Expand Down Expand Up @@ -246,7 +246,7 @@ def unregister(self, category, key):

def unregister_all(self, category: Optional[str] = None):
remaining = set()
for (cat, key) in self._registered:
for cat, key in self._registered:
if category and category == cat:
self.unregister(cat, key)
else:
Expand Down
36 changes: 0 additions & 36 deletions python/ray/tune/tests/test_function_api_legacy.py

This file was deleted.

77 changes: 3 additions & 74 deletions python/ray/tune/trainable/function_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@
SHOULD_CHECKPOINT,
)
from ray.tune.trainable import Trainable
from ray.tune.utils import (
_detect_checkpoint_function,
_detect_config_single,
_detect_reporter,
)
from ray.tune.utils import _detect_config_single
from ray.util.annotations import DeveloperAPI


Expand All @@ -44,64 +40,6 @@
TEMP_MARKER = ".temp_marker"


_CHECKPOINT_DIR_ARG_DEPRECATION_MSG = """Accepting a `checkpoint_dir` argument in your training function is deprecated.
Please use `ray.train.get_checkpoint()` to access your checkpoint as a
`ray.train.Checkpoint` object instead. See below for an example:
Before
------
from ray import tune
def train_fn(config, checkpoint_dir=None):
if checkpoint_dir:
torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
...
tuner = tune.Tuner(train_fn)
tuner.fit()
After
-----
from ray import train, tune
def train_fn(config):
checkpoint: train.Checkpoint = train.get_checkpoint()
if checkpoint:
with checkpoint.as_directory() as checkpoint_dir:
torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
...
tuner = tune.Tuner(train_fn)
tuner.fit()""" # noqa: E501

_REPORTER_ARG_DEPRECATION_MSG = """Accepting a `reporter` in your training function is deprecated.
Please use `ray.train.report()` to report results instead. See below for an example:
Before
------
from ray import tune
def train_fn(config, reporter):
reporter(metric=1)
tuner = tune.Tuner(train_fn)
tuner.fit()
After
-----
from ray import train, tune
def train_fn(config):
train.report({"metric": 1})
tuner = tune.Tuner(train_fn)
tuner.fit()""" # noqa: E501


@DeveloperAPI
class FunctionTrainable(Trainable):
"""Trainable that runs a user function reporting results.
Expand Down Expand Up @@ -271,29 +209,20 @@ def _report_thread_runner_error(self, block=False):

@DeveloperAPI
def wrap_function(
train_func: Callable[[Any], Any], warn: bool = True, name: Optional[str] = None
train_func: Callable[[Any], Any], name: Optional[str] = None
) -> Type["FunctionTrainable"]:
inherit_from = (FunctionTrainable,)

if hasattr(train_func, "__mixins__"):
inherit_from = train_func.__mixins__ + inherit_from

func_args = inspect.getfullargspec(train_func).args
use_checkpoint = _detect_checkpoint_function(train_func)
use_config_single = _detect_config_single(train_func)
use_reporter = _detect_reporter(train_func)

if use_checkpoint:
raise DeprecationWarning(_CHECKPOINT_DIR_ARG_DEPRECATION_MSG)

if use_reporter:
raise DeprecationWarning(_REPORTER_ARG_DEPRECATION_MSG)

if not use_config_single:
# use_reporter is hidden
raise ValueError(
"Unknown argument found in the Trainable function. "
"The function args must include a 'config' positional parameter."
"The function args must include a single 'config' positional parameter.\n"
"Found: {}".format(func_args)
)

Expand Down
14 changes: 0 additions & 14 deletions python/ray/tune/trainable/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from ray.air.config import ScalingConfig
from ray.tune.registry import _ParameterRegistry
from ray.tune.utils import _detect_checkpoint_function
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
Expand Down Expand Up @@ -124,12 +123,6 @@ def setup(self, config):
trainable_with_params = _Inner
else:
# Function trainable
if _detect_checkpoint_function(trainable, partial=True):
from ray.tune.trainable.function_trainable import (
_CHECKPOINT_DIR_ARG_DEPRECATION_MSG,
)

raise DeprecationWarning(_CHECKPOINT_DIR_ARG_DEPRECATION_MSG)

def inner(config):
fn_kwargs = {}
Expand Down Expand Up @@ -223,13 +216,6 @@ def train_fn(config):
if not inspect.isclass(trainable):
if isinstance(trainable, types.MethodType):
# Methods cannot set arbitrary attributes, so we have to wrap them
if _detect_checkpoint_function(trainable, partial=True):
from ray.tune.trainable.function_trainable import (
_CHECKPOINT_DIR_ARG_DEPRECATION_MSG,
)

raise DeprecationWarning(_CHECKPOINT_DIR_ARG_DEPRECATION_MSG)

def _trainable(config):
return trainable(config)

Expand Down
4 changes: 0 additions & 4 deletions python/ray/tune/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
validate_save_restore,
warn_if_slow,
diagnose_serialization,
_detect_checkpoint_function,
_detect_reporter,
_detect_config_single,
wait_for_gpu,
)
Expand All @@ -24,8 +22,6 @@
"validate_save_restore",
"warn_if_slow",
"diagnose_serialization",
"_detect_checkpoint_function",
"_detect_reporter",
"_detect_config_single",
"wait_for_gpu",
]
34 changes: 0 additions & 34 deletions python/ray/tune/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,40 +584,6 @@ def validate_save_restore(
return True


def _detect_checkpoint_function(train_func, abort=False, partial=False):
"""Use checkpointing if any arg has "checkpoint_dir" and args = 2"""
func_sig = inspect.signature(train_func)
validated = True
try:
# check if signature is func(config, checkpoint_dir=None)
if partial:
func_sig.bind_partial({}, checkpoint_dir="tmp/path")
else:
func_sig.bind({}, checkpoint_dir="tmp/path")
except Exception as e:
logger.debug(str(e))
validated = False
if abort and not validated:
func_args = inspect.getfullargspec(train_func).args
raise ValueError(
"Provided training function must have 1 `config` argument "
"`func(config)`. Got {}".format(func_args)
)
return validated


def _detect_reporter(func):
"""Use reporter if any arg has "reporter" and args = 2"""
func_sig = inspect.signature(func)
use_reporter = True
try:
func_sig.bind({}, reporter=None)
except Exception as e:
logger.debug(str(e))
use_reporter = False
return use_reporter


def _detect_config_single(func):
"""Check if func({}) works."""
func_sig = inspect.signature(func)
Expand Down

0 comments on commit e22b4dc

Please sign in to comment.