Update SSD example to report global speed and use proper number of shards (NVIDIA#810)

Signed-off-by: Janusz Lisiecki <[email protected]>
JanuszL authored Apr 26, 2019
1 parent dbfc920 commit 5cf0ec9
Showing 2 changed files with 13 additions and 4 deletions.
docs/examples/pytorch/single_stage_detector/main.py (10 additions, 3 deletions)
@@ -27,9 +27,10 @@
 
 
 class Logger:
-    def __init__(self, batch_size, local_rank, print_freq=20):
+    def __init__(self, batch_size, local_rank, n_gpu, print_freq=20):
         self.batch_size = batch_size
         self.local_rank = local_rank
+        self.n_gpu = n_gpu
         self.print_freq = print_freq
 
         self.processed_samples = 0

@@ -59,7 +60,7 @@ def end_epoch(self):

         if self.local_rank == 0:
             print('Epoch {:2d} finished. Time: {:4f} s, Speed: {:4f} img/sec, Average speed: {:4f}'
-                  .format(len(self.epochs_times)-1, epoch_time, epoch_speed, self.average_speed()))
+                  .format(len(self.epochs_times)-1, epoch_time, epoch_speed * self.n_gpu, self.average_speed() * self.n_gpu))
 
     def average_speed(self):
         return sum(self.epochs_speeds) / len(self.epochs_speeds)
@@ -171,7 +172,7 @@ def train(args):
     train_loader = get_train_loader(args, dboxes)
 
     acc = 0
-    logger = Logger(args.batch_size, args.local_rank)
+    logger = Logger(args.batch_size, args.local_rank, args.N_gpu)
 
     for epoch in range(0, args.epochs):
         logger.start_epoch()
@@ -211,6 +212,12 @@ def train(args):

     start_time = time.time()
     acc, avg_speed = train(args)
+    # avg_speed is reported per node, adjust for the global speed
+    try:
+        num_shards = torch.distributed.get_world_size()
+    except RuntimeError:
+        num_shards = 1
+    avg_speed = num_shards * avg_speed
     training_time = time.time() - start_time
 
     if args.local_rank == 0:
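For context, the fallback above can be exercised standalone. The sketch below is a minimal, hypothetical example (the global_speed helper and the 250 img/sec figure are illustrative, not part of the commit): it queries the world size the same defensive way the commit does, since torch.distributed.get_world_size() raises a RuntimeError when no process group has been initialized, and falls back to a single shard in that case.

import torch.distributed as dist

def global_speed(per_process_speed):
    # Mirror the commit's fallback: get_world_size() raises RuntimeError
    # when torch.distributed has not been initialized (single-GPU runs).
    try:
        num_shards = dist.get_world_size()
    except RuntimeError:
        num_shards = 1
    return per_process_speed * num_shards

# With 4 DDP processes each measuring 250 img/sec locally, every rank
# reports a global speed of 4 * 250 = 1000 img/sec; standalone it stays 250.
print(global_speed(250.0))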
Second changed file (3 additions, 1 deletion)
@@ -29,15 +29,17 @@ def __init__(self, default_boxes, args, seed):

         try:
             shard_id = torch.distributed.get_rank()
+            num_shards = torch.distributed.get_world_size()
         except RuntimeError:
             shard_id = 0
+            num_shards = 1
 
         self.input = ops.COCOReader(
             file_root=args.train_coco_root,
             annotations_file=args.train_annotate,
             skip_empty=True,
             shard_id=shard_id,
-            num_shards=args.N_gpu,
+            num_shards=num_shards,
             ratio=True,
             ltrb=True,
             random_shuffle=True,
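The same fallback drives the reader sharding. Below is a self-contained sketch of a pipeline built this way, using the legacy ops.COCOReader API shown in the diff; the class name, constructor arguments, and dataset paths are placeholders rather than the example's actual code. Each process reads only its own 1/num_shards slice of the dataset, so together the shards cover the data exactly once per epoch.

import torch.distributed as dist
from nvidia.dali import ops
from nvidia.dali.pipeline import Pipeline

class ShardedCOCOPipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id,
                 file_root, annotations_file):
        super(ShardedCOCOPipeline, self).__init__(batch_size, num_threads, device_id)
        # Derive sharding from the torch.distributed process group,
        # falling back to a single shard when it is not initialized.
        try:
            shard_id = dist.get_rank()
            num_shards = dist.get_world_size()
        except RuntimeError:
            shard_id = 0
            num_shards = 1
        self.input = ops.COCOReader(
            file_root=file_root,                # placeholder dataset path
            annotations_file=annotations_file,  # placeholder annotations path
            skip_empty=True,
            shard_id=shard_id,      # which slice this process reads
            num_shards=num_shards,  # total number of slices
            ratio=True,             # boxes as relative coordinates
            ltrb=True,              # boxes as (left, top, right, bottom)
            random_shuffle=True)

    def define_graph(self):
        # COCOReader yields images plus their boxes and labels.
        images, bboxes, labels = self.input()
        return images, bboxes, labels

Instantiated once per GPU with device_id set per rank, this keeps the per-shard dataset size tied to the number of processes actually running rather than to a command-line GPU count.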
