[AIRFLOW-4104] Add type annotations to common classes. (apache#4926)
jmcarp authored and ashb committed Mar 27, 2019
1 parent e27950a commit ab58eb6
Showing 5 changed files with 151 additions and 112 deletions.
9 changes: 5 additions & 4 deletions airflow/hooks/base_hook.py
@@ -24,6 +24,7 @@

import os
import random
from typing import Iterable

from airflow.models.connection import Connection
from airflow.exceptions import AirflowException
@@ -67,7 +68,7 @@ def _get_connection_from_env(cls, conn_id):
return conn

@classmethod
def get_connections(cls, conn_id):
def get_connections(cls, conn_id): # type: (str) -> Iterable[Connection]
conn = cls._get_connection_from_env(conn_id)
if conn:
conns = [conn]
@@ -76,15 +77,15 @@ def get_connections(cls, conn_id):
return conns

@classmethod
def get_connection(cls, conn_id):
conn = random.choice(cls.get_connections(conn_id))
def get_connection(cls, conn_id): # type: (str) -> Connection
conn = random.choice(list(cls.get_connections(conn_id)))
if conn.host:
log = LoggingMixin().log
log.info("Using connection to: %s", conn.debug_info())
return conn

@classmethod
def get_hook(cls, conn_id):
def get_hook(cls, conn_id): # type: (str) -> BaseHook
connection = cls.get_connection(conn_id)
return connection.get_hook()

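The hunk above uses PEP 484 comment-style annotations rather than Python 3 function annotations, so the module stays importable on Python 2 while mypy can still type-check it. Two details worth noting: `cls` is omitted from the type comment, just as `self`/`cls` are omitted from ordinary annotations, and `random.choice` requires a sequence, which is why the now-`Iterable` return value is wrapped in `list(...)`. A minimal runnable sketch of the same pattern (function names are hypothetical, not from the commit):

import random
from typing import Iterable


def candidate_ports(host):  # type: (str) -> Iterable[int]
    # The static return type is a lazy Iterable, not a list.
    return (port for port in (5432, 5433, 5434))


def pick_port(host):  # type: (str) -> int
    # random.choice needs a sequence, so materialize the iterable first.
    return random.choice(list(candidate_ports(host)))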
176 changes: 94 additions & 82 deletions airflow/models/__init__.py
@@ -25,20 +25,13 @@
from builtins import ImportError as BuiltinImportError, bytes, object, str
from collections import defaultdict, namedtuple, OrderedDict
import copy
from typing import Iterable
from datetime import timedelta
from typing import Optional, Union, Type, Callable, Iterable, Set, Dict, Any

from future.standard_library import install_aliases

from airflow.models.base import Base, ID_LEN

try:
# Fix Python > 3.7 deprecation
from collections.abc import Hashable
except ImportError:
# Preserve Python < 3.3 compatibility
from collections import Hashable
from datetime import timedelta

import dill
import functools
import getpass
@@ -74,6 +67,7 @@
croniter, CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError
)
import six
from dateutil.relativedelta import relativedelta

from airflow import settings, utils
from airflow.executors import get_default_executor, LocalExecutor
@@ -83,7 +77,7 @@
AirflowRescheduleException
)
from airflow.dag.base_dag import BaseDag, BaseDagBag
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.lineage import apply_lineage, prepare_lineage, DataSet
from airflow.models.dagpickle import DagPickle
from airflow.models.kubernetes import KubeWorkerIdentifier, KubeResourceVersion # noqa: F401
from airflow.models.log import Log
@@ -116,6 +110,8 @@

XCOM_RETURN_KEY = 'return_value'

ScheduleInterval = Union[str, timedelta, relativedelta]


class InvalidFernetToken(Exception):
# If Fernet isn't loaded we need a valid exception class to catch. If it is
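`ScheduleInterval` is a type alias covering the three forms a schedule can take: a cron expression or preset string, a `datetime.timedelta`, or a `dateutil` `relativedelta`. A small illustration of values the alias admits (example values, not from the commit):

from datetime import timedelta
from typing import Union

from dateutil.relativedelta import relativedelta

ScheduleInterval = Union[str, timedelta, relativedelta]

daily_cron = '0 0 * * *'           # type: ScheduleInterval
every_hour = timedelta(hours=1)    # type: ScheduleInterval
monthly = relativedelta(months=1)  # type: ScheduleInterval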
@@ -2064,43 +2060,44 @@ class derived from this one results in the creation of a task object,

@apply_defaults
def __init__(
self,
task_id,
owner=configuration.conf.get('operators', 'DEFAULT_OWNER'),
email=None,
email_on_retry=True,
email_on_failure=True,
retries=0,
retry_delay=timedelta(seconds=300),
retry_exponential_backoff=False,
max_retry_delay=None,
start_date=None,
end_date=None,
schedule_interval=None, # not hooked as of now
depends_on_past=False,
wait_for_downstream=False,
dag=None,
params=None,
default_args=None,
priority_weight=1,
weight_rule=WeightRule.DOWNSTREAM,
queue=configuration.conf.get('celery', 'default_queue'),
pool=None,
sla=None,
execution_timeout=None,
on_failure_callback=None,
on_success_callback=None,
on_retry_callback=None,
trigger_rule=TriggerRule.ALL_SUCCESS,
resources=None,
run_as_user=None,
task_concurrency=None,
executor_config=None,
do_xcom_push=True,
inlets=None,
outlets=None,
*args,
**kwargs):
self,
task_id, # type: str
owner=configuration.conf.get('operators', 'DEFAULT_OWNER'), # type: str
email=None, # type: Optional[str]
email_on_retry=True, # type: bool
email_on_failure=True, # type: bool
retries=0, # type: int
retry_delay=timedelta(seconds=300), # type: timedelta
retry_exponential_backoff=False, # type: bool
max_retry_delay=None, # type: Optional[datetime]
start_date=None, # type: Optional[datetime]
end_date=None, # type: Optional[datetime]
schedule_interval=None, # not hooked as of now
depends_on_past=False, # type: bool
wait_for_downstream=False, # type: bool
dag=None, # type: Optional[DAG]
params=None, # type: Optional[Dict]
default_args=None, # type: Optional[Dict]
priority_weight=1, # type: int
weight_rule=WeightRule.DOWNSTREAM, # type: str
queue=configuration.conf.get('celery', 'default_queue'), # type: str
pool=None, # type: Optional[str]
sla=None, # type: Optional[timedelta]
execution_timeout=None, # type: Optional[timedelta]
on_failure_callback=None, # type: Optional[Callable]
on_success_callback=None, # type: Optional[Callable]
on_retry_callback=None, # type: Optional[Callable]
trigger_rule=TriggerRule.ALL_SUCCESS, # type: str
resources=None, # type: Optional[Dict]
run_as_user=None, # type: Optional[str]
task_concurrency=None, # type: Optional[int]
executor_config=None, # type: Optional[Dict]
do_xcom_push=True, # type: bool
inlets=None, # type: Optional[Dict]
outlets=None, # type: Optional[Dict]
*args,
**kwargs
):

if args or kwargs:
# TODO remove *args and **kwargs in Airflow 2.0
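Because the signature spans many lines, the rewrite uses PEP 484's per-argument type comments rather than a single function-level comment; mypy reads each `# type:` trailing an argument, and a `# type: (...) -> X` comment after the closing parenthesis can supply the return type. A short sketch of the form (names are hypothetical):

from datetime import timedelta
from typing import Optional


def make_task(
    task_id,  # type: str
    retries=0,  # type: int
    retry_delay=timedelta(seconds=300),  # type: timedelta
    pool=None,  # type: Optional[str]
):
    # type: (...) -> None
    """Per-argument comments annotate the parameters; (...) defers to them."""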
@@ -2183,8 +2180,8 @@ def __init__(
self.do_xcom_push = do_xcom_push

# Private attributes
self._upstream_task_ids = set()
self._downstream_task_ids = set()
self._upstream_task_ids = set() # type: Set[str]
self._downstream_task_ids = set() # type: Set[str]

if not dag and _CONTEXT_MANAGER_DAG:
dag = _CONTEXT_MANAGER_DAG
@@ -2194,8 +2191,8 @@
self._log = logging.getLogger("airflow.task.operators")

# lineage
self.inlets = []
self.outlets = []
self.inlets = [] # type: Iterable[DataSet]
self.outlets = [] # type: Iterable[DataSet]
self.lineage_data = None

self._inlets = {
@@ -2206,7 +2203,7 @@

self._outlets = {
"datasets": [],
}
} # type: Dict

if inlets:
self._inlets.update(inlets)
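Empty containers give the type checker nothing to infer from, so the commit pins instance attributes with type comments at the point of assignment. A minimal sketch of the pattern (class and attribute names are illustrative):

from typing import Dict, List, Set


class TaskStub(object):
    def __init__(self):
        # type: () -> None
        # An empty literal is untyped on its own; the comment fixes its type.
        self._upstream_task_ids = set()  # type: Set[str]
        self.outlets = []  # type: List[str]
        self._outlets = {"datasets": []}  # type: Dict[str, list]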
@@ -2977,29 +2974,32 @@ class DAG(BaseDag, LoggingMixin):
"""

def __init__(
self, dag_id,
description='',
schedule_interval=timedelta(days=1),
start_date=None, end_date=None,
full_filepath=None,
template_searchpath=None,
template_undefined=jinja2.Undefined,
user_defined_macros=None,
user_defined_filters=None,
default_args=None,
concurrency=configuration.conf.getint('core', 'dag_concurrency'),
max_active_runs=configuration.conf.getint(
'core', 'max_active_runs_per_dag'),
dagrun_timeout=None,
sla_miss_callback=None,
default_view=None,
orientation=configuration.conf.get('webserver', 'dag_orientation'),
catchup=configuration.conf.getboolean('scheduler', 'catchup_by_default'),
on_success_callback=None, on_failure_callback=None,
doc_md=None,
params=None,
access_control=None):

self,
dag_id, # type: str
description='', # type: str
schedule_interval=timedelta(days=1), # type: Optional[ScheduleInterval]
start_date=None, # type: Optional[datetime]
end_date=None, # type: Optional[datetime]
full_filepath=None, # type: Optional[str]
template_searchpath=None, # type: Optional[Union[str, Iterable[str]]]
template_undefined=jinja2.Undefined, # type: Type[jinja2.Undefined]
user_defined_macros=None, # type: Optional[Dict]
user_defined_filters=None, # type: Optional[Dict]
default_args=None, # type: Optional[Dict]
concurrency=configuration.conf.getint('core', 'dag_concurrency'), # type: int
max_active_runs=configuration.conf.getint(
'core', 'max_active_runs_per_dag'), # type: int
dagrun_timeout=None, # type: Optional[timedelta]
sla_miss_callback=None, # type: Optional[Callable]
default_view=None, # type: Optional[str]
orientation=configuration.conf.get('webserver', 'dag_orientation'), # type: str
catchup=configuration.conf.getboolean('scheduler', 'catchup_by_default'), # type: bool
on_success_callback=None, # type: Optional[Callable]
on_failure_callback=None, # type: Optional[Callable]
doc_md=None, # type: Optional[str]
params=None, # type: Optional[Dict]
access_control=None # type: Optional[Dict]
):
self.user_defined_macros = user_defined_macros
self.user_defined_filters = user_defined_filters
self.default_args = default_args or {}
@@ -3021,7 +3021,7 @@ def __init__(
self._description = description
# set file location to caller source path
self.fileloc = sys._getframe().f_back.f_code.co_filename
self.task_dict = dict()
self.task_dict = dict() # type: Dict[str, TaskInstance]

# set timezone
if start_date and start_date.tzinfo:
@@ -3050,8 +3050,8 @@ def __init__(
)

self.schedule_interval = schedule_interval
if isinstance(schedule_interval, Hashable) and schedule_interval in cron_presets:
self._schedule_interval = cron_presets.get(schedule_interval)
if isinstance(schedule_interval, six.string_types) and schedule_interval in cron_presets:
self._schedule_interval = cron_presets.get(schedule_interval) # type: Optional[ScheduleInterval]
elif schedule_interval == '@once':
self._schedule_interval = None
else:
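The old guard `isinstance(schedule_interval, Hashable)` only protected the dictionary lookup, and types like `timedelta` are hashable too; since `cron_presets` is keyed by preset strings such as '@daily', checking for `six.string_types` is more precise and lets mypy narrow the type before the `cron_presets.get(...)` call. A sketch of the narrowed check (preset values are illustrative):

from datetime import timedelta

import six

cron_presets = {'@daily': '0 0 * * *', '@hourly': '0 * * * *'}


def resolve_schedule(schedule_interval):
    # Only strings can name a preset; timedelta and relativedelta pass through.
    if isinstance(schedule_interval, six.string_types) and schedule_interval in cron_presets:
        return cron_presets[schedule_interval]
    return schedule_interval


assert resolve_schedule('@daily') == '0 0 * * *'
assert resolve_schedule(timedelta(days=1)) == timedelta(days=1)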
@@ -3076,7 +3076,7 @@ def __init__(
self.on_failure_callback = on_failure_callback
self.doc_md = doc_md

self._old_context_manager_dags = []
self._old_context_manager_dags = [] # type: Iterable[DAG]
self._access_control = access_control

self._comps = {
@@ -4283,7 +4283,13 @@ def setdefault(cls, key, default, deserialize_json=False):

@classmethod
@provide_session
def get(cls, key, default_var=__NO_DEFAULT_SENTINEL, deserialize_json=False, session=None):
def get(
cls,
key, # type: str
default_var=__NO_DEFAULT_SENTINEL, # type: Any
deserialize_json=False, # type: bool
session=None
):
obj = session.query(cls).filter(cls.key == key).first()
if obj is None:
if default_var is not cls.__NO_DEFAULT_SENTINEL:
@@ -4298,15 +4304,21 @@ def get(cls, key, default_var=__NO_DEFAULT_SENTINEL, deserialize_json=False, session=None):

@classmethod
@provide_session
def set(cls, key, value, serialize_json=False, session=None):
def set(
cls,
key, # type: str
value, # type: Any
serialize_json=False, # type: bool
session=None
):

if serialize_json:
stored_value = json.dumps(value)
else:
stored_value = str(value)

Variable.delete(key)
session.add(Variable(key=key, val=stored_value))
session.add(Variable(key=key, val=stored_value)) # type: ignore
session.flush()

@classmethod
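`Variable.set` stores either `str(value)` or a JSON dump, so a value round-trips only if it is read back with the matching `deserialize_json` flag; the `# type: ignore` appears to silence mypy on the keyword-argument construction of the SQLAlchemy model, whose generated `__init__` the checker cannot see (my reading, not stated in the commit). A hedged usage sketch, assuming a configured Airflow metadata database:

from airflow.models import Variable

Variable.set('feature_flags', {'use_new_ui': True}, serialize_json=True)
flags = Variable.get('feature_flags', deserialize_json=True)
assert flags['use_new_ui'] is True

# When the key is absent, default_var is returned instead of raising.
missing = Variable.get('not_set', default_var=None)
assert missing is None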
1 change: 1 addition & 0 deletions airflow/models/connection.py
@@ -264,6 +264,7 @@ def get_hook(self):
elif self.conn_type == 'grpc':
from airflow.contrib.hooks.grpc_hook import GrpcHook
return GrpcHook(grpc_conn_id=self.conn_id)
raise AirflowException("Unknown hook type {}".format(self.conn_type))

def __repr__(self):
return self.conn_id
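Previously `get_hook` fell off the end of its if/elif chain and silently returned `None` for an unrecognized `conn_type`; it now raises `AirflowException`, turning a latent `AttributeError` at the call site into an immediate, descriptive error. A hedged sketch of the new failure mode (the connection id is hypothetical):

from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook

try:
    hook = BaseHook.get_hook('my_unregistered_conn')
except AirflowException as exc:
    # Raised when conn_type matches no known hook, instead of returning None.
    print('cannot build hook: %s' % exc)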
35 changes: 24 additions & 11 deletions airflow/operators/check_operator.py
@@ -19,7 +19,7 @@

from builtins import zip
from builtins import str
from typing import Iterable
from typing import Optional, Any, Iterable, Dict

from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
@@ -69,9 +69,12 @@ class CheckOperator(BaseOperator):

@apply_defaults
def __init__(
self, sql,
conn_id=None,
*args, **kwargs):
self,
sql, # type: str
conn_id=None, # type: Optional[str]
*args,
**kwargs
):
super(CheckOperator, self).__init__(*args, **kwargs)
self.conn_id = conn_id
self.sql = sql
@@ -127,9 +130,14 @@ class ValueCheckOperator(BaseOperator):

@apply_defaults
def __init__(
self, sql, pass_value, tolerance=None,
conn_id=None,
*args, **kwargs):
self,
sql, # type: str
pass_value, # type: Any
tolerance=None, # type: Any
conn_id=None, # type: Optional[str]
*args,
**kwargs
):
super(ValueCheckOperator, self).__init__(*args, **kwargs)
self.sql = sql
self.conn_id = conn_id
@@ -203,10 +211,15 @@ class IntervalCheckOperator(BaseOperator):

@apply_defaults
def __init__(
self, table, metrics_thresholds,
date_filter_column='ds', days_back=-7,
conn_id=None,
*args, **kwargs):
self,
table, # type: str
metrics_thresholds, # type: Dict
date_filter_column='ds', # type: str
days_back=-7, # type: int
conn_id=None, # type: Optional[str]
*args,
**kwargs
):
super(IntervalCheckOperator, self).__init__(*args, **kwargs)
self.table = table
self.metrics_thresholds = metrics_thresholds
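With the annotations in place, the check operators' expected argument types are explicit at the call site. A hedged usage sketch (task ids, SQL, and connection id are hypothetical; a surrounding DAG is assumed):

from airflow.operators.check_operator import CheckOperator, ValueCheckOperator

sanity = CheckOperator(
    task_id='row_count_nonzero',
    sql='SELECT COUNT(*) FROM my_table',
    conn_id='my_db',
)

exact = ValueCheckOperator(
    task_id='row_count_exact',
    sql='SELECT COUNT(*) FROM my_table',
    pass_value=42,
    tolerance=0.1,
    conn_id='my_db',
)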
