Skip to content

Commit

Permalink
update bge train
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Sep 4, 2023
1 parent 8860e5e commit 0255177
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions text2vec/bge_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,21 +269,32 @@ def train(
query, passage = batch
query = self.flat_list(query)
passage = self.flat_list(passage)
query = self.tokenizer(
query,
max_length=self.query_max_len,
truncation=True,
padding=True,
return_tensors='pt'
)
passage = self.tokenizer(
passage,
max_length=self.passage_max_len,
truncation=True,
padding=True,
return_tensors='pt'
)
query = query.to(self.device)
passage = passage.to(self.device)

# get sentence embeddings
with torch.autocast(str(self.device), dtype=torch_type):
q_embeddings = self.encode(
query,
normalize_embeddings=True,
convert_to_tensor=True,
max_seq_length=self.query_max_len
)
p_embeddings = self.encode(
passage,
normalize_embeddings=True,
convert_to_tensor=True,
max_seq_length=self.passage_max_len
)
q_embeddings = self.get_sentence_embeddings(**query)
q_embeddings = torch.nn.functional.normalize(q_embeddings, dim=-1)
q_embeddings = q_embeddings.contiguous()

p_embeddings = self.get_sentence_embeddings(**passage)
p_embeddings = torch.nn.functional.normalize(p_embeddings, dim=-1)
p_embeddings = p_embeddings.contiguous()
scores = self.calc_similarity(q_embeddings, p_embeddings)
scores = scores / temperature
scores = scores.view(q_embeddings.size(0), -1)
Expand Down

0 comments on commit 0255177

Please sign in to comment.