[AIRFLOW-4419] Refine concurrency check in scheduler (apache#5194)
KevinYang21 authored and feng-tao committed Apr 29, 2019
1 parent 477698a commit d63b6c9
Showing 5 changed files with 260 additions and 90 deletions.
76 changes: 34 additions & 42 deletions airflow/jobs.py
@@ -25,24 +25,22 @@
import sys
import threading
import time
from collections import defaultdict, OrderedDict
from collections import OrderedDict, defaultdict
from datetime import timedelta
from time import sleep
from typing import Any

import six
from past.builtins import basestring
from sqlalchemy import (Column, Index, Integer, String, and_, func, not_, or_)
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import make_transient
from typing import Any

from airflow import configuration as conf
from airflow import executors, models, settings
from airflow.exceptions import (
AirflowException, NoAvailablePoolSlot, PoolNotFound, DagConcurrencyLimitReached,
)

from airflow.models import DAG, DagPickle, DagRun, errors, SlaMiss
from airflow.exceptions import (AirflowException, DagConcurrencyLimitReached,
NoAvailablePoolSlot, PoolNotFound)
from airflow.models import DAG, DagPickle, DagRun, SlaMiss, errors
from airflow.stats import Stats
from airflow.task.task_runner import get_task_runner
from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
@@ -1014,13 +1012,14 @@ def _change_state_for_tis_without_dagrun(self,
)

@provide_session
def __get_task_concurrency_map(self, states, session=None):
def __get_concurrency_maps(self, states, session=None):
"""
Returns a map from tasks to number in the states list given.
Get the concurrency maps.
:param states: List of states to query for
:type states: list[airflow.utils.state.State]
:return: A map from (dag_id, task_id) to count of tasks in states
:return: A map from dag_id to # of task instances and
a map from (dag_id, task_id) to # of task instances in the given state list
:rtype: tuple[dict[str, int], dict[tuple[str, str], int]]
"""
@@ -1031,11 +1030,13 @@ def __get_task_concurrency_map(self, states, session=None):
.filter(TI.state.in_(states))
.group_by(TI.task_id, TI.dag_id)
).all()
dag_map = defaultdict(int)
task_map = defaultdict(int)
for result in ti_concurrency_query:
task_id, dag_id, count = result
dag_map[dag_id] += count
task_map[(dag_id, task_id)] = count
return task_map
return dag_map, task_map

@provide_session
def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
@@ -1054,7 +1055,7 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
"""
executable_tis = []

# Get all the queued task instances from associated with scheduled
# Get all task instances associated with scheduled
# DagRuns which are not backfilled, in the given states,
# and the dag is not paused
TI = models.TaskInstance
@@ -1074,6 +1075,8 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
.filter(or_(DM.dag_id == None, # noqa: E711
not_(DM.is_paused)))
)

# Additional filters on task instance state
if None in states:
ti_query = ti_query.filter(
or_(TI.state == None, TI.state.in_(states)) # noqa: E711
@@ -1103,7 +1106,8 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
pool_to_task_instances[task_instance.pool].append(task_instance)

states_to_count_as_running = [State.RUNNING, State.QUEUED]
task_concurrency_map = self.__get_task_concurrency_map(
# dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
dag_concurrency_map, task_concurrency_map = self.__get_concurrency_maps(
states=states_to_count_as_running, session=session)

# Go through each pool, and queue up a task for execution if there are
@@ -1113,9 +1117,9 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
if not pool:
# Arbitrary:
# If queued outside of a pool, trigger no more than
# non_pooled_task_slot_count per run
open_slots = conf.getint('core', 'non_pooled_task_slot_count')
pool_name = 'not_pooled'
# non_pooled_task_slot_count
open_slots = models.Pool.default_pool_open_slots()
pool_name = models.Pool.default_pool_name
else:
if pool not in pools:
self.log.warning(
@@ -1126,19 +1130,16 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
else:
open_slots = pools[pool].open_slots(session=session)

num_queued = len(task_instances)
num_ready = len(task_instances)
self.log.info(
"Figuring out tasks to run in Pool(name=%s) with %s open slots "
"and %s task instances in queue",
pool, open_slots, num_queued
"and %s task instances ready to be queued",
pool, open_slots, num_ready
)

priority_sorted_task_instances = sorted(
task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date))

# DAG IDs with running tasks that equal the concurrency limit of the dag
dag_id_to_possibly_running_task_count = {}

# Number of tasks that cannot be scheduled because of no open slot in pool
num_starving_tasks = 0
for current_index, task_instance in enumerate(priority_sorted_task_instances):
@@ -1156,42 +1157,32 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
dag_id = task_instance.dag_id
simple_dag = simple_dag_bag.get_dag(dag_id)

if dag_id not in dag_id_to_possibly_running_task_count:
dag_id_to_possibly_running_task_count[dag_id] = \
DAG.get_num_task_instances(
dag_id,
simple_dag_bag.get_dag(dag_id).task_ids,
states=states_to_count_as_running,
session=session)

current_task_concurrency = dag_id_to_possibly_running_task_count[dag_id]
task_concurrency_limit = simple_dag_bag.get_dag(dag_id).concurrency
current_dag_concurrency = dag_concurrency_map[dag_id]
dag_concurrency_limit = simple_dag_bag.get_dag(dag_id).concurrency
self.log.info(
"DAG %s has %s/%s running and queued tasks",
dag_id, current_task_concurrency, task_concurrency_limit
dag_id, current_dag_concurrency, dag_concurrency_limit
)
if current_task_concurrency >= task_concurrency_limit:
if current_dag_concurrency >= dag_concurrency_limit:
self.log.info(
"Not executing %s since the number of tasks running or queued "
"from DAG %s is >= to the DAG's task concurrency limit of %s",
task_instance, dag_id, task_concurrency_limit
task_instance, dag_id, dag_concurrency_limit
)
continue

task_concurrency = simple_dag.get_task_special_arg(
task_concurrency_limit = simple_dag.get_task_special_arg(
task_instance.task_id,
'task_concurrency')
if task_concurrency is not None:
num_running = task_concurrency_map[
if task_concurrency_limit is not None:
current_task_concurrency = task_concurrency_map[
(task_instance.dag_id, task_instance.task_id)
]

if num_running >= task_concurrency:
if current_task_concurrency >= task_concurrency_limit:
self.log.info("Not executing %s since the task concurrency for"
" this task has been reached.", task_instance)
continue
else:
task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1

if self.executor.has_task(task_instance):
self.log.debug(
@@ -1201,7 +1192,8 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
continue
executable_tis.append(task_instance)
open_slots -= 1
dag_id_to_possibly_running_task_count[dag_id] += 1
dag_concurrency_map[dag_id] += 1
task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1

Stats.gauge('pool.starving_tasks.{pool_name}'.format(pool_name=pool_name),
num_starving_tasks)
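The heart of this change: the scheduling loop above no longer issues a DAG.get_num_task_instances COUNT query per DAG (the old dag_id_to_possibly_running_task_count dict); instead, __get_concurrency_maps builds both counters up front and the loop increments them in memory as it accepts tasks. A minimal sketch of that flow, assuming plain (dag_id, task_id) pairs in place of real task instances; build_concurrency_maps, can_schedule, and running_tis are illustrative names, not part of the Airflow codebase:

from collections import defaultdict

def build_concurrency_maps(running_tis):
    # One pass over (dag_id, task_id) pairs builds both counters,
    # mirroring __get_concurrency_maps above.
    dag_map = defaultdict(int)   # dag_id -> # of running/queued tasks
    task_map = defaultdict(int)  # (dag_id, task_id) -> # of running/queued tasks
    for dag_id, task_id in running_tis:
        dag_map[dag_id] += 1
        task_map[(dag_id, task_id)] += 1
    return dag_map, task_map

def can_schedule(dag_id, task_id, dag_map, task_map,
                 dag_concurrency_limit, task_concurrency_limit):
    # The two checks the loop applies before queuing a task instance.
    if dag_map[dag_id] >= dag_concurrency_limit:
        return False  # DAG-level concurrency limit reached
    if (task_concurrency_limit is not None and
            task_map[(dag_id, task_id)] >= task_concurrency_limit):
        return False  # per-task task_concurrency limit reached
    # Count the accepted task so later candidates in the same pass see
    # updated totals, as the loop does after appending to executable_tis.
    dag_map[dag_id] += 1
    task_map[(dag_id, task_id)] += 1
    return True

Incrementing both maps in memory after each accepted task is what lets the new code drop the per-DAG database query from inside the loop.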
57 changes: 21 additions & 36 deletions airflow/models/pool.py
@@ -17,11 +17,12 @@
# specific language governing permissions and limitations
# under the License.

from sqlalchemy import Column, Integer, String, Text
from sqlalchemy import Column, Integer, String, Text, func

from airflow import conf
from airflow.models.base import Base
from airflow.utils.db import provide_session
from airflow.utils.state import State
from airflow.utils.db import provide_session


class Pool(Base):
@@ -32,9 +33,21 @@ class Pool(Base):
slots = Column(Integer, default=0)
description = Column(Text)

default_pool_name = 'not_pooled'

def __repr__(self):
return self.pool

@staticmethod
@provide_session
def default_pool_open_slots(session):
from airflow.models import TaskInstance as TI # To avoid circular imports
total_slots = conf.getint('core', 'non_pooled_task_slot_count')
used_slots = session.query(func.count()).filter(
TI.pool == Pool.default_pool_name).filter(
TI.state.in_([State.RUNNING, State.QUEUED])).scalar()
return total_slots - used_slots

def to_json(self):
return {
'id': self.id,
@@ -43,42 +56,14 @@ def to_json(self):
'description': self.description,
}

@provide_session
def used_slots(self, session):
"""
Returns the number of slots used at the moment
"""
from airflow.models.taskinstance import TaskInstance # Avoid circular import

running = (
session
.query(TaskInstance)
.filter(TaskInstance.pool == self.pool)
.filter(TaskInstance.state == State.RUNNING)
.count()
)
return running

@provide_session
def queued_slots(self, session):
"""
Returns the number of slots used at the moment
"""
from airflow.models.taskinstance import TaskInstance # Avoid circular import

return (
session
.query(TaskInstance)
.filter(TaskInstance.pool == self.pool)
.filter(TaskInstance.state == State.QUEUED)
.count()
)

@provide_session
def open_slots(self, session):
"""
Returns the number of slots open at the moment
"""
used_slots = self.used_slots(session=session)
queued_slots = self.queued_slots(session=session)
return self.slots - used_slots - queued_slots
from airflow.models.taskinstance import \
TaskInstance as TI # Avoid circular import

used_slots = session.query(func.count()).filter(TI.pool == self.pool).filter(
TI.state.in_([State.RUNNING, State.QUEUED])).scalar()
return self.slots - used_slots
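Both open_slots and the new default_pool_open_slots now account for used capacity with a single COUNT over RUNNING and QUEUED task instances, replacing the separate used_slots and queued_slots queries. A hedged sketch of the shared pattern, assuming an open SQLAlchemy session; open_slots_for is an illustrative helper, not part of the module:

from sqlalchemy import func

from airflow.models.taskinstance import TaskInstance as TI
from airflow.utils.state import State

def open_slots_for(session, pool_name, total_slots):
    # One aggregate query covers both states that occupy a slot.
    used_slots = (
        session.query(func.count())
        .filter(TI.pool == pool_name)
        .filter(TI.state.in_([State.RUNNING, State.QUEUED]))
        .scalar()
    )
    return total_slots - used_slots

For a named pool, total_slots is the row's slots column; for the default pool it is the non_pooled_task_slot_count config value.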
82 changes: 82 additions & 0 deletions tests/models/test_pool.py
@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import unittest

from airflow import settings
from airflow.models import DAG
from airflow.models.pool import Pool
from airflow.models.taskinstance import TaskInstance as TI
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils import timezone
from airflow.utils.state import State
from tests.test_utils.db import clear_db_pools, clear_db_runs
from tests.test_utils.decorators import mock_conf_get

DEFAULT_DATE = timezone.datetime(2016, 1, 1)


class PoolTest(unittest.TestCase):

def tearDown(self):
clear_db_runs()
clear_db_pools()

def test_open_slots(self):
pool = Pool(pool='test_pool', slots=5)
dag = DAG(
dag_id='test_open_slots',
start_date=DEFAULT_DATE, )
t1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
t2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
ti1 = TI(task=t1, execution_date=DEFAULT_DATE)
ti2 = TI(task=t2, execution_date=DEFAULT_DATE)
ti1.state = State.RUNNING
ti2.state = State.QUEUED

session = settings.Session
session.add(pool)
session.add(ti1)
session.add(ti2)
session.commit()
session.close()

self.assertEqual(3, pool.open_slots())

@mock_conf_get('core', 'non_pooled_task_slot_count', 5)
def test_default_pool_open_slots(self):
dag = DAG(
dag_id='test_default_pool_open_slots',
start_date=DEFAULT_DATE, )
t1 = DummyOperator(task_id='dummy1', dag=dag)
t2 = DummyOperator(task_id='dummy2', dag=dag)
ti1 = TI(task=t1, execution_date=DEFAULT_DATE)
ti2 = TI(task=t2, execution_date=DEFAULT_DATE)
ti1.state = State.RUNNING
ti2.state = State.QUEUED
ti1.pool = Pool.default_pool_name
ti2.pool = Pool.default_pool_name

session = settings.Session
session.add(ti1)
session.add(ti2)
session.commit()
session.close()

self.assertEqual(3, Pool.default_pool_open_slots())
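Both tests reduce to the same arithmetic that open_slots and default_pool_open_slots perform: five total slots (the pool's slots column in the first test, the mocked non_pooled_task_slot_count in the second) minus one RUNNING and one QUEUED task instance:

total_slots = 5  # pool slots / mocked non_pooled_task_slot_count
used_slots = 2   # ti1 (RUNNING) + ti2 (QUEUED)
assert total_slots - used_slots == 3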