From 02551773b203ccfbef963128ad2e9a5ddc2f732a Mon Sep 17 00:00:00 2001 From: shibing624 Date: Mon, 4 Sep 2023 11:31:16 +0800 Subject: [PATCH] update bge train --- text2vec/bge_model.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/text2vec/bge_model.py b/text2vec/bge_model.py index 226c8a9..9b758ea 100644 --- a/text2vec/bge_model.py +++ b/text2vec/bge_model.py @@ -269,21 +269,32 @@ def train( query, passage = batch query = self.flat_list(query) passage = self.flat_list(passage) + query = self.tokenizer( + query, + max_length=self.query_max_len, + truncation=True, + padding=True, + return_tensors='pt' + ) + passage = self.tokenizer( + passage, + max_length=self.passage_max_len, + truncation=True, + padding=True, + return_tensors='pt' + ) + query = query.to(self.device) + passage = passage.to(self.device) # get sentence embeddings with torch.autocast(str(self.device), dtype=torch_type): - q_embeddings = self.encode( - query, - normalize_embeddings=True, - convert_to_tensor=True, - max_seq_length=self.query_max_len - ) - p_embeddings = self.encode( - passage, - normalize_embeddings=True, - convert_to_tensor=True, - max_seq_length=self.passage_max_len - ) + q_embeddings = self.get_sentence_embeddings(**query) + q_embeddings = torch.nn.functional.normalize(q_embeddings, dim=-1) + q_embeddings = q_embeddings.contiguous() + + p_embeddings = self.get_sentence_embeddings(**passage) + p_embeddings = torch.nn.functional.normalize(p_embeddings, dim=-1) + p_embeddings = p_embeddings.contiguous() scores = self.calc_similarity(q_embeddings, p_embeddings) scores = scores / temperature scores = scores.view(q_embeddings.size(0), -1)