Skip to content

Commit

Permalink
update bge model eval.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Sep 3, 2023
1 parent 4012244 commit f305255
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 35 deletions.
1 change: 0 additions & 1 deletion examples/training_bge_model_mydata.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def main():
model = SentenceModel(
model_name_or_path=args.output_dir,
encoder_type=args.encoder_type,
max_seq_length=args.max_seq_length
)
test_data = load_text_matching_test_data(args.test_file)

Expand Down
11 changes: 3 additions & 8 deletions text2vec/bge_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
def __len__(self):
    """Return the number of items held by ``self.dataset``."""
    item_count = len(self.dataset)
    return item_count

def text_2_id(self, text: str, max_len: int):
def text_2_id(self, text, max_len: int):
return self.tokenizer(
text,
max_length=max_len,
Expand All @@ -82,11 +82,6 @@ def __getitem__(self, index: int):
else:
negs = random.sample(self.dataset[index]['neg'], self.train_group_size - 1)
passage.extend(negs)
return query, passage


if isinstance(query, list):
query = sum(query, [])
if isinstance(passage, list):
passage = sum(passage, [])
query_tokens = self.text_2_id(query, self.query_max_len)
passage_tokens = self.text_2_id(passage, self.passage_max_len)
return query_tokens, passage_tokens
126 changes: 100 additions & 26 deletions text2vec/bge_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from text2vec.bge_dataset import BgeTrainDataset
from text2vec.sentence_model import SentenceModel
from text2vec.utils.stats_util import compute_spearmanr, compute_pearsonr
from text2vec.utils.stats_util import set_seed


Expand Down Expand Up @@ -151,13 +152,9 @@ def calc_loss(self, y_true, y_pred):
loss = nn.CrossEntropyLoss(reduction='mean')(y_pred, y_true)
return loss

def calc_similarity(self, q_embs, p_embs):
    """
    Score query embeddings against passage embeddings with a matrix product.

    The passage tensor is transposed before the multiply: a 2-D tensor has
    its two axes swapped, any other rank has its last two axes swapped.

    NOTE(review): this is a raw dot product; it presumably equals cosine
    similarity only if both inputs are already L2-normalized -- confirm
    against the caller.
    """
    if len(p_embs.size()) == 2:
        p_transposed = p_embs.transpose(0, 1)
    else:
        p_transposed = p_embs.transpose(-2, -1)
    return torch.matmul(q_embs, p_transposed)
@staticmethod
def flat_list(l):
return [item for sublist in l for item in sublist]

def train(
self,
Expand Down Expand Up @@ -246,6 +243,7 @@ def train(
"eval_spearman": [],
"eval_pearson": [],
}

for current_epoch in trange(int(num_epochs), desc="Epoch", disable=False, mininterval=0):
self.bert.train()
current_loss = 0
Expand All @@ -261,30 +259,36 @@ def train(
steps_trained_in_current_epoch -= 1
continue
query, passage = batch
# query [batch, 1, seq_len] -> [batch, seq_len]
q_input_ids = query.get('input_ids').squeeze(1).to(self.device)
q_attention_mask = query.get('attention_mask').squeeze(1).to(self.device)
q_token_type_ids = query.get('token_type_ids', None)
if q_token_type_ids is not None:
q_token_type_ids = q_token_type_ids.squeeze(1).to(self.device)
# passage [batch, 1, seq_len] -> [batch, seq_len]
p_input_ids = passage.get('input_ids').squeeze(1).to(self.device)
p_attention_mask = passage.get('attention_mask').squeeze(1).to(self.device)
p_token_type_ids = passage.get('token_type_ids', None)
if p_token_type_ids is not None:
p_token_type_ids = p_token_type_ids.squeeze(1).to(self.device)

# get sentence embeddings of BERT encoder
with torch.autocast(self.device, dtype=torch_type):
q_embeddings = self.get_sentence_embeddings(q_input_ids, q_attention_mask, q_token_type_ids)
p_embeddings = self.get_sentence_embeddings(p_input_ids, p_attention_mask, p_token_type_ids)
scores = self.calc_similarity(q_embeddings, p_embeddings)
query = self.flat_list(query)
passage = self.flat_list(passage)
query = self.tokenizer(
query,
max_length=self.query_max_len,
truncation=True,
padding=True,
return_tensors='pt'
)
passage = self.tokenizer(
passage,
max_length=self.passage_max_len,
truncation=True,
padding=True,
return_tensors='pt'
)
query = query.to(self.device)
passage = passage.to(self.device)

# get sentence embeddings
with torch.autocast(str(self.device), dtype=torch_type):
q_embeddings = self.get_sentence_embeddings(**query)
p_embeddings = self.get_sentence_embeddings(**passage)
scores = torch.cosine_similarity(q_embeddings, p_embeddings)
scores = scores / temperature
scores = scores.view(q_embeddings.size(0), -1)

target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
target = target * (p_embeddings.size(0) // q_embeddings.size(0))
loss = self.calc_loss(scores, target)
loss = self.calc_loss(target, scores)
current_loss = loss.item()
if verbose:
batch_iterator.set_description(
Expand Down Expand Up @@ -322,3 +326,73 @@ def train(
return global_step, training_progress_scores

return global_step, training_progress_scores

def eval_model(self, eval_dataset: Dataset, output_dir: str = None, verbose: bool = True, batch_size: int = 16):
    """
    Run ``self.evaluate`` on ``eval_dataset`` and merge its metrics into ``self.results``.

    Args:
        eval_dataset: Dataset forwarded unchanged to ``self.evaluate``.
        output_dir: Forwarded to ``self.evaluate`` as its second positional argument.
        verbose: When true, log the accumulated ``self.results`` at info level.
        batch_size: Forwarded to ``self.evaluate``.

    Returns:
        The dictionary returned by ``self.evaluate`` for this call.
    """
    metrics = self.evaluate(eval_dataset, output_dir, batch_size=batch_size)
    self.results.update(metrics)

    if not verbose:
        return metrics

    logger.info(self.results)
    return metrics

def evaluate(self, eval_dataset, output_dir: str = None, batch_size: int = 16):
    """
    Embed every (query, passage) batch of eval_dataset and report correlation metrics.

    Utility function to be used by the eval_model() method. Not intended to be used directly.

    Args:
        eval_dataset: Dataset wrapped in a DataLoader; each batch is unpacked
            as ``query, passage``.
        output_dir: When truthy, the directory is created if needed and the
            metrics are written to ``eval_results.txt`` inside it.
        batch_size: Batch size handed to the DataLoader.

    Returns:
        Dict with the keys ``eval_spearman`` and ``eval_pearson``.
    """
    results = {}

    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
    # Put the encoder on the target device and switch it to eval mode.
    self.bert.to(self.device)
    self.bert.eval()

    # NOTE(review): batch_labels is never appended to anywhere in this method,
    # so the correlation helpers below receive an empty label list while
    # batch_preds grows per batch. It looks like gold labels should be read
    # from each batch -- confirm the eval dataset schema and populate this.
    batch_labels = []
    batch_preds = []
    for batch in tqdm(eval_dataloader, disable=False, desc="Running Evaluation"):
        query, passage = batch
        # Collapse one level of nesting before tokenizing
        # (presumably the collated batch is a list of lists of strings -- TODO confirm).
        query = self.flat_list(query)
        passage = self.flat_list(passage)
        # Tokenize with truncation to the configured max length and padding, returning PyTorch tensors.
        query = self.tokenizer(
            query,
            max_length=self.query_max_len,
            truncation=True,
            padding=True,
            return_tensors='pt'
        )
        passage = self.tokenizer(
            passage,
            max_length=self.passage_max_len,
            truncation=True,
            padding=True,
            return_tensors='pt'
        )
        query = query.to(self.device)
        passage = passage.to(self.device)

        # Inference only: no gradients are tracked while embedding and scoring.
        with torch.no_grad():
            q_embeddings = self.get_sentence_embeddings(**query)
            p_embeddings = self.get_sentence_embeddings(**passage)
            # Called with default arguments; assumes query and passage embeddings
            # line up one-to-one -- TODO confirm.
            preds = torch.cosine_similarity(q_embeddings, p_embeddings)
        batch_preds.extend(preds.cpu().numpy())

    spearman = compute_spearmanr(batch_labels, batch_preds)
    pearson = compute_pearsonr(batch_labels, batch_preds)
    logger.debug(f"labels: {batch_labels[:10]}")
    logger.debug(f"preds: {batch_preds[:10]}")
    logger.debug(f"pearson: {pearson}, spearman: {spearman}")

    results["eval_spearman"] = spearman
    results["eval_pearson"] = pearson
    if output_dir:
        # Persist the metrics as "key = value" lines, keys in sorted order.
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "eval_results.txt"), "w") as writer:
            for key in sorted(results.keys()):
                writer.write("{} = {}\n".format(key, str(results[key])))

    return results

0 comments on commit f305255

Please sign in to comment.