Update bertmatching_model.py
add bf16 and data parallel
wptoux authored Jul 14, 2023
1 parent f12b7e5 commit 400c97e
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion text2vec/bertmatching_model.py
@@ -106,6 +106,8 @@ def train_model(
use_hf_dataset: bool = False,
hf_dataset_name: str = "STS-B",
save_model_every_epoch: bool = True,
bf16: bool = False,
data_parallel: bool = False,
):
"""
Trains the model on 'train_file'
@@ -128,6 +130,8 @@
use_hf_dataset (optional): Whether to use the HuggingFace datasets for training.
hf_dataset_name (optional): Name of the dataset to use for the HuggingFace datasets.
save_model_every_epoch (optional): Save model checkpoint every epoch.
bf16 (optional): Use bfloat16 automatic mixed precision (AMP) for training.
data_parallel (optional): Use multi-GPU data-parallel training via nn.DataParallel.
Returns:
global_step: Number of global steps trained
training_details: Full training progress scores
@@ -163,6 +167,8 @@
max_grad_norm=max_grad_norm,
max_steps=max_steps,
save_model_every_epoch=save_model_every_epoch,
bf16=bf16,
data_parallel=data_parallel,
)
logger.info(f" Training model done. Saved to {output_dir}.")
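
Taken together, the train_model() changes let callers opt into bfloat16 autocast and multi-GPU DataParallel while leaving existing calls unchanged, since both flags default to False. A minimal usage sketch follows; the class name, constructor argument, and file paths are assumptions for illustration and do not appear in this diff:

# Hypothetical usage of the new flags; class name, constructor argument,
# and paths are assumed, not taken from the commit.
from text2vec.bertmatching_model import BertMatchModel

model = BertMatchModel("bert-base-chinese")    # assumed constructor signature
model.train_model(
    train_file="data/STS-B/train.tsv",         # assumed path
    output_dir="outputs/bertmatching/",
    save_model_every_epoch=True,
    bf16=True,            # run forward passes under torch.autocast(bfloat16)
    data_parallel=True,   # wrap the encoder in nn.DataParallel across GPUs
)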

@@ -185,6 +191,8 @@ def train(
max_grad_norm: float = 1.0,
max_steps: int = -1,
save_model_every_epoch: bool = True,
bf16: bool = False,
data_parallel: bool = False,
):
"""
Trains the model on train_dataset.
@@ -196,6 +204,9 @@ def train(
self.model.bert.to(device)
set_seed(seed)

if data_parallel:
    # wrap the encoder so each forward pass splits the batch across visible GPUs
    self.model.bert = nn.DataParallel(self.model.bert)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size)  # keep the data order; no shuffling
total_steps = len(train_dataloader) * num_epochs
param_optimizer = list(self.model.bert.named_parameters())
@@ -270,7 +281,13 @@ def train(
token_type_ids = inputs.get('token_type_ids', None)
if token_type_ids is not None:
    token_type_ids = token_type_ids.squeeze(1).to(self.device)
loss, logits, probs = self.model(input_ids, attention_mask, token_type_ids, labels)

if bf16:
    with torch.autocast('cuda', dtype=torch.bfloat16):
        loss, logits, probs = self.model(input_ids, attention_mask, token_type_ids, labels)
else:
    loss, logits, probs = self.model(input_ids, attention_mask, token_type_ids, labels)

current_loss = loss.item()

if verbose:
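
For reference, a self-contained sketch of the same bfloat16-autocast plus DataParallel pattern on a toy module; every name below is illustrative and independent of text2vec. As in the commit, no GradScaler is involved: bfloat16 keeps fp32's exponent range, so the loss scaling usually paired with fp16 AMP is unnecessary.

import torch
from torch import nn

class TinyClassifier(nn.Module):
    """Toy stand-in for a matching model; illustrative only."""
    def __init__(self, dim=16, num_labels=2):
        super().__init__()
        self.fc = nn.Linear(dim, num_labels)

    def forward(self, x, labels):
        logits = self.fc(x)
        loss = nn.functional.cross_entropy(logits, labels)
        return loss, logits

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyClassifier().to(device)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # split each input batch across all visible GPUs
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(8, 16, device=device)
labels = torch.randint(0, 2, (8,), device=device)

if device == "cuda":
    # bfloat16 autocast: same exponent range as fp32, so no GradScaler is needed
    with torch.autocast("cuda", dtype=torch.bfloat16):
        loss, _ = model(x, labels)
else:
    loss, _ = model(x, labels)

loss.mean().backward()  # DataParallel gathers one loss per replica; reduce before backward
optimizer.step()
optimizer.zero_grad()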
