Skip to content

Commit

Permalink
updated to use DB configuration from starlite
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Oct 6, 2022
1 parent 9332da1 commit 634375c
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 289 deletions.
72 changes: 15 additions & 57 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ python = ">=3.10,<4.0"
redis = "*"
rich = "*"
sqlalchemy = {git = "https://github.com/sqlalchemy/sqlalchemy.git", branch = "main", extras = ["asyncio"]}
starlite = {version = "^1.24.0", extras = ["brotli"]}
starlite = {version = "^1.25.0", extras = ["brotli"]}
starlite-jwt = "^1.4.0"
uvicorn = {extras = ["standard"], version = "*"}

Expand Down
4 changes: 2 additions & 2 deletions src/server/app/cli/commands/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from app.asgi import app
from app.cli.console import console
from app.config import settings
from app.core.db import db_session
from app.core.db import db_engine, db_session
from app.core.db.models import BaseModel, meta

logger = logging.getLogger()
Expand Down Expand Up @@ -258,7 +258,7 @@ def show_database_revision() -> None:
async def drop_tables() -> None:
logger.info("Connecting to database backend.")

async with engine.begin() as db:
async with db_engine().begin() as db:
logger.info("[bold red] Dropping the db")
await db.run_sync(BaseModel.metadata.drop_all)
logger.info("[bold red] Dropping the version table")
Expand Down
4 changes: 2 additions & 2 deletions src/server/app/config/alembic.ini
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# A generic, single database configuration.

[alembic]
prepend_sys_path = src:.
prepend_sys_path = src/server:.
# path to migration scripts
script_location = src/app/core/db/migrations
script_location = src/server/app/core/db/migrations

# template used to generate migration files
file_template = %%(year)d-%%(month).2d-%%(day).2d_%%(rev)s
Expand Down
2 changes: 1 addition & 1 deletion src/server/app/config/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class Config:
CONNECT_ARGS: dict[str, Any] = {}
URL: str
MIGRATION_CONFIG: str = f"{BASE_DIR}/config/alembic.ini"
MIGRATION_PATH: str = f"{BASE_DIR}/db/migrations"
MIGRATION_PATH: str = f"{BASE_DIR}/core/db/migrations"
MIGRATION_DDL_VERSION_TABLE: str = "ddl_version"


Expand Down
28 changes: 14 additions & 14 deletions src/server/app/core/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Union, cast
from typing import TYPE_CHECKING, cast

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.pool import NullPool
from starlite.plugins.sql_alchemy import SQLAlchemyConfig, SQLAlchemyEngineConfig, SQLAlchemySessionConfig

Expand Down Expand Up @@ -37,34 +36,35 @@


@contextmanager
def db_session() -> "Iterator[Union[Session, AsyncSession]]":
def db_session() -> "Iterator[AsyncSession]":
"""Use this to get a database session where you can't in starlite
Returns:
config.session_class: _description_
"""
create_engine_callable = (
config.create_async_engine_callable if config.use_async_engine else config.create_engine_callable
)
session_maker_kwargs = session_config.dict(
exclude_none=True, exclude={"future"} if config.use_async_engine else set()
)
session_class = config.session_class or (AsyncSession if config.use_async_engine else Session)
engine = create_engine_callable(config.connection_string, **config.engine_config_dict)
session_maker = config.session_maker(engine, class_=session_class, **session_maker_kwargs) # type: ignore[arg-type]
session = cast("Union[Session, AsyncSession]", session_maker())
try:
session = config.session_maker()
yield session
finally:
session.close()


def db_engine() -> "AsyncEngine":
"""Fetch the db engine
Returns:
config.session_class: _description_
"""
return cast("AsyncEngine", config.engine)


__all__ = [
"config",
"session_config",
"engine_config",
"db_types",
"db_session",
"db_engine",
"models",
"repositories",
]
Loading

0 comments on commit 634375c

Please sign in to comment.