[static runtime] Add _out variants and reuse memory (pytorch#44128)
Summary: Pull Request resolved: pytorch#44128

Test Plan: Imported from OSS

Reviewed By: hlu1

Differential Revision: D23604304

Pulled By: bwasti

fbshipit-source-id: 06a23cb75700a0fc733069071843b7b498e7b9e9
bwasti authored and facebook-github-bot committed Sep 25, 2020
1 parent d1d9017 commit d1a1161
Showing 8 changed files with 192 additions and 65 deletions.
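
The core of this change: for a small whitelist of ops, ProcessedNode now calls ATen _out kernels into output tensors cached in the runtime workspace, so the second and later invocations reuse the buffers allocated on the first run instead of allocating fresh outputs. Below is a minimal standalone sketch of that allocate-once-then-reuse pattern. It is not the StaticRuntime code itself: the string-keyed Workspace map, the run_sigmoid helper, and the main() check are illustrative stand-ins for StaticRuntime::ConstantMap (keyed by graph Value* in the real code) and for the lambdas built in torch/csrc/jit/runtime/static/ops.cpp, and it uses the public at::sigmoid_out rather than the at::native:: kernels called there.

#include <ATen/ATen.h>
#include <string>
#include <unordered_map>

// Stand-in for StaticRuntime::ConstantMap (the real map is keyed by Value*).
using Workspace = std::unordered_map<std::string, at::Tensor>;

// Run sigmoid into a cached output: allocate an empty tensor on the first
// call, then let the _out kernel resize and write into the same storage on
// every later call.
void run_sigmoid(Workspace& ws, const std::string& in, const std::string& out) {
  const at::Tensor& in_t = ws.at(in);
  if (!ws.count(out)) {
    ws.emplace(out, at::empty({0}, in_t.options()));
  }
  at::Tensor out_t = ws.at(out);
  at::sigmoid_out(out_t, in_t);
}

int main() {
  Workspace ws;
  ws.emplace("x", at::randn({2048, 512}));
  run_sigmoid(ws, "x", "y");               // first call allocates "y"
  void* buf = ws.at("y").data_ptr();
  run_sigmoid(ws, "x", "y");               // later calls reuse the same storage
  TORCH_CHECK(ws.at("y").data_ptr() == buf);
  return 0;
}
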
6 changes: 5 additions & 1 deletion benchmarks/static_runtime/CMakeLists.txt
@@ -1,3 +1,7 @@
list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt_bench.cc)
list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt.cc)
list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt_bench.cc)
set(STATIC_RUNTIME_BENCHMARK_SRCS ${STATIC_RUNTIME_BENCHMARK_SRCS} PARENT_SCOPE)

list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt.cc)
list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/test_static_runtime.cc)
set(STATIC_RUNTIME_TEST_SRCS ${STATIC_RUNTIME_TEST_SRCS} PARENT_SCOPE)
2 changes: 2 additions & 0 deletions caffe2/CMakeLists.txt
@@ -1247,7 +1247,9 @@ endif()
if(BUILD_STATIC_RUNTIME_BENCHMARK)
add_subdirectory(${TORCH_ROOT}/benchmarks/static_runtime ${PROJECT_BINARY_DIR}/bin)
add_executable(static_runtime_bench "${STATIC_RUNTIME_BENCHMARK_SRCS}")
add_executable(static_runtime_test "${STATIC_RUNTIME_TEST_SRCS}")
target_link_libraries(static_runtime_bench torch_library benchmark)
target_link_libraries(static_runtime_test torch_library gtest_main)
endif()

if(BUILD_MOBILE_BENCHMARK)
13 changes: 8 additions & 5 deletions test/test_static_runtime.py
@@ -106,7 +106,8 @@ def test_multihead_attention_layer(self):
DROPOUT = 0.1
device = torch.device("cpu")
attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
with torch.no_grad():
src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

attention.eval()
@@ -129,17 +130,19 @@ def test_mlp(self):
bot_l_acc = StaticRuntime(bot_l)
top_l = create_mlp(ln_top, sigmoid_top)
top_l_acc = StaticRuntime(top_l)
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
with torch.no_grad():
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
ref_bot = bot_l(bot_inp)
acc_bot = bot_l_acc(bot_inp)[0]
torch.testing.assert_allclose(acc_bot, ref_bot)
ref_top = top_l(top_inp)
acc_top = top_l_acc(top_inp)[0]
torch.testing.assert_allclose(acc_top, ref_top)
for _ in range(5):
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
with torch.no_grad():
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
ref_bot = bot_l(bot_inp)
acc_bot = bot_l_acc(bot_inp)[0]
torch.testing.assert_allclose(acc_bot, ref_bot)
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -219,6 +219,7 @@ core_sources_full = [
"torch/csrc/jit/runtime/profiling_record.cpp",
"torch/csrc/jit/runtime/symbolic_script.cpp",
"torch/csrc/jit/runtime/static/impl.cpp",
"torch/csrc/jit/runtime/static/ops.cpp",
"torch/csrc/jit/serialization/import.cpp",
"torch/csrc/jit/serialization/import_export_helpers.cpp",
"torch/csrc/jit/serialization/import_source.cpp",
63 changes: 6 additions & 57 deletions torch/csrc/jit/runtime/static/impl.cpp
@@ -4,6 +4,7 @@
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

namespace torch {
@@ -12,48 +13,6 @@ namespace jit {
using c10::DispatchKey;
using c10::RegisterOperators;

static auto reg =
RegisterOperators()
.op("static::add(Tensor a, Tensor b) -> Tensor",
RegisterOperators::options().kernel(
DispatchKey::CPU,
[](at::Tensor a, at::Tensor b) -> at::Tensor { return a + b; }))
.op("static::mul.a(Tensor a, Tensor b) -> Tensor",
RegisterOperators::options().kernel(
DispatchKey::CPU,
[](at::Tensor a, at::Tensor b) -> at::Tensor { return a * b; }))
.op("static::mul.b(Tensor a, int b) -> Tensor",
RegisterOperators::options().kernel(
DispatchKey::CPU,
[](at::Tensor a, int64_t b) -> at::Tensor { return a * b; }));

#define SUPPORTED_OPS(F) \
F(aten::__getitem__) \
F(aten::add) \
F(aten::addmm) \
F(aten::bmm) \
F(aten::cat) \
F(aten::clamp) \
F(aten::contiguous) \
F(aten::div) \
F(aten::flatten) \
F(aten::index_put_) \
F(aten::isnan) \
F(aten::matmul) \
F(aten::mul) \
F(aten::permute) \
F(aten::relu) \
F(aten::sigmoid) \
F(aten::size) \
F(aten::softmax) \
F(aten::t) \
F(aten::to) \
F(aten::transpose) \
F(aten::view) \
F(prim::Constant) \
F(prim::ListConstruct) \
F(prim::TupleConstruct)

StaticRuntime::StaticRuntime(const torch::jit::Module& m)
: module_(m.copy()), graph_(nullptr) {
module_.eval();
@@ -84,19 +43,6 @@ StaticRuntime::StaticRuntime(const torch::jit::Module& m)
}
}

SubgraphRewriter sr;
sr.RegisterRewritePattern(
R"IR(
graph(%x, %w, %s):
%r = aten::add(%x, %w, %s)
return (%r))IR",
R"IR(
graph(%x, %w, %s):
%y = static::add(%x, %w)
%r = static::mul(%y, %s)
return (%r))IR");
sr.runOnGraph(graph_);

// remove unused input 0 from graph
if (graph_->inputs().at(0)->type()->is_module()) {
if (!graph_->inputs().at(0)->hasUses()) {
@@ -157,10 +103,13 @@ ProcessedNode::ProcessedNode(Node* node) : node_(node) {
CHECK(op.hasOperation());
op_ = op.getOperation(node);
}
if (canRunOutOfPlace(node)) {
fn_ = getOutOfPlaceOperation(node);
}
}

void ProcessedNode::run(StaticRuntime::ConstantMap& workspace) const {
if (use_stack_) {
if (!fn_) {
std::vector<IValue> stack;
const size_t size = node_->inputs().size();
stack.reserve(size);
@@ -201,7 +150,7 @@ void ProcessedNode::run(StaticRuntime::ConstantMap& workspace) const {
workspace[node_->outputs()[i]] = stack[i];
}
} else {
TORCH_CHECK(0, "Non-stack execution not yet implemented");
(*fn_)(workspace);
}
}

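For reference, the dispatch added to ProcessedNode above boils down to: prefer the memoized out-of-place function when canRunOutOfPlace matched the node, otherwise fall back to the stack-based Operation. A minimal sketch of that shape, with std::optional standing in for the c10::optional used in impl.h and a trivial map in place of StaticRuntime::ConstantMap:

#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>

using Workspace = std::unordered_map<std::string, int>;  // stand-in for ConstantMap

struct NodeSketch {
  // Set only when canRunOutOfPlace(node) is true in the real code.
  std::optional<std::function<void(Workspace&)>> fn_;

  void run(Workspace& ws) const {
    if (!fn_) {
      std::cout << "stack-based fallback\n";  // generic Operation path (elided)
    } else {
      (*fn_)(ws);  // memory-reusing _out path
    }
  }
};

int main() {
  Workspace ws;
  NodeSketch plain;                        // no out-of-place kernel registered
  plain.run(ws);
  NodeSketch fused;
  fused.fn_ = [](Workspace& w) { w["out"] = 1; };
  fused.run(ws);                           // writes into the shared workspace
  std::cout << ws["out"] << "\n";
  return 0;
}
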
3 changes: 1 addition & 2 deletions torch/csrc/jit/runtime/static/impl.h
@@ -53,8 +53,7 @@ class ProcessedNode {
private:
Node* node_;
c10::optional<Operation> op_;
// if false, we have an optimized version
bool use_stack_ = true;
c10::optional<std::function<void(StaticRuntime::ConstantMap&)>> fn_;
};

} // namespace jit
128 changes: 128 additions & 0 deletions torch/csrc/jit/runtime/static/ops.cpp
@@ -0,0 +1,128 @@
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

bool canRunOutOfPlace(Node* n) {
auto str = std::string(n->kind().toQualString());
if ((str == "aten::add") || (str == "aten::mul") || (str == "aten::addmm") ||
(str == "aten::bmm") || (str == "aten::sigmoid") ||
(str == "aten::cat")) {
return true;
}
return false;
}

std::function<void(StaticRuntime::ConstantMap&)> getOutOfPlaceOperation(
Node* n) {
auto create_empty_from = [](const at::Tensor& t) {
return at::empty({0}, t.options());
};

if (n->kind() == c10::Symbol::fromQualString("aten::add")) {
auto out = n->outputs().at(0);
auto in0 = n->inputs().at(0);
auto in1 = n->inputs().at(1);
auto in2 = n->inputs().at(2);
return [=](StaticRuntime::ConstantMap& ws) {
auto in0_t = ws.at(in0).toTensor();
auto in1_t = ws.at(in1).toTensor();
auto in2_s = ws.at(in2).toScalar();
if (!ws.count(out)) {
ws.emplace(out, create_empty_from(in0_t));
}
auto out_t = ws.at(out).toTensor();
at::native::add_out(out_t, in0_t, in1_t, in2_s);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::mul")) {
auto out = n->outputs().at(0);
auto in0 = n->inputs().at(0);
auto in1 = n->inputs().at(1);
return [=](StaticRuntime::ConstantMap& ws) {
auto in0_t = ws.at(in0).toTensor();
auto in1_t = ws.at(in1).toTensor();
if (!ws.count(out)) {
ws.emplace(out, create_empty_from(in0_t));
}
auto out_t = ws.at(out).toTensor();
at::native::mul_out(out_t, in0_t, in1_t);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::addmm")) {
auto out = n->outputs().at(0);
auto in0 = n->inputs().at(0);
auto in1 = n->inputs().at(1);
auto in2 = n->inputs().at(2);
auto in3 = n->inputs().at(3);
auto in4 = n->inputs().at(4);
return [=](StaticRuntime::ConstantMap& ws) {
auto in0_t = ws.at(in0).toTensor();
auto in1_t = ws.at(in1).toTensor();
auto in2_t = ws.at(in2).toTensor();
auto in3_s = ws.at(in3).toScalar();
auto in4_s = ws.at(in4).toScalar();
if (!ws.count(out)) {
ws.emplace(out, create_empty_from(in0_t));
}
auto out_t = ws.at(out).toTensor();
at::native::addmm_cpu_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::clamp")) {
auto out = n->outputs().at(0);
auto in0 = n->inputs().at(0);
auto in1 = n->inputs().at(1);
auto in2 = n->inputs().at(2);
return [=](StaticRuntime::ConstantMap& ws) {
auto in0_t = ws.at(in0).toTensor();
auto in1_s = ws.at(in1).toScalar();
auto in2_s = ws.at(in2).toScalar();
if (!ws.count(out)) {
ws.emplace(out, create_empty_from(in0_t));
}
auto out_t = ws.at(out).toTensor();
at::native::clamp_out(out_t, in0_t, in1_s, in2_s);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::bmm")) {
auto out = n->outputs().at(0);
auto in0 = n->inputs().at(0);
auto in1 = n->inputs().at(1);
return [=](StaticRuntime::ConstantMap& ws) {
auto in0_t = ws.at(in0).toTensor();
auto in1_t = ws.at(in1).toTensor();
if (!ws.count(out)) {
ws.emplace(out, create_empty_from(in0_t));
}
auto out_t = ws.at(out).toTensor();
at::native::bmm_out_cpu(out_t, in0_t, in1_t);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::cat")) {
auto out = n->outputs().at(0);
auto in0 = n->inputs().at(0);
auto in1 = n->inputs().at(1);
return [=](StaticRuntime::ConstantMap& ws) {
auto in0_tl = ws.at(in0).toTensorVector();
auto in1_i = ws.at(in1).toInt();
if (!ws.count(out)) {
ws.emplace(out, create_empty_from(in0_tl[0]));
}
auto out_t = ws.at(out).toTensor();
at::native::cat_out(out_t, in0_tl, in1_i);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::sigmoid")) {
auto out = n->outputs().at(0);
auto in0 = n->inputs().at(0);
return [=](StaticRuntime::ConstantMap& ws) {
auto in0_t = ws.at(in0).toTensor();
if (!ws.count(out)) {
ws.emplace(out, create_empty_from(in0_t));
}
auto out_t = ws.at(out).toTensor();
at::native::sigmoid_out(out_t, in0_t);
};
}

return [](StaticRuntime::ConstantMap&) { TORCH_CHECK(0); };
}

} // namespace jit
} // namespace torch
41 changes: 41 additions & 0 deletions torch/csrc/jit/runtime/static/ops.h
@@ -0,0 +1,41 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>

namespace torch {
namespace jit {

bool canRunOutOfPlace(Node* n);
std::function<void(StaticRuntime::ConstantMap&)> getOutOfPlaceOperation(
Node* n);

#define SUPPORTED_OPS(F) \
F(aten::__getitem__) \
F(aten::add) \
F(aten::addmm) \
F(aten::bmm) \
F(aten::cat) \
F(aten::clamp) \
F(aten::contiguous) \
F(aten::div) \
F(aten::flatten) \
F(aten::index_put_) \
F(aten::isnan) \
F(aten::matmul) \
F(aten::mul) \
F(aten::permute) \
F(aten::relu) \
F(aten::sigmoid) \
F(aten::size) \
F(aten::softmax) \
F(aten::t) \
F(aten::to) \
F(aten::transpose) \
F(aten::view) \
F(prim::Constant) \
F(prim::ListConstruct) \
F(prim::TupleConstruct)

} // namespace jit
} // namespace torch
