Skip to content

Commit

Permalink
[AIRFLOW-1168] Add closing() to all connections and cursors
Browse files Browse the repository at this point in the history
This will prevent any left-open connections
whenever an exception occurs

Closes apache#2269 from NielsZeilemaker/AIRFLOW-1168
  • Loading branch information
NielsZeilemaker authored and bolkedebruin committed May 12, 2017
1 parent 443e6b2 commit 8aeebd4
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 65 deletions.
125 changes: 60 additions & 65 deletions airflow/hooks/dbapi_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from builtins import str
from past.builtins import basestring
from datetime import datetime
from contextlib import closing
import numpy
import logging
import sys
Expand Down Expand Up @@ -87,10 +88,9 @@ def get_pandas_df(self, sql, parameters=None):
if sys.version_info[0] < 3:
sql = sql.encode('utf-8')
import pandas.io.sql as psql
conn = self.get_conn()
df = psql.read_sql(sql, con=conn, params=parameters)
conn.close()
return df

with closing(self.get_conn()) as conn:
return psql.read_sql(sql, con=conn, params=parameters)

def get_records(self, sql, parameters=None):
"""
Expand All @@ -104,16 +104,14 @@ def get_records(self, sql, parameters=None):
"""
if sys.version_info[0] < 3:
sql = sql.encode('utf-8')
conn = self.get_conn()
cur = self.get_cursor()
if parameters is not None:
cur.execute(sql, parameters)
else:
cur.execute(sql)
rows = cur.fetchall()
cur.close()
conn.close()
return rows

with closing(self.get_conn()) as conn:
with closing(conn.cursor()) as cur:
if parameters is not None:
cur.execute(sql, parameters)
else:
cur.execute(sql)
return cur.fetchall()

def get_first(self, sql, parameters=None):
"""
Expand All @@ -127,16 +125,14 @@ def get_first(self, sql, parameters=None):
"""
if sys.version_info[0] < 3:
sql = sql.encode('utf-8')
conn = self.get_conn()
cur = conn.cursor()
if parameters is not None:
cur.execute(sql, parameters)
else:
cur.execute(sql)
rows = cur.fetchone()
cur.close()
conn.close()
return rows

with closing(self.get_conn()) as conn:
with closing(conn.cursor()) as cur:
if parameters is not None:
cur.execute(sql, parameters)
else:
cur.execute(sql)
return cur.fetchone()

def run(self, sql, autocommit=False, parameters=None):
"""
Expand All @@ -153,25 +149,24 @@ def run(self, sql, autocommit=False, parameters=None):
:param parameters: The parameters to render the SQL query with.
:type parameters: mapping or iterable
"""
conn = self.get_conn()
if isinstance(sql, basestring):
sql = [sql]

if self.supports_autocommit:
self.set_autocommit(conn, autocommit)

cur = conn.cursor()
for s in sql:
if sys.version_info[0] < 3:
s = s.encode('utf-8')
logging.info(s)
if parameters is not None:
cur.execute(s, parameters)
else:
cur.execute(s)
cur.close()
conn.commit()
conn.close()
with closing(self.get_conn()) as conn:
if self.supports_autocommit:
self.set_autocommit(conn, autocommit)
with closing(conn.cursor()) as cur:
for s in sql:
if sys.version_info[0] < 3:
s = s.encode('utf-8')
logging.info(s)
if parameters is not None:
cur.execute(s, parameters)
else:
cur.execute(s)
conn.commit()

def set_autocommit(self, conn, autocommit):
conn.autocommit = autocommit
Expand Down Expand Up @@ -202,30 +197,30 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000):
target_fields = "({})".format(target_fields)
else:
target_fields = ''
conn = self.get_conn()
if self.supports_autocommit:
self.set_autocommit(conn, False)
conn.commit()
cur = conn.cursor()
i = 0
for row in rows:
i += 1
l = []
for cell in row:
l.append(self._serialize_cell(cell, conn))
values = tuple(l)
sql = "INSERT INTO {0} {1} VALUES ({2});".format(
table,
target_fields,
",".join(values))
cur.execute(sql)
if commit_every and i % commit_every == 0:
conn.commit()
logging.info(
"Loaded {i} into {table} rows so far".format(**locals()))
conn.commit()
cur.close()
conn.close()
with closing(self.get_conn()) as conn:
if self.supports_autocommit:
self.set_autocommit(conn, False)
conn.commit()
with closing(conn.cursor()) as cur:
for i, row in enumerate(rows, 1):
l = []
for cell in row:
l.append(self._serialize_cell(cell, conn))
values = tuple(l)
sql = "INSERT INTO {0} {1} VALUES ({2});".format(
table,
target_fields,
",".join(values))
cur.execute(sql)
if commit_every and i % commit_every == 0:
conn.commit()
logging.info(
"Loaded {i} into {table} rows so far".format(**locals()))
conn.commit()
logging.info(
"Done loading. Loaded a total of {i} rows".format(**locals()))

Expand Down
76 changes: 76 additions & 0 deletions tests/hooks/test_dbapi_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import mock
import unittest

from airflow.hooks.dbapi_hook import DbApiHook


class TestDbApiHook(unittest.TestCase):

def setUp(self):
super(TestDbApiHook, self).setUp()

self.cur = mock.MagicMock()
self.conn = conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur

class TestDBApiHook(DbApiHook):
conn_name_attr = 'test_conn_id'

def get_conn(self):
return conn

self.db_hook = TestDBApiHook()

def test_get_records(self):
statement = "SQL"
rows = [("hello",),
("world",)]

self.cur.fetchall.return_value = rows

self.assertEqual(rows, self.db_hook.get_records(statement))

self.conn.close.assert_called_once()
self.cur.close.assert_called_once()
self.cur.execute.assert_called_once_with(statement)

def test_get_records_parameters(self):
statement = "SQL"
parameters = ["X", "Y", "Z"]
rows = [("hello",),
("world",)]

self.cur.fetchall.return_value = rows


self.assertEqual(rows, self.db_hook.get_records(statement, parameters))

self.conn.close.assert_called_once()
self.cur.close.assert_called_once()
self.cur.execute.assert_called_once_with(statement, parameters)

def test_get_records_exception(self):
statement = "SQL"
self.cur.fetchall.side_effect = RuntimeError('Great Problems')

with self.assertRaises(RuntimeError):
self.db_hook.get_records(statement)

self.conn.close.assert_called_once()
self.cur.close.assert_called_once()
self.cur.execute.assert_called_once_with(statement)

0 comments on commit 8aeebd4

Please sign in to comment.