resolved review comments
kvareddy committed May 24, 2022
1 parent 9dc3c42 commit 739cb43
Showing 4 changed files with 24 additions and 15 deletions.
8 changes: 7 additions & 1 deletion megatron/arguments.py
@@ -304,7 +304,13 @@ def parse_args(extra_args_provider=None, defaults={},
        assert args.recompute_method is None, \
            'recompute method is not yet supported for ' \
            'selective recomputing granularity'
-
+
+    # disable sequence parallelism when tp=1
+    # to avoid change in numerics when
+    # sequence_parallelism is enabled.
+    if args.tensor_model_parallel_size == 1:
+        args.sequence_parallel = False
+
    # disable async_tensor_model_parallel_allreduce when
    # model parallel memory optimization is enabled
    if args.sequence_parallel:
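
As a quick illustration of the new check, here is a minimal, hypothetical sketch (validate_args is not Megatron's parse_args; it mirrors only the lines added above): with a single tensor-parallel rank there is nothing to shard along the sequence dimension, so the flag is silently dropped and the tp=1 path stays numerically identical to a run without sequence parallelism.

import argparse

def validate_args(args):
    # Hypothetical helper mirroring the tp=1 check added in parse_args above:
    # with one tensor-parallel rank there is no sequence dimension to shard,
    # so disable sequence parallelism to keep numerics unchanged.
    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False
    return args

parser = argparse.ArgumentParser()
parser.add_argument('--tensor-model-parallel-size', type=int, default=1)
parser.add_argument('--sequence-parallel', action='store_true')

args = validate_args(parser.parse_args(['--sequence-parallel']))
assert args.sequence_parallel is False   # dropped silently when tp == 1
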
23 changes: 13 additions & 10 deletions megatron/global_vars.py
@@ -292,18 +292,21 @@ def log(self, names, normalizer=1.0, reset=True):


class GlobalMemoryBuffer:
-    "Global buffer to avoid dynamic memory allocations"
+    """Global buffer to avoid dynamic memory allocations.
+       Caller should ensure that buffers of the same name
+       are not used concurrently."""

    def __init__(self):
        self.buffer = {}

-    def allocate_tensor(self, tensor_shape, dtype):
+    def get_tensor(self, tensor_shape, dtype, name):
        required_len = reduce(operator.mul, tensor_shape, 1)
-        if self.buffer.get(dtype, None) is None or self.buffer[dtype].numel() < required_len:
-            self.buffer[dtype] = torch.empty(required_len,
-                                             dtype=dtype,
-                                             device=torch.cuda.current_device(),
-                                             requires_grad=False)
-
-        return self.buffer[dtype][0:required_len].view(*tensor_shape)
+        if self.buffer.get((name, dtype), None) is None or \
+                self.buffer[(name, dtype)].numel() < required_len:
+            self.buffer[(name, dtype)] = \
+                torch.empty(required_len,
+                            dtype=dtype,
+                            device=torch.cuda.current_device(),
+                            requires_grad=False)
+
+        return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
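
To show what the renamed API buys, here is a hedged, CPU-only sketch of the (name, dtype) keyed reuse. SimpleGlobalMemoryBuffer is an illustrative stand-in for the class above, not the real implementation, which allocates on torch.cuda.current_device().

import operator
from functools import reduce
import torch

class SimpleGlobalMemoryBuffer:
    """Grow-only scratch buffers keyed by (name, dtype).

    Distinct names get distinct buffers, so two call sites can hold
    their views at the same time without clobbering each other."""

    def __init__(self):
        self.buffer = {}

    def get_tensor(self, tensor_shape, dtype, name):
        required_len = reduce(operator.mul, tensor_shape, 1)
        key = (name, dtype)
        if self.buffer.get(key) is None or self.buffer[key].numel() < required_len:
            # CPU allocation here; the real class uses torch.cuda.current_device().
            self.buffer[key] = torch.empty(required_len, dtype=dtype,
                                           requires_grad=False)
        return self.buffer[key][0:required_len].view(*tensor_shape)

buf = SimpleGlobalMemoryBuffer()
a = buf.get_tensor((4, 8), torch.float32, "mpu")
b = buf.get_tensor((2, 8), torch.float32, "mpu")        # reuses the "mpu" buffer
c = buf.get_tensor((2, 8), torch.float32, "attention")  # separate named buffer
assert a.data_ptr() == b.data_ptr() and a.data_ptr() != c.data_ptr()
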
4 changes: 2 additions & 2 deletions megatron/model/transformer.py
@@ -234,9 +234,9 @@ def forward(self, query_layer, key_layer,
                                       output_size[0] * output_size[1], -1)

        # preallocting input tensor: [b * np, sq, sk]
-        matmul_input_buffer = get_global_memory_buffer().allocate_tensor(
+        matmul_input_buffer = get_global_memory_buffer().get_tensor(
            (output_size[0]*output_size[1], output_size[2], output_size[3]),
-            dtype=query_layer.dtype)
+            query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
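
For context at the call site, a hedged sketch of the preallocation pattern, reusing the SimpleGlobalMemoryBuffer sketch above with illustrative shapes (the remaining baddbmm arguments are collapsed by the diff view, so the values here are assumptions). The buffer supplies baddbmm's required input argument; with beta=0.0 its uninitialized contents are ignored, so the benefit is simply not allocating a fresh [b*np, sq, sk] tensor on every call.

import torch

b_np, sq, sk, hn = 8, 16, 16, 32          # hypothetical [b*np, sq, sk, hn] sizes
query = torch.randn(b_np, sq, hn)
key = torch.randn(b_np, sk, hn)

buf = SimpleGlobalMemoryBuffer()
matmul_input_buffer = buf.get_tensor((b_np, sq, sk), query.dtype, "mpu")
matmul_result = torch.baddbmm(
    matmul_input_buffer,                  # preallocated input; beta=0.0 ignores its contents
    query,                                # [b*np, sq, hn]
    key.transpose(1, 2),                  # [b*np, hn, sk]
    beta=0.0,
    alpha=1.0)
assert matmul_result.shape == (b_np, sq, sk)
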
4 changes: 2 additions & 2 deletions megatron/mpu/layers.py
@@ -221,7 +221,7 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = \
-                get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
            torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
@@ -246,7 +246,7 @@ def backward(ctx, grad_output):
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = \
-                get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
            handle = torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
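
A hedged, single-process sketch of the buffer sizing used by both all-gather call sites, again reusing SimpleGlobalMemoryBuffer from above; world_size is an assumed value, and the torch.distributed._all_gather_base call itself is omitted because it requires an initialized process group.

import torch

def gather_shape(tensor, world_size):
    # Ranks are concatenated along dim 0, so the gathered tensor is
    # world_size times larger in that dimension.
    dim_size = list(tensor.size())
    dim_size[0] = dim_size[0] * world_size
    return dim_size

partial = torch.randn(4, 16)                          # per-rank shard
buf = SimpleGlobalMemoryBuffer()
all_gather_buffer = buf.get_tensor(gather_shape(partial, world_size=8),
                                   partial.dtype, "mpu")
assert list(all_gather_buffer.shape) == [32, 16]
# Forward and backward both draw the "mpu"-named buffer; per the docstring's
# contract, callers must ensure the two uses never overlap in time.
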
