
Commit

Hoist QuantizationProfile instructions right after the last node that updates their inputs to shorten lifetimes of buffers (pytorch#2698)

Summary:
The implementation of this optimization is very similar to hoistDealloc. It runs right before hoistDealloc and helps it do a better job later. As a result, peak memory consumption in the quantization profile collection mode is greatly reduced.
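As a rough, standalone sketch of the hoisting idea in the title (not Glow's actual IR, hoistDealloc, or the scheduler code in this diff), the toy below moves each profiling op right after the last op that writes the buffer it reads, so that buffer can be freed earlier; the Op struct, the hoistProfiles helper, and the A/B/C buffer names are invented for illustration.

// Illustrative toy only: a flat op list standing in for the IR, not Glow's API.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct Op {
  std::string name;
  std::string reads;      // buffer this op only reads ("" if none)
  std::string writes;     // buffer this op writes ("" if none)
  bool isProfile = false; // stands in for a QuantizationProfile instruction
};

// Move every profiling op right after the last preceding op that writes the
// buffer it reads, so the buffer's lifetime ends as early as possible.
void hoistProfiles(std::vector<Op> &ops) {
  for (std::size_t i = 0; i < ops.size(); ++i) {
    if (!ops[i].isProfile)
      continue;
    // Find the last writer of the profiled buffer before the profiling op.
    std::size_t lastWriter = i;
    for (std::size_t j = 0; j < i; ++j)
      if (ops[j].writes == ops[i].reads)
        lastWriter = j;
    if (lastWriter == i || lastWriter + 1 == i)
      continue; // no writer found, or already adjacent
    Op profile = ops[i];
    ops.erase(ops.begin() + static_cast<std::ptrdiff_t>(i));
    ops.insert(ops.begin() + static_cast<std::ptrdiff_t>(lastWriter) + 1,
               profile);
  }
}

int main() {
  // Same shape as the new unit test: add/sub feed mul, mul is saved, and the
  // two profiling ops read the add/sub results.
  std::vector<Op> ops = {
      {"add", "", "A"},         {"sub", "", "B"},
      {"mul", "", "C"},         {"save", "C", ""},
      {"qpAdd", "A", "", true}, {"qpSub", "B", "", true},
  };
  hoistProfiles(ops);
  for (const Op &op : ops)
    std::cout << op.name << " "; // prints: add qpAdd sub qpSub mul save
  std::cout << "\n";
}

The printed order, add qpAdd sub qpSub mul save, mirrors what the new GraphSchedulerTest below checks against the real scheduler.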

Fixes pytorch#2697
Pull Request resolved: pytorch#2698

Differential Revision: D15297832

Pulled By: opti-mix

fbshipit-source-id: 8874a72e7fb25e8be1f895677d3293000e00c990
opti-mix authored and facebook-github-bot committed May 10, 2019
1 parent 36e709c commit 3abdf8d
Showing 2 changed files with 87 additions and 11 deletions.
50 changes: 39 additions & 11 deletions lib/IR/ChildMemSizeBasedScheduler.cpp
@@ -98,24 +98,27 @@ void ChildMemSizeBasedScheduler::orderChildNodesAndSchedule(Node *N) {
orderedChildren.push_back(N->getPredicate());
}

// SaveNode hack:
// We don't model memory dependencies, but we still need to honor them.
// Make sure the SaveNode happens after the last use of the output
// placeholder.
if (auto *save = dyn_cast<SaveNode>(N)) {
auto *destination = save->getOutput().getNode();
// Make sure a node mutating any of its inputs happens after the last
// non-mutating use of the operand being mutated. Examples of such nodes
// are SaveNode and QuantizationProfileNode.
for (unsigned idx = 0, e = N->getNumInputs(); idx < e; ++idx) {
// We don't care about inputs that are not mutated by the node.
if (!N->isOverwrittenNthInput(idx)) {
continue;
}
auto mutatedInput = N->getNthInput(idx);
auto *destination = mutatedInput.getNode();
for (NodeUse &use : destination->getUsers()) {
Node *user = use.getUser();
if (user == save) {
if (user == N) {
continue;
}
// Storage nodes may have users scattered across different functions.
// Nodes may have users scattered across different functions.
// Only account for the ones in this function.
if (&G_ != user->getParent()) {
continue;
}
assert(!isa<SaveNode>(user) &&
"Placeholder must be saved at most once in each function");
orderedChildren.push_back(user);
}
}
@@ -148,9 +151,34 @@ void ChildMemSizeBasedScheduler::orderChildNodesAndSchedule(Node *N) {
orderChildNodesAndSchedule(child);
}

// Schedule the node after all its children are scheduled.
DEBUG_GLOW(llvm::dbgs() << "Scheduled node: " << N->getName() << "\n");
// Schedule the node after all its children are scheduled. We need to perform
// an extra isScheduled check here, because the code below may have scheduled
// the current node while scheduling its children.
if (isScheduled(N)) {
return;
}
scheduled_.push_back(N);
// If this node has a user that has no users of its own and requires no
// additional memory, schedule that user here, because we don't want to
// extend the lifetime of this node's value for no reason. We want to execute
// such a user and get rid of it as soon as possible to reduce memory pressure.
for (NodeUse &use : N->getUsers()) {
Node *user = use.getUser();
// Users may be scattered across different functions.
// Only account for the ones in this function.
if (&G_ != user->getParent()) {
continue;
}
// Bail if the user has users of its own, because nodes that have users
// can't be scheduled safely here without violating dependencies.
if (user->getNumUsers()) {
continue;
}
// Schedule a node if it does not require any additional memory.
if (resultMemSize_[user] == 0) {
orderChildNodesAndSchedule(user);
}
}
}

void ChildMemSizeBasedScheduler::scheduleNodes() {
48 changes: 48 additions & 0 deletions tests/unittests/GraphSchedulerTest.cpp
@@ -113,3 +113,51 @@ TEST(GraphScheduler, testMaxSizeLessThanResultSize) {
std::distance(schedule.begin(), concatSmallIt));
}
}

TEST(GraphScheduler, ScheduleQuantizationProfileRightAfterNodeBeingProfiled) {
Module MD;
PlaceholderBindings bindings;
auto *input1 =
MD.createPlaceholder(ElemKind::FloatTy, {1, 4, 4}, "input1", false);
bindings.allocate(input1);
auto *input2 =
MD.createPlaceholder(ElemKind::FloatTy, {1, 4, 4}, "input2", false);
bindings.allocate(input2);
Function *F = MD.createFunction("F");
Node *add = F->createAdd("add", input1, input2);
Node *sub = F->createSub("sub", input1, input2);
Node *mul = F->createMul("mul", add, sub);
Node *save = F->createSave("save", mul);
Node *quantizationProfileAdd =
F->createQuantizationProfile(bindings, "qpAdd", add);
Node *quantizationProfileSub =
F->createQuantizationProfile(bindings, "qpSub", sub);

// Run the scheduler and check that each QuantizationProfile node is
// scheduled right after the node it profiles, so that the lifetime of the
// buffer being profiled stays as short as possible.
NodesPtrList schedule;
ChildMemSizeBasedScheduler scheduler(*F, schedule);
scheduler.schedule();

// Find the positions of add and quantizationProfileAdd in the schedule.
auto addIt = std::find(schedule.begin(), schedule.end(), add);
auto qpAddIt =
std::find(schedule.begin(), schedule.end(), quantizationProfileAdd);
// Expect the quantization profiling node to be scheduled right after the node
// being profiled.
EXPECT_EQ(++addIt, qpAddIt);

// Find the positions of sub and quantizationProfileSub in the schedule.
auto subIt = std::find(schedule.begin(), schedule.end(), sub);
auto qpSubIt =
std::find(schedule.begin(), schedule.end(), quantizationProfileSub);
// Expect the quantization profiling node to be scheduled right after the node
// being profiled.
EXPECT_EQ(++subIt, qpSubIt);

// Expect the save node to be the last in the schedule.
EXPECT_EQ(save, schedule.back());
}

