forked from NVIDIA/Megatron-LM
Commit
added splitting checkpointed activations
Showing 6 changed files with 234 additions and 1 deletion.
@@ -0,0 +1,145 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


# A dictionary of all the memory buffers allocated.
_MEM_BUFFS = dict()


def allocate_mem_buff(name, numel, dtype, track_usage):
    """Allocate a memory buffer."""
    assert name not in _MEM_BUFFS, \
        'memory buffer {} already allocated.'.format(name)
    _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
    return _MEM_BUFFS[name]


def get_mem_buff(name):
    """Get the memory buffer."""
    return _MEM_BUFFS[name]


class MemoryBuffer:
    """Contiguous memory buffer.
    Allocate a contiguous block of memory of type `dtype` and size `numel`.
    It is used to reduce memory fragmentation.
    Usage: After the allocation, the `_start` index is set to the first
        index of the memory. A memory chunk starting from the `_start` index
        can be `allocated` for an input tensor, with the elements of the
        tensor being copied. The buffer can be reused by resetting the
        `_start` index.
    """
    def __init__(self, name, numel, dtype, track_usage):
        if torch.distributed.get_rank() == 0:
            element_size = torch.tensor([], dtype=dtype).element_size()
            print('> building the {} memory buffer with {} num elements '
                  'and {} dtype ({:.1f} MB)...'.format(
                      name, numel, dtype, numel * element_size / 1024 / 1024),
                  flush=True)
        self.name = name
        self.numel = numel
        self.dtype = dtype
        self.data = torch.empty(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

        # Index tracking the start of the free memory.
        self._start = 0

        # Values used for tracking usage.
        self.track_usage = track_usage
        if self.track_usage:
            self.in_use_value = 0.0
            self.total_value = 0.0

    def reset(self):
        """Reset the buffer start index to the beginning of the buffer."""
        self._start = 0

    def is_in_use(self):
        """Whether the current buffer holds on to any memory."""
        return self._start > 0

    def numel_in_use(self):
        """Return the number of elements in use."""
        return self._start

    def add(self, tensor):
        """Allocate a chunk of memory from the buffer for the input tensor
        and copy its values."""
        assert tensor.dtype == self.dtype, \
            'Input tensor type {} different from buffer type {}'.format(
                tensor.dtype, self.dtype)
        # Number of elements of the input tensor.
        tensor_numel = torch.numel(tensor)
        new_start = self._start + tensor_numel
        assert new_start <= self.numel, \
            'Not enough memory left in the buffer ({} needed, {} available)'.format(
                tensor_numel, self.numel - self._start)
        # The new tensor is a view into the buffer memory.
        new_tensor = self.data[self._start:new_start]
        self._start = new_start
        new_tensor = new_tensor.view(tensor.shape)
        new_tensor.copy_(tensor)
        # Return the view into the buffer holding the copied values.
        return new_tensor

    def get_data(self):
        """Return the data currently in use."""
        if self.track_usage:
            self.in_use_value += float(self._start)
            self.total_value += float(self.numel)
        return self.data[:self._start]

    def print_average_usage(self):
        """Print the memory usage averaged over time. We would like this
        value to be as high as possible."""
        assert self.track_usage, 'You need to enable track usage.'
        if torch.distributed.get_rank() == 0:
            print(' > usage of {} memory buffer: {:.2f} %'.format(
                self.name, self.in_use_value * 100.0 / self.total_value),
                flush=True)


class RingMemBuffer:
    """A ring of memory buffers."""

    def __init__(self, name, num_buffers, numel, dtype, track_usage):
        self.num_buffers = num_buffers
        self.buffers = [
            allocate_mem_buff(name + ' {}'.format(i), numel, dtype, track_usage)
            for i in range(num_buffers)]
        self._index = -1

    def get_next_buffer(self):
        """Advance the ring index and return the next buffer, asserting
        that it is not still in use."""
        self._index += 1
        self._index = self._index % self.num_buffers
        buff = self.buffers[self._index]
        assert not buff.is_in_use(), 'buffer is already in use.'
        return buff
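
Usage sketch (illustrative, not part of this commit): the snippet below assumes the definitions above are in scope, that a CUDA device is available, and that torch.distributed is initialized in a single process, since MemoryBuffer and print_average_usage call torch.distributed.get_rank(). The buffer names and sizes are made up for the example.

import os
import torch
import torch.distributed as dist

# Minimal single-process setup so torch.distributed.get_rank() works;
# Megatron-LM normally does this as part of its regular initialization.
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

# Allocate a 1M-element fp16 buffer with usage tracking (illustrative name/size).
buff = allocate_mem_buff('demo', 1024 * 1024, torch.half, track_usage=True)

# Copy two tensors into the buffer; each call returns a view into buff.data.
a = buff.add(torch.randn(256, 512, dtype=torch.half, device='cuda'))
b = buff.add(torch.randn(128, 512, dtype=torch.half, device='cuda'))
print(buff.numel_in_use())   # 256*512 + 128*512 = 196608 elements in use

# get_data() returns the in-use slice and records usage statistics.
_ = buff.get_data()
buff.print_average_usage()

# Reset before the next iteration so the same memory is reused.
buff.reset()

# A ring of two buffers cycled with get_next_buffer().
ring = RingMemBuffer('ring', 2, 1024 * 1024, torch.half, track_usage=False)
first = ring.get_next_buffer()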