Optimize peer memory halo exchange kernel
Use an alternating double-buffer scheme to allow push-only communication. Remove cooperative group syncs and memory fences.
timmoon10 committed Nov 18, 2022
1 parent 3229126 commit 0bd620e
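
The optimized kernel itself is presumably in one of the other changed files, which are not rendered below. As a rough illustration of the scheme the commit message describes, here is a minimal Python sketch of push-only halo exchange with an alternating double buffer; the two-rank threading setup and all names are illustrative assumptions, not the apex implementation. Each step, a rank writes its halo directly into the peer's receive buffer for the current parity and then raises the matching signal slot; the peer only spins on its own signal, so the sender never reads remote memory. (In this sketch the GIL provides the ordering guarantees that the CUDA kernel has to arrange for itself.)

```python
# Minimal sketch of push-only halo exchange with an alternating double buffer.
# Two "ranks" are simulated with threads; buffers and signals stand in for
# peer-mapped GPU memory. Names and layout are illustrative, not apex's.

import threading

NUM_RANKS = 2
NUM_STEPS = 4

# Per-rank state: two receive buffers and two signal slots (one per parity).
rx_buffers = [[None, None] for _ in range(NUM_RANKS)]
signals = [[0, 0] for _ in range(NUM_RANKS)]


def exchange(rank, step, payload):
    """Push `payload` to the peer rank and wait for the peer's payload."""
    peer = 1 - rank
    parity = step % 2        # alternate between the two buffers each step
    expected = step + 1      # signal value that marks this step as complete

    # Push-only: write into the peer's buffer, then raise the peer's signal.
    rx_buffers[peer][parity] = payload
    signals[peer][parity] = expected

    # Consume: spin on our own signal slot until the peer has pushed to us.
    while signals[rank][parity] != expected:
        pass                 # the CUDA kernel likewise polls a local flag

    return rx_buffers[rank][parity]


def worker(rank, received):
    for step in range(NUM_STEPS):
        halo = f"rank{rank}-step{step}"
        received[rank].append(exchange(rank, step, halo))


received = [[] for _ in range(NUM_RANKS)]
threads = [threading.Thread(target=worker, args=(r, received)) for r in range(NUM_RANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(received[0])  # halos pushed by rank 1
print(received[1])  # halos pushed by rank 0
```

Two buffers suffice in this sketch because steps alternate parity and each step's handshake must complete before a rank advances: a push for step i+2 reuses the buffer from step i only after the peer has already consumed step i and acknowledged step i+1.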
Showing 4 changed files with 401 additions and 408 deletions.
apex/contrib/bottleneck/halo_exchangers.py: 23 changes (16 additions, 7 deletions)
@@ -88,28 +88,37 @@ def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_inp
         inc.left_right_halo_exchange_inplace(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo)
 
 class HaloExchangerPeer(HaloExchanger):
-    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
+    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0):
         super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
         self.diagnostics = False
         self.explicit_nhwc = explicit_nhwc
         self.numSM = numSM
         self.peer_pool = peer_pool
+        self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
+        self.signals[self.rank_in_group].zero_()
+
+    def _allocate_peer_tensor(self, halo):
+
+        # Compute size in bytes
+        # Note: Pad buffer so each CUDA block gets required buffer size
+        size = 4 * halo.numel() * halo.element_size()
+        size_per_block = 128 * 2 * 16 # 128 threads each require two 128b buffers
+        size = (size + size_per_block - 1) // size_per_block * size_per_block
+
+        # Construct dtype peer buffer with desired size
+        shape = [1, 1, 1, size // halo.element_size()]
+        return self.peer_pool.allocate_peer_tensors(shape, halo.dtype, False, True)
 
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         inplace = False if left_input_halo is None and right_input_halo is None else True
         if not inplace:
             left_input_halo = torch.empty_like(right_output_halo)
             right_input_halo = torch.empty_like(left_output_halo)
         channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc
-        left_tx = self.peer_pool.allocate_peer_tensors(list(left_output_halo.shape), left_output_halo.dtype, channels_last, True)
-        right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
+        left_tx = self._allocate_peer_tensor(left_input_halo)
+        right_tx = self._allocate_peer_tensor(right_input_halo)
         pm.push_pull_halos_1d(
-            self.diagnostics, self.explicit_nhwc, self.numSM,
+            self.diagnostics, self.explicit_nhwc, self.numSM, self.rank_in_group,
             self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
             self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
+            self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
         )
         if not inplace:
             return left_input_halo, right_input_halo
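
The new _allocate_peer_tensor helper sizes the staging buffer from the halo it will carry: four times the halo's byte size, rounded up to a whole number of 4096-byte chunks (128 threads per CUDA block, each needing two 16-byte buffers, per the inline comment). A short worked example of that arithmetic, using a made-up fp16 halo shape rather than anything from the commit:

```python
import torch

# Illustrative halo: one NHWC row of shape [1, 1, 100, 64] in fp16 (made-up numbers).
halo = torch.empty(1, 1, 100, 64, dtype=torch.float16)

# Same arithmetic as HaloExchangerPeer._allocate_peer_tensor:
size = 4 * halo.numel() * halo.element_size()   # 4 * 6400 * 2 = 51200 bytes
size_per_block = 128 * 2 * 16                   # 4096 bytes per CUDA block
size = (size + size_per_block - 1) // size_per_block * size_per_block
print(size)                                     # 53248 = 13 * 4096 (rounded up)

shape = [1, 1, 1, size // halo.element_size()]
print(shape)                                    # [1, 1, 1, 26624]
```

Note that the buffer is allocated as a flat [1, 1, 1, N] tensor keyed only on element count, whereas the old code allocated peer tensors with the halo's own shape and memory format; presumably the new kernel treats the peer buffer as raw staging space, with the factor of four leaving room for the alternating double buffers.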
