Move URI helper methods to utils.uri module (mlflow#1915)
Consolidate the URI helper methods that are not tracking-specific into the utils.uri module. The moved helpers also drop their private leading underscores (for example, _is_local_uri becomes is_local_uri).
sueann authored Oct 9, 2019
1 parent 4c405d0 commit 44f731c
Showing 16 changed files with 155 additions and 155 deletions.
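
The change is a pure move-and-rename: the function bodies are untouched, but the helpers lose their leading underscores and gain a single home. As a sketch of the migration pattern callers follow (import paths taken from the diffs below):

    # Before: private helpers scattered across tracking-oriented modules
    from mlflow.tracking.utils import _is_local_uri, _is_databricks_uri
    from mlflow.utils import extract_db_type_from_uri, get_uri_scheme

    # After: the same helpers, public, in one module
    from mlflow.utils.uri import (is_local_uri, is_databricks_uri,
                                  extract_db_type_from_uri, get_uri_scheme)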
mlflow/cli.py: 24 changes (11 additions, 13 deletions)
@@ -9,24 +9,22 @@
from click import UsageError

import mlflow.azureml.cli
-import mlflow.projects as projects
+import mlflow.db
import mlflow.experiments
import mlflow.models.cli
-
-import mlflow.sagemaker.cli
+import mlflow.projects as projects
import mlflow.runs
+import mlflow.sagemaker.cli
+import mlflow.store.artifact.cli
import mlflow.store.db.utils
-import mlflow.db

-from mlflow.tracking.utils import _is_local_uri
-from mlflow.utils.logging_utils import eprint
-from mlflow.utils.process import ShellCommandException
-from mlflow.utils import cli_args
+from mlflow import tracking
from mlflow.server import _run_server
from mlflow.server.handlers import _get_store
from mlflow.store.tracking import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH
-from mlflow import tracking
-import mlflow.store.artifact.cli
+from mlflow.utils import cli_args
+from mlflow.utils.logging_utils import eprint
+from mlflow.utils.process import ShellCommandException
+from mlflow.utils.uri import is_local_uri

_logger = logging.getLogger(__name__)

@@ -180,7 +178,7 @@ def ui(backend_store_uri, default_artifact_root, port):
        backend_store_uri = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH

    if not default_artifact_root:
-        if _is_local_uri(backend_store_uri):
+        if is_local_uri(backend_store_uri):
            default_artifact_root = backend_store_uri
        else:
            default_artifact_root = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH
@@ -254,7 +252,7 @@ def server(backend_store_uri, default_artifact_root, host, port,
        backend_store_uri = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH

    if not default_artifact_root:
-        if _is_local_uri(backend_store_uri):
+        if is_local_uri(backend_store_uri):
            default_artifact_root = backend_store_uri
        else:
            eprint("Option 'default-artifact-root' is required, when backend store is not "
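
For context, a minimal sketch of the artifact-root fallback shared by `mlflow ui` and `mlflow server` after this change; the standalone function and the "./mlruns" constant are assumptions for illustration, not code from this commit:

    from mlflow.utils.uri import is_local_uri

    DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH = "./mlruns"  # assumed value for the sketch

    def resolve_default_artifact_root(backend_store_uri, default_artifact_root=None):
        if not default_artifact_root:
            if is_local_uri(backend_store_uri):
                # Local file-based backend store: keep artifacts next to it.
                return backend_store_uri
            # `mlflow ui` falls back to the default path here; `mlflow server`
            # prints an error and exits instead of guessing.
            return DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH
        return default_artifact_root

    print(resolve_default_artifact_root("./mlruns"))             # ./mlruns
    print(resolve_default_artifact_root("mysql://host/mlflow"))  # ./mlruns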
mlflow/projects/__init__.py: 26 changes (12 additions, 14 deletions)
@@ -19,32 +19,30 @@
import posixpath
import docker

+import mlflow.projects.databricks
import mlflow.tracking as tracking
import mlflow.tracking.fluent as fluent
-from mlflow.projects.submitted_run import LocalSubmittedRun, SubmittedRun
-from mlflow.projects import _project_spec
-from mlflow.exceptions import ExecutionException, MlflowException
from mlflow.entities import RunStatus, SourceType
-from mlflow.tracking.fluent import _get_experiment_id
+from mlflow.exceptions import ExecutionException, MlflowException
+from mlflow.projects import _project_spec
+from mlflow.projects.submitted_run import LocalSubmittedRun, SubmittedRun
from mlflow.tracking.context.default_context import _get_user
from mlflow.tracking.context.git_context import _get_git_commit
-import mlflow.projects.databricks
-from mlflow.utils import process

-from mlflow.store.artifact.local_artifact_repo import LocalArtifactRepository
-from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
+from mlflow.tracking.fluent import _get_experiment_id
+from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository
from mlflow.store.artifact.gcs_artifact_repo import GCSArtifactRepository
from mlflow.store.artifact.hdfs_artifact_repo import HdfsArtifactRepository
-from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository

+from mlflow.store.artifact.local_artifact_repo import LocalArtifactRepository
+from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
+from mlflow.utils import databricks_utils, file_utils, process
from mlflow.utils.file_utils import path_to_local_sqlite_uri, path_to_local_file_uri
from mlflow.utils.mlflow_tags import MLFLOW_PROJECT_ENV, MLFLOW_DOCKER_IMAGE_URI, \
    MLFLOW_DOCKER_IMAGE_ID, MLFLOW_USER, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE, \
    MLFLOW_GIT_COMMIT, MLFLOW_GIT_REPO_URL, MLFLOW_GIT_BRANCH, LEGACY_MLFLOW_GIT_REPO_URL, \
    LEGACY_MLFLOW_GIT_BRANCH_NAME, MLFLOW_PROJECT_ENTRY_POINT, MLFLOW_PARENT_RUN_ID, \
    MLFLOW_PROJECT_BACKEND
-from mlflow.utils import databricks_utils, file_utils
+from mlflow.utils.uri import get_db_profile_from_uri, is_databricks_uri

# TODO: this should be restricted to just Git repos and not S3 and stuff like that
_GIT_URI_REGEX = re.compile(r"^[^/]*:")
@@ -934,8 +932,8 @@ def _get_docker_tracking_cmd_and_envs(tracking_uri):
    if local_path is not None:
        cmds = ["-v", "%s:%s" % (local_path, _MLFLOW_DOCKER_TRACKING_DIR_PATH)]
        env_vars[tracking._TRACKING_URI_ENV_VAR] = container_tracking_uri
-    if tracking.utils._is_databricks_uri(tracking_uri):
-        db_profile = mlflow.tracking.utils.get_db_profile_from_uri(tracking_uri)
+    if is_databricks_uri(tracking_uri):
+        db_profile = get_db_profile_from_uri(tracking_uri)
        config = databricks_utils.get_databricks_host_creds(db_profile)
        # We set these via environment variables so that only the current profile is exposed, rather
        # than all profiles in ~/.databrickscfg; maybe better would be to mount the necessary
mlflow/projects/databricks.py: 9 changes (5 additions, 4 deletions)
@@ -9,14 +9,15 @@

from six.moves import shlex_quote

+from mlflow import tracking
from mlflow.entities import RunStatus
from mlflow.exceptions import MlflowException
from mlflow.projects.submitted_run import SubmittedRun
from mlflow.utils import rest_utils, file_utils, databricks_utils
from mlflow.exceptions import ExecutionException
-from mlflow import tracking
from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_RUN_URL, MLFLOW_DATABRICKS_SHELL_JOB_ID, \
    MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, MLFLOW_DATABRICKS_WEBAPP_URL
+from mlflow.utils.uri import get_db_profile_from_uri, is_databricks_uri, is_http_uri
from mlflow.version import VERSION

# Base directory within driver container for storing files related to MLflow
@@ -40,8 +41,8 @@ def before_run_validations(tracking_uri, backend_config):
    if backend_config is None:
        raise ExecutionException("Backend spec must be provided when launching MLflow project "
                                 "runs on Databricks.")
-    if not tracking.utils._is_databricks_uri(tracking_uri) and \
-            not tracking.utils._is_http_uri(tracking_uri):
+    if not is_databricks_uri(tracking_uri) and \
+            not is_http_uri(tracking_uri):
        raise ExecutionException(
            "When running on Databricks, the MLflow tracking URI must be of the form "
            "'databricks' or 'databricks://profile', or a remote HTTP URI accessible to both the "
@@ -270,7 +271,7 @@ def run_databricks(remote_run, uri, entry_point, work_dir, parameters, experimen
    Run the project at the specified URI on Databricks, returning a ``SubmittedRun`` that can be
    used to query the run's status or wait for the resulting Databricks Job run to terminate.
    """
-    profile = tracking.utils.get_db_profile_from_uri(tracking.get_tracking_uri())
+    profile = get_db_profile_from_uri(tracking.get_tracking_uri())
    run_id = remote_run.info.run_id
    db_job_runner = DatabricksJobRunner(databricks_profile=profile)
    db_run_id = db_job_runner.run_databricks(
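
To make the accepted URI shapes concrete, a small sketch of the check above (the profile name is a placeholder):

    from mlflow.utils.uri import is_databricks_uri, is_http_uri

    for uri in ("databricks",                  # default profile     -> accepted
                "databricks://my-profile",     # named profile       -> accepted
                "https://mlflow.example.com",  # remote HTTP server  -> accepted
                "file:///tmp/mlruns"):         # local-only          -> rejected
        print(uri, "->", is_databricks_uri(uri) or is_http_uri(uri))

The local file URI fails the check because, as the error message says, the tracking server must be reachable from both the client and the Databricks workspace.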
mlflow/spark.py: 3 changes (2 additions, 1 deletion)
@@ -34,6 +34,7 @@
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.model_utils import _get_flavor_configuration
from mlflow.utils.file_utils import TempDir
+from mlflow.utils.uri import is_local_uri

FLAVOR_NAME = "spark"

@@ -122,7 +123,7 @@ def log_model(spark_model, artifact_path, conda_env=None, dfs_tmpdir=None,
    # writing to `file:/uri` will write to the local filesystem from each executor, which will
    # be incorrect on multi-node clusters - to avoid such issues we just use the Model.log() path
    # here.
-    if mlflow.tracking.utils._is_local_uri(run_root_artifact_uri):
+    if is_local_uri(run_root_artifact_uri):
        return Model.log(artifact_path=artifact_path, flavor=mlflow.spark, spark_model=spark_model,
                         conda_env=conda_env, dfs_tmpdir=dfs_tmpdir,
                         sample_input=sample_input)
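
The renamed helper is what picks between the two save paths in `log_model`; for example (URIs are illustrative):

    from mlflow.utils.uri import is_local_uri

    is_local_uri("file:///tmp/mlruns/0/run/artifacts")  # True  -> defer to Model.log()
    is_local_uri("dbfs:/mlflow/0/run/artifacts")        # False -> save directly via dfs_tmpdir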
mlflow/store/artifact/artifact_repository_registry.py: 2 changes (1 addition, 1 deletion)
@@ -12,7 +12,7 @@
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.store.artifact.sftp_artifact_repo import SFTPArtifactRepository

-from mlflow.utils import get_uri_scheme
+from mlflow.utils.uri import get_uri_scheme


class ArtifactRepositoryRegistry:
mlflow/store/artifact/cli.py: 1 change (1 addition, 0 deletions)
@@ -119,5 +119,6 @@ def download_artifacts(run_id, artifact_path, artifact_uri):
    artifact_location = artifact_repo.download_artifacts(artifact_path)
    print(artifact_location)

+
if __name__ == '__main__':
    commands()
mlflow/store/tracking/sqlalchemy_store.py: 5 changes (2 additions, 3 deletions)
@@ -19,8 +19,7 @@
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_ALREADY_EXISTS, \
    INVALID_STATE, RESOURCE_DOES_NOT_EXIST, INTERNAL_ERROR
-from mlflow.tracking.utils import _is_local_uri
-from mlflow.utils import extract_db_type_from_uri
+from mlflow.utils.uri import is_local_uri, extract_db_type_from_uri
from mlflow.utils.file_utils import mkdir, local_file_uri_to_path
from mlflow.utils.search_utils import SearchUtils
from mlflow.utils.validation import _validate_batch_log_limits, _validate_batch_log_data, \
@@ -102,7 +101,7 @@ def __init__(self, db_uri, default_artifact_root):
        self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker)
        SqlAlchemyStore._verify_schema(self.engine)

-        if _is_local_uri(default_artifact_root):
+        if is_local_uri(default_artifact_root):
            mkdir(local_file_uri_to_path(default_artifact_root))

        if len(self.list_experiments()) == 0:
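
Both helpers now arrive in one import. A doctest-style sketch of what the store does with them (URIs are made up, and it is assumed that "mysql" is among the supported DATABASE_ENGINES):

    from mlflow.utils.uri import extract_db_type_from_uri, is_local_uri

    extract_db_type_from_uri("mysql+pymysql://user:pw@host:3306/mlflow")  # "mysql"
    is_local_uri("./mlruns")               # True:  the store mkdirs this artifact root
    is_local_uri("s3://bucket/artifacts")  # False: no local directory is created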
mlflow/tracking/registry.py: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
import entrypoints

from mlflow.exceptions import MlflowException
-from mlflow.utils import get_uri_scheme
+from mlflow.utils.uri import get_uri_scheme


class TrackingStoreRegistry:
mlflow/tracking/utils.py: 31 changes (1 addition, 30 deletions)
@@ -3,8 +3,6 @@
import os
import sys

-from six.moves import urllib
-
from mlflow.store.tracking import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH
from mlflow.store.db.db_types import DATABASE_ENGINES
from mlflow.store.tracking.file_store import FileStore
@@ -13,6 +11,7 @@
from mlflow.utils import env, rest_utils
from mlflow.utils.file_utils import path_to_local_file_uri
from mlflow.utils.databricks_utils import get_databricks_host_creds
+from mlflow.utils.uri import get_db_profile_from_uri

_TRACKING_URI_ENV_VAR = "MLFLOW_TRACKING_URI"
_LOCAL_FS_URI_PREFIX = "file:///"
@@ -70,23 +69,6 @@ def get_tracking_uri():
    return path_to_local_file_uri(os.path.abspath(DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH))


-def _is_local_uri(uri):
-    """Returns true if this is a local file path (/foo or file:/foo)."""
-    scheme = urllib.parse.urlparse(uri).scheme
-    return uri != 'databricks' and (scheme == '' or scheme == 'file')
-
-
-def _is_http_uri(uri):
-    scheme = urllib.parse.urlparse(uri).scheme
-    return scheme == 'http' or scheme == 'https'
-
-
-def _is_databricks_uri(uri):
-    """Databricks URIs look like 'databricks' (default profile) or 'databricks://profile'"""
-    scheme = urllib.parse.urlparse(uri).scheme
-    return scheme == 'databricks' or uri == 'databricks'
-
-
def _get_file_store(store_uri, **_):
    return FileStore(store_uri, store_uri)

@@ -111,17 +93,6 @@ def get_default_host_creds():
    return RestStore(get_default_host_creds)


-def get_db_profile_from_uri(uri):
-    """
-    Get the Databricks profile specified by the tracking URI (if any), otherwise
-    returns None.
-    """
-    parsed_uri = urllib.parse.urlparse(uri)
-    if parsed_uri.scheme == "databricks":
-        return parsed_uri.netloc
-    return None
-
-
def _get_databricks_rest_store(store_uri, **_):
    profile = get_db_profile_from_uri(store_uri)
    return DatabricksRestStore(lambda: get_databricks_host_creds(profile))
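
One edge case worth noting, since `_get_databricks_rest_store` above depends on it; the values follow directly from the function body shown in the diff:

    from mlflow.utils.uri import get_db_profile_from_uri

    get_db_profile_from_uri("databricks://team-a")    # "team-a" (placeholder profile name)
    get_db_profile_from_uri("databricks")             # None: the bare URI has no scheme,
                                                      # so downstream code falls back to
                                                      # the default Databricks profile
    get_db_profile_from_uri("http://localhost:5000")  # None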
mlflow/utils/__init__.py: 38 changes (0 additions, 38 deletions)
@@ -1,40 +1,10 @@
from sys import version_info

-from six.moves import urllib
-
-from mlflow.exceptions import MlflowException
-from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
-
-from mlflow.store.db.db_types import DATABASE_ENGINES
from mlflow.utils.annotations import deprecated, experimental, keyword_only
-from mlflow.utils.validation import _validate_db_type_string

PYTHON_VERSION = "{major}.{minor}.{micro}".format(major=version_info.major,
                                                  minor=version_info.minor,
                                                  micro=version_info.micro)
-_INVALID_DB_URI_MSG = "Please refer to https://mlflow.org/docs/latest/tracking.html#storage for " \
-                      "format specifications."
-
-
-def extract_db_type_from_uri(db_uri):
-    """
-    Parse the specified DB URI to extract the database type. Confirm the database type is
-    supported. If a driver is specified, confirm it passes a plausible regex.
-    """
-    scheme = urllib.parse.urlparse(db_uri).scheme
-    scheme_plus_count = scheme.count('+')
-
-    if scheme_plus_count == 0:
-        db_type = scheme
-    elif scheme_plus_count == 1:
-        db_type, _ = scheme.split('+')
-    else:
-        error_msg = "Invalid database URI: '%s'. %s" % (db_uri, _INVALID_DB_URI_MSG)
-        raise MlflowException(error_msg, INVALID_PARAMETER_VALUE)
-
-    _validate_db_type_string(db_type)
-
-    return db_type


def get_major_minor_py_version(py_version):
@@ -69,11 +39,3 @@ def get_unique_resource_id(max_length=None):
    if max_length is not None:
        unique_id = unique_id[:int(max_length)]
    return unique_id
-
-
-def get_uri_scheme(uri_or_path):
-    scheme = urllib.parse.urlparse(uri_or_path).scheme
-    if any([scheme.lower().startswith(db) for db in DATABASE_ENGINES]):
-        return extract_db_type_from_uri(uri_or_path)
-    else:
-        return scheme
mlflow/utils/uri.py: 66 changes (66 additions, 0 deletions)
@@ -0,0 +1,66 @@
+from six.moves import urllib
+
+from mlflow.exceptions import MlflowException
+from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
+from mlflow.store.db.db_types import DATABASE_ENGINES
+from mlflow.utils.validation import _validate_db_type_string
+
+_INVALID_DB_URI_MSG = "Please refer to https://mlflow.org/docs/latest/tracking.html#storage for " \
+                      "format specifications."
+
+
+def is_local_uri(uri):
+    """Returns true if this is a local file path (/foo or file:/foo)."""
+    scheme = urllib.parse.urlparse(uri).scheme
+    return uri != 'databricks' and (scheme == '' or scheme == 'file')
+
+
+def is_http_uri(uri):
+    scheme = urllib.parse.urlparse(uri).scheme
+    return scheme == 'http' or scheme == 'https'
+
+
+def is_databricks_uri(uri):
+    """Databricks URIs look like 'databricks' (default profile) or 'databricks://profile'"""
+    scheme = urllib.parse.urlparse(uri).scheme
+    return scheme == 'databricks' or uri == 'databricks'
+
+
+def get_db_profile_from_uri(uri):
+    """
+    Get the Databricks profile specified by the tracking URI (if any), otherwise
+    returns None.
+    """
+    parsed_uri = urllib.parse.urlparse(uri)
+    if parsed_uri.scheme == "databricks":
+        return parsed_uri.netloc
+    return None
+
+
+def extract_db_type_from_uri(db_uri):
+    """
+    Parse the specified DB URI to extract the database type. Confirm the database type is
+    supported. If a driver is specified, confirm it passes a plausible regex.
+    """
+    scheme = urllib.parse.urlparse(db_uri).scheme
+    scheme_plus_count = scheme.count('+')
+
+    if scheme_plus_count == 0:
+        db_type = scheme
+    elif scheme_plus_count == 1:
+        db_type, _ = scheme.split('+')
+    else:
+        error_msg = "Invalid database URI: '%s'. %s" % (db_uri, _INVALID_DB_URI_MSG)
+        raise MlflowException(error_msg, INVALID_PARAMETER_VALUE)
+
+    _validate_db_type_string(db_type)
+
+    return db_type
+
+
+def get_uri_scheme(uri_or_path):
+    scheme = urllib.parse.urlparse(uri_or_path).scheme
+    if any([scheme.lower().startswith(db) for db in DATABASE_ENGINES]):
+        return extract_db_type_from_uri(uri_or_path)
+    else:
+        return scheme
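
A few spot checks of the new module's scheme dispatcher, derivable directly from the code above (the last example assumes "mysql" is in DATABASE_ENGINES):

    from mlflow.utils.uri import get_uri_scheme

    get_uri_scheme("s3://bucket/path")         # "s3"
    get_uri_scheme("/home/user/mlruns")        # ""      (plain paths have no scheme)
    get_uri_scheme("mysql+pymysql://host/db")  # "mysql" (driver suffix stripped by
                                               #          extract_db_type_from_uri)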
tests/store/tracking/test_sqlalchemy_store.py: 3 changes (2 additions, 1 deletion)
@@ -27,8 +27,9 @@
from mlflow import entities
from mlflow.exceptions import MlflowException
from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
-from mlflow.utils import extract_db_type_from_uri, mlflow_tags
+from mlflow.utils import mlflow_tags
from mlflow.utils.file_utils import TempDir
+from mlflow.utils.uri import extract_db_type_from_uri
from tests.resources.db.initial_models import Base as InitialBase
from tests.integration.utils import invoke_cli_runner
