Hack for 4-GPU-per-node cross-node training.
Jinyang Li committed Apr 30, 2024
1 parent f4f5b63 commit 1e14e40
Showing 1 changed file with 7 additions and 3 deletions.
utils/general_utils.py: 10 changes (7 additions, 3 deletions)
@@ -179,7 +179,7 @@ def init_distributed(args):
 
         DEFAULT_GROUP = dist.group.WORLD
 
-        num_gpu_per_node = torch.cuda.device_count()
+        num_gpu_per_node = one_node_device_count()
         n_of_nodes = WORLD_SIZE // num_gpu_per_node
         all_in_node_group = []
         for rank in range(n_of_nodes):
@@ -196,10 +196,14 @@ def init_distributed(args):
         DEFAULT_GROUP = SingleGPUGroup()
         IN_NODE_GROUP = SingleGPUGroup()
 
+def one_node_device_count():
+    return 4 # HACK: because in perl cluster, it is always 4 GPUs per node. TODO: change it back.
+    return torch.cuda.device_count()
+
 def get_first_rank_on_cur_node():
     global GLOBAL_RANK
-    NODE_ID = GLOBAL_RANK // torch.cuda.device_count()
-    first_rank_in_node = NODE_ID * torch.cuda.device_count()
+    NODE_ID = GLOBAL_RANK // one_node_device_count()
+    first_rank_in_node = NODE_ID * one_node_device_count()
     return first_rank_in_node


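For context, the arithmetic this hack feeds into can be illustrated with the standalone sketch below. This is not the repository's actual init_distributed: WORLD_SIZE and GLOBAL_RANK here are placeholder values standing in for what the launcher environment normally provides, and the in-node groups are shown as plain rank lists rather than torch.distributed process groups.

# Standalone sketch (illustration only): how a hardcoded per-node GPU count
# drives the node/rank bookkeeping touched by this commit.

def one_node_device_count() -> int:
    # Hardcoded to 4 for a cluster that always exposes 4 GPUs per node;
    # without the hack this would query torch.cuda.device_count().
    return 4

def in_node_rank_lists(world_size: int) -> list[list[int]]:
    # One list of consecutive global ranks per node; the real code would pass
    # each list to dist.new_group() to build the in-node process groups.
    per_node = one_node_device_count()
    n_of_nodes = world_size // per_node
    return [list(range(n * per_node, (n + 1) * per_node)) for n in range(n_of_nodes)]

def get_first_rank_on_cur_node(global_rank: int) -> int:
    # Lowest global rank running on the same node as global_rank.
    node_id = global_rank // one_node_device_count()
    return node_id * one_node_device_count()

if __name__ == "__main__":
    WORLD_SIZE, GLOBAL_RANK = 8, 6                   # e.g. 2 nodes x 4 GPUs, this process is rank 6
    print(in_node_rank_lists(WORLD_SIZE))            # [[0, 1, 2, 3], [4, 5, 6, 7]]
    print(get_first_rank_on_cur_node(GLOBAL_RANK))   # 4

One plausible reason for hardcoding the count is that torch.cuda.device_count() reflects only the GPUs visible to the current process (for example when a launcher restricts CUDA_VISIBLE_DEVICES per rank), which would break the rank-to-node arithmetic; on clusters with a different node size the hardcoded 4 must be reverted, as the TODO in the diff notes.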
