Skip to content

Commit

Permalink
Add kwargs parameter to pyfunc 'add_to_model' (mlflow#791)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbczumar authored Jan 7, 2019
1 parent 0811ae7 commit 9fa1c99
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
14 changes: 12 additions & 2 deletions mlflow/pyfunc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
<data>: data packaged with the model (specified in the MLmodel file)
<env>: Conda environment definition (specified in the MLmodel file)
* The directory structure may contain additional contents that can be referenced by the
``MLmodel`` configuration.
A Python model contains an ``MLmodel`` file in "python_function" format in its root with the
following parameters:
Expand All @@ -48,6 +51,9 @@
Relative path to an exported Conda environment. If present this environment
should be activated prior to running the model.
- **Optionally, any additional parameters necessary for interpreting the serialized model in pyfunc
format.**
.. rubric:: Example
>>> tree example/sklearn_iris/mlruns/run1/outputs/linear-lr
Expand Down Expand Up @@ -81,6 +87,7 @@
import pandas
import shutil
import sys
from copy import deepcopy

from mlflow.tracking.fluent import active_run, log_artifacts
from mlflow import tracking
Expand All @@ -100,7 +107,7 @@
_logger = logging.getLogger(__name__)


def add_to_model(model, loader_module, data=None, code=None, env=None):
def add_to_model(model, loader_module, data=None, code=None, env=None, **kwargs):
"""
Add a pyfunc spec to the model configuration.
Expand All @@ -117,9 +124,12 @@ def add_to_model(model, loader_module, data=None, code=None, env=None):
:param data: Path to the model data.
:param code: Path to the code dependencies.
:param env: Conda environment.
:param kwargs: Additional key-value pairs to include in the pyfunc flavor specification.
Values must be YAML-serializable.
:return: Updated model configuration.
"""
parms = {MAIN: loader_module}
parms = deepcopy(kwargs)
parms[MAIN] = loader_module
parms[PY_VERSION] = PYTHON_VERSION
if code:
parms[CODE] = code
Expand Down
17 changes: 17 additions & 0 deletions tests/pyfunc/test_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,23 @@ def test_model_log(self):
# Remove the log directory in order to avoid adding new tests to pytest...
shutil.rmtree(tracking_dir)

def test_add_to_model_adds_specified_kwargs_to_mlmodel_configuration(self):
custom_kwargs = {
"key1": "value1",
"key2": 20,
"key3": range(10),
}
model_config = Model()
pyfunc.add_to_model(model=model_config,
loader_module=os.path.basename(__file__)[:-3],
data="data",
code="code",
env=None,
**custom_kwargs)

assert pyfunc.FLAVOR_NAME in model_config.flavors
assert all([item in model_config.flavors[pyfunc.FLAVOR_NAME] for item in custom_kwargs])

def _create_conda_env_file(self, tmp):
conda_env_path = tmp.path("conda.yml")
with open(conda_env_path, "w") as f:
Expand Down

0 comments on commit 9fa1c99

Please sign in to comment.