Skip to content

Commit

Permalink
[tune] Allow fetching pinned objects from trainable functions (ray-pr…
Browse files Browse the repository at this point in the history
…oject#1895)

* updates

* lint

* Update util.py

* Update function_runner.py

* updates
  • Loading branch information
ericl authored Apr 16, 2018
1 parent ddfc875 commit ed8c0f1
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/ray/tune/function_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ray.tune import TuneError
from ray.tune.trainable import Trainable
from ray.tune.result import TrainingResult
from ray.tune.util import _serve_get_pin_requests


class StatusReporter(object):
Expand Down Expand Up @@ -108,6 +109,7 @@ def _train(self):
self._default_config["script_min_iter_time_s"]))
result = self._status_reporter._get_and_clear_status()
while result is None:
_serve_get_pin_requests()
time.sleep(1)
result = self._status_reporter._get_and_clear_status()
if result.timesteps_total is None:
Expand Down
19 changes: 19 additions & 0 deletions python/ray/tune/test/trial_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,25 @@ def f():

self.assertEqual(ray.get(f.remote()), "hello")

def testFetchPinned(self):
X = pin_in_object_store("hello")

def train(config, reporter):
get_pinned_object(X)
reporter(timesteps_total=100, done=True)

register_trainable("f1", train)
[trial] = run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 100)

def testRegisterEnv(self):
register_env("foo", lambda: None)
self.assertRaises(TypeError, lambda: register_env("foo", 2))
Expand Down
38 changes: 38 additions & 0 deletions python/ray/tune/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,66 @@
from __future__ import print_function

import base64
import queue
import threading

import ray
from ray.tune.registry import _to_pinnable, _from_pinnable

_pinned_objects = []
_fetch_requests = queue.Queue()
PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"


def pin_in_object_store(obj):
"""Pin an object in the object store.
It will be available as long as the pinning process is alive. The pinned
object can be retrieved by calling get_pinned_object on the identifier
returned by this call.
"""

obj_id = ray.put(_to_pinnable(obj))
_pinned_objects.append(ray.get(obj_id))
return "{}{}".format(PINNED_OBJECT_PREFIX,
base64.b64encode(obj_id.id()).decode("utf-8"))


def get_pinned_object(pinned_id):
"""Retrieve a pinned object from the object store."""

from ray.local_scheduler import ObjectID

if threading.current_thread().getName() != "MainThread":
placeholder = queue.Queue()
_fetch_requests.put((placeholder, pinned_id))
print("Requesting main thread to fetch pinned object", pinned_id)
return placeholder.get()

return _from_pinnable(
ray.get(
ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))


def _serve_get_pin_requests():
"""This is hack to avoid ray.get() on the function runner thread.
The issue is that we run trainable functions on a separate thread,
which cannot access Ray API methods. So instead, that thread puts the
fetch in a queue that is periodically checked from the main thread.
"""

assert threading.current_thread().getName() == "MainThread"

try:
while not _fetch_requests.empty():
(placeholder, pinned_id) = _fetch_requests.get_nowait()
print("Fetching pinned object from main thread", pinned_id)
placeholder.put(get_pinned_object(pinned_id))
except queue.Empty:
pass


if __name__ == '__main__':
ray.init()
X = pin_in_object_store("hello")
Expand Down

0 comments on commit ed8c0f1

Please sign in to comment.