Fixing MyPy issues inside tests/providers/amazon (apache#20561)
khalidmammadov authored Dec 30, 2021
1 parent e07e831 commit 488ed66
Showing 5 changed files with 47 additions and 45 deletions.
16 changes: 10 additions & 6 deletions tests/providers/amazon/aws/hooks/test_dms_task.py
@@ -17,6 +17,7 @@

import json
import unittest
+from typing import Any, Dict
from unittest import mock

import pytest
@@ -55,12 +56,15 @@
    'ReplicationTaskArn': MOCK_TASK_ARN,
    'Status': 'creating',
}
-MOCK_DESCRIBE_RESPONSE = {'ReplicationTasks': [MOCK_TASK_RESPONSE_DATA]}
-MOCK_DESCRIBE_RESPONSE_WITH_MARKER = {'ReplicationTasks': [MOCK_TASK_RESPONSE_DATA], 'Marker': 'marker'}
-MOCK_CREATE_RESPONSE = {'ReplicationTask': MOCK_TASK_RESPONSE_DATA}
-MOCK_START_RESPONSE = {'ReplicationTask': {**MOCK_TASK_RESPONSE_DATA, 'Status': 'starting'}}
-MOCK_STOP_RESPONSE = {'ReplicationTask': {**MOCK_TASK_RESPONSE_DATA, 'Status': 'stopping'}}
-MOCK_DELETE_RESPONSE = {'ReplicationTask': {**MOCK_TASK_RESPONSE_DATA, 'Status': 'deleting'}}
+MOCK_DESCRIBE_RESPONSE: Dict[str, Any] = {'ReplicationTasks': [MOCK_TASK_RESPONSE_DATA]}
+MOCK_DESCRIBE_RESPONSE_WITH_MARKER: Dict[str, Any] = {
+    'ReplicationTasks': [MOCK_TASK_RESPONSE_DATA],
+    'Marker': 'marker',
+}
+MOCK_CREATE_RESPONSE: Dict[str, Any] = {'ReplicationTask': MOCK_TASK_RESPONSE_DATA}
+MOCK_START_RESPONSE: Dict[str, Any] = {'ReplicationTask': {**MOCK_TASK_RESPONSE_DATA, 'Status': 'starting'}}
+MOCK_STOP_RESPONSE: Dict[str, Any] = {'ReplicationTask': {**MOCK_TASK_RESPONSE_DATA, 'Status': 'stopping'}}
+MOCK_DELETE_RESPONSE: Dict[str, Any] = {'ReplicationTask': {**MOCK_TASK_RESPONSE_DATA, 'Status': 'deleting'}}


class TestDmsHook(unittest.TestCase):
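The explicit Dict[str, Any] annotations are what satisfy MyPy here: left unannotated, each mock-response literal is inferred with its own narrow value type, which MyPy then holds the tests to. A minimal sketch of the effect, with illustrative names rather than the test suite's own:

from typing import Any, Dict

# Inferred as Dict[str, str]; MyPy rejects later values of any other type.
narrow = {'Status': 'creating'}
# narrow['ReplicationTasks'] = []   # MyPy: value has type "List[...]", expected "str"

# An explicit annotation keeps the mapping open-ended, which suits mocked
# boto3 responses whose values mix strings, lists, and nested dicts.
wide: Dict[str, Any] = {'Status': 'creating'}
wide['ReplicationTasks'] = [{'Status': 'creating'}]  # accepted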
37 changes: 17 additions & 20 deletions tests/providers/amazon/aws/hooks/test_eks.py
@@ -18,6 +18,7 @@
#
import sys
from copy import deepcopy
+from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type
from unittest import mock
@@ -124,9 +125,7 @@ def __init__(self, count: int, minimal: bool) -> None:
inputs=ClusterInputs, cluster_name=self.existing_cluster_name
)

-def _execute(
-    count: Optional[int] = 1, minimal: Optional[bool] = True
-) -> Tuple[EksHook, ClusterTestDataFactory]:
+def _execute(count: int = 1, minimal: bool = True) -> Tuple[EksHook, ClusterTestDataFactory]:
return eks_hook, ClusterTestDataFactory(count=count, minimal=minimal)

mock_eks().start()
@@ -173,9 +172,7 @@ def __init__(self, count: int, minimal: bool) -> None:
fargate_profile_name=self.existing_fargate_profile_name,
)

-def _execute(
-    count: Optional[int] = 1, minimal: Optional[bool] = True
-) -> Tuple[EksHook, FargateProfileTestDataFactory]:
+def _execute(count: int = 1, minimal: bool = True) -> Tuple[EksHook, FargateProfileTestDataFactory]:
return eks_hook, FargateProfileTestDataFactory(count=count, minimal=minimal)

eks_hook, cluster = cluster_builder()
@@ -217,9 +214,7 @@ def __init__(self, count: int, minimal: bool) -> None:
nodegroup_name=self.existing_nodegroup_name,
)

-def _execute(
-    count: Optional[int] = 1, minimal: Optional[bool] = True
-) -> Tuple[EksHook, NodegroupTestDataFactory]:
+def _execute(count: int = 1, minimal: bool = True) -> Tuple[EksHook, NodegroupTestDataFactory]:
return eks_hook, NodegroupTestDataFactory(count=count, minimal=minimal)

eks_hook, cluster = cluster_builder()
@@ -280,7 +275,7 @@ def test_create_cluster_throws_exception_when_cluster_exists(

with pytest.raises(ClientError) as raised_exception:
eks_hook.create_cluster(
-name=generated_test_data.existing_cluster_name, **dict(ClusterInputs.REQUIRED)
+name=generated_test_data.existing_cluster_name, **dict(ClusterInputs.REQUIRED)  # type: ignore
)

assert_client_error_exception_thrown(
@@ -311,7 +306,7 @@ def test_create_cluster_generates_valid_cluster_arn(self, cluster_builder) -> None:
def test_create_cluster_generates_valid_cluster_created_timestamp(self, cluster_builder) -> None:
_, generated_test_data = cluster_builder()

-result_time: str = generated_test_data.cluster_describe_output[ClusterAttributes.CREATED_AT]
+result_time: datetime = generated_test_data.cluster_describe_output[ClusterAttributes.CREATED_AT]

assert iso_date(result_time) == FROZEN_TIME

@@ -441,7 +436,7 @@ def test_create_nodegroup_throws_exception_when_cluster_not_found(self) -> None:
eks_hook.create_nodegroup(
clusterName=non_existent_cluster_name,
nodegroupName=non_existent_nodegroup_name,
-**dict(NodegroupInputs.REQUIRED),
+**dict(NodegroupInputs.REQUIRED),  # type: ignore
)

assert_client_error_exception_thrown(
@@ -464,7 +459,7 @@ def test_create_nodegroup_throws_exception_when_nodegroup_already_exists(
eks_hook.create_nodegroup(
clusterName=generated_test_data.cluster_name,
nodegroupName=generated_test_data.existing_nodegroup_name,
-**dict(NodegroupInputs.REQUIRED),
+**dict(NodegroupInputs.REQUIRED),  # type: ignore
)

assert_client_error_exception_thrown(
@@ -493,7 +488,7 @@ def test_create_nodegroup_throws_exception_when_cluster_not_active(
eks_hook.create_nodegroup(
clusterName=generated_test_data.cluster_name,
nodegroupName=non_existent_nodegroup_name,
-**dict(NodegroupInputs.REQUIRED),
+**dict(NodegroupInputs.REQUIRED),  # type: ignore
)

assert_client_error_exception_thrown(
@@ -528,15 +523,15 @@ def test_create_nodegroup_generates_valid_nodegroup_arn(self, nodegroup_builder) -> None:
def test_create_nodegroup_generates_valid_nodegroup_created_timestamp(self, nodegroup_builder) -> None:
_, generated_test_data = nodegroup_builder()

-result_time: str = generated_test_data.nodegroup_describe_output[NodegroupAttributes.CREATED_AT]
+result_time: datetime = generated_test_data.nodegroup_describe_output[NodegroupAttributes.CREATED_AT]

assert iso_date(result_time) == FROZEN_TIME

@freeze_time(FROZEN_TIME)
def test_create_nodegroup_generates_valid_nodegroup_modified_timestamp(self, nodegroup_builder) -> None:
_, generated_test_data = nodegroup_builder()

-result_time: str = generated_test_data.nodegroup_describe_output[NodegroupAttributes.MODIFIED_AT]
+result_time: datetime = generated_test_data.nodegroup_describe_output[NodegroupAttributes.MODIFIED_AT]

assert iso_date(result_time) == FROZEN_TIME

@@ -813,7 +808,7 @@ def test_create_fargate_profile_throws_exception_when_cluster_not_found(self) -> None:
eks_hook.create_fargate_profile(
clusterName=non_existent_cluster_name,
fargateProfileName=non_existent_fargate_profile_name,
-**dict(FargateProfileInputs.REQUIRED),
+**dict(FargateProfileInputs.REQUIRED),  # type: ignore
)

assert_client_error_exception_thrown(
@@ -833,7 +828,7 @@ def test_create_fargate_profile_throws_exception_when_fargate_profile_already_exists(
eks_hook.create_fargate_profile(
clusterName=generated_test_data.cluster_name,
fargateProfileName=generated_test_data.existing_fargate_profile_name,
-**dict(FargateProfileInputs.REQUIRED),
+**dict(FargateProfileInputs.REQUIRED),  # type: ignore
)

assert_client_error_exception_thrown(
@@ -862,7 +857,7 @@ def test_create_fargate_profile_throws_exception_when_cluster_not_active(
eks_hook.create_fargate_profile(
clusterName=generated_test_data.cluster_name,
fargateProfileName=non_existent_fargate_profile_name,
-**dict(FargateProfileInputs.REQUIRED),
+**dict(FargateProfileInputs.REQUIRED),  # type: ignore
)

assert_client_error_exception_thrown(
@@ -897,7 +892,9 @@ def test_create_fargate_profile_generates_valid_profile_arn(self, fargate_profile_builder) -> None:
def test_create_fargate_profile_generates_valid_created_timestamp(self, fargate_profile_builder) -> None:
_, generated_test_data = fargate_profile_builder()

-result_time: str = generated_test_data.fargate_describe_output[FargateProfileAttributes.CREATED_AT]
+result_time: datetime = generated_test_data.fargate_describe_output[
+    FargateProfileAttributes.CREATED_AT
+]

assert iso_date(result_time) == FROZEN_TIME

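The changes above either tighten the _execute factory signatures or annotate the describe-output timestamps as datetime to match the datetime objects the mocked describe calls return. The Optional defaults were the actual MyPy complaint: an Optional[int] argument cannot be forwarded to a constructor parameter declared as plain int. A rough sketch of that error, using a stripped-down stand-in for the factory class:

from typing import Optional


class ClusterTestDataFactory:
    def __init__(self, count: int, minimal: bool) -> None:
        self.count = count
        self.minimal = minimal


# Pre-fix signature: MyPy flags the forwarding call with
# 'Argument "count" to "ClusterTestDataFactory" has incompatible type "Optional[int]"; expected "int"'.
def _execute_old(count: Optional[int] = 1, minimal: Optional[bool] = True) -> ClusterTestDataFactory:
    return ClusterTestDataFactory(count=count, minimal=minimal)


# The defaults are never None, so plain int/bool describe the call sites exactly.
def _execute(count: int = 1, minimal: bool = True) -> ClusterTestDataFactory:
    return ClusterTestDataFactory(count=count, minimal=minimal)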
10 changes: 5 additions & 5 deletions tests/providers/amazon/aws/operators/test_eks.py
@@ -85,7 +85,7 @@ class CreateNodegroupParams(TypedDict):
class TestEksCreateClusterOperator(unittest.TestCase):
def setUp(self) -> None:
# Parameters which are needed to create a cluster.
-self.create_cluster_params: ClusterParams = dict(
+self.create_cluster_params: ClusterParams = dict(  # type: ignore
cluster_name=CLUSTER_NAME,
cluster_role_arn=ROLE_ARN[1],
resources_vpc_config=RESOURCES_VPC_CONFIG[1],
@@ -101,7 +101,7 @@ def setUp(self) -> None:
def nodegroup_setUp(self) -> None:
# Parameters which are added to the cluster parameters
# when creating both the cluster and nodegroup together.
-self.base_nodegroup_params: NodeGroupParams = dict(
+self.base_nodegroup_params: NodeGroupParams = dict(  # type: ignore
nodegroup_name=NODEGROUP_NAME,
nodegroup_role_arn=NODEROLE_ARN[1],
)
@@ -122,7 +122,7 @@ def nodegroup_setUp(self) -> None:
def fargate_profile_setUp(self) -> None:
# Parameters which are added to the cluster parameters
# when creating both the cluster and Fargate profile together.
-self.base_fargate_profile_params: BaseFargateProfileParams = dict(
+self.base_fargate_profile_params: BaseFargateProfileParams = dict(  # type: ignore
fargate_profile_name=FARGATE_PROFILE_NAME,
fargate_pod_execution_role_arn=POD_EXECUTION_ROLE_ARN[1],
fargate_selectors=SELECTORS[1],
@@ -180,7 +180,7 @@ def test_execute_when_called_with_fargate_creates_both(

class TestEksCreateFargateProfileOperator(unittest.TestCase):
def setUp(self) -> None:
-self.create_fargate_profile_params: CreateFargateProfileParams = dict(
+self.create_fargate_profile_params: CreateFargateProfileParams = dict(  # type: ignore
cluster_name=CLUSTER_NAME,
pod_execution_role_arn=POD_EXECUTION_ROLE_ARN[1],
selectors=SELECTORS[1],
@@ -202,7 +202,7 @@ def test_execute_when_fargate_profile_does_not_already_exist(self, mock_create_f

class TestEksCreateNodegroupOperator(unittest.TestCase):
def setUp(self) -> None:
-self.create_nodegroup_params: CreateNodegroupParams = dict(
+self.create_nodegroup_params: CreateNodegroupParams = dict(  # type: ignore
cluster_name=CLUSTER_NAME,
nodegroup_name=NODEGROUP_NAME,
nodegroup_subnets=SUBNET_IDS,
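Every # type: ignore in this file sits on an assignment of a plain dict(...) call to a TypedDict-annotated attribute. MyPy treats a TypedDict as a distinct type that only a dict literal or the TypedDict constructor can satisfy, so a generic Dict[str, ...] on the right-hand side is rejected. A small sketch of the rule, with simplified fields and made-up values:

from typing import TypedDict  # Python 3.8+; available from typing_extensions on older interpreters


class ClusterParams(TypedDict):
    cluster_name: str
    cluster_role_arn: str


# Without the ignore, MyPy reports: Incompatible types in assignment
# (expression has type "Dict[str, str]", variable has type "ClusterParams").
params: ClusterParams = dict(cluster_name='demo', cluster_role_arn='demo-role-arn')  # type: ignore

# Building the TypedDict directly (or writing a dict literal) type-checks cleanly.
params_ok = ClusterParams(cluster_name='demo', cluster_role_arn='demo-role-arn')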
14 changes: 7 additions & 7 deletions tests/providers/amazon/aws/utils/eks_test_constants.py
@@ -22,7 +22,7 @@
"""
import re
from enum import Enum
-from typing import Dict, List, Pattern, Tuple
+from typing import Any, Dict, List, Pattern, Tuple

DEFAULT_CONN_ID: str = "aws_default"
DEFAULT_NAMESPACE: str = "default_namespace"
@@ -96,8 +96,8 @@ class ErrorAttributes:
class ClusterInputs:
"""All possible inputs for creating an EKS Cluster."""

-REQUIRED: List[Tuple] = [ROLE_ARN, RESOURCES_VPC_CONFIG]
-OPTIONAL: List[Tuple] = [
+REQUIRED: List[Tuple[str, Any]] = [ROLE_ARN, RESOURCES_VPC_CONFIG]
+OPTIONAL: List[Tuple[str, Any]] = [
CLIENT_REQUEST_TOKEN,
ENCRYPTION_CONFIG,
LOGGING,
@@ -108,15 +108,15 @@ class ClusterInputs:


class FargateProfileInputs:
-REQUIRED: List[Tuple] = [POD_EXECUTION_ROLE_ARN, SELECTORS]
-OPTIONAL: List[Tuple] = [SUBNETS, TAGS]
+REQUIRED: List[Tuple[str, Any]] = [POD_EXECUTION_ROLE_ARN, SELECTORS]
+OPTIONAL: List[Tuple[str, Any]] = [SUBNETS, TAGS]


class NodegroupInputs:
"""All possible inputs for creating an EKS Managed Nodegroup."""

-REQUIRED: List[Tuple] = [NODEROLE_ARN, SUBNETS]
-OPTIONAL: List[Tuple] = [
+REQUIRED: List[Tuple[str, Any]] = [NODEROLE_ARN, SUBNETS]
+OPTIONAL: List[Tuple[str, Any]] = [
AMI_TYPE,
DISK_SIZE,
INSTANCE_TYPES,
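Parametrizing the tuples as Tuple[str, Any] records that every entry is a (keyword, value) pair, so dict(...) over them infers Dict[str, Any] rather than a mapping with unknown key and value types. A short sketch with a made-up entry:

from typing import Any, Dict, List, Tuple

# A bare Tuple erases the element types, so dict() over these pairs loses the str key type.
REQUIRED_LOOSE: List[Tuple] = [('roleArn', 'demo-role-arn')]

# Spelling out (key, value) keeps str keys, which the ** unpacking in the tests relies on.
REQUIRED: List[Tuple[str, Any]] = [('roleArn', 'demo-role-arn')]

params: Dict[str, Any] = dict(REQUIRED)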
15 changes: 8 additions & 7 deletions tests/providers/amazon/aws/utils/eks_test_utils.py
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
+import datetime
import re
from copy import deepcopy
from typing import Dict, List, Optional, Pattern, Tuple, Type, Union
@@ -58,7 +59,7 @@ def attributes_to_test(
:return: Returns a list of tuples containing the keys and values to be validated in testing.
:rtype: List[Tuple]
"""
-result: List[Tuple] = deepcopy(inputs.REQUIRED + inputs.OPTIONAL + [STATUS])
+result: List[Tuple] = deepcopy(inputs.REQUIRED + inputs.OPTIONAL + [STATUS])  # type: ignore
if inputs == ClusterInputs:
result += [(ClusterAttributes.NAME, cluster_name)]
elif inputs == FargateProfileInputs:
@@ -195,13 +196,13 @@ def _input_builder(options: InputTypes, minimal: bool) -> Dict:
:type options: InputTypes
:param minimal: If True, only the required values are generated; if False all values are generated.
:type minimal: bool
-:return: Returns a list of tuples containing the keys and values to be validated in testing.
-:rtype: List[Tuple]
+:return: Returns a dict containing the keys and values to be validated in testing.
+:rtype: Dict
"""
-values: List[Tuple] = deepcopy(options.REQUIRED)
+values: List[Tuple] = deepcopy(options.REQUIRED)  # type: ignore
if not minimal:
values.extend(deepcopy(options.OPTIONAL))
-return dict(values)
+return dict(values)  # type: ignore


def string_to_regex(value: str) -> Pattern[str]:
@@ -259,8 +260,8 @@ def convert_keys(original: Dict) -> Dict:
return {conversion_map[k]: v for (k, v) in deepcopy(original).items()}


-def iso_date(datetime: str) -> str:
-    return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
+def iso_date(input_datetime: datetime.datetime) -> str:
+    return input_datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z"


def generate_dict(prefix, count) -> Dict:
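The old iso_date signature shadowed the datetime module with a str parameter and then called .strftime on it, a mismatch MyPy rightly rejects; the fixed helper takes the datetime.datetime that the mocked describe calls return. A short usage sketch with an illustrative timestamp:

import datetime


def iso_date(input_datetime: datetime.datetime) -> str:
    # Mirrors the fixed helper: format a datetime as the ISO-8601 UTC string the tests compare against.
    return input_datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z"


created_at = datetime.datetime(2013, 11, 27, 1, 42)
assert iso_date(created_at) == "2013-11-27T01:42:00Z"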
