Skip to content

Commit

Permalink
Rename compress to rerank
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiheng-huang committed Aug 20, 2024
1 parent 00559c8 commit f128103
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
9 changes: 4 additions & 5 deletions denser_retriever/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, top_k: int = 50, weight: float = 0.5):
self.weight = weight

@abstractmethod
- def compress_documents(
+ def rerank(
self,
documents: Sequence[Document],
query: str,
Expand All @@ -27,7 +27,7 @@ def __init__(self, model_name: str, **kwargs):
super().__init__()
self.model = HuggingFaceCrossEncoder(model_name=model_name)

- def compress_documents(
+ def rerank(
self,
documents: Sequence[Document],
query: str,
Expand All @@ -36,9 +36,8 @@ def compress_documents(
Rerank documents using CrossEncoder.
Args:
- documents: A sequence of documents to compress.
- query: The query to use for compressing the documents.
- callbacks: Callbacks to run during the compression process.
+ documents: A sequence of documents to rerank.
+ query: The query to use for ranking the documents.
Returns:
A list of tuples containing the document and its score.
Expand Down
6 changes: 3 additions & 3 deletions denser_retriever/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ def _retrieve_by_linear_or_rank(
if self.reranker:
start_time = time.time()
docs = [doc for doc, _ in passages[:self.reranker.top_k]]
- compressed_docs = self.reranker.compress_documents(docs, query)
+ reranked_docs = self.reranker.rerank(docs, query)

passages = merge_results(
passages,
- compressed_docs,
+ reranked_docs,
1.0,
self.reranker.weight,
self.combine_mode,
Expand Down Expand Up @@ -195,7 +195,7 @@ def _retrieve_with_features(

reranked_docs = []
if self.reranker:
- reranked_docs = self.reranker.compress_documents(combined_docs, query)
+ reranked_docs = self.reranker.rerank(combined_docs, query)

_, ks_score_dict, ks_rank_dict = docs_to_dict(ks_docs)
_, vs_score_dict, vs_rank_dict = docs_to_dict(vs_docs)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def test_rerank() -> None:
"bbb3",
]
docs = list(map(lambda text: Document(page_content=text), texts))
- compressor = HFReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
- actual_docs = compressor.compress_documents(docs, "bbb2")
+ reranker = HFReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
+ actual_docs = reranker.rerank(docs, "bbb2")
actual = list(map(lambda doc: doc[0].page_content, actual_docs))[0:3]
expected_returned = ["bbb2", "bbb1", "bbb3"]
expected_not_returned = ["aaa1", "aaa2", "aaa3"]
Expand All @@ -29,6 +29,6 @@ def test_rerank() -> None:

def test_rerank_empty() -> None:
docs: List[Document] = []
- compressor = HFReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
- actual_docs = compressor.compress_documents(docs, "query")
+ reranker = HFReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
+ actual_docs = reranker.rerank(docs, "query")
assert len(actual_docs) == 0

0 comments on commit f128103

Please sign in to comment.