Skip to content

Commit

Permalink
[tune] Add a callable check for converting to trainable (ray-project#…
Browse files Browse the repository at this point in the history
richardliaw authored Jan 8, 2019

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 5dadac1 commit 3331950
Showing 2 changed files with 30 additions and 1 deletion.
13 changes: 12 additions & 1 deletion python/ray/tune/registry.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

import logging
from types import FunctionType

import ray
@@ -17,6 +18,8 @@
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR
]

logger = logging.getLogger(__name__)


def register_trainable(name, trainable):
"""Register a trainable function or class.
@@ -30,8 +33,16 @@ def register_trainable(name, trainable):

from ray.tune.trainable import Trainable, wrap_function

if isinstance(trainable, FunctionType):
if isinstance(trainable, type):
logger.debug("Detected class for trainable.")
elif isinstance(trainable, FunctionType):
logger.debug("Detected function for trainable.")
trainable = wrap_function(trainable)
elif callable(trainable):
logger.warning(
"Detected unknown callable for trainable. Converting to class.")
trainable = wrap_function(trainable)

if not issubclass(trainable, Trainable):
raise TypeError("Second argument must be convertable to Trainable",
trainable)
18 changes: 18 additions & 0 deletions python/ray/tune/test/trial_runner_test.py
Original file line number Diff line number Diff line change
@@ -112,6 +112,24 @@ class B(Trainable):
self.assertRaises(TypeError, lambda: register_trainable("foo", B()))
self.assertRaises(TypeError, lambda: register_trainable("foo", A))

def testRegisterTrainableCallable(self):
def dummy_fn(config, reporter, steps):
reporter(timesteps_total=steps, done=True)

from functools import partial
steps = 500
register_trainable("test", partial(dummy_fn, steps=steps))
[trial] = run_experiments({
"foo": {
"run": "test",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)

def testBuiltInTrainableResources(self):
class B(Trainable):
@classmethod

0 comments on commit 3331950

Please sign in to comment.