forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AIR] Remove extra sessions that are not needed any more (ray-project…
…#37023) Having multiple sessions floating around is confusing and we are going to replace the session concept with a unified context object between train and tune going forward (see ray-project#36706) The changes in detail: - Remove the `Session` interface class -- we are not planning to expose it to the user and it just introduces an additional level of abstraction that is not needed / not aligned with the longer term plan of having a unified context object between train and tune - Remove the `_TrainSessionImpl` and `_TuneSessionImpl` and instead push the functionality down into the `_StatusReporter` and the `_TrainSession` -- we might want to rename `_StatusReporter` to `_TuneSession` to be more consistent.
- Loading branch information
Showing
7 changed files
with
54 additions
and
242 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,96 +0,0 @@ | ||
import warnings | ||
from typing import TYPE_CHECKING, Dict, Optional | ||
|
||
from ray.air._internal.session import Session | ||
from ray.air.checkpoint import Checkpoint | ||
|
||
if TYPE_CHECKING: | ||
# avoid circular import | ||
from ray.data import DataIterator | ||
from ray.train._internal.session import _TrainSession | ||
from ray.tune.execution.placement_groups import PlacementGroupFactory | ||
|
||
|
||
class _TrainSessionImpl(Session): | ||
"""Session client that "per worker train loop" can interact with. | ||
Notice that each worker will automatically switch to its working | ||
directory on entering the train loop. This is to ensure that | ||
each worker can safely write to a local directory without racing | ||
and overwriting each other.""" | ||
|
||
def __init__(self, session: "_TrainSession"): | ||
self._session = session | ||
|
||
def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None: | ||
self._session.report(metrics, checkpoint) | ||
|
||
@property | ||
def loaded_checkpoint(self) -> Optional[Checkpoint]: | ||
ckpt = self._session.loaded_checkpoint | ||
if ckpt: | ||
# The new API should only interact with Checkpoint object. | ||
assert isinstance(ckpt, Checkpoint) | ||
return ckpt | ||
|
||
@property | ||
def experiment_name(self) -> str: | ||
return self._session.trial_info.experiment_name | ||
|
||
@property | ||
def trial_name(self) -> str: | ||
return self._session.trial_info.name | ||
|
||
@property | ||
def trial_id(self) -> str: | ||
return self._session.trial_info.id | ||
|
||
@property | ||
def trial_resources(self) -> "PlacementGroupFactory": | ||
return self._session.trial_info.resources | ||
|
||
@property | ||
def trial_dir(self) -> str: | ||
return self._session.trial_info.logdir | ||
|
||
@property | ||
def world_size(self) -> int: | ||
return self._session.world_size | ||
|
||
@property | ||
def world_rank(self) -> int: | ||
return self._session.world_rank | ||
|
||
@property | ||
def local_rank(self) -> int: | ||
return self._session.local_rank | ||
|
||
@property | ||
def local_world_size(self) -> int: | ||
return self._session.local_world_size | ||
|
||
@property | ||
def node_rank(self) -> int: | ||
return self._session.node_rank | ||
|
||
def get_dataset_shard( | ||
self, | ||
dataset_name: Optional[str] = None, | ||
) -> Optional["DataIterator"]: | ||
shard = self._session.dataset_shard | ||
if shard is None: | ||
warnings.warn( | ||
"No dataset passed in. Returning None. Make sure to " | ||
"pass in a Dataset to Trainer.run to use this " | ||
"function." | ||
) | ||
elif isinstance(shard, dict): | ||
if not dataset_name: | ||
raise RuntimeError( | ||
"Multiple datasets were passed into ``Trainer``, " | ||
"but no ``dataset_name`` is passed into " | ||
"``get_dataset_shard``. Please specify which " | ||
"dataset shard to retrieve." | ||
) | ||
return shard.get(dataset_name) | ||
return shard | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.