Skip to content

Commit

Permalink
[Java] Allow passing internal config from raylet to Java worker (ray-…
Browse files Browse the repository at this point in the history
  • Loading branch information
kfstorm authored Mar 15, 2020
1 parent a87199d commit 630e489
Show file tree
Hide file tree
Showing 17 changed files with 223 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
} catch (IOException e) {
throw new RuntimeException("Failed to create the log directory.", e);
}
nativeSetup(rayConfig.logDir);
nativeSetup(rayConfig.logDir, rayConfig.rayletConfigParameters);
Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook));
}

Expand Down Expand Up @@ -193,7 +193,7 @@ private static native long nativeInitCoreWorker(int workerMode, String storeSock

private static native void nativeDestroyCoreWorker(long nativeCoreWorkerPointer);

private static native void nativeSetup(String logDir);
private static native void nativeSetup(String logDir, Map<String, String> rayletConfigParameters);

private static native void nativeShutdownHook();

Expand Down
10 changes: 5 additions & 5 deletions java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import java.io.File;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
Expand Down Expand Up @@ -67,7 +67,7 @@ public class RayConfig {

public String rayletSocketName;
private int nodeManagerPort;
public final List<String> rayletConfigParameters;
public final Map<String, String> rayletConfigParameters;

public final String jobResourcePath;
public final String pythonWorkerCommand;
Expand Down Expand Up @@ -204,11 +204,11 @@ public RayConfig(Config config) {
}

// Raylet parameters.
rayletConfigParameters = new ArrayList<>();
rayletConfigParameters = new HashMap<>();
Config rayletConfig = config.getConfig("ray.raylet.config");
for (Map.Entry<String, ConfigValue> entry : rayletConfig.entrySet()) {
String parameter = entry.getKey() + "," + entry.getValue().unwrapped();
rayletConfigParameters.add(parameter);
Object value = entry.getValue().unwrapped();
rayletConfigParameters.put(entry.getKey(), value == null ? "" : value.toString());
}

// Job resource path.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,10 @@ private void startGcs() {
gcsServerFile.getAbsolutePath(),
String.format("--redis_address=%s", rayConfig.getRedisIp()),
String.format("--redis_port=%d", rayConfig.getRedisPort()),
String.format("--config_list=%s", String.join(",", rayConfig.rayletConfigParameters)),
String.format("--config_list=%s",
rayConfig.rayletConfigParameters.entrySet().stream()
.map(entry -> entry.getKey() + "," + entry.getValue()).collect(Collectors
.joining(","))),
String.format("--redis_password=%s", redisPasswordOption)
);
startProcess(command, null, "gcs_server");
Expand Down Expand Up @@ -316,7 +319,9 @@ private void startRaylet() {
String.format("--maximum_startup_concurrency=%d", maximumStartupConcurrency),
String.format("--static_resource_list=%s",
ResourceUtil.getResourcesStringFromMap(rayConfig.resources)),
String.format("--config_list=%s", String.join(",", rayConfig.rayletConfigParameters)),
String.format("--config_list=%s", rayConfig.rayletConfigParameters.entrySet().stream()
.map(entry -> entry.getKey() + "," + entry.getValue())
.collect(Collectors.joining(","))),
String.format("--python_worker_command=%s", buildPythonWorkerCommand()),
String.format("--java_worker_command=%s", buildWorkerCommand()),
String.format("--redis_password=%s", redisPasswordOption)
Expand Down Expand Up @@ -378,8 +383,8 @@ private String buildWorkerCommand() {
cmd.add("-Dray.redis.password=" + rayConfig.headRedisPassword);
}

// Number of workers per Java worker process
cmd.add("-Dray.raylet.config.num_workers_per_process_java=RAY_WORKER_NUM_WORKERS_PLACEHOLDER");

cmd.add("RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER");

cmd.addAll(rayConfig.jvmParameters);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.ray.api.test;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.gson.Gson;
Expand All @@ -13,6 +12,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.ray.api.Ray;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.util.NetworkUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -90,15 +90,9 @@ public void setUp() {
String.format("--node-manager-port=%s", nodeManagerPort),
"--load-code-from-local",
"--include-java",
"--java-worker-options=" + workerOptions
"--java-worker-options=" + workerOptions,
"--internal-config=" + new Gson().toJson(RayConfig.create().rayletConfigParameters)
);
String numWorkersPerProcessJava = System
.getProperty("ray.raylet.config.num_workers_per_process_java");
if (!Strings.isNullOrEmpty(numWorkersPerProcessJava)) {
startCommand = ImmutableList.<String>builder().addAll(startCommand)
.add(String.format("--internal-config={\"num_workers_per_process_java\": %s}",
numWorkersPerProcessJava)).build();
}
if (!executeCommand(startCommand, 10, getRayStartEnv())) {
throw new RuntimeException("Couldn't start ray cluster.");
}
Expand Down
39 changes: 39 additions & 0 deletions java/test/src/main/java/org/ray/api/test/RayletConfigTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.ray.api.test;

import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.TestUtils;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/**
 * Sets a {@code ray.raylet.config.*} system property before the test class runs and checks
 * that the same value can be read back from {@code rayletConfigParameters} through an actor
 * call.
 */
public class RayletConfigTest extends BaseTest {

  /** Raylet config entry exercised by this test. */
  private static final String RAY_CONFIG_KEY = "num_workers_per_process_java";
  /** Value assigned to {@link #RAY_CONFIG_KEY} for the duration of this test class. */
  private static final String RAY_CONFIG_VALUE = "2";
  /** Full name of the system property that carries the entry. */
  private static final String RAY_CONFIG_PROPERTY = "ray.raylet.config." + RAY_CONFIG_KEY;

  /** Publishes the config entry as a system property before any test method runs. */
  @BeforeClass
  public void beforeClass() {
    System.setProperty(RAY_CONFIG_PROPERTY, RAY_CONFIG_VALUE);
  }

  /** Removes the system property again so it is not left set after this class. */
  @AfterClass
  public void afterClass() {
    System.clearProperty(RAY_CONFIG_PROPERTY);
  }

  /** Actor that reports the raylet config value seen by the runtime it executes in. */
  public static class TestActor {

    /**
     * Looks up {@code RAY_CONFIG_KEY} in the runtime's {@code rayletConfigParameters}.
     *
     * @return the stored value, or {@code null} if the map has no such key
     */
    public String getConfigValue() {
      return TestUtils.getRuntime().getRayConfig().rayletConfigParameters.get(RAY_CONFIG_KEY);
    }
  }

  /** The value observed through the actor must equal the one set in {@link #beforeClass()}. */
  @Test
  public void testRayletConfigPassThrough() {
    RayActor<TestActor> testActor = Ray.createActor(TestActor::new);
    Assert.assertEquals(testActor.call(TestActor::getConfigValue).get(), RAY_CONFIG_VALUE);
  }
}
4 changes: 2 additions & 2 deletions python/ray/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,11 +1366,11 @@ def build_java_worker_command(
pairs.append(("ray.home", RAY_HOME))
pairs.append(("ray.log-dir", os.path.join(session_dir, "logs")))
pairs.append(("ray.session-dir", session_dir))
pairs.append(("ray.raylet.config.num_workers_per_process_java",
"RAY_WORKER_NUM_WORKERS_PLACEHOLDER"))

command = ["java"] + ["-D{}={}".format(*pair) for pair in pairs]

command += ["RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER"]

# Add ray jars path to java classpath
ray_jars = os.path.join(get_ray_jars_dir(), "*")
if java_worker_options is None:
Expand Down
2 changes: 1 addition & 1 deletion src/ray/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ constexpr char kTaskTablePrefix[] = "TaskTable";
constexpr char kWorkerDynamicOptionPlaceholderPrefix[] =
"RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_";

constexpr char kWorkerNumWorkersPlaceholder[] = "RAY_WORKER_NUM_WORKERS_PLACEHOLDER";
constexpr char kWorkerRayletConfigPlaceholder[] = "RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER";

#endif // RAY_CONSTANTS_H_
33 changes: 33 additions & 0 deletions src/ray/core_worker/lib/java/jni_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,39 @@ inline jobject NativeIdVectorToJavaByteArrayList(JNIEnv *env,
});
}

/// Convert a Java Map<?, ?> to a C++ std::unordered_map<?, ?>.
///
/// Walks `java_map.entrySet()` with its iterator and converts every entry with the
/// supplied converters. A pending Java exception is checked for after every JNI call
/// made here, so the converters are never invoked while an exception raised by this
/// function's own JNI calls is still pending.
///
/// \param env The JNI environment of the calling thread.
/// \param java_map The Java map to convert. May be null, in which case an empty map is
/// returned.
/// \param key_converter Converts the Java key object to `key_type`.
/// \param value_converter Converts the Java value object to `value_type`.
/// \return A native map holding one converted entry per Java entry. If two Java keys
/// convert to the same native key, only the first one encountered is kept (`emplace`
/// does not overwrite).
template <typename key_type, typename value_type>
inline std::unordered_map<key_type, value_type> JavaMapToNativeMap(
    JNIEnv *env, jobject java_map,
    const std::function<key_type(JNIEnv *, jobject)> &key_converter,
    const std::function<value_type(JNIEnv *, jobject)> &value_converter) {
  std::unordered_map<key_type, value_type> native_map;
  if (java_map) {
    jobject entry_set = env->CallObjectMethod(java_map, java_map_entry_set);
    RAY_CHECK_JAVA_EXCEPTION(env);
    jobject iterator = env->CallObjectMethod(entry_set, java_set_iterator);
    RAY_CHECK_JAVA_EXCEPTION(env);
    while (env->CallBooleanMethod(iterator, java_iterator_has_next)) {
      // Covers the hasNext() call that just returned true.
      RAY_CHECK_JAVA_EXCEPTION(env);
      jobject map_entry = env->CallObjectMethod(iterator, java_iterator_next);
      RAY_CHECK_JAVA_EXCEPTION(env);
      // The key stays a plain jobject: this template is generic, and any narrowing
      // (e.g. to jstring) is the converter's job.
      jobject java_key = env->CallObjectMethod(map_entry, java_map_entry_get_key);
      RAY_CHECK_JAVA_EXCEPTION(env);
      key_type key = key_converter(env, java_key);
      jobject java_value = env->CallObjectMethod(map_entry, java_map_entry_get_value);
      // Must be checked before the converter runs: the converter may make further JNI
      // calls, which are not allowed while an exception is pending.
      RAY_CHECK_JAVA_EXCEPTION(env);
      value_type value = value_converter(env, java_value);
      native_map.emplace(key, value);
      // Release per-entry local references eagerly so they do not accumulate over the
      // whole iteration.
      env->DeleteLocalRef(java_key);
      env->DeleteLocalRef(java_value);
      env->DeleteLocalRef(map_entry);
    }
    // Covers the final hasNext() call that returned false and ended the loop.
    RAY_CHECK_JAVA_EXCEPTION(env);
    env->DeleteLocalRef(iterator);
    env->DeleteLocalRef(entry_set);
  }
  return native_map;
}

/// Convert a C++ ray::Buffer to a Java byte array.
inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env,
const std::shared_ptr<ray::Buffer> buffer) {
Expand Down
14 changes: 11 additions & 3 deletions src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,23 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWo
delete core_worker;
}

JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv *env,
jclass,
jstring logDir) {
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(
JNIEnv *env, jclass, jstring logDir, jobject rayletConfigParameters) {
std::string log_dir = JavaStringToNativeString(env, logDir);
ray::RayLog::StartRayLog("java_worker", ray::RayLogLevel::INFO, log_dir);
// TODO (kfstorm): We can't InstallFailureSignalHandler here, because JVM already
// installed its own signal handler. It's possible to fix this by chaining signal
// handlers. But it's not easy. See
// https://docs.oracle.com/javase/9/troubleshoot/handle-signals-and-exceptions.htm.
auto raylet_config = JavaMapToNativeMap<std::string, std::string>(
env, rayletConfigParameters,
[](JNIEnv *env, jobject java_key) {
return JavaStringToNativeString(env, (jstring)java_key);
},
[](JNIEnv *env, jobject java_value) {
return JavaStringToNativeString(env, (jstring)java_value);
});
RayConfig::instance().initialize(raylet_config);
}

JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWorker(JNIEnv *, jclass,
/*
* Class: org_ray_runtime_RayNativeRuntime
* Method: nativeSetup
* Signature: (Ljava/lang/String;)V
* Signature: (Ljava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv *, jclass,
jstring);
jstring,
jobject);

/*
* Class: org_ray_runtime_RayNativeRuntime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,16 @@ inline std::vector<ray::TaskArg> ToTaskArgs(JNIEnv *env, jobject args) {
inline std::unordered_map<std::string, double> ToResources(JNIEnv *env,
jobject java_resources) {
std::unordered_map<std::string, double> resources;
if (java_resources) {
jobject entry_set = env->CallObjectMethod(java_resources, java_map_entry_set);
RAY_CHECK_JAVA_EXCEPTION(env);
jobject iterator = env->CallObjectMethod(entry_set, java_set_iterator);
RAY_CHECK_JAVA_EXCEPTION(env);
while (env->CallBooleanMethod(iterator, java_iterator_has_next)) {
RAY_CHECK_JAVA_EXCEPTION(env);
jobject map_entry = env->CallObjectMethod(iterator, java_iterator_next);
RAY_CHECK_JAVA_EXCEPTION(env);
auto java_key = (jstring)env->CallObjectMethod(map_entry, java_map_entry_get_key);
RAY_CHECK_JAVA_EXCEPTION(env);
std::string key = JavaStringToNativeString(env, java_key);
auto java_value = env->CallObjectMethod(map_entry, java_map_entry_get_value);
RAY_CHECK_JAVA_EXCEPTION(env);
double value = env->CallDoubleMethod(java_value, java_double_double_value);
RAY_CHECK_JAVA_EXCEPTION(env);
resources.emplace(key, value);
}
RAY_CHECK_JAVA_EXCEPTION(env);
}
return resources;
return JavaMapToNativeMap<std::string, double>(
env, java_resources,
[](JNIEnv *env, jobject java_key) {
return JavaStringToNativeString(env, (jstring)java_key);
},
[](JNIEnv *env, jobject java_value) {
double value = env->CallDoubleMethod(java_value, java_double_double_value);
RAY_CHECK_JAVA_EXCEPTION(env);
return value;
});
}

inline ray::TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptions) {
Expand Down
1 change: 1 addition & 0 deletions src/ray/raylet/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ int main(int argc, char *argv[]) {
static_resource_conf[resource_name] = std::stod(resource_quantity);
}

node_manager_config.raylet_config = raylet_config;
node_manager_config.resource_config = ray::ResourceSet(std::move(static_resource_conf));
RAY_LOG(DEBUG) << "Starting raylet with static resource configuration: "
<< node_manager_config.resource_config.ToString();
Expand Down
2 changes: 1 addition & 1 deletion src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
local_available_resources_(config.resource_config),
worker_pool_(
io_service, config.num_initial_workers, config.maximum_startup_concurrency,
gcs_client_, config.worker_commands,
gcs_client_, config.worker_commands, config.raylet_config,
/*starting_worker_timeout_callback=*/
[this]() { this->DispatchTasks(this->local_queues_.GetReadyTasksByClass()); }),
scheduling_policy_(local_queues_),
Expand Down
2 changes: 2 additions & 0 deletions src/ray/raylet/node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ struct NodeManagerConfig {
std::string temp_dir;
/// The path of this ray session dir.
std::string session_dir;
/// The raylet config parameters of this node, as key -> value pairs.
std::unordered_map<std::string, std::string> raylet_config;
};

class NodeManager : public rpc::NodeManagerServiceHandler {
Expand Down
Loading

0 comments on commit 630e489

Please sign in to comment.