Skip to content

Commit

Permalink
[Runs URI][Python] changes in "models" methods (mlflow#1190)
Browse files Browse the repository at this point in the history
* Initial changes to load_model and tests

* UDF and CLI methods

* Move S3 and boto fixtures into helper functions

* Revert space diff

* Test fix

* Lint

* Lint2

* Remote URI tests for h2o, keras, sklearn, pytorch

* Remote URI tests for tensorflow and spark

* Test cases fixes

* Lint

* Pytorch tests fix

* Test fixes

* Param fixes

* Address subset of comments

* Address remaining comments

* Lint

* Remove remote URI test

* Remove unused variable

* Param fix in Keras

* Arg fix for pytorch tests
  • Loading branch information
dbczumar authored May 7, 2019
1 parent 2ee4c64 commit 1b501a9
Show file tree
Hide file tree
Showing 20 changed files with 496 additions and 240 deletions.
31 changes: 21 additions & 10 deletions mlflow/h2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@

import h2o

import mlflow
from mlflow import pyfunc
from mlflow.models import Model
import mlflow.tracking
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.model_utils import _get_flavor_configuration

Expand Down Expand Up @@ -148,24 +149,34 @@ def predict(self, dataframe):
def _load_pyfunc(path):
"""
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
:param path: Local filesystem path to the MLflow Model with the ``h2o`` flavor.
"""
return _H2OModelWrapper(_load_model(path, init=True))


def load_model(path, run_id=None):
def load_model(model_uri):
"""
Load an H2O model from a local file or a run.
This function expects there is an H2O instance initialised with ``h2o.init``.
:param path: Local filesystem path or run-relative artifact path to the model saved
by :py:func:`mlflow.h2o.save_model`.
:param run_id: Run ID. If provided, combined with ``path`` to identify the model.
:param model_uri: The location, in URI format, of the MLflow model, for example:
- ``/Users/me/path/to/local/model``
- ``relative/path/to/local/model``
- ``s3://my_bucket/path/to/model``
- ``runs:/<mlflow_run_id>/run-relative/path/to/model``
For more information about supported URI schemes, see the
`Artifacts Documentation <https://www.mlflow.org/docs/latest/tracking.html#
supported-artifact-stores>`_.
:return: An `H2OEstimator model object
<http://docs.h2o.ai/h2o/latest-stable/h2o-py/docs/intro.html#models>`_.
"""
if run_id is not None:
path = mlflow.tracking.artifact_utils._get_model_log_dir(model_name=path, run_id=run_id)
path = os.path.abspath(path)
flavor_conf = _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME)
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
# Flavor configurations for models saved in MLflow version <= 0.8.0 may not contain a
# `data` key; in this case, we assume the model artifact path to be `model.h2o`
h2o_model_file_path = os.path.join(path, flavor_conf.get("data", "model.h2o"))
h2o_model_file_path = os.path.join(local_model_path, flavor_conf.get("data", "model.h2o"))
return _load_model(path=h2o_model_file_path)
32 changes: 21 additions & 11 deletions mlflow/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mlflow import pyfunc
from mlflow.models import Model
import mlflow.tracking
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.model_utils import _get_flavor_configuration

Expand Down Expand Up @@ -168,9 +169,11 @@ def predict(self, dataframe):
return predicted


def _load_pyfunc(model_file):
def _load_pyfunc(path):
"""
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
:param path: Local filesystem path to the MLflow Model with the ``keras`` flavor.
"""
if K._BACKEND == 'tensorflow':
graph = tf.Graph()
Expand All @@ -181,29 +184,36 @@ def _load_pyfunc(model_file):
with graph.as_default():
with sess.as_default(): # pylint:disable=not-context-manager
K.set_learning_phase(0)
m = _load_model(model_file)
m = _load_model(path)
return _KerasModelWrapper(m, graph, sess)
else:
raise Exception("Unsupported backend '%s'" % K._BACKEND)


def load_model(path, run_id=None):
def load_model(model_uri):
"""
Load a Keras model from a local file or a run.
:param path: Local filesystem path or run-relative artifact path to the model saved
by :py:func:`mlflow.keras.log_model`.
:param run_id: Run ID. If provided, combined with ``path`` to identify the model.
:param model_uri: The location, in URI format, of the MLflow model, for example:
- ``/Users/me/path/to/local/model``
- ``relative/path/to/local/model``
- ``s3://my_bucket/path/to/model``
- ``runs:/<mlflow_run_id>/run-relative/path/to/model``
For more information about supported URI schemes, see the
`Artifacts Documentation <https://www.mlflow.org/docs/latest/tracking.html#
supported-artifact-stores>`_.
:return: A Keras model instance.
>>> # Load persisted model as a Keras model or as a PyFunc, call predict() on a Pandas DataFrame
>>> keras_model = mlflow.keras.load_model("runs:/96771d893a5e46159d9f3b49bf9013e2/models")
>>> predictions = keras_model.predict(x_test)
"""
if run_id is not None:
path = mlflow.tracking.artifact_utils._get_model_log_dir(model_name=path, run_id=run_id)
path = os.path.abspath(path)
flavor_conf = _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME)
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
# Flavor configurations for models saved in MLflow version <= 0.8.0 may not contain a
# `data` key; in this case, we assume the model artifact path to be `model.h5`
keras_model_artifacts_path = os.path.join(path, flavor_conf.get("data", "model.h5"))
keras_model_artifacts_path = os.path.join(local_model_path, flavor_conf.get("data", "model.h5"))
return _load_model(model_file=keras_model_artifacts_path)
51 changes: 32 additions & 19 deletions mlflow/pyfunc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@
from mlflow.models import Model
from mlflow.pyfunc.model import PythonModel, PythonModelContext,\
DEFAULT_CONDA_ENV
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils import PYTHON_VERSION, get_major_minor_py_version
from mlflow.utils.file_utils import TempDir, _copy_file_or_tree
from mlflow.utils.model_utils import _get_flavor_configuration
Expand Down Expand Up @@ -254,37 +255,43 @@ def add_to_model(model, loader_module, data=None, code=None, env=None, **kwargs)
return model.add_flavor(FLAVOR_NAME, **parms)


def _load_model_env(path, run_id=None):
def _load_model_env(path):
"""
Get ENV file string from a model configuration stored in Python Function format.
Returned value is a model-relative path to a Conda Environment file,
or None if none was specified at model save time
"""
if run_id is not None:
path = tracking.artifact_utils._get_model_log_dir(path, run_id)
return _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME).get(ENV, None)


def load_pyfunc(path, run_id=None, suppress_warnings=False):
def load_pyfunc(model_uri, suppress_warnings=False):
"""
Load a model stored in Python function format.
:param path: Path to the model.
:param run_id: MLflow run ID.
:param model_uri: The location, in URI format, of the MLflow model, for example:
- ``/Users/me/path/to/local/model``
- ``relative/path/to/local/model``
- ``s3://my_bucket/path/to/model``
- ``runs:/<mlflow_run_id>/run-relative/path/to/model``
For more information about supported URI schemes, see the
`Artifacts Documentation <https://www.mlflow.org/docs/latest/tracking.html#
supported-artifact-stores>`_.
:param suppress_warnings: If True, non-fatal warning messages associated with the model
loading process will be suppressed. If False, these warning messages
will be emitted.
"""
if run_id is not None:
path = tracking.artifact_utils._get_model_log_dir(path, run_id)
conf = _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME)
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
model_py_version = conf.get(PY_VERSION)
if not suppress_warnings:
_warn_potentially_incompatible_py_version_if_necessary(model_py_version=model_py_version)
if CODE in conf and conf[CODE]:
code_path = os.path.join(path, conf[CODE])
code_path = os.path.join(local_model_path, conf[CODE])
mlflow.pyfunc.utils._add_code_to_system_path(code_path=code_path)
data_path = os.path.join(path, conf[DATA]) if (DATA in conf) else path
data_path = os.path.join(local_model_path, conf[DATA]) if (DATA in conf) else local_model_path
return importlib.import_module(conf[MAIN])._load_pyfunc(data_path)


Expand All @@ -307,7 +314,7 @@ def _warn_potentially_incompatible_py_version_if_necessary(model_py_version=None
model_py_version, PYTHON_VERSION)


def spark_udf(spark, path, run_id=None, result_type="double"):
def spark_udf(spark, model_uri, result_type="double"):
"""
A Spark UDF that can be used to invoke the Python function formatted model.
Expand All @@ -323,9 +330,17 @@ def spark_udf(spark, path, run_id=None, result_type="double"):
>>> df.withColumn("prediction", predict("name", "age")).show()
:param spark: A SparkSession object.
:param path: A path containing a :py:mod:`mlflow.pyfunc` model.
:param run_id: ID of the run that produced this model. If provided, ``run_id`` is used to
retrieve the model logged with MLflow.
:param model_uri: The location, in URI format, of the MLflow model with the
:py:mod:`mlflow.pyfunc` flavor, for example:
- ``/Users/me/path/to/local/model``
- ``relative/path/to/local/model``
- ``s3://my_bucket/path/to/model``
- ``runs:/<mlflow_run_id>/run-relative/path/to/model``
For more information about supported URI schemes, see the
`Artifacts Documentation <https://www.mlflow.org/docs/latest/tracking.html#
supported-artifact-stores>`_.
:param result_type: the return type of the user-defined function. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
Only a primitive type or an array (pyspark.sql.types.ArrayType) of primitive
Expand Down Expand Up @@ -372,10 +387,8 @@ def spark_udf(spark, path, run_id=None, result_type="double"):
"of the following types types: {}".format(str(elem_type), str(supported_types)),
error_code=INVALID_PARAMETER_VALUE)

if run_id:
path = tracking.artifact_utils._get_model_log_dir(path, run_id)

archive_path = SparkModelCache.add_local_model(spark, path)
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
archive_path = SparkModelCache.add_local_model(spark, local_model_path)

def predict(*args):
model = SparkModelCache.get_or_load(archive_path)
Expand Down
36 changes: 13 additions & 23 deletions mlflow/pyfunc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from mlflow.projects import _get_conda_bin_executable, _get_or_create_conda_env
from mlflow.pyfunc import load_pyfunc, scoring_server, _load_model_env
from mlflow.tracking.artifact_utils import _get_model_log_dir
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils import cli_args


Expand Down Expand Up @@ -45,58 +45,48 @@ def commands():


@commands.command("serve")
@cli_args.MODEL_PATH
@cli_args.RUN_ID
@cli_args.MODEL_URI
@click.option("--port", "-p", default=5000, help="Server port. [default: 5000]")
@click.option("--host", "-h", default="127.0.0.1", help="Server host. [default: 127.0.0.1]")
@cli_args.NO_CONDA
def serve(model_path, run_id, port, host, no_conda):
def serve(model_uri, port, host, no_conda):
"""
Serve a pyfunc model saved with MLflow by launching a webserver on the specified
host and port. For information about the input data formats accepted by the webserver,
see the following documentation:
https://www.mlflow.org/docs/latest/models.html#pyfunc-deployment.
If a ``run_id`` is specified, ``model-path`` is treated as an artifact path within that run;
otherwise it is treated as a local path.
"""
if run_id:
model_path = _get_model_log_dir(model_path, run_id)
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)

model_env_file = _load_model_env(model_path)
model_env_file = _load_model_env(path=local_model_path)
if not no_conda and model_env_file is not None:
conda_env_path = os.path.join(model_path, model_env_file)
conda_env_path = os.path.join(local_model_path, model_env_file)
return _rerun_in_conda(conda_env_path)

app = scoring_server.init(load_pyfunc(model_path))
app = scoring_server.init(load_pyfunc(local_model_path))
app.run(port=port, host=host)


@commands.command("predict")
@cli_args.MODEL_PATH
@cli_args.RUN_ID
@cli_args.MODEL_URI
@click.option("--input-path", "-i", help="CSV containing pandas DataFrame to predict against.",
required=True)
@click.option("--output-path", "-o", help="File to output results to as CSV file." +
" If not provided, output to stdout.")
@cli_args.NO_CONDA
def predict(model_path, run_id, input_path, output_path, no_conda):
def predict(model_uri, input_path, output_path, no_conda):
"""
Load a pandas DataFrame and run a python_function model saved with MLflow against it.
Return the prediction results as a CSV-formatted pandas DataFrame.
If a ``run-id`` is specified, ``model-path`` is treated as an artifact path within that run;
otherwise it is treated as a local path.
"""
if run_id:
model_path = _get_model_log_dir(model_path, run_id)
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)

model_env_file = _load_model_env(model_path)
model_env_file = _load_model_env(path=local_model_path)
if not no_conda and model_env_file is not None:
conda_env_path = os.path.join(model_path, model_env_file)
conda_env_path = os.path.join(local_model_path, model_env_file)
return _rerun_in_conda(conda_env_path)

model = load_pyfunc(model_path)
model = load_pyfunc(local_model_path)
df = pandas.read_csv(input_path)
result = model.predict(df)
out_stream = sys.stdout
Expand Down
37 changes: 24 additions & 13 deletions mlflow/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
import torch
import torchvision

import mlflow
import mlflow.pyfunc.utils as pyfunc_utils
from mlflow import pyfunc
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST
from mlflow.pytorch import pickle_module as mlflow_pytorch_pickle_module
import mlflow.tracking
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.file_utils import _copy_file_or_tree
from mlflow.utils.model_utils import _get_flavor_configuration
Expand Down Expand Up @@ -291,14 +292,23 @@ def _load_model(path, **kwargs):
return torch.load(model_path, **kwargs)


def load_model(path, run_id=None, **kwargs):
def load_model(model_uri, **kwargs):
"""
Load a PyTorch model from a local file or a run.
:param path: Local filesystem path or run-relative artifact path to the model saved
by :py:func:`mlflow.pytorch.log_model`.
:param run_id: Run ID. If provided, combined with ``path`` to identify the model.
:param model_uri: The location, in URI format, of the MLflow model, for example:
- ``/Users/me/path/to/local/model``
- ``relative/path/to/local/model``
- ``s3://my_bucket/path/to/model``
- ``runs:/<mlflow_run_id>/run-relative/path/to/model``
For more information about supported URI schemes, see the
`Artifacts Documentation <https://www.mlflow.org/docs/latest/tracking.html#
supported-artifact-stores>`_.
:param kwargs: kwargs to pass to ``torch.load`` method.
:return: A PyTorch model.
>>> import torch
>>> import mlflow
Expand All @@ -309,30 +319,31 @@ def load_model(path, run_id=None, **kwargs):
>>> pytorch_model = mlflow.pytorch.load_model("runs:/{}/{}".format(run_id, model_path_dir))
>>> y_pred = pytorch_model(x_new_data)
"""
if run_id is not None:
path = mlflow.tracking.artifact_utils._get_model_log_dir(model_name=path, run_id=run_id)
path = os.path.abspath(path)

local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
try:
pyfunc_conf = _get_flavor_configuration(model_path=path, flavor_name=pyfunc.FLAVOR_NAME)
pyfunc_conf = _get_flavor_configuration(
model_path=local_model_path, flavor_name=pyfunc.FLAVOR_NAME)
except MlflowException:
pyfunc_conf = {}
code_subpath = pyfunc_conf.get(pyfunc.CODE)
if code_subpath is not None:
pyfunc_utils._add_code_to_system_path(code_path=os.path.join(path, code_subpath))
pyfunc_utils._add_code_to_system_path(
code_path=os.path.join(local_model_path, code_subpath))

pytorch_conf = _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME)
pytorch_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
if torch.__version__ != pytorch_conf["pytorch_version"]:
_logger.warning(
"Stored model version '%s' does not match installed PyTorch version '%s'",
pytorch_conf["pytorch_version"], torch.__version__)
torch_model_artifacts_path = os.path.join(path, pytorch_conf['model_data'])
torch_model_artifacts_path = os.path.join(local_model_path, pytorch_conf['model_data'])
return _load_model(path=torch_model_artifacts_path, **kwargs)


def _load_pyfunc(path, **kwargs):
"""
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
:param path: Local filesystem path to the MLflow Model with the ``pytorch`` flavor.
"""
return _PyTorchWrapper(_load_model(path, **kwargs))

Expand Down
Loading

0 comments on commit 1b501a9

Please sign in to comment.