Skip to content

Commit

Permalink
Fix eval function for distributed training (PaddlePaddle#1228)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghaoshuang authored Jul 1, 2022
1 parent fc90903 commit 81fa182
Showing 1 changed file with 6 additions and 26 deletions.
32 changes: 6 additions & 26 deletions demo/auto_compression/semantic_segmentation/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def parse_args():
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):

nranks = paddle.distributed.ParallelEnv().local_rank

batch_sampler = paddle.io.DistributedBatchSampler(
if nranks > 1 and paddle.distributed.get_rank() != 0:
return
batch_sampler = paddle.io.BatchSampler(
eval_dataset, batch_size=1, shuffle=False, drop_last=False)
loader = paddle.io.DataLoader(
eval_dataset,
Expand Down Expand Up @@ -116,30 +117,9 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
paddle.to_tensor(label),
eval_dataset.num_classes,
ignore_index=eval_dataset.ignore_index)

if nranks > 1:
intersect_area_list = []
pred_area_list = []
label_area_list = []
paddle.distributed.all_gather(intersect_area_list, intersect_area)
paddle.distributed.all_gather(pred_area_list, pred_area)
paddle.distributed.all_gather(label_area_list, label_area)

# Some image has been evaluated and should be eliminated in last iter
if (iter + 1) * nranks > len(eval_dataset):
valid = len(eval_dataset) - iter * nranks
intersect_area_list = intersect_area_list[:valid]
pred_area_list = pred_area_list[:valid]
label_area_list = label_area_list[:valid]

for i in range(len(intersect_area_list)):
intersect_area_all = intersect_area_all + intersect_area_list[i]
pred_area_all = pred_area_all + pred_area_list[i]
label_area_all = label_area_all + label_area_list[i]
else:
intersect_area_all = intersect_area_all + intersect_area
pred_area_all = pred_area_all + pred_area
label_area_all = label_area_all + label_area
intersect_area_all = intersect_area_all + intersect_area
pred_area_all = pred_area_all + pred_area
label_area_all = label_area_all + label_area

class_iou, miou = metrics.mean_iou(intersect_area_all, pred_area_all,
label_area_all)
Expand Down

0 comments on commit 81fa182

Please sign in to comment.