Skip to content

Commit

Permalink
Kill Index Attempts for previous model (onyx-dot-app#1088)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhongsun96 authored Feb 17, 2024
1 parent 269431c commit 514e7f6
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 55 deletions.
4 changes: 4 additions & 0 deletions backend/danswer/background/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
Expand Down Expand Up @@ -381,6 +382,9 @@ def check_index_swap(db_session: Session) -> None:
db_session=db_session,
)

# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)

# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
Expand Down
37 changes: 27 additions & 10 deletions backend/danswer/db/index_attempt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_index_attempt(
def create_index_attempt(
connector_id: int,
credential_id: int,
embedding_model_id: int | None,
embedding_model_id: int,
db_session: Session,
from_beginning: bool = False,
) -> int:
Expand Down Expand Up @@ -248,24 +248,41 @@ def cancel_indexing_attempts_for_connector(
EmbeddingModel.status != IndexModelStatus.FUTURE
)

stmt = delete(IndexAttempt).where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.status == IndexingStatus.NOT_STARTED,
stmt = (
update(IndexAttempt)
.where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.status == IndexingStatus.NOT_STARTED,
)
.values(status=IndexingStatus.FAILED)
)

if not include_secondary_index:
stmt = stmt.where(
or_(
IndexAttempt.embedding_model_id.is_(None),
IndexAttempt.embedding_model_id.in_(subquery),
)
)
stmt = stmt.where(IndexAttempt.embedding_model_id.in_(subquery))

db_session.execute(stmt)

db_session.commit()


def cancel_indexing_attempts_past_model(
    db_session: Session,
) -> None:
    """Mark every unfinished index attempt tied to a now-PAST embedding model
    as FAILED, so no further indexing work is done for a superseded model.

    Commits the transaction before returning.
    """
    # Only attempts that have not yet completed are eligible for cancellation.
    unfinished_statuses = [IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED]

    cancel_stmt = (
        update(IndexAttempt)
        .where(
            IndexAttempt.status.in_(unfinished_statuses),
            # Correlate each attempt with its embedding model row and
            # restrict the update to models already marked PAST.
            IndexAttempt.embedding_model_id == EmbeddingModel.id,
            EmbeddingModel.status == IndexModelStatus.PAST,
        )
        .values(status=IndexingStatus.FAILED)
    )

    db_session.execute(cancel_stmt)
    db_session.commit()


def count_unique_cc_pairs_with_index_attempts(
embedding_model_id: int | None,
db_session: Session,
Expand Down
96 changes: 51 additions & 45 deletions backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.document_index.factory import get_default_document_index
from danswer.llm.factory import get_default_llm
from danswer.search.search_nlp_models import warm_up_models
Expand Down Expand Up @@ -209,6 +210,8 @@ def get_application() -> FastAPI:

@application.on_event("startup")
def startup_event() -> None:
engine = get_sqlalchemy_engine()

verify_auth = fetch_versioned_implementation(
"danswer.auth.users", "verify_auth_setting"
)
Expand Down Expand Up @@ -242,66 +245,69 @@ def startup_event() -> None:
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
)

with Session(get_sqlalchemy_engine()) as db_session:
with Session(engine) as db_session:
db_embedding_model = get_current_db_embedding_model(db_session)
secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)

if ENABLE_RERANKING_REAL_TIME_FLOW:
logger.info("Reranking step of search flow is enabled.")
cancel_indexing_attempts_past_model(db_session)

logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
logger.info(f'Query embedding prefix: "{db_embedding_model.query_prefix}"')
logger.info(
f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
)
logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
logger.info(
f'Query embedding prefix: "{db_embedding_model.query_prefix}"'
)
logger.info(
f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
)

if MODEL_SERVER_HOST:
logger.info(
f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
)
else:
logger.info("Warming up local NLP models.")
warm_up_models(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
)
if ENABLE_RERANKING_REAL_TIME_FLOW:
logger.info("Reranking step of search flow is enabled.")

if torch.cuda.is_available():
logger.info("GPU is available")
if MODEL_SERVER_HOST:
logger.info(
f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
)
else:
logger.info("GPU is not available")
logger.info(f"Torch Threads: {torch.get_num_threads()}")
logger.info("Warming up local NLP models.")
warm_up_models(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
)

if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
logger.info(f"Torch Threads: {torch.get_num_threads()}")

logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
nltk.download("punkt", quiet=True)
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
nltk.download("punkt", quiet=True)

logger.info("Verifying default connector/credential exist.")
with Session(get_sqlalchemy_engine()) as db_session:
logger.info("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)

logger.info("Loading default Prompts and Personas")
load_chat_yamls()
logger.info("Loading default Prompts and Personas")
load_chat_yamls()

logger.info("Verifying Document Index(s) is/are available.")
logger.info("Verifying Document Index(s) is/are available.")

document_index = get_default_document_index(
primary_index_name=db_embedding_model.index_name,
secondary_index_name=secondary_db_embedding_model.index_name
if secondary_db_embedding_model
else None,
)
document_index.ensure_indices_exist(
index_embedding_dim=db_embedding_model.model_dim,
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
if secondary_db_embedding_model
else None,
)
document_index = get_default_document_index(
primary_index_name=db_embedding_model.index_name,
secondary_index_name=secondary_db_embedding_model.index_name
if secondary_db_embedding_model
else None,
)
document_index.ensure_indices_exist(
index_embedding_dim=db_embedding_model.model_dim,
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
if secondary_db_embedding_model
else None,
)

optional_telemetry(
record_type=RecordType.VERSION, data={"version": __version__}
Expand Down
4 changes: 4 additions & 0 deletions backend/danswer/server/documents/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.index_attempt import cancel_indexing_attempts_for_connector
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempts
Expand Down Expand Up @@ -456,6 +457,9 @@ def update_connector_from_model(
if updated_connector.disabled:
cancel_indexing_attempts_for_connector(connector_id, db_session)

    # Defensive cleanup: also fail any lingering attempts tied to past embedding models
cancel_indexing_attempts_past_model(db_session)

return ConnectorSnapshot(
id=updated_connector.id,
name=updated_connector.name,
Expand Down

0 comments on commit 514e7f6

Please sign in to comment.