diff --git a/mlflow/exceptions.py b/mlflow/exceptions.py index 6cf1696a59886..f51ae186ce685 100644 --- a/mlflow/exceptions.py +++ b/mlflow/exceptions.py @@ -1,5 +1,31 @@ +import json + +from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, ErrorCode + + class MlflowException(Exception): """Base exception in MLflow.""" + def __init__(self, message, error_code=INTERNAL_ERROR): + try: + self.error_code = ErrorCode.Name(error_code) + except (ValueError, TypeError): + self.error_code = ErrorCode.Name(INTERNAL_ERROR) + self.message = message + super(MlflowException, self).__init__(message) + + def serialize_as_json(self): + return json.dumps({'error_code': self.error_code, 'message': self.message}) + + +class RestException(MlflowException): + """Exception thrown on non 200-level responses from the REST API""" + def __init__(self, json): + error_code = json['error_code'] + message = error_code + if 'message' in json: + message = "%s: %s" % (error_code, json['message']) + super(RestException, self).__init__(message, error_code=error_code) + self.json = json class IllegalArtifactPathError(MlflowException): diff --git a/mlflow/server/handlers.py b/mlflow/server/handlers.py index f127defbaea46..807466057e60b 100644 --- a/mlflow/server/handlers.py +++ b/mlflow/server/handlers.py @@ -4,10 +4,12 @@ import re import six +from functools import wraps from flask import Response, request, send_file from querystring_parser import parser from mlflow.entities import Metric, Param, RunTag, ViewType +from mlflow.exceptions import MlflowException from mlflow.protos import databricks_pb2 from mlflow.protos.service_pb2 import CreateExperiment, MlflowService, GetExperiment, \ GetRun, SearchRuns, ListArtifacts, GetMetricHistory, CreateRun, \ @@ -59,6 +61,19 @@ def _get_request_message(request_message, flask_request=request): return request_message +def catch_mlflow_exception(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except MlflowException as e: + response = Response(mimetype='application/json') + response.set_data(e.serialize_as_json()) + response.status_code = 500 + return response + return wrapper + + def get_handler(request_class): """ :param request_class: The type of protobuf message @@ -71,6 +86,7 @@ def get_handler(request_class): 'csv', 'tsv', 'md', 'rst', 'MLmodel', 'MLproject'] +@catch_mlflow_exception def get_artifact_handler(): query_string = request.query_string.decode('utf-8') request_dict = parser.parse(query_string, normalized=True) @@ -89,6 +105,7 @@ def _not_implemented(): return response +@catch_mlflow_exception def _create_experiment(): request_message = _get_request_message(CreateExperiment()) experiment_id = _get_store().create_experiment(request_message.name, @@ -100,6 +117,7 @@ def _create_experiment(): return response +@catch_mlflow_exception def _get_experiment(): request_message = _get_request_message(GetExperiment()) response_message = GetExperiment.Response() @@ -113,6 +131,7 @@ def _get_experiment(): return response +@catch_mlflow_exception def _delete_experiment(): request_message = _get_request_message(DeleteExperiment()) _get_store().delete_experiment(request_message.experiment_id) @@ -122,6 +141,7 @@ def _delete_experiment(): return response +@catch_mlflow_exception def _restore_experiment(): request_message = _get_request_message(RestoreExperiment()) _get_store().restore_experiment(request_message.experiment_id) @@ -131,6 +151,7 @@ def _restore_experiment(): return response +@catch_mlflow_exception def _create_run(): request_message = _get_request_message(CreateRun()) @@ -154,6 +175,7 @@ def _create_run(): return response +@catch_mlflow_exception def _update_run(): request_message = _get_request_message(UpdateRun()) updated_info = _get_store().update_run_info(request_message.run_uuid, request_message.status, @@ -164,6 +186,7 @@ def _update_run(): return response +@catch_mlflow_exception def _delete_run(): request_message = _get_request_message(DeleteRun()) _get_store().delete_run(request_message.run_id) @@ -173,6 +196,7 @@ def _delete_run(): return response +@catch_mlflow_exception def _restore_run(): request_message = _get_request_message(RestoreRun()) _get_store().restore_run(request_message.run_id) @@ -182,6 +206,7 @@ def _restore_run(): return response +@catch_mlflow_exception def _log_metric(): request_message = _get_request_message(LogMetric()) metric = Metric(request_message.key, request_message.value, request_message.timestamp) @@ -192,6 +217,7 @@ def _log_metric(): return response +@catch_mlflow_exception def _log_param(): request_message = _get_request_message(LogParam()) param = Param(request_message.key, request_message.value) @@ -202,6 +228,7 @@ def _log_param(): return response +@catch_mlflow_exception def _set_tag(): request_message = _get_request_message(SetTag()) tag = RunTag(request_message.key, request_message.value) @@ -212,6 +239,7 @@ def _set_tag(): return response +@catch_mlflow_exception def _get_run(): request_message = _get_request_message(GetRun()) response_message = GetRun.Response() @@ -221,6 +249,7 @@ def _get_run(): return response +@catch_mlflow_exception def _search_runs(): request_message = _get_request_message(SearchRuns()) response_message = SearchRuns.Response() @@ -236,6 +265,7 @@ def _search_runs(): return response +@catch_mlflow_exception def _list_artifacts(): request_message = _get_request_message(ListArtifacts()) response_message = ListArtifacts.Response() @@ -252,6 +282,7 @@ def _list_artifacts(): return response +@catch_mlflow_exception def _get_metric_history(): request_message = _get_request_message(GetMetricHistory()) response_message = GetMetricHistory.Response() @@ -263,6 +294,7 @@ def _get_metric_history(): return response +@catch_mlflow_exception def _get_metric(): request_message = _get_request_message(GetMetric()) response_message = GetMetric.Response() @@ -273,6 +305,7 @@ def _get_metric(): return response +@catch_mlflow_exception def _get_param(): request_message = _get_request_message(GetParam()) response_message = GetParam.Response() @@ -283,6 +316,7 @@ def _get_param(): return response +@catch_mlflow_exception def _list_experiments(): request_message = _get_request_message(ListExperiments()) experiment_entities = _get_store().list_experiments(request_message.view_type) @@ -293,6 +327,7 @@ def _list_experiments(): return response +@catch_mlflow_exception def _get_artifact_repo(run): store = _get_store() if run.info.artifact_uri: diff --git a/mlflow/store/rest_store.py b/mlflow/store/rest_store.py index 56be6116444dc..4b789c06e4e59 100644 --- a/mlflow/store/rest_store.py +++ b/mlflow/store/rest_store.py @@ -1,5 +1,6 @@ import json +from mlflow.exceptions import RestException from mlflow.store.abstract_store import AbstractStore from mlflow.entities import Experiment, Run, RunInfo, RunTag, Param, Metric, ViewType @@ -35,16 +36,6 @@ def _api_method_to_info(): _METHOD_TO_INFO = _api_method_to_info() -class RestException(Exception): - """Exception thrown on 400-level errors from the REST API""" - def __init__(self, json): - message = json['error_code'] - if 'message' in json: - message = "%s: %s" % (message, json['message']) - super(RestException, self).__init__(message) - self.json = json - - class RestStore(AbstractStore): """ Client for a remote tracking server accessed via REST API calls diff --git a/mlflow/utils/validation.py b/mlflow/utils/validation.py index 5e4a249150df8..0869a9f967848 100644 --- a/mlflow/utils/validation.py +++ b/mlflow/utils/validation.py @@ -4,6 +4,9 @@ import os.path import re +from mlflow.exceptions import MlflowException +from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE + _VALID_PARAM_AND_METRIC_NAMES = re.compile(r"^[/\w.\- ]*$") # Regex for valid run IDs: must be a 32-character hex string. @@ -55,4 +58,4 @@ def _validate_tag_name(name): def _validate_run_id(run_id): """Check that `run_id` is a valid run ID and raise an exception if it isn't.""" if _RUN_ID_REGEX.match(run_id) is None: - raise Exception("Invalid run ID: '%s'" % run_id) + raise MlflowException("Invalid run ID: '%s'" % run_id, error_code=INVALID_PARAMETER_VALUE) diff --git a/tests/server/test_handlers.py b/tests/server/test_handlers.py index 5949d5df83ec2..370b189e99b3b 100644 --- a/tests/server/test_handlers.py +++ b/tests/server/test_handlers.py @@ -1,9 +1,13 @@ +import json + import mock import pytest from mlflow.entities import ViewType +from mlflow.exceptions import MlflowException +from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, ErrorCode from mlflow.server.handlers import get_endpoints, _create_experiment, _get_request_message, \ - _search_runs + _search_runs, catch_mlflow_exception from mlflow.protos.service_pb2 import CreateExperiment, SearchRuns @@ -72,3 +76,16 @@ def test_search_runs_default_view_type(mock_get_request_message, mock_store): _search_runs() args, _ = mock_store.search_runs.call_args assert args[2] == ViewType.ACTIVE_ONLY + + +def test_catch_mlflow_exception(): + @catch_mlflow_exception + def test_handler(): + raise MlflowException('test error', error_code=INTERNAL_ERROR) + + # pylint: disable=assignment-from-no-return + response = test_handler() + json_response = json.loads(response.get_data()) + assert response.status_code == 500 + assert json_response['error_code'] == ErrorCode.Name(INTERNAL_ERROR) + assert json_response['message'] == 'test error' diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000000000..75bdb1d6d9106 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,19 @@ +import json + +from mlflow.exceptions import MlflowException +from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE + + +class TestMlflowException(object): + def test_error_code_constructor(self): + assert MlflowException('test', error_code=INVALID_PARAMETER_VALUE).error_code == \ + 'INVALID_PARAMETER_VALUE' + + def test_default_error_code(self): + assert MlflowException('test').error_code == 'INTERNAL_ERROR' + + def test_serialize_to_json(self): + mlflow_exception = MlflowException('test') + deserialized = json.loads(mlflow_exception.serialize_as_json()) + assert deserialized['message'] == 'test' + assert deserialized['error_code'] == 'INTERNAL_ERROR'