Skip to content

Commit

Permalink
VertexOrchestrator apply node selector constraint if gpu_limit > 0 (zenml-io#935)
Browse files Browse the repository at this point in the history

* Add node selector constraint if `gpu_limit > 0`

* Change `gpu_count` to `NonNegativeInt`

* Fix formatting

* Update condition to make it shorter

Co-authored-by: Felix Altenberger <[email protected]>

Co-authored-by: Felix Altenberger <[email protected]>
  • Loading branch information
gabrielmbmb and fa9r authored Oct 7, 2022
1 parent e1a440c commit bd084c0
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/zenml/config/resource_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from enum import Enum
from typing import Optional, Union

from pydantic import Extra, Field, PositiveFloat, PositiveInt
from pydantic import Extra, Field, NonNegativeInt, PositiveFloat

from zenml.config.base_settings import BaseSettings

Expand Down Expand Up @@ -69,7 +69,7 @@ class ResourceSettings(BaseSettings):
"""

cpu_count: Optional[PositiveFloat] = None
gpu_count: Optional[PositiveInt] = None
gpu_count: Optional[NonNegativeInt] = None
memory: Optional[str] = Field(regex=MEMORY_REGEX)

@property
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/integrations/gcp/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@
_VERTEX_JOB_STATE_FAILED,
_VERTEX_JOB_STATE_CANCELLED,
)

# Label name of the GKE accelerator node selector constraint. The Vertex
# orchestrator compares the label of a configured node selector constraint
# against this value and does not apply that constraint when the GPU limit
# is 0.
GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL = (
"cloud.google.com/gke-accelerator"
)
21 changes: 15 additions & 6 deletions src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
from zenml.constants import ORCHESTRATOR_DOCKER_IMAGE_KEY
from zenml.enums import StackComponentType
from zenml.integrations.gcp import GCP_ARTIFACT_STORE_FLAVOR
from zenml.integrations.gcp.constants import (
GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
)
from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
VertexOrchestratorConfig,
)
Expand Down Expand Up @@ -232,16 +235,22 @@ def _configure_container_resources(
if memory_limit is not None:
container_op = container_op.set_memory_limit(memory_limit)

if self.config.node_selector_constraint is not None:
container_op = container_op.add_node_selector_constraint(
label_name=self.config.node_selector_constraint[0],
value=self.config.node_selector_constraint[1],
)

gpu_limit = resource_settings.gpu_count or self.config.gpu_limit
if gpu_limit is not None:
container_op = container_op.set_gpu_limit(gpu_limit)

if self.config.node_selector_constraint is not None:
constraint_label = self.config.node_selector_constraint[0]
value = self.config.node_selector_constraint[1]
if not (
constraint_label
== GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
and gpu_limit == 0
):
container_op.add_node_selector_constraint(
constraint_label, value
)

def prepare_or_run_pipeline(
self,
deployment: "PipelineDeployment",
Expand Down

0 comments on commit bd084c0

Please sign in to comment.