[static runtime] Add _out variants and reuse memory (pytorch#44128)
Summary: Pull Request resolved: pytorch#44128

Test Plan: Imported from OSS

Reviewed By: hlu1

Differential Revision: D23604304

Pulled By: bwasti

fbshipit-source-id: 06a23cb75700a0fc733069071843b7b498e7b9e9
commit d1a1161 (1 parent: d1d9017)
8 changed files with 192 additions and 65 deletions.
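The core change: instead of allocating a fresh output tensor on every operator call, the static runtime allocates each output once and routes later calls through ATen's _out variants, which write into existing storage. A minimal standalone sketch of that pattern, using the public at::add_out API rather than the commit's internal helpers:

#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::randn({4, 4});
  at::Tensor b = at::randn({4, 4});

  // Allocate the output once; starting empty lets the first _out call
  // resize it to the correct shape.
  at::Tensor out = at::empty({0}, a.options());

  for (int i = 0; i < 100; ++i) {
    // Each iteration writes into the same storage; no per-call
    // output allocation.
    at::add_out(out, a, b);
  }
}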
benchmarks/static_runtime/CMakeLists.txt
@@ -1,3 +1,7 @@
list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt.cc)
list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt_bench.cc)
set(STATIC_RUNTIME_BENCHMARK_SRCS ${STATIC_RUNTIME_BENCHMARK_SRCS} PARENT_SCOPE)

list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt.cc)
list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/test_static_runtime.cc)
set(STATIC_RUNTIME_TEST_SRCS ${STATIC_RUNTIME_TEST_SRCS} PARENT_SCOPE)
torch/csrc/jit/runtime/static/ops.cpp
@@ -0,0 +1,128 @@
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

bool canRunOutOfPlace(Node* n) {
  auto str = std::string(n->kind().toQualString());
  if ((str == "aten::add") || (str == "aten::mul") || (str == "aten::addmm") ||
      (str == "aten::bmm") || (str == "aten::sigmoid") ||
      (str == "aten::cat")) {
    return true;
  }
  return false;
}

std::function<void(StaticRuntime::ConstantMap&)> getOutOfPlaceOperation(
    Node* n) {
  // Outputs start as empty tensors; the first call to an _out variant
  // resizes them, and later calls reuse the same storage.
  auto create_empty_from = [](const at::Tensor& t) {
    return at::empty({0}, t.options());
  };

  if (n->kind() == c10::Symbol::fromQualString("aten::add")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    auto in2 = n->inputs().at(2);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      auto in1_t = ws.at(in1).toTensor();
      auto in2_s = ws.at(in2).toScalar();
      // Allocate the output slot on first use; reuse it afterwards.
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::add_out(out_t, in0_t, in1_t, in2_s);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::mul")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      auto in1_t = ws.at(in1).toTensor();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::mul_out(out_t, in0_t, in1_t);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::addmm")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    auto in2 = n->inputs().at(2);
    auto in3 = n->inputs().at(3);
    auto in4 = n->inputs().at(4);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      auto in1_t = ws.at(in1).toTensor();
      auto in2_t = ws.at(in2).toTensor();
      auto in3_s = ws.at(in3).toScalar();
      auto in4_s = ws.at(in4).toScalar();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::addmm_cpu_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::clamp")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    auto in2 = n->inputs().at(2);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      auto in1_s = ws.at(in1).toScalar();
      auto in2_s = ws.at(in2).toScalar();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::clamp_out(out_t, in0_t, in1_s, in2_s);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::bmm")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      auto in1_t = ws.at(in1).toTensor();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::bmm_out_cpu(out_t, in0_t, in1_t);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::cat")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_tl = ws.at(in0).toTensorVector();
      auto in1_i = ws.at(in1).toInt();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_tl[0]));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::cat_out(out_t, in0_tl, in1_i);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::sigmoid")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::sigmoid_out(out_t, in0_t);
    };
  }

  // No _out variant registered for this node kind.
  return [](StaticRuntime::ConstantMap&) { TORCH_CHECK(0); };
}

} // namespace jit
} // namespace torch
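For context (not part of the diff), this is roughly how a runtime would consume these helpers. The driver below is hypothetical, and StaticRuntime::ConstantMap is assumed to map Value* to IValue, as the ws.at(...) lookups above suggest:

#include <torch/csrc/jit/runtime/static/ops.h>

#include <functional>
#include <vector>

namespace torch {
namespace jit {

// Hypothetical driver loop: fetch the closure once per node, then
// re-invoke it on every inference run so that the outputs it allocated
// on the first run are reused afterwards.
void runGraphSketch(
    const std::vector<Node*>& nodes,
    StaticRuntime::ConstantMap& ws) {
  for (Node* n : nodes) {
    if (canRunOutOfPlace(n)) {
      auto op = getOutOfPlaceOperation(n); // ideally cached per node
      op(ws); // writes the node's outputs into ws in place
    }
    // else: fall back to the regular interpreter path
  }
}

} // namespace jit
} // namespace torch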
torch/csrc/jit/runtime/static/ops.h
@@ -0,0 +1,41 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>

namespace torch {
namespace jit {

bool canRunOutOfPlace(Node* n);
std::function<void(StaticRuntime::ConstantMap&)> getOutOfPlaceOperation(
    Node* n);

#define SUPPORTED_OPS(F) \
  F(aten::__getitem__)   \
  F(aten::add)           \
  F(aten::addmm)         \
  F(aten::bmm)           \
  F(aten::cat)           \
  F(aten::clamp)         \
  F(aten::contiguous)    \
  F(aten::div)           \
  F(aten::flatten)       \
  F(aten::index_put_)    \
  F(aten::isnan)         \
  F(aten::matmul)        \
  F(aten::mul)           \
  F(aten::permute)       \
  F(aten::relu)          \
  F(aten::sigmoid)       \
  F(aten::size)          \
  F(aten::softmax)       \
  F(aten::t)             \
  F(aten::to)            \
  F(aten::transpose)     \
  F(aten::view)          \
  F(prim::Constant)      \
  F(prim::ListConstruct) \
  F(prim::TupleConstruct)

} // namespace jit
} // namespace torch
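SUPPORTED_OPS is an X-macro: callers pass in a macro F that is expanded once per operator. The commit itself doesn't show a consumer, but as an illustration of the usual pattern, one can stringize each op name into a lookup set (this helper is hypothetical):

#include <torch/csrc/jit/runtime/static/ops.h>

#include <string>
#include <unordered_set>

// Hypothetical consumer of the X-macro above: expand F once per op,
// stringizing each qualified name into a set of supported operators.
inline const std::unordered_set<std::string>& supportedOpSet() {
  static const std::unordered_set<std::string> ops = {
#define STRINGIZE_OP(op) #op,
      SUPPORTED_OPS(STRINGIZE_OP)
#undef STRINGIZE_OP
  };
  return ops;
}

// e.g. supportedOpSet().count("aten::add") == 1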