Skip to content

Commit

Permalink
refactor: fix experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
jotyy committed Aug 20, 2024
1 parent f128103 commit c94a9c1
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 42 deletions.
8 changes: 1 addition & 7 deletions denser_retriever/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
- from langchain_huggingface import HuggingFaceEmbeddings

from denser_retriever.gradient_boost import DenserGradientBoost
from denser_retriever.keyword import DenserKeywordSearch
Expand All @@ -20,10 +19,6 @@
)
from denser_retriever.vectordb.base import DenserVectorDB

- DEFAULT_EMBEDDINGS = HuggingFaceEmbeddings(
-     model_name="sentence-transformers/all-MiniLM-L6-v2"
- )

config_to_features = {
"es+vs": ["1,2,3,4,5,6", None],
"es+rr": ["1,2,3,7,8,9", None],
Expand Down Expand Up @@ -136,7 +131,7 @@ def _retrieve_by_linear_or_rank(

if self.reranker:
start_time = time.time()
- docs = [doc for doc, _ in passages[:self.reranker.top_k]]
+ docs = [doc for doc, _ in passages[: self.reranker.top_k]]
reranked_docs = self.reranker.rerank(docs, query)

passages = merge_results(
Expand All @@ -149,7 +144,6 @@ def _retrieve_by_linear_or_rank(
rerank_time_sec = time.time() - start_time
logger.info(f"Rerank time: {rerank_time_sec:.3f} sec.")


return passages[:k]

def _retrieve_by_model(
Expand Down
5 changes: 1 addition & 4 deletions experiments/index_and_query_local_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@
vector_db=MilvusDenserVectorDB(
auto_id=True,
connection_args={"uri": "http://localhost:19530"},
- embedding_function=embeddings,
),
keyword_search=ElasticKeywordSearch(
es_connection=create_elasticsearch_client(url="http://localhost:9200"),
),
reranker=reranker,
- gradient_boost=XGradientBoost(
-     "experiments/models/msmarco_xgb_es+vs+rr_n.json"
- ),
+ gradient_boost=XGradientBoost("experiments/models/msmarco_xgb_es+vs+rr_n.json"),
embeddings=embeddings,
combine_mode="model",
xgb_model_features="es+vs+rr_n",
Expand Down
28 changes: 12 additions & 16 deletions experiments/index_and_query_website.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,41 @@
import textwrap

from langchain_community.document_loaders import WebBaseLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

- from denser_retriever.gradient_boost import DenserGradientBoost
+ from denser_retriever.gradient_boost import XGradientBoost
from denser_retriever.keyword import (
DenserKeywordSearch,
ElasticKeywordSearch,
create_elasticsearch_client,
)
- from denser_retriever.reranker import DenserReranker
- from denser_retriever.retriever import DEFAULT_EMBEDDINGS, DenserRetriever
+ from denser_retriever.reranker import HFReranker
+ from denser_retriever.retriever import DenserRetriever
from denser_retriever.vectordb.milvus import MilvusDenserVectorDB

from langchain_community.document_loaders import WebBaseLoader

web_site = "https://denser.ai"
loader = WebBaseLoader(web_site)

docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(docs)

embeddings = DEFAULT_EMBEDDINGS
retriever = DenserRetriever(
index_name="agent_webpage",
vector_db=MilvusDenserVectorDB(
collection_name="agent_webpage",
auto_id=True,
connection_args={"uri": "http://localhost:19530"},
embedding_function=embeddings,
),
keyword_search=DenserKeywordSearch(
keyword_search=ElasticKeywordSearch(
index_name="agent_webpage",
field_types={"title": {"type": "keyword"}},
es_connection=create_elasticsearch_client(url="http://localhost:9200"),
),
embeddings=embeddings,
reranker=DenserReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"),
gradient_boost=DenserGradientBoost(
"experiments/models/scifact_xgb_es+vs+rr_n.json"
embeddings=HuggingFaceEmbeddings(
model_name="sentence-transformers/msmarco-MiniLM-L-6-v2"
),
reranker=HFReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"),
gradient_boost=XGradientBoost("experiments/models/scifact_xgb_es+vs+rr_n.json"),
combine_mode="model",
xgb_model_features="es+vs+rr_n",
)
Expand All @@ -54,5 +52,3 @@
print(f"\nMetadata: {r[0].metadata}")
print(f"Score: {r[1]}")
print(f"{'='*40}\n")

retriever.clear()
29 changes: 14 additions & 15 deletions experiments/utils.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
- from denser_retriever.keyword import DenserKeywordSearch, create_elasticsearch_client
- from denser_retriever.reranker import DenserReranker
- from denser_retriever.retriever import DEFAULT_EMBEDDINGS
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from denser_retriever.keyword import ElasticKeywordSearch, create_elasticsearch_client
+ from denser_retriever.reranker import HFReranker
from denser_retriever.vectordb.milvus import MilvusDenserVectorDB

- index_name="unit_test_retriever"
+ index_name = "unit_test_retriever"

- embeddings = DEFAULT_EMBEDDINGS
+ embeddings = HuggingFaceEmbeddings(
+     model_name="sentence-transformers/msmarco-MiniLM-L-6-v2"
+ )

milvus = MilvusDenserVectorDB(
collection_name=index_name,
embedding_function=embeddings,
connection_args={"uri": "http://localhost:19530"},
auto_id=True,
)
collection_name=index_name,
embedding_function=embeddings,
connection_args={"uri": "http://localhost:19530"},
auto_id=True,
)

- elasticsearch = DenserKeywordSearch(
+ elasticsearch = ElasticKeywordSearch(
index_name=index_name,
field_types={
"title": {"type": "keyword"},
},
es_connection=create_elasticsearch_client(url="http://localhost:9200"),
)
- reranker = DenserReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
+ reranker = HFReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")

0 comments on commit c94a9c1

Please sign in to comment.