Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
heiwang1997 authored Jun 9, 2023
1 parent faa5860 commit 1051273
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def determine_usable_gpus():
os.environ['CUDA_VISIBLE_DEVICES'] = selection_str

if program_args.gpus > 1 and program_args.accelerator is None:
program_args.accelerator = 'ddp'
program_args.accelerator = 'gpu'
program_args.strategy = 'ddp'


def is_rank_zero():
Expand Down Expand Up @@ -139,7 +140,8 @@ def readable_name_from_exec(exec_list: List[str]):
determine_usable_gpus()
else:
# Align parameters.
program_args.accelerator = 'ddp'
program_args.accelerator = 'gpu'
program_args.strategy = 'ddp'

# Profiling and debugging options
torch.autograd.set_detect_anomaly(program_args.debug)
Expand Down

0 comments on commit 1051273

Please sign in to comment.