diff --git a/src/db/torndb.py b/src/db/torndb.py index a8ecfd0..dc72390 100755 --- a/src/db/torndb.py +++ b/src/db/torndb.py @@ -1,11 +1,13 @@ import logging import time +from typing import List, Dict import pymysql +from pymysql.cursors import SSCursor class Connection(object): - def __init__(self, host, database, user=None, password=None, **kwargs): + def __init__(self, host: str, database: str, user: str, password: str, **kwargs): self.host = host self.max_idle_time = float(7 * 3600) @@ -24,10 +26,7 @@ def __init__(self, host, database, user=None, password=None, **kwargs): self._db = None self._last_use_time = time.time() - try: - self.reconnect() - except Exception: - logging.error("Cannot connect to MySQL on %s", self.host, exc_info=True) + self.reconnect() def __del__(self): self.close() @@ -44,10 +43,10 @@ def reconnect(self): self._db = pymysql.connect(**self._db_args) self._db.autocommit(True) - def iter(self, query, *args, **kwargs): + def iter(self, query: str, *args, **kwargs): """Returns an iterator for the given query and args.""" self._ensure_connected() - cursor = pymysql.cursors.SSCursor(self._db) + cursor = SSCursor(self._db) try: self._execute(cursor, query, args, kwargs) column_names = [d[0] for d in cursor.description] @@ -56,7 +55,7 @@ def iter(self, query, *args, **kwargs): finally: cursor.close() - def query(self, query, *args, **kwargs): + def query(self, query: str, *args, **kwargs): """Returns a row list for the given query and args.""" cursor = self._cursor() try: @@ -66,7 +65,7 @@ def query(self, query, *args, **kwargs): finally: cursor.close() - def get(self, query, *args, **kwargs): + def get(self, query: str, *args: List[str], **kwargs: Dict): """Returns the (singular) row returned by the given query. If the query has no results, returns None. If it has @@ -80,7 +79,7 @@ def get(self, query, *args, **kwargs): else: return rows[0] - def insert(self, query, *args, **kwargs): + def insert(self, query: str, *args, **kwargs): """Executes the given query, returning the last rowid from the query.""" cursor = self._cursor() try: @@ -89,7 +88,7 @@ def insert(self, query, *args, **kwargs): finally: cursor.close() - def update(self, query, *args, **kwargs): + def update(self, query: str, *args, **kwargs): """Executes the given query, returning the rowcount from the query.""" cursor = self._cursor() try: @@ -98,13 +97,13 @@ def update(self, query, *args, **kwargs): finally: cursor.close() - def execute_many(self, query, args): + def execute_many(self, query: str, args): """Executes the given query against all the given param sequences. We return the last rowid from the query. """ return self.insert_many(query, args) - def insert_many(self, query, args): + def insert_many(self, query: str, args): """Executes the given query against all the given param sequences. We return the last rowid from the query. """ @@ -115,7 +114,7 @@ def insert_many(self, query, args): finally: cursor.close() - def update_many(self, query, args): + def update_many(self, query: str, args): """Executes the given query against all the given param sequences. We return the rowcount from the query. """ @@ -140,7 +139,7 @@ def _cursor(self): self._ensure_connected() return self._db.cursor() - def _execute(self, cursor, query, args, kwargs): + def _execute(self, cursor: SSCursor, query: str, args, kwargs): try: return cursor.execute(query, kwargs or args) except pymysql.OperationalError: