Skip to content

Commit

Permalink
Enable reading ~/.databrickscfg files form MLflow Java API (mlflow#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
aarondav authored Aug 29, 2018
1 parent 5ae59fe commit d02916a
Show file tree
Hide file tree
Showing 10 changed files with 486 additions and 9 deletions.
14 changes: 11 additions & 3 deletions mlflow/java/client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,18 @@
<artifactId>commons-io</artifactId>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
<version>1.3.9</version>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
<version>1.3.9</version>
</dependency>
<dependency>
<groupId>javax.annotation</groupId>
<artifactId>javax.annotation-api</artifactId>
</dependency>
<dependency>
<groupId>org.ini4j</groupId>
<artifactId>ini4j</artifactId>
</dependency>
</dependencies>

<build>
Expand Down Expand Up @@ -108,6 +112,10 @@
<pattern>com.databricks.api.proto.databricks</pattern>
<shadedPattern>${mlflow.shade.packageName}.databricks</shadedPattern>
</relocation>
<relocation>
<pattern>org.yaml</pattern>
<shadedPattern>${mlflow.shade.packageName}.yaml</shadedPattern>
</relocation>
</relocations>
</configuration>
<executions>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import org.apache.http.client.utils.URIBuilder;

import org.mlflow.api.proto.Service.*;
import org.mlflow.tracking.creds.BasicMlflowHostCreds;
import org.mlflow.tracking.creds.MlflowHostCredsProvider;
import org.mlflow.tracking.creds.*;

import java.net.URI;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -207,8 +206,20 @@ private static String getDefaultTrackingUri() {
private static MlflowHostCredsProvider getHostCredsProviderFromTrackingUri(String trackingUri) {
URI uri = URI.create(trackingUri);
MlflowHostCredsProvider provider;

if ("http".equals(uri.getScheme()) || "https".equals(uri.getScheme())) {
provider = new BasicMlflowHostCreds(trackingUri);
} else if (trackingUri.equals("databricks")) {
MlflowHostCredsProvider profileProvider = new DatabricksConfigHostCredsProvider();
MlflowHostCredsProvider dynamicProvider =
DatabricksDynamicHostCredsProvider.createIfAvailable();
if (dynamicProvider != null) {
provider = new HostCredsProviderChain(dynamicProvider, profileProvider);
} else {
provider = profileProvider;
}
} else if ("databricks".equals(uri.getScheme())) {
provider = new DatabricksConfigHostCredsProvider(uri.getHost());
} else if (uri.getScheme() == null || "file".equals(uri.getScheme())) {
throw new IllegalArgumentException("Java Client currently does not support" +
" local tracking URIs. Please point to a Tracking Server.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public class BasicMlflowHostCreds implements MlflowHostCreds, MlflowHostCredsPro
private String username;
private String password;
private String token;
private boolean noTlsVerify;
private boolean shouldIgnoreTlsVerification;

public BasicMlflowHostCreds(String host) {
this.host = host;
Expand All @@ -28,12 +28,12 @@ public BasicMlflowHostCreds(
String username,
String password,
String token,
boolean noTlsVerify) {
boolean shouldIgnoreTlsVerification) {
this.host = host;
this.username = username;
this.password = password;
this.token = token;
this.noTlsVerify = noTlsVerify;
this.shouldIgnoreTlsVerification = shouldIgnoreTlsVerification;
}

@Override
Expand All @@ -58,7 +58,7 @@ public String getToken() {

@Override
public boolean shouldIgnoreTlsVerification() {
return noTlsVerify;
return shouldIgnoreTlsVerification;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package org.mlflow.tracking.creds;

import java.io.File;
import java.io.IOException;
import java.nio.file.Paths;

import org.ini4j.Ini;
import org.ini4j.Profile;

public class DatabricksConfigHostCredsProvider implements MlflowHostCredsProvider {
private static final String CONFIG_FILE_ENV_VAR = "DATABRICKS_CONFIG_FILE";

private final String profile;

private MlflowHostCreds hostCreds;

public DatabricksConfigHostCredsProvider(String profile) {
this.profile = profile;
}

public DatabricksConfigHostCredsProvider() {
this.profile = null;
}

private void loadConfigIfNecessary() {
if (hostCreds == null) {
reloadConfig();
}
}

private void reloadConfig() {
String basePath = System.getenv(CONFIG_FILE_ENV_VAR);
if (basePath == null) {
String userHome = System.getProperty("user.home");
basePath = Paths.get(userHome, ".databrickscfg").toString();
}

if (!new File(basePath).isFile()) {
throw new IllegalStateException("Could not find Databricks configuration file" +
" (" + basePath + "). Please run 'databricks configure' using the Databricks CLI.");
}

Ini ini;
try {
ini = new Ini(new File(basePath));
} catch (IOException e) {
throw new IllegalStateException("Failed to load databrickscfg file at " + basePath, e);
}

Profile.Section section;
if (profile == null) {
section = ini.get("DEFAULT");
if (section == null) {
throw new IllegalStateException("Could not find 'DEFAULT' section within config file" +
" (" + basePath + "). Please run 'databricks configure' using the Databricks CLI.");
}
} else {
section = ini.get(profile);
if (section == null) {
throw new IllegalStateException("Could not find '" + profile + "' section within config" +
" file (" + basePath + "). Please run 'databricks configure --profile " + profile + "'" +
" using the Databricks CLI.");
}
}
assert (section != null);

String host = section.get("host");
String username = section.get("username");
String password = section.get("password");
String token = section.get("token");
boolean insecure = section.get("insecure", "false").equals("true");

if (host == null) {
throw new IllegalStateException("No 'host' configured within Databricks config file" +
" (" + basePath + "). Please run 'databricks configure' using the Databricks CLI.");
}

boolean hasValidUserPassword = username != null && password != null;
boolean hasValidToken = token != null;
if (!hasValidUserPassword && !hasValidToken) {
throw new IllegalStateException("No authentication configured within Databricks config file" +
" (" + basePath + "). Please run 'databricks configure' using the Databricks CLI.");
}

this.hostCreds = new BasicMlflowHostCreds(host, username, password, token, insecure);
}

@Override
public MlflowHostCreds getHostCreds() {
loadConfigIfNecessary();
return hostCreds;
}

@Override
public void refresh() {
reloadConfig();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package org.mlflow.tracking.creds;

import java.util.Map;

import com.google.common.annotations.VisibleForTesting;
import org.apache.log4j.Logger;

public class DatabricksDynamicHostCredsProvider implements MlflowHostCredsProvider {
private static final Logger logger = Logger.getLogger(DatabricksDynamicHostCredsProvider.class);

private final Map<String, String> configProvider;

private DatabricksDynamicHostCredsProvider(Map<String, String> configProvider) {
this.configProvider = configProvider;
}

public static DatabricksDynamicHostCredsProvider createIfAvailable() {
return createIfAvailable("com.databricks.config.DatabricksClientSettingsProvider");
}

@VisibleForTesting
static DatabricksDynamicHostCredsProvider createIfAvailable(String className) {
try {
Class<?> cls = Class.forName(className);
return new DatabricksDynamicHostCredsProvider((Map<String, String>) cls.newInstance());
} catch (ClassNotFoundException e) {
return null;
} catch (IllegalAccessException | InstantiationException e) {
logger.warn("Found but failed to invoke dynamic config provider", e);
return null;
}

}

@Override
public MlflowHostCreds getHostCreds() {
return new BasicMlflowHostCreds(
configProvider.get("host"),
configProvider.get("username"),
configProvider.get("password"),
configProvider.get("token"),
"true".equals(configProvider.get("shouldIgnoreTlsVerification"))
);
}

@Override
public void refresh() {
// no-op
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package org.mlflow.tracking.creds;

import org.apache.log4j.Logger;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.mlflow.tracking.MlflowClientException;

public class HostCredsProviderChain implements MlflowHostCredsProvider {
private static final Logger logger = Logger.getLogger(HostCredsProviderChain.class);

private final List<MlflowHostCredsProvider> hostCredsProviders = new ArrayList<>();

public HostCredsProviderChain(MlflowHostCredsProvider... hostCredsProviders) {
this.hostCredsProviders.addAll(Arrays.asList(hostCredsProviders));
}

@Override
public MlflowHostCreds getHostCreds() {
List<String> exceptionMessages = new ArrayList<>();
for (MlflowHostCredsProvider provider : hostCredsProviders) {
try {
MlflowHostCreds hostCreds = provider.getHostCreds();

if (hostCreds != null && hostCreds.getHost() != null) {
logger.debug("Loading credentials from " + provider.toString());
return hostCreds;
}
} catch (Exception e) {
String message = provider + ": " + e.getMessage();
logger.debug("Unable to load credentials from " + message);
exceptionMessages.add(message);
}
}
throw new MlflowClientException("Unable to load MLflow Host/Credentials from any provider in" +
" the chain: " + exceptionMessages);
}

@Override
public void refresh() {
for (MlflowHostCredsProvider provider : hostCredsProviders) {
provider.refresh();
}
}
}
Loading

0 comments on commit d02916a

Please sign in to comment.