Skip to content

Commit

Permalink
update local rank value
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Sep 16, 2023
1 parent 63be6d2 commit aa87312
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion text2vec/bertmatching_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def train(
if data_parallel:
self.bert = nn.DataParallel(self.bert)
num_devices = torch.cuda.device_count()
local_rank = int(os.environ["LOCAL_RANK"])
local_rank = int(os.environ.get("LOCAL_RANK", 0))
sampler = DistributedSampler(train_dataset, num_replicas=num_devices, rank=local_rank)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, shuffle=False)
else:
Expand Down
2 changes: 1 addition & 1 deletion text2vec/bge_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def train(
if data_parallel:
self.bert = nn.DataParallel(self.bert)
num_devices = torch.cuda.device_count()
local_rank = int(os.environ["LOCAL_RANK"])
local_rank = int(os.environ.get("LOCAL_RANK", 0))
sampler = DistributedSampler(train_dataset, num_replicas=num_devices, rank=local_rank)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, shuffle=False)
else:
Expand Down
2 changes: 1 addition & 1 deletion text2vec/cosent_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def train(
if data_parallel:
self.bert = nn.DataParallel(self.bert)
num_devices = torch.cuda.device_count()
local_rank = int(os.environ["LOCAL_RANK"])
local_rank = int(os.environ.get("LOCAL_RANK", 0))
sampler = DistributedSampler(train_dataset, num_replicas=num_devices, rank=local_rank)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, shuffle=False)
else:
Expand Down
2 changes: 1 addition & 1 deletion text2vec/sentencebert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def train(
if data_parallel:
self.bert = nn.DataParallel(self.bert)
num_devices = torch.cuda.device_count()
local_rank = int(os.environ["LOCAL_RANK"])
local_rank = int(os.environ.get("LOCAL_RANK", 0))
sampler = DistributedSampler(train_dataset, num_replicas=num_devices, rank=local_rank)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, shuffle=False)
else:
Expand Down

0 comments on commit aa87312

Please sign in to comment.