Skip to content

Commit

Permalink
[Model Control Plane] link cached artifacts (zenml-io#1946)
Browse files Browse the repository at this point in the history
* link outputs of cached steps

* refactor

* improve test case

* port stability fix here

* linking from cached step with same id

* stabilize tests

* restore tests

* fix bug in implementation
  • Loading branch information
avishniakov authored Oct 15, 2023
1 parent 82092e0 commit 85db1d5
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/zenml/model/artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _model_config(self) -> "ModelConfig":
"""
try:
model_config = get_step_context().model_config
except StepContextError:
except (StepContextError, RuntimeError):
model_config = None
# Check if a specific model name is provided and it doesn't match the context name
if (self.model_name is not None) and (
Expand Down
47 changes: 47 additions & 0 deletions src/zenml/orchestrators/step_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from zenml.utils import string_utils

if TYPE_CHECKING:
from zenml.model import ModelConfig
from zenml.models.artifact_models import ArtifactResponseModel
from zenml.models.pipeline_deployment_models import (
PipelineDeploymentResponseModel,
Expand Down Expand Up @@ -380,11 +381,57 @@ def _prepare(
output_name: artifact.id
for output_name, artifact in cached_outputs.items()
}
if model_config:
self._link_cached_artifacts_to_model_version(
model_config=model_config,
step_run=step_run,
)
step_run.status = ExecutionStatus.CACHED
step_run.end_time = step_run.start_time

return execution_needed, step_run

def _link_cached_artifacts_to_model_version(
    self,
    model_config: "ModelConfig",
    step_run: StepRunRequestModel,
) -> None:
    """Link the outputs of a cached step to a model version.

    The step is re-loaded from its source so that the return annotations
    of its entrypoint can be inspected. Every entry of `step_run.outputs`
    whose name appears in those annotations is linked, using the annotated
    artifact config if there is one and a default `ArtifactConfig`
    otherwise.

    Args:
        model_config: The model config of the current step.
        step_run: The step to run; its `outputs` are the values that get
            linked.
    """
    from zenml.model.artifact_config import ArtifactConfig
    from zenml.steps.base_step import BaseStep
    from zenml.steps.utils import parse_return_type_annotations

    model_version = model_config.get_or_create_model_version()
    loaded_step = BaseStep.load_from_source(self._step.spec.source)
    annotations_by_output = parse_return_type_annotations(
        loaded_step.entrypoint
    )

    for name, cached_output in step_run.outputs.items():
        # Outputs without a matching return annotation are not linked.
        if name not in annotations_by_output:
            continue

        annotation = annotations_by_output[name]
        if annotation and annotation.artifact_config is not None:
            declared_config = annotation.artifact_config
        else:
            declared_config = ArtifactConfig()

        # Work on a copy so the config attached to the annotation is
        # never mutated by the fallbacks applied below.
        link_config = declared_config.copy()
        # An explicitly configured model name / version wins; otherwise
        # fall back to the model version resolved from `model_config`.
        link_config.model_name = (
            declared_config.model_name or model_version.model.name
        )
        link_config.model_version = (
            link_config.model_version or model_version.name
        )
        link_config._pipeline_name = (
            self._deployment.pipeline_configuration.name
        )
        link_config._step_name = self._step_name
        link_config.link_to_model(cached_output)

def _run_step(
self,
pipeline_run: PipelineRunResponseModel,
Expand Down
58 changes: 48 additions & 10 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5949,6 +5949,39 @@ def create_model_version_artifact_link(
EntityExistsError: If a link with the given name already exists.
"""
with Session(self.engine) as session:
collision_msg = (
"Unable to create model version link {name}: "
"An artifact with same ID is already tracked in {version} model version "
"with the same name. It has to be deleted first."
)
existing_model_version_artifact_link_in_other_name = session.exec(
select(ModelVersionArtifactSchema)
.where(
and_(
or_(
ModelVersionArtifactSchema.name
!= model_version_artifact_link.name,
ModelVersionArtifactSchema.pipeline_name
!= model_version_artifact_link.pipeline_name,
ModelVersionArtifactSchema.step_name
!= model_version_artifact_link.step_name,
),
ModelVersionArtifactSchema.artifact_id
== model_version_artifact_link.artifact,
)
)
.where(
ModelVersionArtifactSchema.model_version_id
== model_version_artifact_link.model_version
)
).first()
if existing_model_version_artifact_link_in_other_name is not None:
raise EntityExistsError(
collision_msg.format(
name=existing_model_version_artifact_link_in_other_name.name,
version=existing_model_version_artifact_link_in_other_name.model_version,
)
)
existing_model_version_artifact_link = session.exec(
select(ModelVersionArtifactSchema)
.where(
Expand All @@ -5971,16 +6004,21 @@ def create_model_version_artifact_link(
)
.order_by(ModelVersionArtifactSchema.version.desc()) # type: ignore[attr-defined]
).first()
if existing_model_version_artifact_link is not None and (
existing_model_version_artifact_link.artifact_id
== model_version_artifact_link.artifact
or model_version_artifact_link.overwrite
):
raise EntityExistsError(
f"Unable to create model version link {existing_model_version_artifact_link.name}: "
f"An artifact with same ID is already tracked in {existing_model_version_artifact_link.model_version} model version "
"with the same name. It has to be deleted first."
)
if existing_model_version_artifact_link is not None:
if model_version_artifact_link.overwrite:
raise EntityExistsError(
collision_msg.format(
name=existing_model_version_artifact_link.name,
version=existing_model_version_artifact_link.model_version,
)
)
elif (
model_version_artifact_link.artifact
== existing_model_version_artifact_link.artifact_id
):
return ModelVersionArtifactSchema.to_model(
existing_model_version_artifact_link
)

if (
model_version_artifact_link.name is None
Expand Down
133 changes: 133 additions & 0 deletions tests/integration/functional/model/test_artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tests.integration.functional.utils import model_killer
from zenml import pipeline, step
from zenml.client import Client
from zenml.constants import RUNNING_MODEL_VERSION
from zenml.enums import ModelStages
from zenml.exceptions import EntityExistsError
from zenml.model import (
Expand Down Expand Up @@ -779,3 +780,135 @@ def test_link_with_manual_linkage_flexible_config(
assert len(links) == 1
assert links[0].link_version == 1
assert links[0].name == "1"


@step(enable_cache=True)
def _cacheable_step_annotated() -> Annotated[
    str, "cacheable", ModelArtifactConfig()
]:
    """Cache-enabled step whose output "cacheable" carries a `ModelArtifactConfig`."""
    payload = "cacheable"
    return payload


@step(enable_cache=True)
def _cacheable_step_not_annotated():
    """Cache-enabled step that deliberately has no return annotation."""
    payload = "cacheable"
    return payload


@step(enable_cache=True)
def _cacheable_step_custom_model_annotated() -> Annotated[
    str,
    "cacheable",
    ArtifactConfig(model_name="bar", model_version=RUNNING_MODEL_VERSION),
]:
    """Cache-enabled step whose output "cacheable" is configured for model "bar".

    The artifact config pins the model name to "bar" and the model version
    to `RUNNING_MODEL_VERSION`.
    """
    payload = "cacheable"
    return payload


@step(enable_cache=False)
def _non_cacheable_step():
    """Step with caching switched off and no return annotation."""
    payload = "not cacheable"
    return payload


def test_artifacts_linked_from_cache_steps():
    """Test that artifacts are linked from cache steps."""

    @pipeline(
        model_config=ModelConfig(name="foo", create_new_model_version=True),
        enable_cache=False,
    )
    def _inner_pipeline(force_disable_cache: bool = False):
        # NOTE: despite its name, the flag is forwarded as `enable_cache`,
        # so passing True switches caching ON for these three steps.
        for cache_toggled_step in (
            _cacheable_step_annotated,
            _cacheable_step_not_annotated,
            _cacheable_step_custom_model_annotated,
        ):
            cache_toggled_step.with_options(
                enable_cache=force_disable_cache
            )()
        _non_cacheable_step()

    custom_model_key = (
        "_inner_pipeline::_cacheable_step_custom_model_annotated::cacheable"
    )

    with model_killer():
        client = Client()

        # The first run executes with caching off, the second with it on.
        for i in range(1, 3):
            failure_note = f"Failed on {i} run"

            fake_version = ModelConfig(
                name="bar", create_new_model_version=True
            ).get_or_create_model_version()
            _inner_pipeline(i != 1)

            # Model "foo": version `i` is looked up after the i-th run.
            foo_version = client.get_model_version(
                model_name_or_id="foo", model_version_name_or_number_or_id=i
            )
            foo_artifacts = foo_version.artifact_object_ids
            foo_models = foo_version.model_object_ids
            assert len(foo_artifacts) == 2, failure_note
            assert len(foo_models) == 1, failure_note
            assert set(foo_artifacts.keys()) == {
                "_inner_pipeline::_non_cacheable_step::output",
                "_inner_pipeline::_cacheable_step_not_annotated::output",
            }, failure_note
            assert set(foo_models.keys()) == {
                "_inner_pipeline::_cacheable_step_annotated::cacheable",
            }, failure_note

            # Model "bar": only the explicitly configured step links here.
            bar_version = client.get_model_version(
                model_name_or_id="bar",
                model_version_name_or_number_or_id=RUNNING_MODEL_VERSION,
            )
            bar_artifacts = bar_version.artifact_object_ids
            assert len(bar_artifacts) == 1, failure_note
            assert set(bar_artifacts.keys()) == {
                custom_model_key,
            }, failure_note
            assert len(bar_artifacts[custom_model_key]) == 1, failure_note

            # NOTE(review): presumably renames the running version of "bar"
            # so the next iteration starts a fresh one - confirm.
            fake_version._update_default_running_version_name()


def test_artifacts_linked_from_cache_steps_same_id():
    """Test that artifacts are linked from cache steps with same id.

    This case appears if cached step is executed inside same model version
    and we need to silently pass linkage without failing on same id.
    """

    @pipeline(
        model_config=ModelConfig(name="foo", create_new_model_version=True),
        enable_cache=False,
    )
    def _inner_pipeline(force_disable_cache: bool = False):
        # NOTE: despite its name, the flag is forwarded as `enable_cache`,
        # so passing True switches caching ON for the step below.
        _cacheable_step_custom_model_annotated.with_options(
            enable_cache=force_disable_cache
        )()
        _non_cacheable_step()

    custom_model_key = (
        "_inner_pipeline::_cacheable_step_custom_model_annotated::cacheable"
    )

    with model_killer():
        client = Client()

        # The first run executes with caching off, the second with it on.
        for i in range(1, 3):
            failure_note = f"Failed on {i} run"

            ModelConfig(
                name="bar", create_new_model_version=True
            ).get_or_create_model_version()
            _inner_pipeline(i != 1)

            bar_version = client.get_model_version(
                model_name_or_id="bar",
                model_version_name_or_number_or_id=RUNNING_MODEL_VERSION,
            )
            bar_artifacts = bar_version.artifact_object_ids
            assert len(bar_artifacts) == 1, failure_note
            assert set(bar_artifacts.keys()) == {
                custom_model_key,
            }, failure_note
            assert len(bar_artifacts[custom_model_key]) == 1, failure_note
Loading

0 comments on commit 85db1d5

Please sign in to comment.