added splitting checkpointed activations
shoeybi committed Sep 2, 2020
1 parent 42d2be0 commit d6c4248
Showing 6 changed files with 234 additions and 1 deletion.
4 changes: 4 additions & 0 deletions megatron/arguments.py
@@ -200,6 +200,10 @@ def _add_training_args(parser):
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--distribute-checkpointed-activations',
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across model parallel group.')
    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                       help='chunk size (number of layers) for checkpointing.')
    group.add_argument('--train-iters', type=int, default=None,
21 changes: 21 additions & 0 deletions megatron/initialize.py
@@ -72,6 +72,9 @@ def finish_mpu_init():
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

    # Initialize memory buffers.
    _initialize_mem_buffs()

    # Autoresume.
    _init_autoresume()
@@ -151,3 +154,21 @@ def _write_args_to_tensorboard():
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)))


def _initialize_mem_buffs():
    """Initialize manually allocated static memory."""
    args = get_args()

    # Initialize memory for checkpointed activations.
    if args.distribute_checkpointed_activations:
        per_layer = args.batch_size * args.max_position_embeddings * \
                    args.hidden_size // args.model_parallel_size
        assert args.num_layers % args.checkpoint_num_layers == 0, \
            'number of layers is not divisible by checkpoint-num-layers'
        num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
        numel = per_layer * num_checkpointer_layers
        dtype = torch.half
        if not args.fp16:
            dtype = torch.float
        mpu.init_checkpointed_activations_memory_buffer(numel, dtype)
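
For a sense of scale, the buffer holds one model-parallel shard of the checkpointed hidden states for every checkpointed chunk of layers. A rough sizing in the same spirit as the code above, with illustrative hyperparameters that are assumptions rather than values taken from this commit:

# Illustrative sizing only -- all hyperparameters below are assumptions.
batch_size = 8
seq_length = 1024          # plays the role of max_position_embeddings
hidden_size = 1536
model_parallel_size = 2
num_layers = 24
checkpoint_num_layers = 1

per_layer = batch_size * seq_length * hidden_size // model_parallel_size
num_chunks = num_layers // checkpoint_num_layers
numel = per_layer * num_chunks     # 150,994,944 elements
# At 2 bytes per element (fp16) this is a ~288 MB static buffer per rank.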
145 changes: 145 additions & 0 deletions megatron/memory.py
@@ -0,0 +1,145 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch


# A dictionary of all the memory buffers allocated.
_MEM_BUFFS = dict()


def allocate_mem_buff(name, numel, dtype, track_usage):
    """Allocate a memory buffer."""
    assert name not in _MEM_BUFFS, \
        'memory buffer {} already allocated.'.format(name)
    _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
    return _MEM_BUFFS[name]


def get_mem_buff(name):
    """Get the memory buffer."""
    return _MEM_BUFFS[name]


class MemoryBuffer:
    """Contiguous memory buffer.
    Allocate a contiguous memory of type `dtype` and size `numel`. It is
    used to reduce memory fragmentation.
    Usage: After the allocation, the `_start` index is set to the first
           index of the memory. A memory chunk starting from the `_start`
           index can be `allocated` for an input tensor, with the elements
           of the tensor being copied. The buffer can be reused by resetting
           the `_start` index.
    """
    def __init__(self, name, numel, dtype, track_usage):
        if torch.distributed.get_rank() == 0:
            element_size = torch.tensor([], dtype=dtype).element_size()
            print('> building the {} memory buffer with {} num elements '
                  'and {} dtype ({:.1f} MB)...'.format(
                      name, numel, dtype, numel*element_size/1024/1024),
                  flush=True)
        self.name = name
        self.numel = numel
        self.dtype = dtype
        self.data = torch.empty(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

        # Index tracking the start of the free memory.
        self._start = 0

        # Values used for tracking usage.
        self.track_usage = track_usage
        if self.track_usage:
            self.in_use_value = 0.0
            self.total_value = 0.0


    def reset(self):
        """Reset the buffer start index to the beginning of the buffer."""
        self._start = 0


    def is_in_use(self):
        """Whether the current buffer holds on to any memory."""
        return self._start > 0


    def numel_in_use(self):
        """Return the number of elements in use."""
        return self._start


    def add(self, tensor):
        """Allocate a chunk of memory from the buffer for the tensor and
        copy the values."""
        assert tensor.dtype == self.dtype, \
            'Input tensor type {} different from buffer type {}'.format(
                tensor.dtype, self.dtype)
        # Number of elements of the input tensor.
        tensor_numel = torch.numel(tensor)
        new_start = self._start + tensor_numel
        assert new_start <= self.numel, \
            'Not enough memory left in the buffer ({} > {})'.format(
                tensor_numel, self.numel - self._start)
        # New tensor is a view into the memory.
        new_tensor = self.data[self._start:new_start]
        self._start = new_start
        new_tensor = new_tensor.view(tensor.shape)
        new_tensor.copy_(tensor)
        # Return a pointer to the new tensor.
        return new_tensor


    def get_data(self):
        """Return the data currently in use."""
        if self.track_usage:
            self.in_use_value += float(self._start)
            self.total_value += float(self.numel)
        return self.data[:self._start]


    def print_average_usage(self):
        """Print memory usage averaged over time. We would like this value
        to be as high as possible."""
        assert self.track_usage, 'You need to enable track usage.'
        if torch.distributed.get_rank() == 0:
            print(' > usage of {} memory buffer: {:.2f} %'.format(
                self.name, self.in_use_value * 100.0 / self.total_value),
                flush=True)



class RingMemBuffer:
    """A ring of memory buffers."""

    def __init__(self, name, num_buffers, numel, dtype, track_usage):
        self.num_buffers = num_buffers
        self.buffers = [
            allocate_mem_buff(name+' {}'.format(i), numel, dtype, track_usage)
            for i in range(num_buffers)]
        self._index = -1


    def get_next_buffer(self):
        self._index += 1
        self._index = self._index % self.num_buffers
        buff = self.buffers[self._index]
        assert not buff.is_in_use(), 'buffer is already in use.'
        return buff
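
A minimal usage sketch of the buffer API above. It assumes a CUDA device and sets up a throwaway single-process group only because the constructor and print_average_usage report from rank 0; the shapes and sizes are illustrative, not values from this commit.

import torch
import torch.distributed as dist
from megatron.memory import allocate_mem_buff

# Single-process group purely so rank-0 printing works in this sketch.
dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500',
                        rank=0, world_size=1)

buff = allocate_mem_buff('demo', numel=4 * 1024 * 1024,
                         dtype=torch.float, track_usage=True)

for _ in range(3):                                # one "iteration" per pass
    x = torch.randn(2, 512, 1024, device='cuda')  # 1M elements, fits the buffer
    x_in_buff = buff.add(x)                       # copy x into contiguous memory
    assert buff.is_in_use() and buff.numel_in_use() == x.numel()
    _ = buff.get_data()                           # slice currently occupied
    buff.reset()                                  # reuse the memory next time

buff.print_average_usage()                        # ~25 % for these sizes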
2 changes: 2 additions & 0 deletions megatron/model/transformer.py
@@ -411,6 +411,8 @@ def custom_forward(*inputs):
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
2 changes: 2 additions & 0 deletions megatron/mpu/__init__.py
@@ -45,7 +45,9 @@

from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer

from .utils import divide
from .utils import split_tensor_along_last_dim
61 changes: 60 additions & 1 deletion megatron/mpu/random.py
@@ -24,14 +24,35 @@
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

from megatron.memory import allocate_mem_buff

from .initialize import get_data_parallel_rank
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size


# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'


# Whether to apply model parallelism to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None


def init_checkpointed_activations_memory_buffer(numel, dtype):
    """Initialize the memory buffer for the checkpointed activations."""
    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
        'checkpointed activations', numel, dtype, track_usage=False)


def reset_checkpointed_activations_memory_buffer():
    """Reset the memory used for checkpointing."""
    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()


def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
@@ -65,6 +86,29 @@ def cb():
    _lazy_call(cb)


def split_tensor_into_1d_equal_chunks(tensor):
    """Break a tensor into equal 1D chunks."""
    data = tensor.view(-1)
    partition_size = torch.numel(data) // get_model_parallel_world_size()
    start_index = partition_size * get_model_parallel_rank()
    end_index = start_index + partition_size
    return data[start_index:end_index]


def gather_split_1d_tensor(tensor):
    """Opposite of the above function, gather values from model parallel ranks."""
    world_size = get_model_parallel_world_size()
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
    chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
    torch.distributed.all_gather(chunks, tensor,
                                 group=get_model_parallel_group())
    return gathered
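
These two helpers are inverses of each other up to a reshape: each model parallel rank keeps one contiguous 1/world_size slice of the flattened tensor (so its element count must be divisible by the world size), and an all_gather of those slices restores the original data. A single-process sketch of the round trip, simulating the ranks by hand rather than using a real process group:

import torch

world_size = 4                                     # pretend model parallel size
x = torch.arange(2 * 6, dtype=torch.float).view(2, 6)

# What split_tensor_into_1d_equal_chunks does on each rank r.
data = x.view(-1)
partition_size = torch.numel(data) // world_size
shards = [data[r * partition_size:(r + 1) * partition_size]
          for r in range(world_size)]

# What gather_split_1d_tensor reconstructs via all_gather, followed by the
# .view(ctx.input_0_shape) done in CheckpointFunction.backward.
gathered = torch.cat(shards).view(x.shape)
assert torch.equal(gathered, x)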


class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
@@ -199,9 +243,21 @@ def forward(ctx, run_function, *args):
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
            ctx.input_0_shape = args[0].data.shape
            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
                args[0].data)

        # Store everything.
        ctx.save_for_backward(*args)

        return outputs

    @staticmethod
@@ -210,6 +266,9 @@ def backward(ctx, *args):
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
