Skip to content

Commit

Permalink
refactor: add class interface
Browse files Browse the repository at this point in the history
  • Loading branch information
aplavin committed Dec 22, 2015
1 parent 7efea50 commit 14fe4bc
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pandasql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .sqldf import sqldf
from .sqldf import sqldf, PandaSQL
import os
import pandas as pd

Expand Down
38 changes: 22 additions & 16 deletions pandasql/sqldf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,26 @@ class PandaSQLException(Exception):
pass


class PandaSQL:
    """Callable that runs a query against an engine built from a database URI.

    Calling an instance takes every table name that
    ``_extract_table_names`` reports for the query, looks it up in the
    keyword arguments, passes the value through ``_ensure_data_frame`` and
    ``_write_table``, and finally returns what ``read_sql`` produces for
    the query on the stored engine.
    """

    def __init__(self, db_uri):
        """Create the engine for *db_uri* and reject unsupported backends.

        Raises PandaSQLException when the engine's name is neither
        'sqlite' nor 'postgresql'.
        """
        self.engine = create_engine(db_uri)
        supported = self.engine.name in ('sqlite', 'postgresql')
        if not supported:
            raise PandaSQLException('Currently only sqlite and postgresql are supported.')

    def __call__(self, query, **env):
        """Run *query*, taking the tables it references from *env*.

        query -- the query text handed to ``read_sql``
        env   -- keyword arguments mapping table names to the objects
                 that should back them

        Raises PandaSQLException when a table name extracted from the
        query has no matching keyword argument.
        """
        # Tables are handled one at a time, in the order they are reported:
        # a missing name raises only after earlier tables were written.
        for name in _extract_table_names(query):
            if name not in env:
                raise PandaSQLException("%s not found" % name)
            frame = _ensure_data_frame(env[name], name)
            _write_table(name, frame, self.engine)

        return read_sql(query, self.engine)


def _ensure_data_frame(obj, name):
"""
obj a python object to be converted to a DataFrame
Expand All @@ -36,6 +56,7 @@ def _ensure_data_frame(obj, name):
if not isinstance(df, pd.DataFrame):
raise PandaSQLException("%s is not of a supported data type" % name)

# XXX: what is this for? Tests pass without it.
for col in df:
if df[col].dtype == np.int64:
df[col] = df[col].astype(np.float)
Expand Down Expand Up @@ -98,19 +119,4 @@ def sqldf(q, env, db_uri='sqlite:///:memory:'):
>>> sqldf("select * from df;", locals())
>>> sqldf("select avg(x) from df;", locals())
"""

engine = create_engine(db_uri)
if engine.name not in ('sqlite', 'postgresql'):
raise PandaSQLException('Currently only sqlite and postgresql are supported.')

tables = _extract_table_names(q)
for table in tables:
if table not in env:
raise PandaSQLException("%s not found" % table)
df = env[table]
df = _ensure_data_frame(df, table)
_write_table(table, df, engine)

result = read_sql(q, engine)

return result
return PandaSQL(db_uri)(q, **env)
18 changes: 17 additions & 1 deletion pandasql/tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pandas as pd
from pandasql import sqldf, load_meat
from pandasql import PandaSQL, sqldf, load_meat
import string
import pytest
import pandas.util.testing as pdtest
Expand All @@ -10,6 +10,11 @@ def db_uri(request):
return request.param


@pytest.fixture()
def pandasql(db_uri):
    """Provide a PandaSQL instance constructed from the ``db_uri`` fixture."""
    instance = PandaSQL(db_uri)
    return instance


def test_select(db_uri):
df = pd.DataFrame({
"letter_pos": [i for i in range(len(string.ascii_letters))],
Expand All @@ -21,6 +26,17 @@ def test_select(db_uri):
pdtest.assert_frame_equal(result, df.head(10))


def test_select_using_class(pandasql):
    """A LIMIT 10 select through a PandaSQL instance yields the first ten rows."""
    # The frame has to stay bound to the name ``df``: the query refers to it
    # by that name, and the call below forwards this function's locals as
    # keyword arguments.
    df = pd.DataFrame({
        "letter_pos": [i for i in range(len(string.ascii_letters))],
        "l2": list(string.ascii_letters)
    })
    result = pandasql("SELECT * FROM df LIMIT 10;", **locals())

    assert len(result) == 10
    # Bound only after the query call, so it is not part of the forwarded locals.
    first_ten = df.head(10)
    pdtest.assert_frame_equal(result, first_ten)


def test_join(db_uri):
df = pd.DataFrame({
"letter_pos": [i for i in range(len(string.ascii_letters))],
Expand Down

0 comments on commit 14fe4bc

Please sign in to comment.