Skip to content

Commit

Permalink
[core] runtime context resource ids getter (ray-project#26907)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardliaw authored Jul 24, 2022
1 parent acbab51 commit d01a80e
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ def get_gpu_ids():
return assigned_ids


@Deprecated
@Deprecated(message="Use ray.get_runtime_context().get_assigned_resources() instead.")
def get_resource_ids():
"""Get the IDs of the resources that are available to the worker.
Expand Down
28 changes: 26 additions & 2 deletions python/ray/runtime_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional
from typing import Any, Dict, Optional

import ray._private.worker
from ray._private.client_mode_hook import client_mode_hook
Expand All @@ -17,7 +17,7 @@ def __init__(self, worker):
assert worker is not None
self.worker = worker

def get(self):
def get(self) -> Dict[str, Any]:
"""Get a dictionary of the current context.
Returns:
Expand Down Expand Up @@ -261,6 +261,30 @@ def should_capture_child_tasks_in_placement_group(self):
"""
return self.worker.should_capture_child_tasks_in_placement_group

def get_assigned_resources(self):
"""Get the assigned resources to this worker.
By default for tasks, this will return {"CPU": 1}.
By default for actors, this will return {}. This is because
actors do not have CPUs assigned to them by default.
Returns:
A dictionary mapping the name of a resource to a float, where
the float represents the amount of that resource reserved
for this worker.
"""
assert (
self.worker.mode == ray._private.worker.WORKER_MODE
), f"This method is only available when the process is a\
worker. Current mode: {self.worker.mode}"
self.worker.check_connected()
resource_id_map = self.worker.core_worker.resource_ids()
resource_map = {
res: sum(amt for _, amt in mapping)
for res, mapping in resource_id_map.items()
}
return resource_map

def get_runtime_env_string(self):
"""Get the runtime env string used for the current driver or worker.
Expand Down
6 changes: 3 additions & 3 deletions python/ray/tests/test_placement_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def test_placement_group_actor_resource_ids(ray_start_cluster, connect_to_client
@ray.remote(num_cpus=1)
class F:
def f(self):
return ray._private.worker.get_resource_ids()
return ray.get_runtime_context().get_assigned_resources()

cluster = ray_start_cluster
num_nodes = 1
Expand All @@ -393,7 +393,7 @@ def f(self):
def test_placement_group_task_resource_ids(ray_start_cluster, connect_to_client):
@ray.remote(num_cpus=1)
def f():
return ray._private.worker.get_resource_ids()
return ray.get_runtime_context().get_assigned_resources()

cluster = ray_start_cluster
num_nodes = 1
Expand Down Expand Up @@ -425,7 +425,7 @@ def f():
def test_placement_group_hang(ray_start_cluster, connect_to_client):
@ray.remote(num_cpus=1)
def f():
return ray._private.worker.get_resource_ids()
return ray.get_runtime_context().get_assigned_resources()

cluster = ray_start_cluster
num_nodes = 1
Expand Down
28 changes: 28 additions & 0 deletions python/ray/tests/test_runtime_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,34 @@ def echo2(self, s):
assert ray.get(ray.get(obj)) == "hello"


def test_get_assigned_resources(ray_start_10_cpus):
@ray.remote
class Echo:
def check(self):
return ray.get_runtime_context().get_assigned_resources()

e = Echo.remote()
result = e.check.remote()
print(ray.get(result))
assert ray.get(result).get("CPU") is None
ray.kill(e)

e = Echo.options(num_cpus=4).remote()
result = e.check.remote()
assert ray.get(result)["CPU"] == 4.0
ray.kill(e)

@ray.remote
def check():
return ray.get_runtime_context().get_assigned_resources()

result = check.remote()
assert ray.get(result)["CPU"] == 1.0

result = check.options(num_cpus=2).remote()
assert ray.get(result)["CPU"] == 2.0


def test_actor_stats_normal_task(ray_start_regular):
# Because it works at the core worker level, this API works for tasks.
@ray.remote
Expand Down

0 comments on commit d01a80e

Please sign in to comment.