Skip to content

Commit

Permalink
Don't create default experiment if the FileStore root is already pres…
Browse files Browse the repository at this point in the history
…ent (mlflow#604)

* init

* wip

* Revert "wip"

This reverts commit 426ece7.

* fix some tests

* revert some changes

* fix more tests

* fix another test

* fix more tests
  • Loading branch information
andrewmchen authored Oct 9, 2018
1 parent 293903b commit 6598c62
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ private MlflowClient startServerProcess() throws IOException {
int freePort = getFreePort();
String bindAddress = "127.0.0.1";
pb.command("mlflow", "server", "--host", bindAddress, "--port", "" + freePort,
"--file-store", tempDir.toString(), "--workers", "1");
"--file-store", tempDir.resolve("mlruns").toString(), "--workers", "1");
serverProcess = pb.start();

// NB: We cannot use pb.inheritIO() because that interacts poorly with the Maven
Expand Down
9 changes: 4 additions & 5 deletions mlflow/store/file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ def __init__(self, root_directory=None, artifact_root_uri=None):
# Create root directory if needed
if not exists(self.root_directory):
mkdir(self.root_directory)
# Create trash folder if needed
if not exists(self.trash_folder):
mkdir(self.trash_folder)
# Create default experiment if needed
if not self._has_experiment(experiment_id=Experiment.DEFAULT_EXPERIMENT_ID):
print("here")
self._create_experiment_with_id(name="Default",
experiment_id=Experiment.DEFAULT_EXPERIMENT_ID,
artifact_uri=None)
# Create trash folder if needed
if not exists(self.trash_folder):
mkdir(self.trash_folder)

def _check_root_dir(self):
"""
Expand Down
16 changes: 7 additions & 9 deletions tests/projects/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import mlflow
from mlflow.entities import RunStatus, ViewType
from mlflow.exceptions import ExecutionException
from mlflow.store.file_store import FileStore
from mlflow.utils import env

from tests.projects.utils import TEST_PROJECT_DIR, TEST_PROJECT_NAME, GIT_PROJECT_URI, \
Expand Down Expand Up @@ -138,8 +137,7 @@ def test_is_valid_branch_name(local_git_repo):

@pytest.mark.parametrize("use_start_run", map(str, [0, 1]))
@pytest.mark.parametrize("version", [None, "master", "git-commit"])
def test_run_local_git_repo(tmpdir,
local_git_repo,
def test_run_local_git_repo(local_git_repo,
local_git_repo_uri,
tracking_uri_mock, # pylint: disable=unused-argument
use_start_run,
Expand All @@ -163,13 +161,13 @@ def test_run_local_git_repo(tmpdir,
validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
# Validate run contents in the FileStore
run_uuid = submitted_run.run_id
store = FileStore(tmpdir.strpath)
run_infos = store.list_run_infos(experiment_id=0, run_view_type=ViewType.ACTIVE_ONLY)
mlflow_service = mlflow.tracking.MlflowClient()
run_infos = mlflow_service.list_run_infos(experiment_id=0, run_view_type=ViewType.ACTIVE_ONLY)
assert "file:" in run_infos[0].source_name
assert len(run_infos) == 1
store_run_uuid = run_infos[0].run_uuid
assert run_uuid == store_run_uuid
run = store.get_run(run_uuid)
run = mlflow_service.get_run(run_uuid)
expected_params = {"use_start_run": use_start_run}
assert run.info.status == RunStatus.FINISHED
assert len(run.data.params) == len(expected_params)
Expand Down Expand Up @@ -211,12 +209,12 @@ def test_run(tmpdir, tracking_uri_mock, use_start_run): # pylint: disable=unuse
validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
# Validate run contents in the FileStore
run_uuid = submitted_run.run_id
store = FileStore(tmpdir.strpath)
run_infos = store.list_run_infos(experiment_id=0, run_view_type=ViewType.ACTIVE_ONLY)
mlflow_service = mlflow.tracking.MlflowClient()
run_infos = mlflow_service.list_run_infos(experiment_id=0, run_view_type=ViewType.ACTIVE_ONLY)
assert len(run_infos) == 1
store_run_uuid = run_infos[0].run_uuid
assert run_uuid == store_run_uuid
run = store.get_run(run_uuid)
run = mlflow_service.get_run(run_uuid)
expected_params = {"use_start_run": use_start_run}
assert run.info.status == RunStatus.FINISHED
assert len(run.data.params) == len(expected_params)
Expand Down
2 changes: 1 addition & 1 deletion tests/projects/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def assert_dirs_equal(expected, actual):
@pytest.fixture()
def tracking_uri_mock(tmpdir):
try:
mlflow.set_tracking_uri(tmpdir.strpath)
mlflow.set_tracking_uri(os.path.join(tmpdir.strpath, 'mlruns'))
yield tmpdir
finally:
mlflow.set_tracking_uri(None)
2 changes: 1 addition & 1 deletion tests/spark/test_spark_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_sparkml_model_log(tmpdir, spark_model_iris):
for dfs_tmp_dir in [None, os.path.join(str(tmpdir), "test")]:
print("should_start_run =", should_start_run, "dfs_tmp_dir =", dfs_tmp_dir)
try:
tracking_dir = os.path.abspath(str(tmpdir.mkdir("mlruns")))
tracking_dir = os.path.abspath(str(tmpdir.join("mlruns")))
mlflow.set_tracking_uri("file://%s" % tracking_dir)
if should_start_run:
mlflow.start_run()
Expand Down
6 changes: 6 additions & 0 deletions tests/store/test_file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,9 @@ def test_create_run_with_parent_id(self):
'entry_point_name', 0, None, [], 'test_parent_run_id')
assert any([t.key == MLFLOW_PARENT_RUN_ID and t.value == 'test_parent_run_id'
for t in fs.get_all_tags(run.info.run_uuid)])

def test_default_experiment_initialization(self):
fs = FileStore(self.test_root)
fs.delete_experiment(Experiment.DEFAULT_EXPERIMENT_ID)
fs = FileStore(self.test_root)
assert fs.get_experiment(0).lifecycle_stage == Experiment.DELETED_LIFECYCLE

0 comments on commit 6598c62

Please sign in to comment.