Skip to content

Commit

Permalink
Increase typing for Apache and http provider package (apache#9729)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ephraimbuddy authored Jul 19, 2020
1 parent 750555f commit 4d74ac2
Show file tree
Hide file tree
Showing 37 changed files with 904 additions and 667 deletions.
6 changes: 3 additions & 3 deletions airflow/hooks/base_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
"""Base class for all hooks"""
import logging
import random
from typing import List
from typing import Any, List

from airflow import secrets
from airflow.models import Connection
from airflow.models.connection import Connection
from airflow.utils.log.logging_mixin import LoggingMixin

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -82,6 +82,6 @@ def get_hook(cls, conn_id: str) -> "BaseHook":
connection = cls.get_connection(conn_id)
return connection.get_hook()

def get_conn(self):
def get_conn(self) -> Any:
"""Returns connection for the hook."""
raise NotImplementedError()
5 changes: 3 additions & 2 deletions airflow/providers/apache/cassandra/sensors/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
of a record in a Cassandra cluster.
"""

from typing import Dict
from typing import Any, Dict, Tuple

from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
Expand Down Expand Up @@ -56,7 +56,8 @@ class CassandraRecordSensor(BaseSensorOperator):
template_fields = ('table', 'keys')

@apply_defaults
def __init__(self, table: str, keys: Dict[str, str], cassandra_conn_id: str, *args, **kwargs) -> None:
def __init__(self, table: str, keys: Dict[str, str], cassandra_conn_id: str,
*args: Tuple[Any, ...], **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.cassandra_conn_id = cassandra_conn_id
self.table = table
Expand Down
7 changes: 4 additions & 3 deletions airflow/providers/apache/cassandra/sensors/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
of a table in a Cassandra cluster.
"""

from typing import Dict
from typing import Any, Dict, Tuple

from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
Expand Down Expand Up @@ -54,12 +54,13 @@ class CassandraTableSensor(BaseSensorOperator):
template_fields = ('table',)

@apply_defaults
def __init__(self, table: str, cassandra_conn_id: str, *args, **kwargs) -> None:
def __init__(self, table: str, cassandra_conn_id: str, *args: Tuple[Any, ...],
**kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.cassandra_conn_id = cassandra_conn_id
self.table = table

def poke(self, context: Dict) -> bool:
def poke(self, context: Dict[Any, Any]) -> bool:
self.log.info('Sensor check existence of table: %s', self.table)
hook = CassandraHook(self.cassandra_conn_id)
return hook.table_exists(self.table)
30 changes: 18 additions & 12 deletions airflow/providers/apache/druid/hooks/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.

import time
from typing import Any, Dict, Iterable, Optional, Tuple

import requests
from pydruid.db import connect
Expand All @@ -43,11 +44,13 @@ class DruidHook(BaseHook):
:param max_ingestion_time: The maximum ingestion time before assuming the job failed
:type max_ingestion_time: int
"""

def __init__(
self,
druid_ingest_conn_id='druid_ingest_default',
timeout=1,
max_ingestion_time=None):
self,
druid_ingest_conn_id: str = 'druid_ingest_default',
timeout: int = 1,
max_ingestion_time: Optional[int] = None
) -> None:

super().__init__()
self.druid_ingest_conn_id = druid_ingest_conn_id
Expand All @@ -58,7 +61,7 @@ def __init__(
if self.timeout < 1:
raise ValueError("Druid timeout should be equal or greater than 1")

def get_conn_url(self):
def get_conn_url(self) -> str:
"""
Get Druid connection url
"""
Expand All @@ -70,7 +73,7 @@ def get_conn_url(self):
return "{conn_type}://{host}:{port}/{endpoint}".format(
conn_type=conn_type, host=host, port=port, endpoint=endpoint)

def get_auth(self):
def get_auth(self) -> Optional[requests.auth.HTTPBasicAuth]:
"""
Return username and password from connections tab as requests.auth.HTTPBasicAuth object.
Expand All @@ -84,7 +87,7 @@ def get_auth(self):
else:
return None

def submit_indexing_job(self, json_index_spec: str):
def submit_indexing_job(self, json_index_spec: Dict[str, Any]) -> None:
"""
Submit Druid ingestion job
"""
Expand Down Expand Up @@ -144,11 +147,11 @@ class DruidDbApiHook(DbApiHook):
default_conn_name = 'druid_broker_default'
supports_autocommit = False

def get_conn(self):
def get_conn(self) -> connect:
"""
Establish a connection to druid broker.
"""
conn = self.get_connection(self.druid_broker_conn_id) # pylint: disable=no-member
conn = self.get_connection(self.conn_name_attr)
druid_broker_conn = connect(
host=conn.host,
port=conn.port,
Expand All @@ -160,7 +163,7 @@ def get_conn(self):
self.log.info('Get the connection to druid broker on %s using user %s', conn.host, conn.login)
return druid_broker_conn

def get_uri(self):
def get_uri(self) -> str:
"""
Get the connection uri for druid broker.
Expand All @@ -175,8 +178,11 @@ def get_uri(self):
return '{conn_type}://{host}/{endpoint}'.format(
conn_type=conn_type, host=host, endpoint=endpoint)

def set_autocommit(self, conn, autocommit):
def set_autocommit(self, conn: connect, autocommit: bool) -> NotImplemented:
raise NotImplementedError()

def insert_rows(self, table, rows, target_fields=None, commit_every=1000):
def insert_rows(self, table: str, rows: Iterable[Tuple[str]],
target_fields: Optional[Iterable[str]] = None,
commit_every: int = 1000, replace: bool = False,
**kwargs: Any) -> NotImplemented:
raise NotImplementedError()
11 changes: 6 additions & 5 deletions airflow/providers/apache/druid/operators/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.

import json
from typing import Any, Dict, Optional

from airflow.models import BaseOperator
from airflow.providers.apache.druid.hooks.druid import DruidHook
Expand All @@ -37,16 +38,16 @@ class DruidOperator(BaseOperator):
template_ext = ('.json',)

@apply_defaults
def __init__(self, json_index_file,
druid_ingest_conn_id='druid_ingest_default',
max_ingestion_time=None,
*args, **kwargs):
def __init__(self, json_index_file: str,
druid_ingest_conn_id: str = 'druid_ingest_default',
max_ingestion_time: Optional[int] = None,
*args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.json_index_file = json_index_file
self.conn_id = druid_ingest_conn_id
self.max_ingestion_time = max_ingestion_time

def execute(self, context):
def execute(self, context: Dict[Any, Any]) -> None:
hook = DruidHook(
druid_ingest_conn_id=self.conn_id,
max_ingestion_time=self.max_ingestion_time
Expand Down
16 changes: 9 additions & 7 deletions airflow/providers/apache/druid/operators/druid_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Optional

from airflow.exceptions import AirflowException
from airflow.operators.check_operator import CheckOperator
Expand Down Expand Up @@ -57,21 +58,22 @@ class DruidCheckOperator(CheckOperator):

@apply_defaults
def __init__(
self,
sql: str,
druid_broker_conn_id: str = 'druid_broker_default',
*args, **kwargs) -> None:
self,
sql: str,
druid_broker_conn_id: str = 'druid_broker_default',
*args: Any, **kwargs: Any
) -> None:
super().__init__(sql=sql, *args, **kwargs)
self.druid_broker_conn_id = druid_broker_conn_id
self.sql = sql

def get_db_hook(self):
def get_db_hook(self) -> DruidDbApiHook:
"""
Return the druid db api hook.
"""
return DruidDbApiHook(druid_broker_conn_id=self.druid_broker_conn_id)

def get_first(self, sql):
def get_first(self, sql: str) -> Any:
"""
Executes the druid sql to druid broker and returns the first resulting row.
Expand All @@ -82,7 +84,7 @@ def get_first(self, sql):
cur.execute(sql)
return cur.fetchone()

def execute(self, context=None):
def execute(self, context: Optional[Dict[Any, Any]] = None) -> None:
self.log.info('Executing SQL check: %s', self.sql)
record = self.get_first(self.sql)
self.log.info("Record: %s", str(record))
Expand Down
45 changes: 24 additions & 21 deletions airflow/providers/apache/druid/transfers/hive_to_druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
This module contains operator to move data from Hive to Druid.
"""

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from airflow.models import BaseOperator
from airflow.providers.apache.druid.hooks.druid import DruidHook
Expand Down Expand Up @@ -84,23 +84,25 @@ class HiveToDruidOperator(BaseOperator):

@apply_defaults
def __init__( # pylint: disable=too-many-arguments
self,
sql: str,
druid_datasource: str,
ts_dim: str,
metric_spec: Optional[List] = None,
hive_cli_conn_id: str = 'hive_cli_default',
druid_ingest_conn_id: str = 'druid_ingest_default',
metastore_conn_id: str = 'metastore_default',
hadoop_dependency_coordinates: Optional[List[str]] = None,
intervals: Optional[List] = None,
num_shards: float = -1,
target_partition_size: int = -1,
query_granularity: str = "NONE",
segment_granularity: str = "DAY",
hive_tblproperties: Optional[Dict] = None,
job_properties: Optional[Dict] = None,
*args, **kwargs) -> None:
self,
sql: str,
druid_datasource: str,
ts_dim: str,
metric_spec: Optional[List[Any]] = None,
hive_cli_conn_id: str = 'hive_cli_default',
druid_ingest_conn_id: str = 'druid_ingest_default',
metastore_conn_id: str = 'metastore_default',
hadoop_dependency_coordinates: Optional[List[str]] = None,
intervals: Optional[List[Any]] = None,
num_shards: float = -1,
target_partition_size: int = -1,
query_granularity: str = "NONE",
segment_granularity: str = "DAY",
hive_tblproperties: Optional[Dict[Any, Any]] = None,
job_properties: Optional[Dict[Any, Any]] = None,
*args: Any,
**kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self.sql = sql
self.druid_datasource = druid_datasource
Expand All @@ -120,7 +122,7 @@ def __init__( # pylint: disable=too-many-arguments
self.hive_tblproperties = hive_tblproperties or {}
self.job_properties = job_properties

def execute(self, context):
def execute(self, context: Dict[str, Any]) -> None:
hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
self.log.info("Extracting data from Hive")
hive_table = 'druid.' + context['task_instance_key_str'].replace('.', '_')
Expand Down Expand Up @@ -172,7 +174,8 @@ def execute(self, context):
hql = "DROP TABLE IF EXISTS {}".format(hive_table)
hive.run_cli(hql)

def construct_ingest_query(self, static_path, columns):
def construct_ingest_query(self, static_path: str,
columns: List[str]) -> Dict[str, Any]:
"""
Builds an ingest query for an HDFS TSV load.
Expand All @@ -199,7 +202,7 @@ def construct_ingest_query(self, static_path, columns):
# or a metric, as the dimension columns
dimensions = [c for c in columns if c not in metric_names and c != self.ts_dim]

ingest_query_dict = {
ingest_query_dict: Dict[str, Any] = {
"type": "index_hadoop",
"spec": {
"dataSchema": {
Expand Down
13 changes: 10 additions & 3 deletions airflow/providers/apache/hdfs/hooks/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
# specific language governing permissions and limitations
# under the License.
"""Hook for HDFS operations"""
from typing import Any, Optional

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook

try:
from snakebite.client import AutoConfigClient, Client, HAClient, Namenode # pylint: disable=syntax-error

snakebite_loaded = True
except ImportError:
snakebite_loaded = False
Expand All @@ -43,8 +46,12 @@ class HDFSHook(BaseHook):
:param autoconfig: use snakebite's automatically configured client
:type autoconfig: bool
"""
def __init__(self, hdfs_conn_id='hdfs_default', proxy_user=None,
autoconfig=False):

def __init__(self,
hdfs_conn_id: str = 'hdfs_default',
proxy_user: Optional[str] = None,
autoconfig: bool = False
):
super().__init__()
if not snakebite_loaded:
raise ImportError(
Expand All @@ -56,7 +63,7 @@ def __init__(self, hdfs_conn_id='hdfs_default', proxy_user=None,
self.proxy_user = proxy_user
self.autoconfig = autoconfig

def get_conn(self):
def get_conn(self) -> Any:
"""
Returns a snakebite HDFSClient object.
"""
Expand Down
Loading

0 comments on commit 4d74ac2

Please sign in to comment.