From 38b424138539decd1ce396d0d73c6b73451e547b Mon Sep 17 00:00:00 2001 From: nate nowack Date: Mon, 3 Jun 2024 13:42:56 -0500 Subject: [PATCH] migrate `prefect-aws` to pydantic 2 (#13726) --- requirements-client.txt | 1 + requirements-dev.txt | 1 - .../prefect-aws/prefect_aws/batch.py | 3 +- .../prefect_aws/client_parameters.py | 18 +- .../prefect-aws/prefect_aws/client_waiter.py | 3 +- .../prefect-aws/prefect_aws/credentials.py | 24 +- .../prefect-aws/prefect_aws/glue_job.py | 10 +- .../prefect_aws/lambda_function.py | 14 +- .../prefect-aws/prefect_aws/s3.py | 10 +- .../prefect_aws/secrets_manager.py | 8 +- .../prefect_aws/workers/ecs_worker.py | 185 +++++++++------ .../prefect-aws/tests/conftest.py | 4 +- src/integrations/prefect-aws/tests/test_s3.py | 16 -- .../prefect-aws/tests/test_secrets_manager.py | 7 +- .../tests/workers/test_ecs_worker.py | 211 ++++++++++-------- 15 files changed, 268 insertions(+), 247 deletions(-) diff --git a/requirements-client.txt b/requirements-client.txt index f7166e55ec34..06dacab0f1c2 100644 --- a/requirements-client.txt +++ b/requirements-client.txt @@ -4,6 +4,7 @@ cachetools >= 5.3, < 6.0 cloudpickle >= 2.0, < 4.0 coolname >= 1.0.4, < 3.0.0 croniter >= 1.0.12, < 3.0.0 +exceptiongroup >= 1.0.0 fastapi >= 0.111.0, < 1.0.0 fsspec >= 2022.5.0 graphviz >= 0.20.1 diff --git a/requirements-dev.txt b/requirements-dev.txt index 64274ce2b8ea..188c9105e7f8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,7 +2,6 @@ ruff cairosvg codespell>=2.2.6 ddtrace -exceptiongroup ipython jinja2 mkdocs diff --git a/src/integrations/prefect-aws/prefect_aws/batch.py b/src/integrations/prefect-aws/prefect_aws/batch.py index 4be9dbd7bfe9..86b26b51e18b 100644 --- a/src/integrations/prefect-aws/prefect_aws/batch.py +++ b/src/integrations/prefect-aws/prefect_aws/batch.py @@ -3,11 +3,12 @@ from typing import Any, Dict, Optional from prefect import get_run_logger, task -from prefect.utilities.asyncutils import run_sync_in_worker_thread +from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible from prefect_aws.credentials import AwsCredentials @task +@sync_compatible async def batch_submit( job_name: str, job_queue: str, diff --git a/src/integrations/prefect-aws/prefect_aws/client_parameters.py b/src/integrations/prefect-aws/prefect_aws/client_parameters.py index 6b47c422b48b..ca1a12ceda6c 100644 --- a/src/integrations/prefect-aws/prefect_aws/client_parameters.py +++ b/src/integrations/prefect-aws/prefect_aws/client_parameters.py @@ -5,15 +5,10 @@ from botocore import UNSIGNED from botocore.client import Config -from pydantic import VERSION as PYDANTIC_VERSION +from pydantic import BaseModel, Field, FilePath, field_validator, model_validator from prefect_aws.utilities import hash_collection -if PYDANTIC_VERSION.startswith("2."): - from pydantic.v1 import BaseModel, Field, FilePath, root_validator, validator -else: - from pydantic import BaseModel, Field, FilePath, root_validator, validator - class AwsClientParameters(BaseModel): """ @@ -84,7 +79,8 @@ def __hash__(self): ) ) - @validator("config", pre=True) + @field_validator("config", mode="before") + @classmethod def instantiate_config(cls, value: Union[Config, Dict[str, Any]]) -> Dict[str, Any]: """ Casts lists to Config instances. 
@@ -93,7 +89,8 @@ def instantiate_config(cls, value: Union[Config, Dict[str, Any]]) -> Dict[str, A return value.__dict__["_user_provided_options"] return value - @root_validator + @model_validator(mode="before") + @classmethod def deprecated_verify_cert_path(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ If verify is not a bool, raise a warning. @@ -112,7 +109,8 @@ def deprecated_verify_cert_path(cls, values: Dict[str, Any]) -> Dict[str, Any]: ) return values - @root_validator + @model_validator(mode="before") + @classmethod def verify_cert_path_and_verify(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ If verify_cert_path is set but verify is False, raise a warning. @@ -139,7 +137,7 @@ def get_params_override(self) -> Dict[str, Any]: Return the dictionary of the parameters to override. The parameters to override are the one which are not None. """ - params = self.dict() + params = self.model_dump() if params.get("verify_cert_path"): # to ensure that verify doesn't re-overwrite verify_cert_path params.pop("verify") diff --git a/src/integrations/prefect-aws/prefect_aws/client_waiter.py b/src/integrations/prefect-aws/prefect_aws/client_waiter.py index f289fb40a7d4..0d0c73ce91c0 100644 --- a/src/integrations/prefect-aws/prefect_aws/client_waiter.py +++ b/src/integrations/prefect-aws/prefect_aws/client_waiter.py @@ -5,11 +5,12 @@ from botocore.waiter import WaiterModel, create_waiter_with_client from prefect import get_run_logger, task -from prefect.utilities.asyncutils import run_sync_in_worker_thread +from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible from prefect_aws.credentials import AwsCredentials @task +@sync_compatible async def client_waiter( client: str, waiter_name: str, diff --git a/src/integrations/prefect-aws/prefect_aws/credentials.py b/src/integrations/prefect-aws/prefect_aws/credentials.py index 04e020ce692f..987a4c722660 100644 --- a/src/integrations/prefect-aws/prefect_aws/credentials.py +++ b/src/integrations/prefect-aws/prefect_aws/credentials.py @@ -8,15 +8,9 @@ import boto3 from mypy_boto3_s3 import S3Client from mypy_boto3_secretsmanager import SecretsManagerClient -from pydantic import VERSION as PYDANTIC_VERSION +from pydantic import ConfigDict, Field, SecretStr from prefect.blocks.abstract import CredentialsBlock - -if PYDANTIC_VERSION.startswith("2."): - from pydantic.v1 import Field, SecretStr -else: - from pydantic import Field, SecretStr - from prefect_aws.client_parameters import AwsClientParameters _LOCK = Lock() @@ -72,6 +66,8 @@ class AwsCredentials(CredentialsBlock): ``` """ # noqa E501 + model_config = ConfigDict(arbitrary_types_allowed=True) + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa _block_type_name = "AWS Credentials" _documentation_url = "https://prefecthq.github.io/prefect-aws/credentials/#prefect_aws.credentials.AwsCredentials" # noqa @@ -107,11 +103,6 @@ class AwsCredentials(CredentialsBlock): title="AWS Client Parameters", ) - class Config: - """Config class for pydantic model.""" - - arbitrary_types_allowed = True - def __hash__(self): field_hashes = ( hash(self.aws_access_key_id), @@ -209,6 +200,8 @@ class MinIOCredentials(CredentialsBlock): ``` """ # noqa E501 + model_config = ConfigDict(arbitrary_types_allowed=True) + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/676cb17bcbdff601f97e0a02ff8bcb480e91ff40-250x250.png" # noqa _block_type_name = "MinIO Credentials" _description = ( @@ -231,18 +224,13 @@ class 
MinIOCredentials(CredentialsBlock): description="Extra parameters to initialize the Client.", ) - class Config: - """Config class for pydantic model.""" - - arbitrary_types_allowed = True - def __hash__(self): return hash( ( hash(self.minio_root_user), hash(self.minio_root_password), hash(self.region_name), - hash(frozenset(self.aws_client_parameters.dict().items())), + hash(frozenset(self.aws_client_parameters.model_dump().items())), ) ) diff --git a/src/integrations/prefect-aws/prefect_aws/glue_job.py b/src/integrations/prefect-aws/prefect_aws/glue_job.py index c131265027c6..5f6d35e2dbc6 100644 --- a/src/integrations/prefect-aws/prefect_aws/glue_job.py +++ b/src/integrations/prefect-aws/prefect_aws/glue_job.py @@ -6,21 +6,15 @@ import time from typing import Any, Optional -from pydantic import VERSION as PYDANTIC_VERSION +from pydantic import BaseModel, Field from prefect.blocks.abstract import JobBlock, JobRun - -if PYDANTIC_VERSION.startswith("2."): - from pydantic.v1 import BaseModel, Field -else: - from pydantic import BaseModel, Field - from prefect_aws import AwsCredentials _GlueJobClient = Any -class GlueJobRun(JobRun, BaseModel): +class GlueJobRun(BaseModel, JobRun): """Execute a Glue Job""" job_name: str = Field( diff --git a/src/integrations/prefect-aws/prefect_aws/lambda_function.py b/src/integrations/prefect-aws/prefect_aws/lambda_function.py index fb76f956cdd3..cd87294b6f2d 100644 --- a/src/integrations/prefect-aws/prefect_aws/lambda_function.py +++ b/src/integrations/prefect-aws/prefect_aws/lambda_function.py @@ -49,19 +49,14 @@ ``` """ + import json from typing import Literal, Optional -from pydantic import VERSION as PYDANTIC_VERSION +from pydantic import Field from prefect.blocks.core import Block from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible - -if PYDANTIC_VERSION.startswith("2."): - from pydantic.v1 import Field -else: - from pydantic import Field - from prefect_aws.credentials import AwsCredentials @@ -109,11 +104,6 @@ class LambdaFunction(Block): description="The AWS credentials to invoke the Lambda with.", ) - class Config: - """Lambda's pydantic configuration.""" - - smart_union = True - def _get_lambda_client(self): """ Retrieve a boto3 session and Lambda client diff --git a/src/integrations/prefect-aws/prefect_aws/s3.py b/src/integrations/prefect-aws/prefect_aws/s3.py index 84410f148fcc..439bac6ade81 100644 --- a/src/integrations/prefect-aws/prefect_aws/s3.py +++ b/src/integrations/prefect-aws/prefect_aws/s3.py @@ -1,4 +1,5 @@ """Tasks for interacting with AWS S3""" + import asyncio import io import os @@ -9,19 +10,13 @@ import boto3 from botocore.paginate import PageIterator from botocore.response import StreamingBody -from pydantic import VERSION as PYDANTIC_VERSION +from pydantic import Field from prefect import get_run_logger, task from prefect.blocks.abstract import ObjectStorageBlock from prefect.filesystems import WritableDeploymentStorage, WritableFileSystem from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible from prefect.utilities.filesystem import filter_files - -if PYDANTIC_VERSION.startswith("2."): - from pydantic.v1 import Field -else: - from pydantic import Field - from prefect_aws import AwsCredentials, MinIOCredentials from prefect_aws.client_parameters import AwsClientParameters @@ -394,7 +389,6 @@ async def example_s3_list_objects_flow(): class S3Bucket(WritableFileSystem, WritableDeploymentStorage, ObjectStorageBlock): - """ Block used to store data using AWS S3 or 
S3-compatible object storage like MinIO. diff --git a/src/integrations/prefect-aws/prefect_aws/secrets_manager.py b/src/integrations/prefect-aws/prefect_aws/secrets_manager.py index a3af406b537c..5e46692ec6ea 100644 --- a/src/integrations/prefect-aws/prefect_aws/secrets_manager.py +++ b/src/integrations/prefect-aws/prefect_aws/secrets_manager.py @@ -3,17 +3,11 @@ from typing import Any, Dict, List, Optional, Union from botocore.exceptions import ClientError -from pydantic import VERSION as PYDANTIC_VERSION +from pydantic import Field from prefect import get_run_logger, task from prefect.blocks.abstract import SecretBlock from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible - -if PYDANTIC_VERSION.startswith("2."): - from pydantic.v1 import Field -else: - from pydantic import Field - from prefect_aws import AwsCredentials diff --git a/src/integrations/prefect-aws/prefect_aws/workers/ecs_worker.py b/src/integrations/prefect-aws/prefect_aws/workers/ecs_worker.py index 054d57e8f9a9..c825efd4a5bf 100644 --- a/src/integrations/prefect-aws/prefect_aws/workers/ecs_worker.py +++ b/src/integrations/prefect-aws/prefect_aws/workers/ecs_worker.py @@ -58,8 +58,13 @@ import anyio import anyio.abc import yaml -from pydantic import VERSION as PYDANTIC_VERSION +from pydantic import BaseModel, Field, model_validator +from slugify import slugify +from tenacity import retry, stop_after_attempt, wait_fixed, wait_random +from typing_extensions import Literal, Self +from prefect.client.orchestration import PrefectClient +from prefect.client.utilities import inject_client from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound from prefect.server.schemas.core import FlowRun from prefect.utilities.asyncutils import run_sync_in_worker_thread @@ -69,17 +74,10 @@ BaseVariables, BaseWorker, BaseWorkerResult, + apply_values, + resolve_block_document_references, + resolve_variables, ) - -if PYDANTIC_VERSION.startswith("2."): - from pydantic.v1 import BaseModel, Field, root_validator -else: - from pydantic import BaseModel, Field, root_validator - -from slugify import slugify -from tenacity import retry, stop_after_attempt, wait_fixed, wait_random -from typing_extensions import Literal - from prefect_aws.credentials import AwsCredentials, ClientType # Internal type alias for ECS clients which are generated dynamically in botocore @@ -164,7 +162,7 @@ def _default_task_run_request_template() -> dict: return yaml.safe_load(DEFAULT_TASK_RUN_REQUEST_TEMPLATE) -def _drop_empty_keys_from_task_definition(taskdef: dict): +def _drop_empty_keys_from_dict(taskdef: dict): """ Recursively drop keys with 'empty' values from a task definition dict. 
@@ -174,11 +172,11 @@ def _drop_empty_keys_from_task_definition(taskdef: dict): if not value: taskdef.pop(key) if isinstance(value, dict): - _drop_empty_keys_from_task_definition(value) - if isinstance(value, list): + _drop_empty_keys_from_dict(value) + if isinstance(value, list) and key != "capacity_provider_strategy": for v in value: if isinstance(v, dict): - _drop_empty_keys_from_task_definition(v) + _drop_empty_keys_from_dict(v) def _get_container(containers: List[dict], name: str) -> Optional[dict]: @@ -264,11 +262,13 @@ class ECSJobConfiguration(BaseJobConfiguration): """ aws_credentials: Optional[AwsCredentials] = Field(default_factory=AwsCredentials) - task_definition: Optional[Dict[str, Any]] = Field( - template=_default_task_definition_template() + task_definition: Dict[str, Any] = Field( + default_factory=dict, + json_schema_extra=dict(template=_default_task_definition_template()), ) task_run_request: Dict[str, Any] = Field( - template=_default_task_run_request_template() + default_factory=dict, + json_schema_extra=dict(template=_default_task_run_request_template()), ) configure_cloudwatch_logs: Optional[bool] = Field(default=None) cloudwatch_logs_options: Dict[str, str] = Field(default_factory=dict) @@ -283,101 +283,144 @@ class ECSJobConfiguration(BaseJobConfiguration): cluster: Optional[str] = Field(default=None) match_latest_revision_in_family: bool = Field(default=False) - @root_validator - def task_run_request_requires_arn_if_no_task_definition_given(cls, values) -> dict: + execution_role_arn: Optional[str] = Field( + title="Execution Role ARN", + default=None, + description=( + "An execution role to use for the task. This controls the permissions of " + "the task when it is launching. If this value is not null, it will " + "override the value in the task definition. An execution role must be " + "provided to capture logs from the container." + ), + ) + + @model_validator(mode="after") + def task_run_request_requires_arn_if_no_task_definition_given(self) -> Self: """ If no task definition is provided, a task definition ARN must be present on the task run request. """ - if not values.get("task_run_request", {}).get( - "taskDefinition" - ) and not values.get("task_definition"): + if ( + not (self.task_run_request or {}).get("taskDefinition") + and not self.task_definition + ): raise ValueError( "A task definition must be provided if a task definition ARN is not " "present on the task run request." ) - return values + return self - @root_validator - def container_name_default_from_task_definition(cls, values) -> dict: + @model_validator(mode="after") + def container_name_default_from_task_definition(self) -> Self: """ Infers the container name from the task definition if not provided. """ - if values.get("container_name") is None: - values["container_name"] = _container_name_from_task_definition( - values.get("task_definition") + if self.container_name is None: + self.container_name = _container_name_from_task_definition( + self.task_definition ) # We may not have a name here still; for example if someone is using a task # definition arn. In that case, we'll perform similar logic later to find # the name to treat as the "orchestration" container. - return values + return self - @root_validator(pre=True) - def set_default_configure_cloudwatch_logs(cls, values: dict) -> dict: + @model_validator(mode="after") + def set_default_configure_cloudwatch_logs(self) -> Self: """ Streaming output generally requires CloudWatch logs to be configured. 
To avoid entangled arguments in the simple case, `configure_cloudwatch_logs` defaults to matching the value of `stream_output`. """ - configure_cloudwatch_logs = values.get("configure_cloudwatch_logs") + configure_cloudwatch_logs = self.configure_cloudwatch_logs if configure_cloudwatch_logs is None: - values["configure_cloudwatch_logs"] = values.get("stream_output") - return values + self.configure_cloudwatch_logs = self.stream_output + return self - @root_validator + @model_validator(mode="after") def configure_cloudwatch_logs_requires_execution_role_arn( - cls, values: dict - ) -> dict: + self, + ) -> Self: """ Enforces that an execution role arn is provided (or could be provided by a runtime task definition) when configuring logging. """ if ( - values.get("configure_cloudwatch_logs") - and not values.get("execution_role_arn") + self.configure_cloudwatch_logs + and not self.execution_role_arn # TODO: Does not match # Do not raise if they've linked to another task definition or provided # it without using our shortcuts - and not values.get("task_run_request", {}).get("taskDefinition") - and not (values.get("task_definition") or {}).get("executionRoleArn") + and not (self.task_run_request or {}).get("taskDefinition") + and not (self.task_definition or {}).get("executionRoleArn") ): raise ValueError( "An `execution_role_arn` must be provided to use " "`configure_cloudwatch_logs` or `stream_logs`." ) - return values + return self - @root_validator + @model_validator(mode="after") def cloudwatch_logs_options_requires_configure_cloudwatch_logs( - cls, values: dict - ) -> dict: + self, + ) -> Self: """ Enforces that an execution role arn is provided (or could be provided by a runtime task definition) when configuring logging. """ - if values.get("cloudwatch_logs_options") and not values.get( - "configure_cloudwatch_logs" - ): + if self.cloudwatch_logs_options and not self.configure_cloudwatch_logs: raise ValueError( "`configure_cloudwatch_log` must be enabled to use " "`cloudwatch_logs_options`." ) - return values + return self - @root_validator - def network_configuration_requires_vpc_id(cls, values: dict) -> dict: + @model_validator(mode="after") + def network_configuration_requires_vpc_id(self) -> Self: """ Enforces a `vpc_id` is provided when custom network configuration mode is enabled for network settings. """ - if values.get("network_configuration") and not values.get("vpc_id"): + if self.network_configuration and not self.vpc_id: raise ValueError( "You must provide a `vpc_id` to enable custom `network_configuration`." ) - return values + return self + + @classmethod + @inject_client + async def from_template_and_values( + cls, + base_job_template: dict, + values: dict, + client: Optional[PrefectClient] = None, + ): + """Creates a valid worker configuration object from the provided base + configuration and overrides. + + Important: this method expects that the base_job_template was already + validated server-side. 
+ """ + + job_config: Dict[str, Any] = base_job_template["job_configuration"] + variables_schema = base_job_template["variables"] + variables = cls._get_base_config_defaults( + variables_schema.get("properties", {}) + ) + variables.update(values) + + _drop_empty_keys_from_dict(variables) # TODO: investigate why this is necessary + + populated_configuration = apply_values(template=job_config, values=variables) + populated_configuration = await resolve_block_document_references( + template=populated_configuration, client=client + ) + populated_configuration = await resolve_variables( + template=populated_configuration, client=client + ) + return cls(**populated_configuration) class ECSVariables(BaseVariables): @@ -428,9 +471,7 @@ class ECSVariables(BaseVariables): "field will be slugified to match AWS character requirements." ), ) - launch_type: Optional[ - Literal["FARGATE", "EC2", "EXTERNAL", "FARGATE_SPOT"] - ] = Field( + launch_type: Literal["FARGATE", "EC2", "EXTERNAL", "FARGATE_SPOT"] = Field( default=ECS_DEFAULT_LAUNCH_TYPE, description=( "The type of ECS task run infrastructure that should be used. Note that" @@ -438,7 +479,7 @@ class ECSVariables(BaseVariables): " the proper capacity provider strategy if set here." ), ) - capacity_provider_strategy: Optional[List[CapacityProvider]] = Field( + capacity_provider_strategy: List[CapacityProvider] = Field( default_factory=list, description=( "The capacity provider strategy to use when running the task. " @@ -454,7 +495,7 @@ class ECSVariables(BaseVariables): "defaults to a Prefect base image matching your local versions." ), ) - cpu: int = Field( + cpu: Optional[int] = Field( title="CPU", default=None, description=( @@ -463,7 +504,7 @@ class ECSVariables(BaseVariables): f"{ECS_DEFAULT_CPU} will be used unless present on the task definition." ), ) - memory: int = Field( + memory: Optional[int] = Field( default=None, description=( "The amount of memory to provide to the ECS task. Valid amounts are " @@ -471,7 +512,7 @@ class ECSVariables(BaseVariables): f"{ECS_DEFAULT_MEMORY} will be used unless present on the task definition." ), ) - container_name: str = Field( + container_name: Optional[str] = Field( default=None, description=( "The name of the container flow run orchestration will occur in. If not " @@ -480,7 +521,7 @@ class ECSVariables(BaseVariables): "be used." ), ) - task_role_arn: str = Field( + task_role_arn: Optional[str] = Field( title="Task Role ARN", default=None, description=( @@ -488,7 +529,7 @@ class ECSVariables(BaseVariables): "task while it is running." ), ) - execution_role_arn: str = Field( + execution_role_arn: Optional[str] = Field( title="Execution Role ARN", default=None, description=( @@ -509,7 +550,7 @@ class ECSVariables(BaseVariables): "VPC will be used. If no default VPC can be found, the task run will fail." ), ) - configure_cloudwatch_logs: bool = Field( + configure_cloudwatch_logs: Optional[bool] = Field( default=None, description=( "If enabled, the Prefect container will be configured to send its output " @@ -550,7 +591,7 @@ class ECSVariables(BaseVariables): ), ) - stream_output: bool = Field( + stream_output: Optional[bool] = Field( default=None, description=( "If enabled, logs will be streamed from the Prefect container to the local " @@ -606,7 +647,7 @@ class ECSWorker(BaseWorker): A Prefect worker to run flow runs as ECS tasks. 
""" - type = "ecs" + type: str = "ecs" job_configuration = ECSJobConfiguration job_configuration_variables = ECSVariables _description = ( @@ -707,6 +748,7 @@ def _create_task_and_wait_for_start( task_definition = self._prepare_task_definition( configuration, region=ecs_client.meta.region_name, flow_run=flow_run ) + ( task_definition_arn, new_task_definition_registered, @@ -978,6 +1020,7 @@ def _register_task_definition( "Task definition request" f"{json.dumps(task_definition, indent=2, default=str)}" ) + response = ecs_client.register_task_definition(**task_definition) return response["taskDefinition"]["taskDefinitionArn"] @@ -1468,7 +1511,11 @@ def _prepare_task_run_request( task_run_request = deepcopy(configuration.task_run_request) task_run_request.setdefault("taskDefinition", task_definition_arn) - assert task_run_request["taskDefinition"] == task_definition_arn + + assert task_run_request["taskDefinition"] == task_definition_arn, ( + f"Task definition ARN mismatch: {task_run_request['taskDefinition']!r} " + f"!= {task_definition_arn!r}" + ) capacityProviderStrategy = task_run_request.get("capacityProviderStrategy") if capacityProviderStrategy: @@ -1668,8 +1715,8 @@ def _task_definitions_equal(self, taskdef_1, taskdef_2) -> bool: taskdef.setdefault("networkMode", "bridge") - _drop_empty_keys_from_task_definition(taskdef_1) - _drop_empty_keys_from_task_definition(taskdef_2) + _drop_empty_keys_from_dict(taskdef_1) + _drop_empty_keys_from_dict(taskdef_2) # Clear fields that change on registration for comparison for field in ECS_POST_REGISTRATION_FIELDS: diff --git a/src/integrations/prefect-aws/tests/conftest.py b/src/integrations/prefect-aws/tests/conftest.py index a6ac4e6c75b0..f6a9132ff519 100644 --- a/src/integrations/prefect-aws/tests/conftest.py +++ b/src/integrations/prefect-aws/tests/conftest.py @@ -21,13 +21,13 @@ def prefect_db(): @pytest.fixture -def aws_credentials(): +async def aws_credentials(): block = AwsCredentials( aws_access_key_id="access_key_id", aws_secret_access_key="secret_access_key", region_name="us-east-1", ) - block.save("test-creds-block", overwrite=True) + await block.save("test-creds-block", overwrite=True) return block diff --git a/src/integrations/prefect-aws/tests/test_s3.py b/src/integrations/prefect-aws/tests/test_s3.py index d16641a33280..05b688c6805e 100644 --- a/src/integrations/prefect-aws/tests/test_s3.py +++ b/src/integrations/prefect-aws/tests/test_s3.py @@ -18,7 +18,6 @@ ) from prefect import flow -from prefect.deployments import Deployment aws_clients = [ "aws_client_parameters_custom_endpoint", @@ -750,21 +749,6 @@ def test_write_path_in_sync_context(s3_bucket): assert content == b"hello" -def test_deployment_default_basepath(s3_bucket): - deployment = Deployment(name="testing", storage=s3_bucket) - assert deployment.location == "/" - - -def test_deployment_set_basepath(aws_creds_block): - s3_bucket_block = S3Bucket( - bucket_name=BUCKET_NAME, - credentials=aws_creds_block, - bucket_folder="home", - ) - deployment = Deployment(name="testing", storage=s3_bucket_block) - assert deployment.location == "home/" - - def test_resolve_path(s3_bucket): assert s3_bucket._resolve_path("") == "" diff --git a/src/integrations/prefect-aws/tests/test_secrets_manager.py b/src/integrations/prefect-aws/tests/test_secrets_manager.py index 3d08be0d1c9b..b1a479125eb2 100644 --- a/src/integrations/prefect-aws/tests/test_secrets_manager.py +++ b/src/integrations/prefect-aws/tests/test_secrets_manager.py @@ -1,6 +1,7 @@ -from datetime import datetime, timedelta 
+from datetime import timedelta import boto3 +import pendulum import pytest from moto import mock_secretsmanager from prefect_aws.secrets_manager import ( @@ -158,10 +159,10 @@ async def test_flow(): if not force_delete_without_recovery: assert deletion_date.date() == ( - datetime.utcnow().date() + timedelta(days=recovery_window_in_days) + pendulum.now("UTC").date() + timedelta(days=recovery_window_in_days) ) else: - assert deletion_date.date() == datetime.utcnow().date() + assert deletion_date.date() == pendulum.now("UTC").date() class TestAwsSecret: diff --git a/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py b/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py index 59a1fa6eadcb..1b68e94738aa 100644 --- a/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py +++ b/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py @@ -10,19 +10,9 @@ import botocore import pytest import yaml +from exceptiongroup import ExceptionGroup, catch from moto import mock_ec2, mock_ecs, mock_logs from moto.ec2.utils import generate_instance_identity_document -from pydantic import VERSION as PYDANTIC_VERSION - -from prefect.server.schemas.core import FlowRun -from prefect.utilities.asyncutils import run_sync_in_worker_thread -from prefect.utilities.slugify import slugify - -if PYDANTIC_VERSION.startswith("2."): - from pydantic.v1 import ValidationError -else: - from pydantic import ValidationError - from prefect_aws.credentials import _get_client_cached from prefect_aws.workers.ecs_worker import ( _TAG_REGEX, @@ -42,6 +32,11 @@ mask_sensitive_env_values, parse_identifier, ) +from pydantic import ValidationError + +from prefect.server.schemas.core import FlowRun +from prefect.utilities.asyncutils import run_sync_in_worker_thread +from prefect.utilities.slugify import slugify TEST_TASK_DEFINITION_YAML = """ containerDefinitions: @@ -71,14 +66,6 @@ def reset_task_definition_cache(): yield -@pytest.fixture(autouse=True) -def patch_task_watch_poll_interval(monkeypatch): - # Patch the poll interval to be way shorter for speed during testing! 
- monkeypatch.setattr( - ECSVariables.__fields__["task_watch_poll_interval"], "default", 0.05 - ) - - def inject_moto_patches(moto_mock, patches: Dict[str, List[Callable]]): def injected_call(method, patch_list, *args, **kwargs): for patch in patch_list: @@ -312,14 +299,14 @@ def ecs_mocks( async def construct_configuration(**options): - variables = ECSVariables(**options) - print(f"Using variables: {variables.json(indent=2)}") + variables = ECSVariables(**options | {"task_watch_poll_interval": 0.03}) + print(f"Using variables: {variables.model_dump_json(indent=2, exclude_none=True)}") configuration = await ECSJobConfiguration.from_template_and_values( base_job_template=ECSWorker.get_default_base_job_template(), values={**variables.model_dump(exclude_none=True)}, ) - print(f"Constructed test configuration: {configuration.json(indent=2)}") + print(f"Constructed test configuration: {configuration.model_dump_json(indent=2)}") return configuration @@ -327,8 +314,10 @@ async def construct_configuration(**options): async def construct_configuration_with_job_template( template_overrides: dict, **variables: dict ): - variables = ECSVariables(**variables) - print(f"Using variables: {variables.json(indent=2)}") + variables: ECSVariables = ECSVariables( + **variables | {"task_watch_poll_interval": 0.03} + ) + print(f"Using variables: {variables.model_dump_json(indent=2)}") base_template = ECSWorker.get_default_base_job_template() for key in template_overrides: @@ -339,11 +328,13 @@ async def construct_configuration_with_job_template( f" {json.dumps(base_template['job_configuration'], indent=2)}" ) - configuration = await ECSJobConfiguration.from_template_and_values( - base_job_template=base_template, - values={**variables.model_dump(exclude_none=True)}, + configuration: ECSJobConfiguration = ( + await ECSJobConfiguration.from_template_and_values( + base_job_template=base_template, + values={**variables.model_dump(exclude_none=True)}, + ) ) - print(f"Constructed test configuration: {configuration.json(indent=2)}") + print(f"Constructed test configuration: {configuration.model_dump_json(indent=2)}") return configuration @@ -1132,14 +1123,17 @@ async def test_network_config_from_custom_settings_invalid_subnet( session = aws_credentials.get_boto3_session() - with pytest.raises( - ValueError, - match=( - r"Subnets \['sn-8asdas'\] not found within VPC with ID " - + vpc.id - + r"\.Please check that VPC is associated with supplied subnets\." - ), - ): + def handle_error(exc_group: ExceptionGroup): + assert len(exc_group.exceptions) == 1 + assert isinstance(exc_group.exceptions[0], ExceptionGroup) + assert len(exc_group.exceptions[0].exceptions) == 1 + assert isinstance(exc_group.exceptions[0].exceptions[0], ValueError) + assert ( + f"Subnets ['sn-8asdas'] not found within VPC with ID {vpc.id}." + "Please check that VPC is associated with supplied subnets." + ) in str(exc_group.exceptions[0].exceptions[0]) + + with catch({ValueError: handle_error}): async with ECSWorker(work_pool_name="test") as worker: original_run_task = worker._create_task_run mock_run_task = MagicMock(side_effect=original_run_task) @@ -1174,14 +1168,17 @@ async def test_network_config_from_custom_settings_invalid_subnet_multiple_vpc_s session = aws_credentials.get_boto3_session() - with pytest.raises( - ValueError, - match=( - rf"Subnets \['{invalid_subnet_id}', '{subnet.id}'\] not found within VPC" - f" with ID {vpc.id}.Please check that VPC is associated with supplied" - " subnets." 
- ), - ): + def handle_error(exc_group: ExceptionGroup): + assert len(exc_group.exceptions) == 1 + assert isinstance(exc_group.exceptions[0], ExceptionGroup) + assert len(exc_group.exceptions[0].exceptions) == 1 + assert isinstance(exc_group.exceptions[0].exceptions[0], ValueError) + assert ( + f"Subnets ['{invalid_subnet_id}', '{subnet.id}'] not found within VPC with ID" + f" {vpc.id}.Please check that VPC is associated with supplied subnets." + ) in str(exc_group.exceptions[0].exceptions[0]) + + with catch({ValueError: handle_error}): async with ECSWorker(work_pool_name="test") as worker: original_run_task = worker._create_task_run mock_run_task = MagicMock(side_effect=original_run_task) @@ -1289,7 +1286,14 @@ async def test_network_config_missing_default_vpc( configuration = await construct_configuration(aws_credentials=aws_credentials) - with pytest.raises(ValueError, match="Failed to find the default VPC"): + def handle_error(exc_grp: ExceptionGroup): + assert len(exc_grp.exceptions) == 1 + assert isinstance(exc_grp.exceptions[0], ExceptionGroup) + exc = exc_grp.exceptions[0].exceptions[0] + assert isinstance(exc, ValueError) + assert "Failed to find the default VPC" in str(exc) + + with catch({ValueError: handle_error}): async with ECSWorker(work_pool_name="test") as worker: await run_then_stop_task(worker, configuration, flow_run) @@ -1307,9 +1311,14 @@ async def test_network_config_from_vpc_with_no_subnets( vpc_id=vpc.id, ) - with pytest.raises( - ValueError, match=f"Failed to find subnets for VPC with ID {vpc.id}" - ): + def handle_error(exc_grp: ExceptionGroup): + assert len(exc_grp.exceptions) == 1 + assert isinstance(exc_grp.exceptions[0], ExceptionGroup) + exc = exc_grp.exceptions[0].exceptions[0] + assert isinstance(exc, ValueError) + assert "Failed to find subnets for VPC with ID" in str(exc) + + with catch({ValueError: handle_error}): async with ECSWorker(work_pool_name="test") as worker: await run_then_stop_task(worker, configuration, flow_run) @@ -1327,13 +1336,17 @@ async def test_bridge_network_mode_raises_on_fargate( template_overrides=dict(task_definition={"networkMode": "bridge"}), ) - with pytest.raises( - ValueError, - match=( - "Found network mode 'bridge' which is not compatible with launch type " - f"{launch_type!r}" - ), - ): + def handle_error(exc_grp: ExceptionGroup): + assert len(exc_grp.exceptions) == 1 + assert isinstance(exc_grp.exceptions[0], ExceptionGroup) + exc = exc_grp.exceptions[0].exceptions[0] + assert isinstance(exc, ValueError) + assert ( + "Found network mode 'bridge' which is not compatible with launch type" + in str(exc) + ) + + with catch({ValueError: handle_error}): async with ECSWorker(work_pool_name="test") as worker: await run_then_stop_task(worker, configuration, flow_run) @@ -1416,10 +1429,14 @@ async def test_run_task_error_handling( "botocore.client.BaseClient._make_api_call", new=mock_make_api_call ): async with ECSWorker(work_pool_name="test") as worker: - with pytest.raises(RuntimeError, match="Failed to run ECS task") as exc: - await run_then_stop_task(worker, configuration, flow_run) - assert exc.value.args[0] == "Failed to run ECS task: string" + def handle_error(exc_grp: ExceptionGroup): + assert len(exc_grp.exceptions) == 1 + assert isinstance(exc_grp.exceptions[0], RuntimeError) + assert exc_grp.exceptions[0].args[0] == "Failed to run ECS task: string" + + with catch({RuntimeError: handle_error}): + await run_then_stop_task(worker, configuration, flow_run) @pytest.mark.usefixtures("ecs_mocks") @@ -1694,7 +1711,9 @@ async 
def test_worker_task_definition_cache_is_per_deployment_id( async with ECSWorker(work_pool_name="test") as worker: result_1 = await run_then_stop_task(worker, configuration, flow_run) result_2 = await run_then_stop_task( - worker, configuration, flow_run.copy(update=dict(deployment_id=uuid4())) + worker, + configuration, + flow_run.model_copy(update=dict(deployment_id=uuid4())), ) assert result_2.status_code == 0 @@ -1781,7 +1800,6 @@ async def test_worker_task_definition_cache_miss_on_deregistered( {"env": {"FOO": "BAR"}}, {"command": "test"}, {"labels": {"FOO": "BAR"}}, - {"stream_output": True, "configure_cloudwatch_logs": False}, {"cluster": "test"}, {"task_role_arn": "test"}, # Note: null environment variables can cause override, but not when missing @@ -2164,7 +2182,6 @@ async def test_user_defined_environment_variables_in_task_run_request_template( "environment": [ {"name": "BAR", "value": "FOO"}, {"name": "OVERRIDE", "value": "OLD"}, - {"name": "UNSET", "value": "GONE"}, ], } ], @@ -2260,14 +2277,13 @@ async def test_kill_infrastructure(aws_credentials, cluster: str, flow_run): cluster=cluster, ) - with anyio.fail_after(5): - async with ECSWorker(work_pool_name="test") as worker: - async with anyio.create_task_group() as tg: - identifier = await tg.start(worker.run, flow_run, configuration) + async with ECSWorker(work_pool_name="test") as worker: + async with anyio.create_task_group() as tg: + identifier = await tg.start(worker.run, flow_run, configuration) - await worker.kill_infrastructure( - configuration=configuration, infrastructure_pid=identifier - ) + await worker.kill_infrastructure( + configuration=configuration, infrastructure_pid=identifier + ) _, task_arn = parse_identifier(identifier) task = describe_task(ecs_client, task_arn) @@ -2280,7 +2296,7 @@ async def test_kill_infrastructure_with_invalid_identifier(aws_credentials): aws_credentials=aws_credentials, ) - with pytest.raises(ValueError): + with catch({ValueError: lambda exc_group: None}): async with ECSWorker(work_pool_name="test") as worker: await worker.kill_infrastructure(configuration, "test") @@ -2292,13 +2308,16 @@ async def test_kill_infrastructure_with_mismatched_cluster(aws_credentials): cluster="foo", ) - with pytest.raises( - InfrastructureNotAvailable, - match=( - "Cannot stop ECS task: this infrastructure block has access to cluster " - "'foo' but the task is running in cluster 'bar'." - ), - ): + def handle_error(exc_group: ExceptionGroup): + assert len(exc_group.exceptions) == 1 + assert isinstance(exc_group.exceptions[0], InfrastructureNotAvailable) + assert ( + "Cannot stop ECS task: this infrastructure block has access to cluster" + " 'foo' but the task is running in cluster 'bar'." + in str(exc_group.exceptions[0]) + ) + + with catch({InfrastructureNotAvailable: handle_error}): async with ECSWorker(work_pool_name="test") as worker: await worker.kill_infrastructure(configuration, "bar:::task_arn") @@ -2310,10 +2329,14 @@ async def test_kill_infrastructure_with_cluster_that_does_not_exist(aws_credenti cluster="foo", ) - with pytest.raises( - InfrastructureNotFound, - match="Cannot stop ECS task: the cluster 'foo' could not be found.", - ): + def handle_error(exc_group: ExceptionGroup): + assert len(exc_group.exceptions) == 1 + assert isinstance(exc_group.exceptions[0], InfrastructureNotFound) + assert "Cannot stop ECS task: the cluster 'foo' could not be found." 
in str( + exc_group.exceptions[0] + ) + + with catch({InfrastructureNotFound: handle_error}): async with ECSWorker(work_pool_name="test") as worker: await worker.kill_infrastructure(configuration, "foo::task_arn") @@ -2331,13 +2354,15 @@ async def test_kill_infrastructure_with_task_that_does_not_exist( async with ECSWorker(work_pool_name="test") as worker: await run_then_stop_task(worker, configuration, flow_run) - with pytest.raises( - InfrastructureNotFound, - match=( + def handle_error(exc_group: ExceptionGroup): + assert len(exc_group.exceptions) == 1 + assert isinstance(exc_group.exceptions[0], InfrastructureNotFound) + assert ( "Cannot stop ECS task: the task 'foo' could not be found in cluster" - " 'default'" - ), - ): + " 'default'" in str(exc_group.exceptions[0]) + ) + + with catch({InfrastructureNotFound: handle_error}): await worker.kill_infrastructure(configuration, "default::foo") @@ -2348,10 +2373,14 @@ async def test_kill_infrastructure_with_cluster_that_has_no_tasks(aws_credential cluster="default", ) - with pytest.raises( - InfrastructureNotFound, - match="Cannot stop ECS task: the cluster 'default' has no tasks.", - ): + def handle_error(exc_group: ExceptionGroup): + assert len(exc_group.exceptions) == 1 + assert isinstance(exc_group.exceptions[0], InfrastructureNotFound) + assert "Cannot stop ECS task: the cluster 'default' has no tasks." in str( + exc_group.exceptions[0] + ) + + with catch({InfrastructureNotFound: handle_error}): async with ECSWorker(work_pool_name="test") as worker: await worker.kill_infrastructure(configuration, "default::foo") @@ -2417,7 +2446,7 @@ async def test_retry_on_failed_task_start( }, ) - with pytest.raises(RuntimeError): + with catch({RuntimeError: lambda exc_group: None}): async with ECSWorker(work_pool_name="test") as worker: await run_then_stop_task(worker, configuration, flow_run)
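
Reviewer note (not part of the diff): the hunks above repeat a small set of Pydantic v1 -> v2 substitutions: `class Config` becomes `model_config = ConfigDict(...)`, `@validator(..., pre=True)` becomes `@field_validator(..., mode="before")` plus `@classmethod`, `@root_validator` becomes `@model_validator`, arbitrary `Field(template=...)` kwargs move under `json_schema_extra`, and `.dict()` / `.json()` / `.copy()` become `model_dump()` / `model_dump_json()` / `model_copy()`. Below is a minimal runnable sketch of the before/after pattern; `ExampleParams` and its fields are hypothetical stand-ins, not prefect-aws classes:

    # Pydantic v2 equivalents of the v1 idioms this patch removes.
    from typing import Any, Dict, Optional

    from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
    from typing_extensions import Self


    class ExampleParams(BaseModel):
        # v1: `class Config: arbitrary_types_allowed = True`
        model_config = ConfigDict(arbitrary_types_allowed=True)

        # v1: `Field(template=...)`; v2 routes extra schema metadata
        # through `json_schema_extra`.
        options: Dict[str, Any] = Field(
            default_factory=dict, json_schema_extra={"template": "{{ options }}"}
        )
        name: Optional[str] = Field(default=None)

        # v1: `@validator("options", pre=True)`
        @field_validator("options", mode="before")
        @classmethod
        def coerce_options(cls, value: Any) -> Dict[str, Any]:
            # Normalize non-dict input before per-field validation runs.
            return value if isinstance(value, dict) else dict(value)

        # v1: `@root_validator` receiving a `values` dict; an "after"-mode
        # model validator receives the constructed model and returns `self`.
        @model_validator(mode="after")
        def default_name_from_options(self) -> Self:
            if self.name is None:
                self.name = self.options.get("name")
            return self


    params = ExampleParams(options=[("name", "demo")])
    print(params.model_dump())                      # v1: params.dict()
    print(params.model_dump_json(indent=2))         # v1: params.json(indent=2)
    print(params.model_copy(update={"name": "x"}))  # v1: params.copy(update=...)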
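
Reviewer note (not part of the diff): the worker tests trade `pytest.raises(...)` for `exceptiongroup.catch(...)` because failures raised while the worker runs inside an anyio task group now surface wrapped in (possibly nested) ExceptionGroups, which `pytest.raises` does not match directly. A small sketch of that test pattern under stated assumptions; `run()` is a hypothetical stand-in, not worker code:

    from exceptiongroup import ExceptionGroup, catch


    def handle_value_error(exc_group: ExceptionGroup) -> None:
        # Assert against the wrapped exception(s), not the group itself.
        assert len(exc_group.exceptions) == 1
        assert "boom" in str(exc_group.exceptions[0])


    def run() -> None:
        # Stand-in for awaiting a worker inside an anyio task group, which
        # collects failures into an ExceptionGroup.
        raise ExceptionGroup("unhandled errors", [ValueError("boom")])


    # `catch` splits the raised group and routes the ValueError subgroup to
    # the matching handler; exceptions with no matching handler propagate.
    with catch({ValueError: handle_value_error}):
        run()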