Skip to content

Commit

Permalink
Delete artifact links from model version via Client, ModelVersion and…
Browse files Browse the repository at this point in the history
… API (zenml-io#2191)

* typo

* client and ModelVersion implementations

* adding endpoints

* split endpoints

* lint

* coderabbitai feedback

* lint

* Apply suggestions from code review

Co-authored-by: Alex Strick van Linschoten <[email protected]>

* finish renaming

* renaming leftover

* move deletion to sql

---------

Co-authored-by: Alex Strick van Linschoten <[email protected]>
  • Loading branch information
avishniakov and strickvl authored Jan 10, 2024
1 parent 1030551 commit 5e47b81
Show file tree
Hide file tree
Showing 9 changed files with 299 additions and 8 deletions.
49 changes: 46 additions & 3 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2948,7 +2948,7 @@ def delete_artifact_version(
database, not the actual object stored in the artifact store.
Args:
name_id_or_prefix: The ID or name or prefix of the artifact to
name_id_or_prefix: The ID of artifact version or name or prefix of the artifact to
delete.
version: The version of the artifact to delete.
delete_metadata: If True, delete the metadata of the artifact
Expand Down Expand Up @@ -5003,8 +5003,6 @@ def update_model_version(

#################################################
# Model Versions Artifacts
#
# Only view capabilities are exposed via client.
#################################################

def list_model_version_artifact_links(
Expand Down Expand Up @@ -5074,6 +5072,51 @@ def list_model_version_artifact_links(
hydrate=hydrate,
)

def delete_model_version_artifact_link(
    self, model_version_id: UUID, artifact_version_id: UUID
) -> None:
    """Delete model version to artifact link in Model Control Plane.

    The link is looked up through `list_model_version_artifact_links`
    using both IDs as filters and then removed through the zen store. If no
    link matches the filters, the call returns without deleting anything
    and without raising.

    Args:
        model_version_id: The id of the model version holding the link.
        artifact_version_id: The id of the artifact version whose link to
            the model version should be deleted.

    Raises:
        RuntimeError: If more than one artifact link is found for given
            filters.
    """
    artifact_links = self.list_model_version_artifact_links(
        model_version_id=model_version_id,
        artifact_version_id=artifact_version_id,
    )
    # No matching link: nothing to delete.
    if not artifact_links.items:
        return

    # Exactly one link is expected per (model version, artifact version)
    # pair; anything else is reported instead of deleting an arbitrary one.
    if artifact_links.total > 1:
        raise RuntimeError(
            "More than one artifact link found for given model version "
            f"`{model_version_id}` and artifact version "
            f"`{artifact_version_id}`. This should not be happening and "
            "might indicate a corrupted state of your ZenML database. "
            "Please seek support via Community Slack."
        )

    self.zen_store.delete_model_version_artifact_link(
        model_version_id=model_version_id,
        model_version_artifact_link_name_or_id=artifact_links.items[0].id,
    )

def delete_all_model_version_artifact_links(
    self, model_version_id: UUID, only_links: bool
) -> None:
    """Remove every artifact link held by a model version.

    Thin pass-through: both arguments are forwarded unchanged to the
    configured zen store.

    Args:
        model_version_id: The id of the model version holding the links.
        only_links: If true, only delete the link to the artifact.
    """
    store = self.zen_store
    store.delete_all_model_version_artifact_links(
        model_version_id, only_links
    )

#################################################
# Model Versions Pipeline Runs
#
Expand Down
63 changes: 63 additions & 0 deletions src/zenml/model/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,69 @@ def metadata(self) -> Dict[str, "MetadataType"]:
for name, response in response.run_metadata.items()
}

def delete_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    only_link: bool = True,
    delete_metadata: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Remove an artifact linked to this model version.

    The artifact is resolved with `get_artifact`. When that lookup does not
    return an `ArtifactVersionResponse`, nothing happens. Otherwise the
    link is deleted first and, unless `only_link` is set, the artifact
    version itself is deleted afterwards.

    Args:
        name: The name of the artifact to delete.
        version: The version of the artifact to delete (None for
            latest/non-versioned).
        only_link: Whether to only delete the link to the artifact.
        delete_metadata: Whether to delete the metadata of the artifact.
        delete_from_artifact_store: Whether to delete the artifact from
            the artifact store.
    """
    from zenml.client import Client
    from zenml.models import ArtifactVersionResponse

    linked_version = self.get_artifact(name, version)
    if not isinstance(linked_version, ArtifactVersionResponse):
        return

    zenml_client = Client()
    zenml_client.delete_model_version_artifact_link(
        model_version_id=self.id,
        artifact_version_id=linked_version.id,
    )
    if only_link:
        return

    zenml_client.delete_artifact_version(
        name_id_or_prefix=linked_version.id,
        delete_metadata=delete_metadata,
        delete_from_artifact_store=delete_from_artifact_store,
    )

def delete_all_artifacts(
    self,
    only_link: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Delete all artifacts linked to this model version.

    When `only_link` is False and `delete_from_artifact_store` is True,
    every data, model and deployment artifact version linked to this model
    version is first removed from the artifact store. Afterwards all links
    are deleted through the client, which receives `only_link` unchanged.

    Args:
        only_link: Whether to only delete the link to the artifact.
        delete_from_artifact_store: Whether to delete the artifact from
            the artifact store. Only takes effect when `only_link` is
            False.
    """
    from itertools import chain

    from zenml.client import Client

    client = Client()

    if not only_link and delete_from_artifact_store:
        mv = self._get_model_version()
        # Walk the three categories one after another instead of merging
        # them with `dict.update`: a merge would mutate the mapping
        # returned by `mv.data_artifacts` and, should the same key appear
        # in more than one category, silently drop the earlier entry so
        # its versions would never be removed from the artifact store.
        for versions_ in chain(
            mv.data_artifacts.values(),
            mv.model_artifacts.values(),
            mv.deployment_artifacts.values(),
        ):
            for artifact_response_ in versions_.values():
                client._delete_artifact_from_artifact_store(
                    artifact_version=artifact_response_
                )

    client.delete_all_model_version_artifact_links(self.id, only_link)

#########################
# Internal methods #
#########################
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/models/v2/core/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def data_artifacts(
}

@property
def endpoint_artifacts(
def deployment_artifacts(
self,
) -> Dict[str, Dict[str, "ArtifactVersionResponse"]]:
"""Get all deployment artifacts linked to this model version.
Expand Down
26 changes: 25 additions & 1 deletion src/zenml/zen_server/routers/model_versions_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def delete_model_version_artifact_link(
model_version_artifact_link_name_or_id: Union[str, UUID],
_: AuthContext = Security(authorize),
) -> None:
"""Deletes a model version link.
"""Deletes a model version to artifact link.
Args:
model_version_id: ID of the model version containing the link.
Expand All @@ -256,6 +256,30 @@ def delete_model_version_artifact_link(
)


@router.delete(
    "/{model_version_id}" + ARTIFACTS,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_all_model_version_artifact_links(
    model_version_id: UUID,
    only_links: bool = True,
    _: AuthContext = Security(authorize),
) -> None:
    """Delete every artifact link of a model version.

    Args:
        model_version_id: ID of the model version containing links.
        only_links: Whether to only delete the link to the artifact.
    """
    # The caller needs UPDATE permission on the model version itself.
    # NOTE(review): no separate check is made here for the artifact
    # versions that are removed when `only_links` is False - confirm this
    # is intended.
    target_version = zen_store().get_model_version(model_version_id)
    verify_permission_for_model(target_version, action=Action.UPDATE)

    zen_store().delete_all_model_version_artifact_links(
        model_version_id=model_version_id,
        only_links=only_links,
    )


##############################
# Model Version Pipeline Runs
##############################
Expand Down
16 changes: 16 additions & 0 deletions src/zenml/zen_stores/rest_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2836,6 +2836,22 @@ def delete_model_version_artifact_link(
route=f"{MODEL_VERSIONS}/{model_version_id}{ARTIFACTS}",
)

def delete_all_model_version_artifact_links(
    self,
    model_version_id: UUID,
    only_links: bool = True,
) -> None:
    """Delete all links between a model version and its artifacts.

    Issues a DELETE request against the model version's artifacts route,
    passing `only_links` as a query parameter.

    Args:
        model_version_id: ID of the model version containing the links.
        only_links: Flag deciding whether to delete only links or all.
    """
    query = {"only_links": only_links}
    route = f"{MODEL_VERSIONS}/{model_version_id}{ARTIFACTS}"
    self.delete(route, params=query)

# ---------------------- Model Versions Pipeline Runs ----------------------

def create_model_version_pipeline_run_link(
Expand Down
41 changes: 39 additions & 2 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
OperationalError,
)
from sqlalchemy.orm import noload
from sqlmodel import Session, SQLModel, create_engine, or_, select
from sqlmodel import Session, SQLModel, col, create_engine, delete, or_, select
from sqlmodel.sql.expression import Select, SelectOfScalar

from zenml.analytics.enums import AnalyticsEvent
Expand Down Expand Up @@ -1624,7 +1624,7 @@ def get_artifact_version(
if artifact_version is None:
raise KeyError(
f"Unable to get artifact version with ID "
f"{artifact_version_id}: No artifact versionwith this ID "
f"{artifact_version_id}: No artifact version with this ID "
f"found."
)
return artifact_version.to_model(hydrate=hydrate)
Expand Down Expand Up @@ -6827,6 +6827,43 @@ def delete_model_version_artifact_link(
session.delete(model_version_artifact_link)
session.commit()

def delete_all_model_version_artifact_links(
    self,
    model_version_id: UUID,
    only_links: bool = True,
) -> None:
    """Delete every artifact link of a model version in one transaction.

    When `only_links` is False, the artifact versions referenced by the
    links are deleted first; the link rows are deleted afterwards in either
    case. Everything is committed once at the end.

    Args:
        model_version_id: ID of the model version containing the links.
        only_links: Whether to only delete the link to the artifact.
    """
    with Session(self.engine) as session:
        if not only_links:
            # Collect the linked artifact version IDs while the link rows
            # still exist, then remove those artifact versions.
            linked_rows = session.execute(
                select(ModelVersionArtifactSchema.artifact_version_id).where(
                    ModelVersionArtifactSchema.model_version_id
                    == model_version_id
                )
            ).fetchall()
            linked_artifact_version_ids = [row[0] for row in linked_rows]
            session.execute(
                delete(ArtifactVersionSchema).where(
                    col(ArtifactVersionSchema.id).in_(
                        linked_artifact_version_ids
                    )
                ),
            )

        drop_links = delete(ModelVersionArtifactSchema).where(
            ModelVersionArtifactSchema.model_version_id == model_version_id
        )
        session.execute(drop_links)

        session.commit()

# ---------------------- Model Versions Pipeline Runs ----------------------

def create_model_version_pipeline_run_link(
Expand Down
13 changes: 13 additions & 0 deletions src/zenml/zen_stores/zen_store_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2199,6 +2199,19 @@ def delete_model_version_artifact_link(
KeyError: specified ID or name not found.
"""

@abstractmethod
def delete_all_model_version_artifact_links(
self,
model_version_id: UUID,
only_links: bool = True,
) -> None:
"""Deletes all model version to artifact links.

Abstract operation; this declaration has no body of its own.

Args:
model_version_id: ID of the model version containing the links.
only_links: Flag deciding whether to delete only links or all.
NOTE(review): "all" presumably means the linked artifact versions
are deleted as well - confirm against the concrete stores.
"""

# -------------------- Model Versions Pipeline Runs --------------------

@abstractmethod
Expand Down
69 changes: 69 additions & 0 deletions tests/integration/functional/model/test_model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ def step_metadata_logging_functional():
assert get_step_context().model_version.metadata["foo"] == "bar"


@step
def simple_producer() -> str:
"""Simple producer step that returns the constant string `foo`."""
return "foo"


@step
def consume_from_model_version(
is_consume: bool,
Expand Down Expand Up @@ -393,6 +399,69 @@ def my_pipeline():
assert len(mv.metadata) == 1
assert mv.metadata["foo"] == "bar"

@pytest.mark.parametrize("delete_artifacts", [False, True])
def test_deletion_of_links(
    self, clean_client: "Client", delete_artifacts: bool
):
    """Test deleting artifact links (optionally with artifacts) from ModelVersion.

    Covers a single-artifact deletion via `delete_artifact` followed by a
    bulk deletion via `delete_all_artifacts`, each after a fresh pipeline
    run that links two artifacts.
    """

    @pipeline(
        model_version=ModelVersion(
            name=MODEL_NAME,
        ),
        enable_cache=False,
    )
    def _inner_pipeline():
        simple_producer()
        simple_producer(id="other_named_producer")

    def _run_and_collect():
        """Run the pipeline, snapshot linked artifact IDs, drop the run."""
        _inner_pipeline()
        latest = ModelVersion(name=MODEL_NAME, version="latest")
        linked_ids = latest._get_model_version().data_artifact_ids
        assert len(linked_ids) == 2

        # delete run to enable artifacts deletion
        last_run = clean_client.get_pipeline(
            name_id_or_prefix="_inner_pipeline"
        ).last_run
        clean_client.delete_pipeline_run(last_run.id)
        return latest, linked_ids

    def _check_artifact_version(artifact_version_id):
        """Expect the artifact version gone or still present, per mode."""
        if delete_artifacts:
            with pytest.raises(KeyError):
                clean_client.get_artifact_version(artifact_version_id)
        else:
            fetched = clean_client.get_artifact_version(artifact_version_id)
            assert fetched.id == artifact_version_id

    target_output = "_inner_pipeline::other_named_producer::output"

    # Phase 1: remove one named artifact.
    mv, artifact_ids = _run_and_collect()
    mv.delete_artifact(
        only_link=not delete_artifacts,
        name=target_output,
    )
    assert len(mv._get_model_version().data_artifact_ids) == 1
    _check_artifact_version(artifact_ids[target_output]["1"])

    # Phase 2: remove everything that a fresh run linked.
    mv, artifact_ids = _run_and_collect()
    mv.delete_all_artifacts(only_link=not delete_artifacts)
    assert len(mv._get_model_version().data_artifact_ids) == 0
    for ids_by_version in artifact_ids.values():
        for artifact_version_id in ids_by_version.values():
            _check_artifact_version(artifact_version_id)

def test_that_artifacts_are_not_linked_to_models_outside_of_the_context(
self, clean_client: "Client"
):
Expand Down
Loading

0 comments on commit 5e47b81

Please sign in to comment.