Skip to content

Commit

Permalink
Merge pull request apache#1376 from bolkedebruin/multiprocessing_sche…
Browse files Browse the repository at this point in the history
…duler

Use multiprocessing for the scheduler
  • Loading branch information
bolkedebruin committed Apr 17, 2016
2 parents 8b63891 + a36861a commit 7da6a94
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 26 deletions.
7 changes: 7 additions & 0 deletions airflow/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def run_command(command):
'job_heartbeat_sec': 5,
'scheduler_heartbeat_sec': 60,
'authenticate': False,
'max_threads': 2,
},
'celery': {
'default_queue': 'default',
Expand Down Expand Up @@ -329,6 +330,11 @@ def run_command(command):
# statsd_port = 8125
# statsd_prefix = airflow
# The scheduler can run multiple threads in parallel to schedule dags.
# This defines how many threads will run. However, airflow will never
# use more threads than the number of CPU cores available.
max_threads = 2
[mesos]
# Mesos master address which MesosExecutor will connect to.
master = localhost:5050
Expand Down Expand Up @@ -414,6 +420,7 @@ def run_command(command):
job_heartbeat_sec = 1
scheduler_heartbeat_sec = 5
authenticate = true
max_threads = 2
"""


Expand Down
106 changes: 80 additions & 26 deletions airflow/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import logging
import socket
import subprocess
import multiprocessing
import math
from time import sleep

from sqlalchemy import Column, Integer, String, DateTime, func, Index, or_
Expand Down Expand Up @@ -213,14 +215,20 @@ class SchedulerJob(BaseJob):
def __init__(
self,
dag_id=None,
dag_ids=None,
subdir=None,
test_mode=False,
refresh_dags_every=10,
num_runs=None,
do_pickle=False,
*args, **kwargs):

# for BaseJob compatibility
self.dag_id = dag_id
self.dag_ids = [dag_id] if dag_id else []
if dag_ids:
self.dag_ids.extend(dag_ids)

self.subdir = subdir

if test_mode:
Expand All @@ -234,6 +242,11 @@ def __init__(
super(SchedulerJob, self).__init__(*args, **kwargs)

self.heartrate = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC')
self.max_threads = min(conf.getint('scheduler', 'max_threads'), multiprocessing.cpu_count())
if 'sqlite' in conf.get('core', 'sql_alchemy_conn'):
if self.max_threads > 1:
self.logger.error("Cannot use more than 1 thread when using sqlite. Setting max_threads to 1")
self.max_threads = 1

@provide_session
def manage_slas(self, dag, session=None):
Expand Down Expand Up @@ -422,7 +435,7 @@ def schedule_dag(self, dag):

# don't ever schedule prior to the dag's start_date
if dag.start_date:
next_run_date = max(next_run_date, dag.start_date)
next_run_date = dag.start_date if not next_run_date else max(next_run_date, dag.start_date)

# this structure is necessary to avoid a TypeError from concatenating
# NoneType
Expand All @@ -447,7 +460,7 @@ def schedule_dag(self, dag):
session.commit()
return next_run

def process_dag(self, dag, executor):
def process_dag(self, dag, queue):
"""
This method schedules a single DAG by looking at the latest
run for each task and attempting to schedule the following run.
Expand Down Expand Up @@ -500,6 +513,7 @@ def process_dag(self, dag, executor):
could_not_run = set()
self.logger.info('Checking dependencies on {} tasks instances, minus {} '
'skippable ones'.format(len(descartes), len(skip_tis)))

for task, dttm in descartes:
if task.adhoc or (task.task_id, dttm) in skip_tis:
continue
Expand All @@ -510,14 +524,14 @@ def process_dag(self, dag, executor):
State.RUNNING, State.QUEUED, State.SUCCESS, State.FAILED):
continue
elif ti.is_runnable(flag_upstream_failed=True):
self.logger.debug('Firing task: {}'.format(ti))
executor.queue_task_instance(ti, pickle_id=pickle_id)
self.logger.debug('Queuing task: {}'.format(ti))
queue.put((ti.key, pickle_id))
else:
could_not_run.add(ti)

# this type of deadlock happens when dagruns can't even start and so
# the TI's haven't been persisted to the database.
if len(could_not_run) == len(descartes):
# the TI's haven't been persisted to the database.
if len(could_not_run) == len(descartes) and len(could_not_run) > 0:
self.logger.error(
'Dag runs are deadlocked for DAG: {}'.format(dag.dag_id))
(session
Expand Down Expand Up @@ -657,8 +671,32 @@ def prioritize_queued(self, session, executor, dagbag):

session.commit()

def _split_dags(self, dags, size):
    """
    Split a list of dags into consecutive chunks of at most ``size`` items.

    _split_dags([1,2,3,4,5,6], 3) becomes [[1,2,3],[4,5,6]]

    :param dags: the list to split; an empty list yields an empty list
    :param size: the maximum chunk length. Non-integral values (for
        instance the float that ``math.ceil`` returns on Python 2) are
        rounded up to the next integer, and anything below 1 is treated
        as 1.
    :return: a list of lists; the last chunk may be shorter than ``size``
    """
    # range() and slicing both reject floats, so normalise to an int
    # before using size as a step and as a slice bound.
    size = max(1, int(math.ceil(size)))
    return [dags[i:i + size] for i in range(0, len(dags), size)]

def _do_dags(self, dagbag, dags, tis_out):
    """
    Schedule and process every dag in ``dags``, one after the other.

    :param dagbag: object whose ``get_dag(dag_id)`` returns the dag to
        work on, or a falsy value when it is not available
    :param dags: iterable of objects exposing ``dag_id``
    :param tis_out: passed through unchanged to ``process_dag``
    """
    for listed_dag in dags:
        self.logger.debug("Scheduling {}".format(listed_dag.dag_id))
        # Work on whatever the dagbag hands back for this id rather than
        # on the object we were given.
        current = dagbag.get_dag(listed_dag.dag_id)
        if current:
            try:
                self.schedule_dag(current)
                self.process_dag(current, tis_out)
                self.manage_slas(current)
            except Exception as err:
                # A failure in one dag is logged and must not stop the
                # remaining dags from being handled.
                self.logger.exception(err)

def _execute(self):
dag_id = self.dag_id
TI = models.TaskInstance

pessimistic_connection_handling()

Expand All @@ -668,8 +706,8 @@ def _execute(self):
dagbag = models.DagBag(self.subdir, sync_to_db=True)
executor = self.executor = dagbag.executor
executor.start()
i = 0
while not self.num_runs or self.num_runs > i:
self.runs = 0
while not self.num_runs or self.num_runs > self.runs:
try:
loop_start_dttm = datetime.now()
try:
Expand All @@ -678,35 +716,51 @@ def _execute(self):
except Exception as e:
self.logger.exception(e)

i += 1
self.runs += 1
try:
if i % self.refresh_dags_every == 0:
if self.num_runs % self.refresh_dags_every == 0:
dagbag = models.DagBag(self.subdir, sync_to_db=True)
else:
dagbag.collect_dags(only_if_updated=True)
except:
self.logger.error("Failed at reloading the dagbag")
except Exception as e:
self.logger.error("Failed at reloading the dagbag. {}".format(e))
Stats.incr('dag_refresh_error', 1, 1)
sleep(5)

if dag_id:
dags = [dagbag.dags[dag_id]]
if len(self.dag_ids) > 0:
dags = [dag for dag in dagbag.dags.values() if dag.dag_id in self.dag_ids]
else:
dags = [
dag for dag in dagbag.dags.values()
if not dag.parent_dag]

paused_dag_ids = dagbag.paused_dags()
for dag in dags:
self.logger.debug("Scheduling {}".format(dag.dag_id))
dag = dagbag.get_dag(dag.dag_id)
if not dag or (dag.dag_id in paused_dag_ids):
continue
try:
self.schedule_dag(dag)
self.process_dag(dag, executor)
self.manage_slas(dag)
except Exception as e:
self.logger.exception(e)
dags = [x for x in dags if x.dag_id not in paused_dag_ids]
# dags = filter(lambda x: x.dag_id not in paused_dag_ids, dags)

self.logger.debug("Total Cores: {} Max Threads: {} DAGs:{}".
format(multiprocessing.cpu_count(),
self.max_threads,
len(dags)))
dags = self._split_dags(dags, math.ceil(len(dags) / self.max_threads))
tis_q = multiprocessing.Queue()
jobs = [multiprocessing.Process(target=self._do_dags,
args=(dagbag, dags[i], tis_q))
for i in range(len(dags))]

self.logger.info("Starting {} scheduler jobs".format(len(jobs)))
for j in jobs:
j.start()
for j in jobs:
j.join()

while not tis_q.empty():
ti_key, pickle_id = tis_q.get()
dag = dagbag.dags[ti_key[0]]
task = dag.get_task(ti_key[1])
ti = TI(task, ti_key[2])
self.executor.queue_task_instance(ti, pickle_id=pickle_id)

self.logger.info("Done queuing tasks, calling the executor's "
"heartbeat")
duration_sec = (datetime.now() - loop_start_dttm).total_seconds()
Expand Down
16 changes: 16 additions & 0 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,22 @@ def merge_conn(conn, session=None):
session.commit()


@event.listens_for(settings.engine, "connect")
def connect(dbapi_connection, connection_record):
    """
    Tag a connection record with the pid of the process handling the
    engine's "connect" event, so later code can tell which process the
    record belongs to.
    """
    connection_record.info.update(pid=os.getpid())


@event.listens_for(settings.engine, "checkout")
def checkout(dbapi_connection, connection_record, connection_proxy):
    """
    Refuse to hand out a connection whose record is tagged with a
    different pid than the current process.

    On a mismatch the connection references on both the record and the
    proxy are cleared and ``exc.DisconnectionError`` is raised.
    NOTE(review): presumably the pool reacts to that error by opening a
    fresh connection -- confirm against the SQLAlchemy pooling docs.
    """
    current_pid = os.getpid()
    owner_pid = connection_record.info['pid']
    if owner_pid == current_pid:
        return

    # Drop both references so the other process' connection is not reused.
    connection_record.connection = connection_proxy.connection = None
    raise exc.DisconnectionError(
        "Connection record belongs to pid {}, "
        "attempting to check out in pid {}".format(owner_pid, current_pid))


def initdb():
session = settings.Session()

Expand Down
2 changes: 2 additions & 0 deletions scripts/ci/airflow_travis.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,5 @@ default_queue = default
job_heartbeat_sec = 1
scheduler_heartbeat_sec = 5
authenticate = true
max_threads = 2

22 changes: 22 additions & 0 deletions tests/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from airflow.utils.timeout import timeout
from airflow.utils.db import provide_session

from airflow import configuration
configuration.test_mode()

DEV_NULL = '/dev/null'
DEFAULT_DATE = datetime.datetime(2016, 1, 1)

Expand Down Expand Up @@ -284,6 +287,7 @@ def test_scheduler_pooled_tasks(self):
scheduler.run()

task_1 = dag.tasks[0]
logging.info("Trying to find task {}".format(task_1))
ti = TI(task_1, dag.start_date)
ti.refresh_from_db()
self.assertEqual(ti.state, State.FAILED)
Expand Down Expand Up @@ -364,3 +368,21 @@ def test_scheduler_start_date(self):
session = settings.Session()
self.assertEqual(
len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1)

def test_scheduler_multiprocessing(self):
    """
    Run the scheduler over two dags at once and check the stored task
    instances for one of them.
    """
    scheduled_ids = ['test_start_date_scheduling', 'test_dagrun_states_success']
    for scheduled_id in scheduled_ids:
        # start from a clean slate for each dag
        self.dagbag.get_dag(scheduled_id).clear()

    SchedulerJob(dag_ids=scheduled_ids, num_runs=2).run()

    # no task instances are expected for this dag
    checked_id = 'test_start_date_scheduling'
    session = settings.Session()
    stored_tis = session.query(TI).filter(TI.dag_id == checked_id).all()
    self.assertEqual(len(stored_tis), 0)

0 comments on commit 7da6a94

Please sign in to comment.