Skip to content

Commit

Permalink
Delete DDP hooks in Reducer destructor (#21591)
Browse files Browse the repository at this point in the history
Summary:
Closes pytorch/pytorch#21344

DDP assigns the original module to the first module replica instead of creating a new one. Then, it creates a new Reducer to add post hooks to sync gradients. However, because every reconstructed DDP instance wraps the same original module, all their reducers will add hooks to the same set of variables. This PR deletes DDP hooks from variables when destructing Reducer, trying to make DDP failure recoverable.

pietern kuttas and I discussed the following solutions:

#### Solution 1

Keep `add_post_hook` API intact, and do a `dynamic_cast` in `del_post_hook` to check hook type. If the type matches Reducer's hook, delete it. As pietern mentioned, this will not work if we create multiple DDP instances from the same original model.

#### Solution 2

Use a counter to generate a unique key for every hook in `Function`, and keep them in a map. return the key to the caller of `add_post_hook`, and ask the caller to provide key if it needs to delete the hook.

Con: this would add extra overhead to `add_post_hook` and every `Function` object.

#### Solution 3 [Current implementation]

kuttas suggests that, instead of generating a unique key, directly using the address of the pointer would be better. In order to avoid messing up dereferencing, let `add_post_hook` to return a `uintptr_t`.
Pull Request resolved: pytorch/pytorch#21591

Differential Revision: D15745706

Pulled By: mrshenli

fbshipit-source-id: e56d2d48de0c65f6667790ab16337eac7f7d8b76
  • Loading branch information
mrshenli authored and facebook-github-bot committed Jun 12, 2019
1 parent 1e4af2b commit cbcb2b5
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 6 deletions.
65 changes: 65 additions & 0 deletions test/test_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2490,6 +2490,71 @@ def forward(self, x):
loss = criterion(output, target)
loss.backward()

@skip_if_not_nccl
@skip_if_not_multigpu
def test_failure_recovery(self):
store = c10d.FileStore(self.file.name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

# need to create a separate file for the recovered FileStore, because
# the original one will be deleted when destructing the first FileStore.
recovery_filename = self.file.name + "_recovery"

if self.rank == 0:
# the file will be deleted by the recovered FileStore
open(recovery_filename, "w").close()

# not necessary to run barrier here, as DDP will synchronize

class TestModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
self.fc1 = nn.Linear(2, 10, bias=False)
self.fc2 = nn.Linear(10, 4, bias=False)
self.relu = nn.ReLU()

def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
return F.softmax(x, dim=1)

device_id = gpus_for_rank(self.world_size)[self.rank][0]
model = TestModel().float().to(device_id)
ddp = DistributedDataParallel(
model,
device_ids=[device_id],
process_group=process_group,
)

batch_size = 4
criterion = nn.CrossEntropyLoss()
input = torch.rand([batch_size, 2], dtype=torch.float)
target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)

for _ in range(6):
output = ddp(input)
loss = criterion(output, target)
loss.backward()

del ddp
del process_group
del store # this will delete self.file

store = c10d.FileStore(recovery_filename, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
ddp = DistributedDataParallel(
model,
device_ids=[device_id],
process_group=process_group,
)

input = torch.rand([batch_size, 2], dtype=torch.float)
target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)
for _ in range(6):
output = ddp(input)
loss = criterion(output, target)
loss.backward()


class ReducerModule(nn.Module):
def __init__(self):
Expand Down
16 changes: 15 additions & 1 deletion torch/csrc/autograd/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,29 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
// Hook API
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
post_hooks_.push_back(std::move(post_hook));
// Use the raw pointer as the unique key to identify this hook. This key
// can then be used in del_post_hook(key) to remove this hook.
return reinterpret_cast<std::uintptr_t>(post_hooks_.back().get());
}

const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() const
noexcept {
return post_hooks_;
}

// delete a post hook matching the key
bool del_post_hook(const uintptr_t& key) {
for (auto it = post_hooks_.begin(); it != post_hooks_.end();) {
if (key == reinterpret_cast<std::uintptr_t>(it->get())) {
post_hooks_.erase(it);
return true;
}
}
return false;
}

std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
return post_hooks_;
}
Expand Down
26 changes: 21 additions & 5 deletions torch/csrc/distributed/c10d/reducer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,14 @@ Reducer::Reducer(
auto grad_accumulator = variable.grad_accumulator();

// Hook to execute after the gradient accumulator has executed.
grad_accumulator->add_post_hook(torch::make_unique<LambdaPostHook>([=] {
std::lock_guard<std::mutex> lock(this->mutex_);
this->mark_variable_ready(
replica_index, variable_index, /* called_from_autograd= */ true);
}));
hooks_[grad_accumulator->add_post_hook(
torch::make_unique<LambdaPostHook>([=] {
std::lock_guard<std::mutex> lock(this->mutex_);
this->mark_variable_ready(
replica_index,
variable_index,
/* called_from_autograd= */ true);
}))] = grad_accumulator;

// Map raw function pointer to replica index and parameter index.
// This is used later on when the autograd graph is traversed
Expand Down Expand Up @@ -138,6 +141,19 @@ Reducer::Reducer(
}
}

Reducer::~Reducer() noexcept(false) {
// Remove all hooks on variables registered by this Reducer. This is necessary
// to make DDP failure recoverable. Otherwise, multiple Reducer instances
// (from recoveries) will add their hooks to the original model, and those
// hooks will try to invoke methods on a deleted Reducer objects.
for (auto& hook : hooks_) {
auto& key = hook.first;
auto& grad_accumulator = hook.second;
AT_ASSERTM(grad_accumulator->del_post_hook(key),
"Reducer attempts to delete a non-existing hook.");
}
}

// Called when the gradient for the specified variable is ready.
// It can be called from two places:
// - By an autograd thread after executing a gradient accumulator function.
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/distributed/c10d/reducer.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Reducer {
std::vector<std::vector<size_t>> bucket_indices,
std::shared_ptr<c10d::ProcessGroup> process_group);

~Reducer() noexcept(false);

// To (re-)initialize bucket assignment, pass a list of buckets, each
// of which is specified by a list of indices in the variables list.
// This function performs validation that the variables within a bucket
Expand Down Expand Up @@ -52,6 +54,8 @@ class Reducer {
std::vector<std::vector<std::shared_ptr<torch::autograd::Function>>>
grad_accumulators_;
std::unordered_map<torch::autograd::Function*, std::tuple<int, int>> func_;
std::unordered_map<uintptr_t, std::shared_ptr<torch::autograd::Function>>
hooks_;

bool expect_autograd_hooks_;
bool require_finalize_;
Expand Down

0 comments on commit cbcb2b5

Please sign in to comment.