Skip to content

Commit

Permalink
Fix tag resource id generator (zenml-io#2056)
Browse files Browse the repository at this point in the history
* fix tag resource id generator

* revert "fix"

* fix creation/deletion of tag links

* docs

* Update src/zenml/zen_stores/schemas/tag_schemas.py

Co-authored-by: Stefan Nica <[email protected]>

* lint

* Auto-update of E2E template

---------

Co-authored-by: Stefan Nica <[email protected]>
Co-authored-by: GitHub Actions <[email protected]>
  • Loading branch information
3 people authored Nov 20, 2023
1 parent 2b661b4 commit 51768a0
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 71 deletions.
11 changes: 0 additions & 11 deletions src/zenml/models/tag_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,6 @@ class TagResourceBaseModel(BaseModel):
resource_id: UUID
resource_type: TaggableResourceTypes

@property
def tag_resource_id(self) -> UUID:
"""Get stable ID from tag_id and resource_id.
Returns:
The generated stable ID.
"""
from zenml.utils.tag_utils import _get_tag_resource_id

return _get_tag_resource_id(self.tag_id, self.resource_id)


class TagResourceResponseModel(TagResourceBaseModel, BaseResponseModel):
"""Response model for tag resource relationships."""
Expand Down
22 changes: 0 additions & 22 deletions src/zenml/utils/tag_utils.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def upgrade() -> None:
"""Upgrade database schema and/or data, creating a new revision."""
# ### commands auto generated by Alembic - please adjust! ###
from zenml.enums import ColorVariants, TaggableResourceTypes
from zenml.utils.tag_utils import _get_tag_resource_id

bind = op.get_bind()
session = sqlmodel.Session(bind=bind)
Expand Down Expand Up @@ -107,7 +106,7 @@ def upgrade() -> None:
for model_id_, tags_in_model in model_tags_prepared:
for tag in tags_in_model:
insert_tag_models += (
f"('{_get_tag_resource_id(tags_ids_mapping[tag],model_id_).hex}', "
f"('{uuid4().hex}', "
f"'{tags_ids_mapping[tag]}', '{model_id_}', "
f"'{TaggableResourceTypes.MODEL.value}', '{now}', '{now}'),"
)
Expand Down
1 change: 0 additions & 1 deletion src/zenml/zen_stores/schemas/tag_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def from_request(
The converted schema.
"""
return cls(
id=request.tag_resource_id,
tag_id=request.tag_id,
resource_id=request.resource_id,
resource_type=request.resource_type.value,
Expand Down
55 changes: 32 additions & 23 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6524,13 +6524,17 @@ def _get_tag_schema(

def _get_tag_model_schema(
self,
tag_resource_id: Union[str, UUID],
tag_id: UUID,
resource_id: UUID,
resource_type: TaggableResourceTypes,
session: Session,
) -> TagResourceSchema:
"""Gets a tag model schema by name or ID.
"""Gets a tag model schema by tag and resource.
Args:
tag_resource_id: The ID of the tag resource relation to get.
tag_id: The ID of the tag to get.
resource_id: The ID of the resource to get.
resource_type: The type of the resource to get.
session: The database session to use.
Returns:
Expand All @@ -6542,13 +6546,17 @@ def _get_tag_model_schema(
with Session(self.engine) as session:
schema = session.exec(
select(TagResourceSchema).where(
TagResourceSchema.id == tag_resource_id
TagResourceSchema.tag_id == tag_id,
TagResourceSchema.resource_id == resource_id,
TagResourceSchema.resource_type == resource_type.value,
)
).first()
if schema is None:
raise KeyError(
f"Unable to get {TagResourceSchema.__tablename__} with name or ID "
f"'{tag_resource_id}': No {TagResourceSchema.__tablename__} with this ID found."
f"Unable to get {TagResourceSchema.__tablename__} with IDs "
f"`tag_id`='{tag_id}' and `resource_id`='{resource_id}' and "
f"`resource_type`='{resource_type.value}': No "
f"{TagResourceSchema.__tablename__} with these IDs found."
)
return schema

Expand Down Expand Up @@ -7376,13 +7384,12 @@ def _detach_tags_from_resource(
resource_id: The id of the resource.
resource_type: The type of the resource to create link with.
"""
from zenml.utils.tag_utils import _get_tag_resource_id

for tag_name in tag_names:
try:
tag = self.get_tag(tag_name)
self.delete_tag_resource(
tag_resource_id=_get_tag_resource_id(tag.id, resource_id),
tag_id=tag.id,
resource_id=resource_id,
resource_type=resource_type,
)
except KeyError:
Expand Down Expand Up @@ -7541,7 +7548,10 @@ def create_tag_resource(
with Session(self.engine) as session:
existing_tag_resource = session.exec(
select(TagResourceSchema).where(
TagResourceSchema.id == tag_resource.tag_resource_id
TagResourceSchema.tag_id == tag_resource.tag_id,
TagResourceSchema.resource_id == tag_resource.resource_id,
TagResourceSchema.resource_type
== tag_resource.resource_type.value,
)
).first()
if existing_tag_resource is not None:
Expand All @@ -7559,34 +7569,33 @@ def create_tag_resource(

def delete_tag_resource(
self,
tag_resource_id: UUID,
tag_id: UUID,
resource_id: UUID,
resource_type: TaggableResourceTypes,
) -> None:
"""Deletes a tag resource relationship.
Args:
tag_resource_id: id of the tag<>resource to delete.
resource_type: The type of the resource to create link with.
tag_id: The ID of the tag to delete.
resource_id: The ID of the resource to delete.
resource_type: The type of the resource to delete.
Raises:
KeyError: specified ID not found.
RuntimeError: on resource type mismatch.
"""
with Session(self.engine) as session:
tag_model = self._get_tag_model_schema(
tag_resource_id=tag_resource_id,
tag_id=tag_id,
resource_id=resource_id,
resource_type=resource_type,
session=session,
)
if tag_model is None:
raise KeyError(
f"Unable to delete tag<>resource with ID `{tag_resource_id}`: "
f"No tag<>resource with these ID found."
)
if tag_model.resource_type != resource_type.value:
raise RuntimeError(
f"Unable to delete tag<>resource with ID `{tag_resource_id}`: "
f"Resource type in request `{resource_type.value}` do not match "
f"resource type defined in database `{tag_model.resource_type}`."
f"Unable to delete tag<>resource with IDs: "
f"`tag_id`='{tag_id}' and `resource_id`='{resource_id}' and "
f"`resource_type`='{resource_type.value}': No "
"tag<>resource with these IDs found."
)
session.delete(tag_model)
session.commit()
21 changes: 9 additions & 12 deletions tests/integration/functional/zen_stores/test_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@
_load_file_from_artifact_store,
)
from zenml.utils.enum_utils import StrEnum
from zenml.utils.tag_utils import _get_tag_resource_id
from zenml.zen_stores.base_zen_store import (
DEFAULT_ADMIN_ROLE,
DEFAULT_GUEST_ROLE,
Expand Down Expand Up @@ -5108,12 +5107,14 @@ def test_delete_tag_resource_pass(self, client):
)
)
client.zen_store.delete_tag_resource(
tag_resource_id=_get_tag_resource_id(tag.id, resource_id),
tag_id=tag.id,
resource_id=resource_id,
resource_type=TaggableResourceTypes.MODEL,
)
with pytest.raises(KeyError):
client.zen_store.delete_tag_resource(
tag_resource_id=_get_tag_resource_id(tag.id, resource_id),
tag_id=tag.id,
resource_id=resource_id,
resource_type=TaggableResourceTypes.MODEL,
)

Expand All @@ -5135,12 +5136,10 @@ class MockTaggableResourceTypes(StrEnum):
resource_type=TaggableResourceTypes.MODEL,
)
)
with pytest.raises(
RuntimeError,
match="Resource type in request.*do not match",
):
with pytest.raises(KeyError):
client.zen_store.delete_tag_resource(
tag_resource_id=_get_tag_resource_id(tag.id, resource_id),
tag_id=tag.id,
resource_id=resource_id,
resource_type=MockTaggableResourceTypes.APPLE,
)

Expand Down Expand Up @@ -5193,9 +5192,7 @@ def test_cascade_deletion(self, use_model, use_tag, client):
)
# cleanup
client.zen_store.delete_tag_resource(
_get_tag_resource_id(
tag_id=tag.id,
resource_id=fake_model_id,
),
tag_id=tag.id,
resource_id=fake_model_id,
resource_type=TaggableResourceTypes.MODEL,
)

0 comments on commit 51768a0

Please sign in to comment.