Skip to content

Commit

Permalink
update ddp steps.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Sep 4, 2023
1 parent 273c4b8 commit 8c67305
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions text2vec/cosent_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,16 @@ def train(
logger.debug("Use device: {}".format(self.device))
self.bert.to(self.device)
set_seed(seed)
num_devices = 1

if data_parallel:
self.bert = nn.DataParallel(self.bert)
world_size = torch.cuda.device_count()
local_rank = int(os.environ["LOCAL_RANK"])
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=local_rank)
num_devices = torch.cuda.device_count()
sampler = DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, shuffle=False)
else:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False) # not shuffle
total_steps = len(train_dataloader) * num_epochs
total_steps = len(train_dataloader) * num_epochs // num_devices
param_optimizer = list(self.bert.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
Expand Down Expand Up @@ -276,7 +276,8 @@ def train(
current_loss = loss.item()
if verbose:
batch_iterator.set_description(
f"Epoch: {epoch_number + 1}/{num_epochs}, Batch:{step}/{len(train_dataloader)}, Loss: {current_loss:9.4f}")
f"Epoch: {epoch_number + 1}/{num_epochs}, "
f"Batch:{step}/{len(train_dataloader)//num_devices}, Loss: {current_loss:9.4f}")

if gradient_accumulation_steps > 1:
loss = loss / gradient_accumulation_steps
Expand Down

0 comments on commit 8c67305

Please sign in to comment.