Skip to content

Commit

Permalink
[autoscaler] Split autoscaler interface public private (ray-project#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
ericl authored Sep 19, 2020
1 parent 9a07c7b commit 6a227ae
Show file tree
Hide file tree
Showing 56 changed files with 436 additions and 251 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1797,6 +1797,7 @@ filegroup(
srcs = glob([
"python/ray/*.py",
"python/ray/autoscaler/*.py",
"python/ray/autoscaler/_private/*.py",
"python/ray/autoscaler/aws/example-full.yaml",
"python/ray/autoscaler/azure/example-full.yaml",
"python/ray/autoscaler/gcp/example-full.yaml",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import ray
from ray import tune
from ray.autoscaler.commands import kill_node
from ray.autoscaler._private.commands import kill_node
from ray.tune import CLIReporter
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.schedulers import PopulationBasedTraining
Expand Down
5 changes: 5 additions & 0 deletions python/ray/autoscaler/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Note on interface stability
===========================

All the public Python methods and attributes declared in this package can be considered stable public interfaces (except for those in the _private package).

We also guarantee backwards compatibility for the cluster YAMLs.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@

from ray.experimental.internal_kv import _internal_kv_put, \
_internal_kv_initialized
from ray.autoscaler.node_provider import get_node_provider
from ray.autoscaler.node_provider import _get_node_provider
from ray.autoscaler.tags import (TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
TAG_RAY_FILE_MOUNTS_CONTENTS,
TAG_RAY_NODE_STATUS, TAG_RAY_NODE_KIND,
TAG_RAY_USER_NODE_TYPE, STATUS_UP_TO_DATE,
NODE_KIND_WORKER, NODE_KIND_UNMANAGED)
from ray.autoscaler.updater import NodeUpdaterThread
from ray.autoscaler.node_launcher import NodeLauncher
from ray.autoscaler.resource_demand_scheduler import ResourceDemandScheduler
from ray.autoscaler.util import ConcurrentCounter, validate_config, \
from ray.autoscaler._private.updater import NodeUpdaterThread
from ray.autoscaler._private.node_launcher import NodeLauncher
from ray.autoscaler._private.resource_demand_scheduler import \
ResourceDemandScheduler
from ray.autoscaler._private.util import ConcurrentCounter, validate_config, \
with_head_node_ip, hash_launch_conf, hash_runtime_conf, \
DEBUG_AUTOSCALING_STATUS, DEBUG_AUTOSCALING_ERROR
from ray.ray_constants import AUTOSCALER_MAX_NUM_FAILURES, \
Expand Down Expand Up @@ -305,8 +306,8 @@ def reset(self, errors_fatal=False):
self.runtime_hash = new_runtime_hash
self.file_mounts_contents_hash = new_file_mounts_contents_hash
if not self.provider:
self.provider = get_node_provider(self.config["provider"],
self.config["cluster_name"])
self.provider = _get_node_provider(self.config["provider"],
self.config["cluster_name"])
# Check whether we can enable the resource demand scheduler.
if "available_node_types" in self.config:
self.available_node_types = self.config["available_node_types"]
Expand Down Expand Up @@ -579,8 +580,3 @@ def kill_workers(self):
self.provider.terminate_nodes(nodes)
logger.error("StandardAutoscaler: terminated {} node(s)".format(
len(nodes)))


def request_resources(num_cpus=None, num_gpus=None):
    """Removed API stub.

    Always raises, pointing callers at the relocated implementation in
    ``ray.autoscaler.commands.request_resources``.
    """
    message = (
        "Please use ray.autoscaler.commands.request_resources instead.")
    raise DeprecationWarning(message)
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@

from ray.ray_constants import BOTO_MAX_RETRIES
from ray.autoscaler.tags import NODE_KIND_WORKER, NODE_KIND_HEAD
from ray.autoscaler.aws.utils import LazyDefaultDict, handle_boto_error
from ray.autoscaler.node_provider import PROVIDER_PRETTY_NAMES
from ray.autoscaler.node_provider import _PROVIDER_PRETTY_NAMES
from ray.autoscaler._private.aws.utils import LazyDefaultDict, \
handle_boto_error
from ray.autoscaler._private.cli_logger import cli_logger

from ray.autoscaler.cli_logger import cli_logger
import colorful as cf

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -100,7 +101,7 @@ def _arn_to_name(arn):


def log_to_cli(config):
provider_name = PROVIDER_PRETTY_NAMES.get("aws", None)
provider_name = _PROVIDER_PRETTY_NAMES.get("aws", None)

cli_logger.doassert(provider_name is not None,
"Could not find a pretty name for the AWS provider.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from botocore.config import Config

from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.aws.config import bootstrap_aws
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME, \
TAG_RAY_LAUNCH_CONFIG, TAG_RAY_NODE_KIND, TAG_RAY_USER_NODE_TYPE
from ray.ray_constants import BOTO_MAX_RETRIES, BOTO_CREATE_MAX_RETRIES
from ray.autoscaler.log_timer import LogTimer
from ray.autoscaler._private.aws.config import bootstrap_aws
from ray.autoscaler._private.log_timer import LogTimer

from ray.autoscaler.aws.utils import boto_exception_handler
from ray.autoscaler.cli_logger import cli_logger
from ray.autoscaler._private.aws.utils import boto_exception_handler
from ray.autoscaler._private.cli_logger import cli_logger
import colorful as cf

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict

from ray.autoscaler.cli_logger import cli_logger
from ray.autoscaler._private.cli_logger import cli_logger
import colorful as cf


Expand Down
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from knack.util import CLIError

from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.azure.config import bootstrap_azure
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
from ray.autoscaler._private.azure.config import bootstrap_azure

VM_NAME_MAX_LEN = 64
VM_NAME_UUID_LEN = 8
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# function for demonstration purposes. Primarily useful for tuning color and
# other formatting.

from ray.autoscaler.cli_logger import cli_logger
from ray.autoscaler._private.cli_logger import cli_logger
import colorful as cf

cli_logger.old_style = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@
import time
import warnings

from ray.autoscaler.docker import check_bind_mounts_cmd, \
from ray.autoscaler._private.docker import check_bind_mounts_cmd, \
check_docker_running_cmd, \
check_docker_image, \
docker_start_cmds, \
DOCKER_MOUNT_PREFIX, \
with_docker_exec
from ray.autoscaler.log_timer import LogTimer
from ray.autoscaler._private.log_timer import LogTimer

from ray.autoscaler.subprocess_output_util import (
from ray.autoscaler._private.subprocess_output_util import (
run_cmd_redirected, ProcessRunnerError, is_output_redirected)

from ray.autoscaler.cli_logger import cli_logger
from ray.autoscaler._private.cli_logger import cli_logger
import colorful as cf

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import subprocess
import tempfile
import time
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List

import click
import yaml
Expand All @@ -20,26 +20,23 @@

from ray.experimental.internal_kv import _internal_kv_get
import ray.services as services
from ray.autoscaler.util import validate_config, hash_runtime_conf, \
from ray.ray_constants import AUTOSCALER_RESOURCE_REQUEST_CHANNEL
from ray.autoscaler._private.util import validate_config, hash_runtime_conf, \
hash_launch_conf, prepare_config, DEBUG_AUTOSCALING_ERROR, \
DEBUG_AUTOSCALING_STATUS
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS, \
PROVIDER_PRETTY_NAMES, try_get_log_state, try_logging_config, \
try_reload_log_state
from ray.autoscaler.node_provider import _get_node_provider, \
_NODE_PROVIDERS, _PROVIDER_PRETTY_NAMES
from ray.autoscaler.tags import TAG_RAY_NODE_KIND, TAG_RAY_LAUNCH_CONFIG, \
TAG_RAY_NODE_NAME, NODE_KIND_WORKER, NODE_KIND_HEAD, TAG_RAY_USER_NODE_TYPE

from ray.ray_constants import AUTOSCALER_RESOURCE_REQUEST_CHANNEL
from ray.autoscaler.updater import NodeUpdaterThread
from ray.autoscaler.command_runner import set_using_login_shells, \
from ray.autoscaler._private.cli_logger import cli_logger
from ray.autoscaler._private.updater import NodeUpdaterThread
from ray.autoscaler._private.command_runner import set_using_login_shells, \
set_rsync_silent
from ray.autoscaler.log_timer import LogTimer
from ray.autoscaler._private.log_timer import LogTimer
from ray.worker import global_worker
from ray.util.debug import log_once

import ray.autoscaler.subprocess_output_util as cmd_output_util

from ray.autoscaler.cli_logger import cli_logger
import ray.autoscaler._private.subprocess_output_util as cmd_output_util

logger = logging.getLogger(__name__)

Expand All @@ -59,6 +56,26 @@ def _redis():
return redis_client


def try_logging_config(config):
    """Emit provider-specific config output to the CLI, if supported.

    Only the AWS provider implements this hook; for every other provider
    type the call is a no-op.
    """
    provider_type = config["provider"]["type"]
    if provider_type != "aws":
        return
    from ray.autoscaler._private.aws.config import log_to_cli
    log_to_cli(config)


def try_get_log_state(provider_config):
    """Capture provider-specific log state for later restoration.

    Returns the AWS log state when the provider type is ``aws``; otherwise
    returns None (no other provider supports this).
    """
    if provider_config["type"] != "aws":
        return None
    from ray.autoscaler._private.aws.config import get_log_state
    return get_log_state()


def try_reload_log_state(provider_config, log_state):
    """Restore previously captured provider log state (AWS only).

    A falsy ``log_state`` is ignored; non-AWS providers are a no-op.
    """
    if log_state and provider_config["type"] == "aws":
        from ray.autoscaler._private.aws.config import reload_log_state
        return reload_log_state(log_state)


def debug_status():
"""Return a debug string for the autoscaler."""
status = _internal_kv_get(DEBUG_AUTOSCALING_STATUS)
Expand Down Expand Up @@ -143,14 +160,14 @@ def handle_yaml_error(e):

# todo: validate file_mounts, ssh keys, etc.

importer = NODE_PROVIDERS.get(config["provider"]["type"])
importer = _NODE_PROVIDERS.get(config["provider"]["type"])
if not importer:
cli_logger.abort(
"Unknown provider type " + cf.bold("{}") + "\n"
"Available providers are: {}", config["provider"]["type"],
cli_logger.render_list([
k for k in NODE_PROVIDERS.keys()
if NODE_PROVIDERS[k] is not None
k for k in _NODE_PROVIDERS.keys()
if _NODE_PROVIDERS[k] is not None
]))
raise NotImplementedError("Unsupported provider {}".format(
config["provider"]))
Expand Down Expand Up @@ -236,7 +253,7 @@ def _bootstrap_config(config: Dict[str, Any],
config_cache.get("_version", "none"), CONFIG_CACHE_VERSION)
validate_config(config)

importer = NODE_PROVIDERS.get(config["provider"]["type"])
importer = _NODE_PROVIDERS.get(config["provider"]["type"])
if not importer:
raise NotImplementedError("Unsupported provider {}".format(
config["provider"]))
Expand All @@ -245,7 +262,7 @@ def _bootstrap_config(config: Dict[str, Any],

with cli_logger.timed(
"Checking {} environment settings",
PROVIDER_PRETTY_NAMES.get(config["provider"]["type"])):
_PROVIDER_PRETTY_NAMES.get(config["provider"]["type"])):
resolved_config = provider_cls.bootstrap_config(config)

if not no_config_cache:
Expand Down Expand Up @@ -298,7 +315,7 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
cli_logger.old_exception(
logger, "Ignoring error attempting a clean shutdown.")

provider = get_node_provider(config["provider"], config["cluster_name"])
provider = _get_node_provider(config["provider"], config["cluster_name"])
try:

def remaining_nodes():
Expand Down Expand Up @@ -402,7 +419,7 @@ def kill_node(config_file, yes, hard, override_cluster_name):
cli_logger.confirm(yes, "A random node will be killed.")
cli_logger.old_confirm("This will kill a node in your cluster", yes)

provider = get_node_provider(config["provider"], config["cluster_name"])
provider = _get_node_provider(config["provider"], config["cluster_name"])
try:
nodes = provider.non_terminated_nodes({
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
Expand Down Expand Up @@ -491,8 +508,8 @@ def get_or_create_head_node(config,
_provider=None,
_runner=subprocess):
"""Create the cluster head node, which in turn creates the workers."""
provider = (_provider or get_node_provider(config["provider"],
config["cluster_name"]))
provider = (_provider or _get_node_provider(config["provider"],
config["cluster_name"]))

config = copy.deepcopy(config)
config_file = os.path.abspath(config_file)
Expand Down Expand Up @@ -793,7 +810,7 @@ def attach_cluster(config_file: str,

def exec_cluster(config_file: str,
*,
cmd: Any = None,
cmd: str = None,
run_env: str = "auto",
screen: bool = False,
tmux: bool = False,
Expand Down Expand Up @@ -833,7 +850,7 @@ def exec_cluster(config_file: str,
head_node = _get_head_node(
config, config_file, override_cluster_name, create_if_needed=start)

provider = get_node_provider(config["provider"], config["cluster_name"])
provider = _get_node_provider(config["provider"], config["cluster_name"])
try:
updater = NodeUpdaterThread(
node_id=head_node,
Expand Down Expand Up @@ -955,7 +972,7 @@ def rsync(config_file: str,
is_file_mount = True
break

provider = get_node_provider(config["provider"], config["cluster_name"])
provider = _get_node_provider(config["provider"], config["cluster_name"])
try:
nodes = []
if all_nodes:
Expand Down Expand Up @@ -1010,7 +1027,7 @@ def get_head_node_ip(config_file: str,
if override_cluster_name is not None:
config["cluster_name"] = override_cluster_name

provider = get_node_provider(config["provider"], config["cluster_name"])
provider = _get_node_provider(config["provider"], config["cluster_name"])
try:
head_node = _get_head_node(config, config_file, override_cluster_name)
if config.get("provider", {}).get("use_internal_ips", False) is True:
Expand All @@ -1024,14 +1041,14 @@ def get_head_node_ip(config_file: str,


def get_worker_node_ips(config_file: str,
override_cluster_name: Optional[str]) -> str:
override_cluster_name: Optional[str]) -> List[str]:
"""Returns worker node IPs for given configuration file."""

config = yaml.safe_load(open(config_file).read())
if override_cluster_name is not None:
config["cluster_name"] = override_cluster_name

provider = get_node_provider(config["provider"], config["cluster_name"])
provider = _get_node_provider(config["provider"], config["cluster_name"])
try:
nodes = provider.non_terminated_nodes({
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
Expand All @@ -1051,7 +1068,7 @@ def _get_worker_nodes(config, override_cluster_name):
if override_cluster_name is not None:
config["cluster_name"] = override_cluster_name

provider = get_node_provider(config["provider"], config["cluster_name"])
provider = _get_node_provider(config["provider"], config["cluster_name"])
try:
return provider.non_terminated_nodes({
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
Expand All @@ -1064,7 +1081,7 @@ def _get_head_node(config: Dict[str, Any],
config_file: str,
override_cluster_name: Optional[str],
create_if_needed: bool = False) -> str:
provider = get_node_provider(config["provider"], config["cluster_name"])
provider = _get_node_provider(config["provider"], config["cluster_name"])
try:
head_node_tags = {
TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
Expand Down
File renamed without changes.
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import logging

from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.gcp.config import bootstrap_gcp
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL, \
from ray.autoscaler._private.gcp.config import bootstrap_gcp
from ray.autoscaler._private.gcp.config import MAX_POLLS, POLL_INTERVAL, \
construct_clients_from_provider_config

logger = logging.getLogger(__name__)
Expand Down
Loading

0 comments on commit 6a227ae

Please sign in to comment.