def add_to_index(self, sentences_or_file_path: Union[str, List[str]],
                 device: str = None,
                 batch_size: int = 64) -> None:
    """Encode additional sentences and append them to the existing index.

    Args:
        sentences_or_file_path: Either a list of sentences, or a string, which
            is treated as the path of a text file that is read line by line
            (one sentence per line, trailing whitespace stripped).
        device: Forwarded unchanged to ``self.encode``.
        batch_size: Forwarded unchanged to ``self.encode``.

    Raises:
        RuntimeError: If the instance holds no index yet.
    """
    # Fail fast: without this check the (potentially expensive) encoding runs
    # first and the call then dies with an opaque subscript error.
    if getattr(self, "index", None) is None:
        raise RuntimeError("No index to add to; call build_index() first.")

    # A string input is assumed to be the path of a file that stores the
    # sentences, not a sentence itself.
    if isinstance(sentences_or_file_path, str):
        with open(sentences_or_file_path, "r") as f:
            logger.info("Loading sentences from %s ...", sentences_or_file_path)
            sentences = [line.rstrip() for line in tqdm(f)]
    else:
        sentences = sentences_or_file_path

    # Nothing to encode: leave the index untouched instead of handing an
    # empty batch to the encoder.
    if len(sentences) == 0:
        logger.info("No sentences to add; index left unchanged")
        return

    logger.info("Encoding embeddings for sentences...")
    embeddings = self.encode(sentences, device=device, batch_size=batch_size,
                             normalize_to_unit=True, return_numpy=True)

    if self.is_faiss_index:
        # The vectors are cast to float32 before being handed to the index.
        self.index["index"].add(embeddings.astype(np.float32))
    else:
        # Otherwise the stored value is concatenated with the new embeddings
        # (np.concatenate defaults to axis 0).
        self.index["index"] = np.concatenate((self.index["index"], embeddings))

    # Rebind rather than extend in place, so a list object that may be shared
    # with the caller is not mutated.
    # NOTE(review): presumably build_index stores the caller's list directly;
    # confirm against that method.
    self.index["sentences"] = [*self.index["sentences"], *sentences]
    logger.info("Finished")