Skip to content

Commit

Permalink
torndb typing hint
Browse files Browse the repository at this point in the history
  • Loading branch information
liuxiaobai authored and liuxiaobai committed May 15, 2018
1 parent 6d7779e commit f630169
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions src/db/torndb.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
import time
from typing import List, Dict

import pymysql
from pymysql.cursors import SSCursor


class Connection(object):
def __init__(self, host, database, user=None, password=None, **kwargs):
def __init__(self, host: str, database: str, user: str, password: str, **kwargs):
self.host = host
self.max_idle_time = float(7 * 3600)

Expand All @@ -24,10 +26,7 @@ def __init__(self, host, database, user=None, password=None, **kwargs):

self._db = None
self._last_use_time = time.time()
try:
self.reconnect()
except Exception:
logging.error("Cannot connect to MySQL on %s", self.host, exc_info=True)
self.reconnect()

def __del__(self):
self.close()
Expand All @@ -44,10 +43,10 @@ def reconnect(self):
self._db = pymysql.connect(**self._db_args)
self._db.autocommit(True)

def iter(self, query, *args, **kwargs):
def iter(self, query: str, *args, **kwargs):
"""Returns an iterator for the given query and args."""
self._ensure_connected()
cursor = pymysql.cursors.SSCursor(self._db)
cursor = SSCursor(self._db)
try:
self._execute(cursor, query, args, kwargs)
column_names = [d[0] for d in cursor.description]
Expand All @@ -56,7 +55,7 @@ def iter(self, query, *args, **kwargs):
finally:
cursor.close()

def query(self, query, *args, **kwargs):
def query(self, query: str, *args, **kwargs):
"""Returns a row list for the given query and args."""
cursor = self._cursor()
try:
Expand All @@ -66,7 +65,7 @@ def query(self, query, *args, **kwargs):
finally:
cursor.close()

def get(self, query, *args, **kwargs):
def get(self, query: str, *args: List[str], **kwargs: Dict):
"""Returns the (singular) row returned by the given query.
If the query has no results, returns None. If it has
Expand All @@ -80,7 +79,7 @@ def get(self, query, *args, **kwargs):
else:
return rows[0]

def insert(self, query, *args, **kwargs):
def insert(self, query: str, *args, **kwargs):
"""Executes the given query, returning the last rowid from the query."""
cursor = self._cursor()
try:
Expand All @@ -89,7 +88,7 @@ def insert(self, query, *args, **kwargs):
finally:
cursor.close()

def update(self, query, *args, **kwargs):
def update(self, query: str, *args, **kwargs):
"""Executes the given query, returning the rowcount from the query."""
cursor = self._cursor()
try:
Expand All @@ -98,13 +97,13 @@ def update(self, query, *args, **kwargs):
finally:
cursor.close()

def execute_many(self, query, args):
def execute_many(self, query: str, args):
"""Executes the given query against all the given param sequences.
We return the last rowid from the query.
"""
return self.insert_many(query, args)

def insert_many(self, query, args):
def insert_many(self, query: str, args):
"""Executes the given query against all the given param sequences.
We return the last rowid from the query.
"""
Expand All @@ -115,7 +114,7 @@ def insert_many(self, query, args):
finally:
cursor.close()

def update_many(self, query, args):
def update_many(self, query: str, args):
"""Executes the given query against all the given param sequences.
We return the rowcount from the query.
"""
Expand All @@ -140,7 +139,7 @@ def _cursor(self):
self._ensure_connected()
return self._db.cursor()

def _execute(self, cursor, query, args, kwargs):
def _execute(self, cursor: SSCursor, query: str, args, kwargs):
try:
return cursor.execute(query, kwargs or args)
except pymysql.OperationalError:
Expand Down

0 comments on commit f630169

Please sign in to comment.