Skip to content

Commit

Permalink
database: rename Adapter to Client (iterative#10151)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Dec 9, 2023
1 parent 1d47b03 commit 77b451d
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 25 deletions.
4 changes: 2 additions & 2 deletions dvc/commands/imp_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class CmdTestDb(CmdBase):
def run(self):
from dvc.database import get_adapter
from dvc.database import get_client
from dvc.database.dbt_utils import DBT_PROJECT_FILE, is_dbt_project
from dvc.dependency.db import _get_dbt_config
from dvc.exceptions import DvcException
Expand Down Expand Up @@ -47,7 +47,7 @@ def run(self):
"provide arguments or set a configuration"
)

adapter = get_adapter(conn_config, project_dir=project_dir, **dbt_config)
adapter = get_client(conn_config, project_dir=project_dir, **dbt_config)
with adapter as db:
ui.write(f"Testing with {db}", styled=True)

Expand Down
12 changes: 6 additions & 6 deletions dvc/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
from .dbt_models import get_model
from .serializer import export

Adapter = Union[sqla.SQLAlchemyAdapter, dbt_query.DbtAdapter]
Client = Union[sqla.SQLAlchemyClient, dbt_query.DbtClient]


def get_adapter(
def get_client(
config,
project_dir: Optional[str] = None,
profile: Optional[str] = None,
target: Optional[str] = None,
**kwargs: Any,
) -> "ContextManager[Adapter]":
) -> "ContextManager[Client]":
if config:
return sqla.adapter(config, **kwargs)
return dbt_query.adapter(project_dir=project_dir, profile=profile, target=target)
return sqla.client(config, **kwargs)
return dbt_query.client(project_dir=project_dir, profile=profile, target=target)


__all__ = ["export", "get_adapter", "get_model", "Adapter"]
__all__ = ["export", "get_client", "get_model", "Client"]
8 changes: 4 additions & 4 deletions dvc/database/dbt_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


@define
class DbtAdapter:
class DbtClient:
adapter: "BaseAdapter" = field(repr=lambda o: type(o).__qualname__)
creds: Dict[str, Any] = field(repr=False)

Expand Down Expand Up @@ -59,12 +59,12 @@ def test_connection(self, onerror: Optional[Callable[[], Any]] = None) -> None:

@contextmanager
@check_dbt("query")
def adapter(
def client(
project_dir: Optional[str] = None,
profiles_dir: Optional[str] = None,
profile: Optional[str] = None,
target: Optional[str] = None,
) -> Iterator["DbtAdapter"]:
) -> Iterator["DbtClient"]:
from dbt.adapters import factory as adapters_factory
from dbt.adapters.base.impl import BaseAdapter

Expand All @@ -85,4 +85,4 @@ def adapter(
except: # noqa: E722
creds = {}

yield DbtAdapter(adapter, creds)
yield DbtClient(adapter, creds)
8 changes: 4 additions & 4 deletions dvc/database/sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def url_from_config(config: Union[str, URL, Dict[str, str]]) -> URL:


@dataclass
class SQLAlchemyAdapter:
class SQLAlchemyClient:
engine: Engine

@contextmanager
Expand Down Expand Up @@ -61,13 +61,13 @@ def handle_error(url: URL):


@contextmanager
def adapter(
def client(
url_or_config: Union[str, URL, Dict[str, str]], **engine_kwargs: Any
) -> Iterator[SQLAlchemyAdapter]:
) -> Iterator[SQLAlchemyClient]:
url = url_from_config(url_or_config)
with handle_error(url):
engine = create_engine(url, **engine_kwargs)
try:
yield SQLAlchemyAdapter(engine)
yield SQLAlchemyClient(engine)
finally:
engine.dispose()
4 changes: 2 additions & 2 deletions dvc/dependency/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def download(
file_format: Optional[str] = None,
**kwargs: Any,
) -> None:
from dvc.database import export, get_adapter
from dvc.database import export, get_client

db_info = self.info.get(PARAM_DB, {})
query = db_info.get(self.PARAM_QUERY)
Expand All @@ -140,7 +140,7 @@ def download(
raise DvcException(f"connection {connection} not found in config")

project_dir = self.repo.root_dir
with get_adapter(
with get_client(
config, project_dir=project_dir, profile=profile, target=target
) as db:
logger.debug("using adapter: %s", db)
Expand Down
14 changes: 7 additions & 7 deletions tests/func/test_import_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@


@pytest.fixture
def adapter(mocker):
m = mocker.patch("dvc.database.get_adapter")
adapter = mocker.MagicMock()
adapter.query.return_value.__enter__.side_effect = serializers
m.return_value.__enter__.return_value = adapter
def client(mocker):
m = mocker.patch("dvc.database.get_client")
client = mocker.MagicMock()
client.query.return_value.__enter__.side_effect = serializers
m.return_value.__enter__.return_value = client
return m


def test_sql(adapter, tmp_dir, dvc):
def test_sql(client, tmp_dir, dvc):
stage = dvc.imp_db(
sql="select * from model", profile="profile", output_format="json"
)
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_sql(adapter, tmp_dir, dvc):
}


def test_sql_conn_string(adapter, tmp_dir, dvc):
def test_sql_conn_string(client, tmp_dir, dvc):
with dvc.config.edit(level="local") as conf:
conf["db"] = {"conn": {"url": "conn"}}

Expand Down

0 comments on commit 77b451d

Please sign in to comment.