Skip to content

Commit

Permalink
Merge pull request NVIDIA#391 from NVIDIA/persistent_sync_bn_group8_fix
Browse files Browse the repository at this point in the history
Fixing rank mapping for bn_group size == 8
  • Loading branch information
jjsjann123 authored Jul 13, 2019
2 parents 1483f22 + 89ae9e5 commit a00952b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions apex/contrib/groupbn/batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2,
self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)

if bn_group>4:
self.pair_handle3 = handles_l[local_rank ^ 3].cpu().contiguous()
pair_offset3 = offsets_l[local_rank ^ 3].cpu()
self.pair_handle3 = handles_l[local_rank ^ 4].cpu().contiguous()
pair_offset3 = offsets_l[local_rank ^ 4].cpu()
self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)

#FIXME: get magic value into C code and eliminate from here
Expand Down

0 comments on commit a00952b

Please sign in to comment.