forked from mlflow/mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable reading ~/.databrickscfg files form MLflow Java API (mlflow#398)
- Loading branch information
Showing
10 changed files
with
486 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
98 changes: 98 additions & 0 deletions
98
...ava/client/src/main/java/org/mlflow/tracking/creds/DatabricksConfigHostCredsProvider.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package org.mlflow.tracking.creds; | ||
|
||
import java.io.File; | ||
import java.io.IOException; | ||
import java.nio.file.Paths; | ||
|
||
import org.ini4j.Ini; | ||
import org.ini4j.Profile; | ||
|
||
public class DatabricksConfigHostCredsProvider implements MlflowHostCredsProvider { | ||
private static final String CONFIG_FILE_ENV_VAR = "DATABRICKS_CONFIG_FILE"; | ||
|
||
private final String profile; | ||
|
||
private MlflowHostCreds hostCreds; | ||
|
||
public DatabricksConfigHostCredsProvider(String profile) { | ||
this.profile = profile; | ||
} | ||
|
||
public DatabricksConfigHostCredsProvider() { | ||
this.profile = null; | ||
} | ||
|
||
private void loadConfigIfNecessary() { | ||
if (hostCreds == null) { | ||
reloadConfig(); | ||
} | ||
} | ||
|
||
private void reloadConfig() { | ||
String basePath = System.getenv(CONFIG_FILE_ENV_VAR); | ||
if (basePath == null) { | ||
String userHome = System.getProperty("user.home"); | ||
basePath = Paths.get(userHome, ".databrickscfg").toString(); | ||
} | ||
|
||
if (!new File(basePath).isFile()) { | ||
throw new IllegalStateException("Could not find Databricks configuration file" + | ||
" (" + basePath + "). Please run 'databricks configure' using the Databricks CLI."); | ||
} | ||
|
||
Ini ini; | ||
try { | ||
ini = new Ini(new File(basePath)); | ||
} catch (IOException e) { | ||
throw new IllegalStateException("Failed to load databrickscfg file at " + basePath, e); | ||
} | ||
|
||
Profile.Section section; | ||
if (profile == null) { | ||
section = ini.get("DEFAULT"); | ||
if (section == null) { | ||
throw new IllegalStateException("Could not find 'DEFAULT' section within config file" + | ||
" (" + basePath + "). Please run 'databricks configure' using the Databricks CLI."); | ||
} | ||
} else { | ||
section = ini.get(profile); | ||
if (section == null) { | ||
throw new IllegalStateException("Could not find '" + profile + "' section within config" + | ||
" file (" + basePath + "). Please run 'databricks configure --profile " + profile + "'" + | ||
" using the Databricks CLI."); | ||
} | ||
} | ||
assert (section != null); | ||
|
||
String host = section.get("host"); | ||
String username = section.get("username"); | ||
String password = section.get("password"); | ||
String token = section.get("token"); | ||
boolean insecure = section.get("insecure", "false").equals("true"); | ||
|
||
if (host == null) { | ||
throw new IllegalStateException("No 'host' configured within Databricks config file" + | ||
" (" + basePath + "). Please run 'databricks configure' using the Databricks CLI."); | ||
} | ||
|
||
boolean hasValidUserPassword = username != null && password != null; | ||
boolean hasValidToken = token != null; | ||
if (!hasValidUserPassword && !hasValidToken) { | ||
throw new IllegalStateException("No authentication configured within Databricks config file" + | ||
" (" + basePath + "). Please run 'databricks configure' using the Databricks CLI."); | ||
} | ||
|
||
this.hostCreds = new BasicMlflowHostCreds(host, username, password, token, insecure); | ||
} | ||
|
||
@Override | ||
public MlflowHostCreds getHostCreds() { | ||
loadConfigIfNecessary(); | ||
return hostCreds; | ||
} | ||
|
||
@Override | ||
public void refresh() { | ||
reloadConfig(); | ||
} | ||
} |
50 changes: 50 additions & 0 deletions
50
...va/client/src/main/java/org/mlflow/tracking/creds/DatabricksDynamicHostCredsProvider.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package org.mlflow.tracking.creds; | ||
|
||
import java.util.Map; | ||
|
||
import com.google.common.annotations.VisibleForTesting; | ||
import org.apache.log4j.Logger; | ||
|
||
public class DatabricksDynamicHostCredsProvider implements MlflowHostCredsProvider { | ||
private static final Logger logger = Logger.getLogger(DatabricksDynamicHostCredsProvider.class); | ||
|
||
private final Map<String, String> configProvider; | ||
|
||
private DatabricksDynamicHostCredsProvider(Map<String, String> configProvider) { | ||
this.configProvider = configProvider; | ||
} | ||
|
||
public static DatabricksDynamicHostCredsProvider createIfAvailable() { | ||
return createIfAvailable("com.databricks.config.DatabricksClientSettingsProvider"); | ||
} | ||
|
||
@VisibleForTesting | ||
static DatabricksDynamicHostCredsProvider createIfAvailable(String className) { | ||
try { | ||
Class<?> cls = Class.forName(className); | ||
return new DatabricksDynamicHostCredsProvider((Map<String, String>) cls.newInstance()); | ||
} catch (ClassNotFoundException e) { | ||
return null; | ||
} catch (IllegalAccessException | InstantiationException e) { | ||
logger.warn("Found but failed to invoke dynamic config provider", e); | ||
return null; | ||
} | ||
|
||
} | ||
|
||
@Override | ||
public MlflowHostCreds getHostCreds() { | ||
return new BasicMlflowHostCreds( | ||
configProvider.get("host"), | ||
configProvider.get("username"), | ||
configProvider.get("password"), | ||
configProvider.get("token"), | ||
"true".equals(configProvider.get("shouldIgnoreTlsVerification")) | ||
); | ||
} | ||
|
||
@Override | ||
public void refresh() { | ||
// no-op | ||
} | ||
} |
47 changes: 47 additions & 0 deletions
47
mlflow/java/client/src/main/java/org/mlflow/tracking/creds/HostCredsProviderChain.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package org.mlflow.tracking.creds; | ||
|
||
import org.apache.log4j.Logger; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
|
||
import org.mlflow.tracking.MlflowClientException; | ||
|
||
public class HostCredsProviderChain implements MlflowHostCredsProvider { | ||
private static final Logger logger = Logger.getLogger(HostCredsProviderChain.class); | ||
|
||
private final List<MlflowHostCredsProvider> hostCredsProviders = new ArrayList<>(); | ||
|
||
public HostCredsProviderChain(MlflowHostCredsProvider... hostCredsProviders) { | ||
this.hostCredsProviders.addAll(Arrays.asList(hostCredsProviders)); | ||
} | ||
|
||
@Override | ||
public MlflowHostCreds getHostCreds() { | ||
List<String> exceptionMessages = new ArrayList<>(); | ||
for (MlflowHostCredsProvider provider : hostCredsProviders) { | ||
try { | ||
MlflowHostCreds hostCreds = provider.getHostCreds(); | ||
|
||
if (hostCreds != null && hostCreds.getHost() != null) { | ||
logger.debug("Loading credentials from " + provider.toString()); | ||
return hostCreds; | ||
} | ||
} catch (Exception e) { | ||
String message = provider + ": " + e.getMessage(); | ||
logger.debug("Unable to load credentials from " + message); | ||
exceptionMessages.add(message); | ||
} | ||
} | ||
throw new MlflowClientException("Unable to load MLflow Host/Credentials from any provider in" + | ||
" the chain: " + exceptionMessages); | ||
} | ||
|
||
@Override | ||
public void refresh() { | ||
for (MlflowHostCredsProvider provider : hostCredsProviders) { | ||
provider.refresh(); | ||
} | ||
} | ||
} |
Oops, something went wrong.